文章目录

  • 1 前言
  • 2 RNN 的弊端
  • 3 LSTM
  • 4 代码实现
  • 5 重要部份讲解
  • 6 输出:

1 前言

和前几篇文章一样,依旧是分为讲解和代码实现。

2 RNN 的弊端


之前我们说过, RNN 是在有顺序的数据上进行学习的. 为了记住这些数据, RNN 会像人一样产生对先前发生事件的记忆. 不过一般形式的 RNN 就像一个老爷爷, 有时候比较健忘. 为什么会这样呢?


再来看看 RNN是怎样学习的吧. 红烧排骨这个信息原的记忆要进过长途跋涉才能抵达最后一个时间点. 然后我们得到误差, 而且在 反向传递 得到的误差的时候, 他在每一步都会 乘以一个自己的参数 W. 如果这个 W 是一个小于1 的数, 比如0.9. 这个0.9 不断乘以误差, 误差传到初始时间点也会是一个接近于零的数, 所以对于初始时刻, 误差相当于就消失了. 我们把这个问题叫做梯度消失或者梯度弥散 Gradient vanishing. 反之如果 W 是一个大于1 的数, 比如1.1 不断累乘, 则到最后变成了无穷大的数, RNN被这无穷大的数撑死了, 这种情况我们叫做剃度爆炸, Gradient exploding. 这就是普通 RNN 没有办法回忆起久远记忆的原因.

3 LSTM


LSTM 就是为了解决这个问题而诞生的. LSTM 和普通 RNN 相比, 多出了三个控制器. (输入控制, 输出控制, 忘记控制). 现在, LSTM RNN 内部的情况是这样.

他多了一个 控制全局的记忆, 我们用粗线代替. 为了方便理解, 我们把粗线想象成电影或游戏当中的 主线剧情. 而原本的 RNN 体系就是 分线剧情. 三个控制器都是在原始的 RNN 体系上, 我们先看 输入方面 , 如果此时的分线剧情对于剧终结果十分重要, 输入控制就会将这个分线剧情按重要程度 写入主线剧情 进行分析. 再看 忘记方面, 如果此时的分线剧情更改了我们对之前剧情的想法, 那么忘记控制就会将之前的某些主线剧情忘记, 按比例替换成现在的新剧情. 所以 主线剧情的更新就取决于输入 和忘记 控制. 最后的输出方面, 输出控制会基于目前的主线剧情和分线剧情判断要输出的到底是什么.基于这些控制机制, LSTM 就像延缓记忆衰退的良药, 可以带来更好的结果.
Long Short Term Mermory network(LSTM)是一种特殊的RNNs,可以很好地解决长时依赖问题。那么它与常规神经网络有什么不同?

4 代码实现

我们的任务是,由sin曲线预测出cos曲线。

import numpy as np
np.random.seed(1337)
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense,LSTM,TimeDistributed
from keras.optimizers import AdamBATCH_START=0
TIME_STEPS=20  # 一个batch里面取20步 看蓝色的线怎么对应上红色线
BATCH_SIZE=50
INPUT_SIZE=1 # 蓝色线一个点
OUTPUT_SIZE=1 # 红色线一个点
CELL_SIZE=20
LR=0.006def get_batch():global BATCH_START, TIME_STEPS# xs shape (50batch, 20steps)xs = np.arange(BATCH_START, BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10*np.pi) #np.piseq = np.sin(xs)res = np.cos(xs)BATCH_START += TIME_STEPS# plt.plot(xs[0,:],res[0,:],'r',xs[0,:],seq[0,:],'b--')# plt.show()return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]#get_batch()
#exit()model=Sequential()# build a LSTM RNN
model.add(LSTM(batch_input_shape=(BATCH_SIZE,TIME_STEPS,INPUT_SIZE),output_dim=CELL_SIZE, #unitreturn_sequences=True,# 每一个时间点都有一个output 分类问题时只有最后一个时刻输出stateful=True,#batch和batch是否有关联
))# add output layer
model.add(TimeDistributed(Dense(OUTPUT_SIZE))) #每一个output全连接adam=Adam(LR)model.compile(optimizer=adam,loss="mse",
)print("____________Training_____________)")
for step in range(501):# data shape = (batch_num, steps, inputs/outputs)X_batch, Y_batch, xs = get_batch()cost = model.train_on_batch(X_batch, Y_batch)pred = model.predict(X_batch, BATCH_SIZE)# 生成的蓝色线是我要预测的cos线 发现越来越贴合plt.plot(xs[0, :], Y_batch[0].flatten(), 'r', xs[0, :], pred.flatten()[:TIME_STEPS], 'b--')plt.ylim((-1.2, 1.2))plt.draw()plt.pause(0.1)if step % 10 == 0:print('train cost: ', cost)

5 重要部份讲解

这次我们使用RNN来求解回归(Regression)问题. 首先生成序列sin(x),对应输出数据为cos(x),设置序列步长为20,每次训练的BATCH_SIZE为50.

def get_batch():
global BATCH_START, TIME_STEPS
# xs shape (50batch, 20steps)
xs = np.arange(BATCH_START, BATCH_START+TIME_STEPSBATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10np.pi)
seq = np.sin(xs)
res = np.cos(xs)
BATCH_START += TIME_STEPS
return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]

搭建模型
然后添加LSTM RNN层,输入为训练数据,输出数据大小由CELL_SIZE定义。因为每一个输入都对应一个输出,所以return_sequences=True。 每一个点的当前输出都受前面所有输出的影响,BATCH之间的参数也需要记忆,故stateful=True

