RNN - 预测正弦函数

  • 参考《TensorFlow实战Google深度学习框架》。不使用TFLearn,只使用TensorFlow
  • 完整代码看这里
  • 如果对RNN不理解,请看RNN递归神经网络的直观理解:基于TensorFlow的简单RNN例子
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

数据准备

# 训练数据个数
training_examples = 10000
# 测试数据个数
testing_examples = 1000
# sin函数的采样间隔
sample_gap = 0.01
# 每个训练样本的长度
timesteps = 20
def generate_data(seq):'''生成数据,seq是一序列的连续的sin的值'''X = []y = []# 用前 timesteps 个sin值,估计第 timesteps+1 个# 因此, 输入 X 是一段序列,输出 y 是一个值 for i in range(len(seq) - timesteps -1):X.append(seq[i : i+timesteps])y.append(seq[i+timesteps])return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)
test_start = training_examples*sample_gap
test_end = test_start + testing_examples*sample_gaptrain_x, train_y = generate_data( np.sin( np.linspace(0, test_start, training_examples) ) )
test_x, test_y = generate_data( np.sin( np.linspace(test_start, test_end, testing_examples) ) )

建立RNN模型

设置模型参数

lstm_size = 30
lstm_layers = 2
batch_size = 64

定义输入输出

x = tf.placeholder(tf.float32, [None, timesteps, 1], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 1], name='input_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')

建立LSTM层

# 有lstm_size个单元
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
# 添加dropout
drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
# 一层不够,就多来几层
def lstm_cell():return tf.contrib.rnn.BasicLSTMCell(lstm_size)
cell = tf.contrib.rnn.MultiRNNCell([ lstm_cell() for _ in range(lstm_layers)])# 进行forward,得到隐层的输出
outputs, final_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
# 在本问题中只关注最后一个时刻的输出结果,该结果为下一个时刻的预测值
outputs = outputs[:,-1]# 定义输出层, 输出值[-1,1],因此激活函数用tanh
predictions = tf.contrib.layers.fully_connected(outputs, 1, activation_fn=tf.tanh)
# 定义损失函数
cost = tf.losses.mean_squared_error(y_, predictions)
# 定义优化步骤
optimizer = tf.train.AdamOptimizer().minimize(cost)

训练

# 获取一个batch_size大小的数据
def get_batches(X, y, batch_size=64):for i in range(0, len(X), batch_size):begin_i = iend_i = i + batch_size if (i+batch_size) < len(X) else len(X)yield X[begin_i:end_i], y[begin_i:end_i]
epochs = 20
session = tf.Session()
with session.as_default() as sess:# 初始化变量tf.global_variables_initializer().run()iteration = 1for e in range(epochs):for xs, ys in get_batches(train_x, train_y, batch_size):# xs[:,:,None] 增加一个维度,例如[64, 20] ==> [64, 20, 1],为了对应输入# 同理 ys[:,None]feed_dict = { x:xs[:,:,None], y_:ys[:,None], keep_prob:.5 }loss, _ = sess.run([cost, optimizer], feed_dict=feed_dict)if iteration % 100 == 0:print('Epochs:{}/{}'.format(e, epochs),'Iteration:{}'.format(iteration),'Train loss: {:.8f}'.format(loss))iteration += 1
Epochs:0/20 Iteration:100 Train loss: 0.01009926
Epochs:1/20 Iteration:200 Train loss: 0.02012673
Epochs:1/20 Iteration:300 Train loss: 0.00237983
Epochs:2/20 Iteration:400 Train loss: 0.00029798
Epochs:3/20 Iteration:500 Train loss: 0.00283409
Epochs:3/20 Iteration:600 Train loss: 0.00115144
Epochs:4/20 Iteration:700 Train loss: 0.00130756
Epochs:5/20 Iteration:800 Train loss: 0.00029282
Epochs:5/20 Iteration:900 Train loss: 0.00045034
Epochs:6/20 Iteration:1000 Train loss: 0.00007531
Epochs:7/20 Iteration:1100 Train loss: 0.00189699
Epochs:7/20 Iteration:1200 Train loss: 0.00022669
Epochs:8/20 Iteration:1300 Train loss: 0.00065262
Epochs:8/20 Iteration:1400 Train loss: 0.00001342
Epochs:9/20 Iteration:1500 Train loss: 0.00037799
Epochs:10/20 Iteration:1600 Train loss: 0.00009412
Epochs:10/20 Iteration:1700 Train loss: 0.00110568
Epochs:11/20 Iteration:1800 Train loss: 0.00024895
Epochs:12/20 Iteration:1900 Train loss: 0.00287319
Epochs:12/20 Iteration:2000 Train loss: 0.00012025
Epochs:13/20 Iteration:2100 Train loss: 0.00353661
Epochs:14/20 Iteration:2200 Train loss: 0.00045697
Epochs:14/20 Iteration:2300 Train loss: 0.00103393
Epochs:15/20 Iteration:2400 Train loss: 0.00045038
Epochs:16/20 Iteration:2500 Train loss: 0.00022164
Epochs:16/20 Iteration:2600 Train loss: 0.00026206
Epochs:17/20 Iteration:2700 Train loss: 0.00279484
Epochs:17/20 Iteration:2800 Train loss: 0.00024887
Epochs:18/20 Iteration:2900 Train loss: 0.00263336
Epochs:19/20 Iteration:3000 Train loss: 0.00071482
Epochs:19/20 Iteration:3100 Train loss: 0.00026286