model.add(LSTM(
batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_SIZE), # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,
output_dim=CELL_SIZE,
return_sequences=True, # True: output at all steps. False: output as last step.
stateful=True, # True: the final state of batch1 is feed into the initial state of batch2
))
最后添加输出层,LSTM层的每一步都有输出,使用TimeDistributed函数。

model.add(TimeDistributed(Dense(OUTPUT_SIZE)))

训练
设置优化方法,loss函数和metrics方法之后就可以开始训练了。 训练501次,调用matplotlib函数采用动画的方式输出结果。

for step in range(501):
# data shape = (batch_num, steps, inputs/outputs)
X_batch, Y_batch, xs = get_batch()
cost = model.train_on_batch(X_batch, Y_batch)
pred = model.predict(X_batch, BATCH_SIZE)
plt.plot(xs[0, :], Y_batch[0].flatten(), ‘r’, xs[0, :], pred.flatten()[:TIME_STEPS], ‘b–’)
plt.ylim((-1.2, 1.2))
plt.draw()
plt.pause(0.1)
if step % 10 == 0:
print('train cost: ', cost)

6 输出:

____________Training_____________)
2020-04-12 10:46:58.468867: I T:\src\github\tensorflow\tensorflow\core\platform\cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
train cost:  0.50940645
train cost:  0.367941

把学习效率调低一点,效果就好很多。

Keras【Deep Learning With Python】LSTM 循环神经网络解决Regressor回归问题相关推荐

  1. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  2. Keras Tutorial: Deep Learning in Python

    This Keras tutorial introduces you to deep learning in Python: learn to preprocess your data, model, ...

  3. 自然语言处理--Keras 实现LSTM循环神经网络分类 IMDB 电影评论数据集

    LSTM 对于循环网络的每一层都引入了状态(state)的概念,状态作为网络的记忆(memory).但什么是记忆呢?记忆将由一个向量来表示,这个向量与元胞中神经元的元素数量相同.记忆单元将是一个由 n ...

  4. 如何用Python和循环神经网络预测严重交通拥堵?

    作者 | 王树义 来源 | 玉树芝兰(nkwangshuyi) 本文为你介绍,如何从 Waze 交通事件开放数据中,利用序列模型找到规律,进行分类预测.以便相关部门可以未雨绸缪,提前有效干预可能发生的 ...

  5. Deep learning with Python 学习笔记(6)

    本节介绍循环神经网络及其优化 循环神经网络(RNN,recurrent neural network)处理序列的方式是,遍历所有序列元素,并保存一个状态(state),其中包含与已查看内容相关的信息. ...

  6. 如何用 Python 和循环神经网络预测严重交通拥堵?

    本文为你介绍,如何从 Waze 交通事件开放数据中,利用序列模型找到规律,进行分类预测.以便相关部门可以未雨绸缪,提前有效干预可能发生的严重拥堵. 寻找 之前在<文科生如何理解循环神经网络(RN ...

  7. Deep learning with Python 学习笔记(9)

    神经网络模型的优化 使用 Keras 回调函数 使用 model.fit()或 model.fit_generator() 在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推 ...

  8. python思想读后感_《Deep Learning with Python》读后感精选

    <Deep Learning with Python>是一本由Francois Chollet著作,Manning Publications出版的Paperback图书,本书定价:USD ...

  9. Python TensorFlow循环神经网络RNN-LSTM神经网络预测股票市场价格时间序列和MSE评估准确性...

    全文链接:http://tecdat.cn/?p=26562 该项目包括: 自 2000 年 1 月以来的股票价格数据.我们使用的是 Microsoft 股票. 将时间序列数据转换为分类问题. 使用 ...

最新文章

  1. 基于JWT的API权限校验:需求分析
  2. left join on用法_MySQL 多表查询 quot;Joinquot;+“case when”语句总结
  3. 少了unicon-tools是不行滴
  4. python tkinter小项目
  5. 中文表示什么_中文分词是个伪问题
  6. 计算机大学离散数学难吗,大学离散数学怎么学
  7. matlab步长教程,matlab仿真步长
  8. 销售管理系统哪个好用?
  9. 小波变换matlab加密,混沌和小波变换的图像加密压缩算法
  10. 知道今天是星期几java_java如何判断今天是星期几
  11. led大屏按实际尺寸设计画面_LED显示屏尺寸规格及计算方法
  12. AVS2实时编码器xavs2的运行
  13. JAVA数字大写金额转换
  14. URP shader 学习笔记
  15. HDU2037:今年暑假不AC
  16. 打分功能,车牌区域划分
  17. stm32ftp服务器实现文件传输,stm32 ftp服务器
  18. 【C语言】C语言实现中文字符(句号,感叹号,问号)的十进制数值
  19. Unity自定义字体 包括中文
  20. 浙江省计算机二级office选择判断题库,浙江省计算机二级office选择判断(无答案版...

热门文章

  1. pytorch多进程加载数据
  2. 用openCV去除文字中乱入的线条
  3. eclipse 无法解析导入 javax.servlet 的解决方法
  4. 世界上最完美的公式 ----欧拉公式
  5. tc c语言弹出式下拉式菜单,c语言制作弹出式菜单
  6. java调用存储过程 oracle_java调用oracle存储过程
  7. 二级计算机为让利消费者,计算机二级office题库训练题(2)
  8. php给留言分配id_如何使用php生成唯一ID的4种方法
  9. 某同学使用计算机求30,概率论与数理统计习题集及答案
  10. 兔子繁殖MATLAB,2011-2012数学建模题