测试

with session.as_default() as sess:## 测试结果feed_dict = {x:test_x[:,:,None], keep_prob:1.0}results = sess.run(predictions, feed_dict=feed_dict)plt.plot(results,'r', label='predicted')plt.plot(test_y, 'g--', label='real sin')plt.legend()plt.show()

TensorFlow-RNN循环神经网络 Example 1:预测Sin函数相关推荐

  1. tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】

    之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...

  2. PyTorch-09 循环神经网络RNNLSTM (时间序列表示、RNN循环神经网络、RNN Layer使用、时间序列预测案例、RNN训练难题、解决梯度离散LSTM、LSTM使用、情感分类问题实战)

    PyTorch-09 循环神经网络RNN&LSTM (时间序列表示.RNN循环神经网络.RNN Layer使用.时间序列预测案例(一层的预测点的案例).RNN训练难题(梯度爆炸和梯度离散)和解 ...

  3. RNN循环神经网络 、LSTM长短期记忆网络实现时间序列长期利率预测

    全文链接:http://tecdat.cn/?p=25133 2017 年年中,R 推出了 Keras 包 _,_这是一个在 Tensorflow 之上运行的综合库,具有 CPU 和 GPU 功能(点 ...

  4. 神经网络学习小记录2——利用tensorflow构建循环神经网络(RNN)

    神经网络学习小记录2--利用tensorflow构建循环神经网络(RNN) 学习前言 RNN简介 tensorflow中RNN的相关函数 tf.nn.rnn_cell.BasicLSTMCell tf ...

  5. 【机器学习】RNN循环神经网络

    循环神经网络归属: 领域:机器学习 方向:自然语言处理 贡献:自动文本生成 循环神经网络实际应用: 生活中因为原始数据都是序列化的,比如自然语言,语音处理,时间序列问题(股票价格)等问题, 这个时候需 ...

  6. [译] RNN 循环神经网络系列 2:文本分类

    原文地址:RECURRENT NEURAL NETWORKS (RNN) – PART 2: TEXT CLASSIFICATION 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永 ...

  7. RNN 循环神经网络系列 5: 自定义单元

    原文地址:RECURRENT NEURAL NETWORK (RNN) – PART 5: CUSTOM CELLS 原文作者:GokuMohandas 译文出自:掘金翻译计划 本文永久链接:gith ...

  8. Recurrent Neural Networks(RNN) 循环神经网络初探

    1. 针对机器学习/深度神经网络"记忆能力"的讨论 0x1:数据规律的本质是能代表此类数据的通用模式 - 数据挖掘的本质是在进行模式提取 数据的本质是存储信息的介质,而模式(pat ...

  9. rnn 循环神经网络

    rnn 循环神经网络 创建日期 星期四 10 一月 2019 rnn为 recurrent natural network, 递归神经网络 是一种基于序列的神经网络, 序列可以是时间,文本序列等,和普 ...

  10. RNN循环神经网络(recurrent neural network)

     自己开发了一个股票智能分析软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  RNN循环神经网络 ...

最新文章

  1. 目标检测比赛中的trick
  2. sql语句查询Oracle|sql server|access 数据库里的所有表名,字段名
  3. JTA 深度历险 - 原理与实现
  4. CDI和EJB:在事务成功时发送异步邮件
  5. 第六节: 六类Calander处理六种不同的时间场景
  6. HttpClient简介
  7. SBuild 0.2.0 发布,基于 Scala 的构建系统
  8. 干货!!月薪过万行业,软件测试必懂的基本概念
  9. HiveQL学习笔记(五):Hive练习题
  10. ios ipad适配_安卓和iOS的APP在开发时有哪些区别?
  11. 集成产品开发-IPD简介
  12. 武汉大学计算机学院 优秀夏令营,武汉大学计算机学院2016年优秀大学生暑期夏令营活动实施方案...
  13. 分享几个翻译PDF的软件给你
  14. **传统线上支付 区块链**
  15. 博基计划(4)---近红外光谱过程分析中基线漂移的主要来源
  16. 各种图片格式综述(转载)
  17. YOLOv5之detect.py文件
  18. 专接本数学第一章知识梳理(1)
  19. phpcms mysql debug_phpcms教程之mysql配置优化
  20. 11G Upgrade

热门文章

  1. 计算机网络实训简介,计算机网络实验报告介绍.doc
  2. mysql并行加载机制_Mysql表引擎优化
  3. 【项目调研+论文阅读】基于医学文献的实体抽取(NER)方法研究 day5
  4. mysql存储ip地址_MySQL怎样存储IP地址
  5. php td复制剪贴板,选择一个带有Javascript的完整表格(复制到剪贴板)
  6. java oracle executeupdate 无效_Java语言的品味(三)
  7. (源码实例)通过层DIV实现,当鼠标放在链接上面,显示图片及文字
  8. Python案例:获取天气信息并绘制气温折线图
  9. Django讲课笔记04:Django项目的调试
  10. 大数据学习笔记28:MR案例——多输出源处理成绩