文章目录

  • 1.LSTM 网络
  • 2.之前也提到过RNNs取得了不错的成绩,这些成绩很多是基于LSTMs来做的,说明LSTMs适用于大部分的序列场景应用。
  • 3.代码实现

1.LSTM 网络

可以理解为RNN的升级。
Long Short Term Memory networks(以下简称LSTMs),一种特殊的RNN网络,该网络设计出来是为了解决长依赖问题。该网络由 Hochreiter & Schmidhuber (1997)引入,并有许多人对其进行了改进和普及。他们的工作被用来解决了各种各样的问题,直到目前还被广泛应用。

所有循环神经网络都具有神经网络的重复模块链的形式。 在标准的RNN中,该重复模块将具有非常简单的结构,例如单个tanh层。标准的RNN网络如下图所示

LSTMs也具有这种链式结构,但是它的重复单元不同于标准RNN网络里的单元只有一个网络层,它的内部有四个网络层。LSTMs的结构如下图所示。

在解释LSTMs的详细结构时先定义一下图中各个符号的含义,符号包括下面几种

图中黄色类似于CNN里的激活函数操作,粉色圆圈表示点操作,单箭头表示数据流向,箭头合并表示向量的合并(concat)操作,箭头分叉表示向量的拷贝操作。

2.之前也提到过RNNs取得了不错的成绩,这些成绩很多是基于LSTMs来做的,说明LSTMs适用于大部分的序列场景应用。

3.代码实现


# please note, all tutorial code are running under python3.5.
# If you use the version like python2.7, please modify the code accordingly# 8 - RNN LSTM Regressor example# to try tensorflow, un-comment following two lines
# import os
# os.environ['KERAS_BACKEND']='tensorflow'
import numpy as np
np.random.seed(1337)  # for reproducibility
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, Dense
from keras.optimizers import AdamBATCH_START = 0
TIME_STEPS = 20
BATCH_SIZE = 50
INPUT_SIZE = 1
OUTPUT_SIZE = 1
CELL_SIZE = 20
LR = 0.006def get_batch():global BATCH_START, TIME_STEPS# xs shape (50batch, 20steps)xs = np.arange(BATCH_START, BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10*np.pi)seq = np.sin(xs)res = np.cos(xs)BATCH_START += TIME_STEPS# plt.plot(xs[0, :], res[0, :], 'r', xs[0, :], seq[0, :], 'b--')# plt.show()return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]model = Sequential()
# build a LSTM RNN
model.add(LSTM(batch_input_shape=(BATCH_SIZE, TIME_STEPS, INPUT_SIZE),       # Or: input_dim=INPUT_SIZE, input_length=TIME_STEPS,output_dim=CELL_SIZE,return_sequences=True,      # True: output at all steps. False: output as last step.stateful=True,              # True: the final state of batch1 is feed into the initial state of batch2
))
# add output layer
model.add(TimeDistributed(Dense(OUTPUT_SIZE)))
adam = Adam(LR)
model.compile(optimizer=adam,loss='mse',)print('Training ------------')
for step in range(501):# data shape = (batch_num, steps, inputs/outputs)X_batch, Y_batch, xs = get_batch()cost = model.train_on_batch(X_batch, Y_batch)pred = model.predict(X_batch, BATCH_SIZE)plt.plot(xs[0, :], Y_batch[0].flatten(), 'r', xs[0, :], pred.flatten()[:TIME_STEPS], 'b--')plt.ylim((-1.2, 1.2))plt.draw()plt.pause(0.1)if step % 10 == 0:print('train cost: ', cost)

Keras【Deep Learning With Python】更优模型探索Keras实现LSTM相关推荐

  1. Keras【Deep Learning With Python】更优模型探索Keras实现CNN

    文章目录 CNN 1.CNN介绍 2.CNN基本原理 代码实现 手写数字复杂网络层抽特征可视化工具http://scs.ryerson.ca/~aharley/vis/conv/ CNN 1.CNN介 ...

  2. Keras【Deep Learning With Python】更优模型探索Keras实现RNN

    文章目录 RNN简介 1.RNN的应用 2.什么是RNN? 3.RNN用来做什么? 4. 训练RNNs Keras代码实现(Mnist) RNN简介 1.RNN的应用 RNN主要有两个应用,一是评测一 ...

  3. Keras Tutorial: Deep Learning in Python

    This Keras tutorial introduces you to deep learning in Python: learn to preprocess your data, model, ...

  4. Deep learning with Python 学习笔记(9)

    神经网络模型的优化 使用 Keras 回调函数 使用 model.fit()或 model.fit_generator() 在一个大型数据集上启动数十轮的训练,有点类似于扔一架纸飞机,一开始给它一点推 ...

  5. python思想读后感_《Deep Learning with Python》读后感精选

    <Deep Learning with Python>是一本由Francois Chollet著作,Manning Publications出版的Paperback图书,本书定价:USD ...

  6. Deep learning with Python 学习笔记(6)

    本节介绍循环神经网络及其优化 循环神经网络(RNN,recurrent neural network)处理序列的方式是,遍历所有序列元素,并保存一个状态(state),其中包含与已查看内容相关的信息. ...

  7. Python深度学习:基于PyTorch [Deep Learning with Python and PyTorch]

    作者:吴茂贵,郁明敏,杨本法,李涛,张粤磊 著 出版社:机械工业出版社 品牌:机工出版 出版时间:2019-11-01 Python深度学习:基于PyTorch [Deep Learning with ...

  8. Deep Learning with Python

    1.学习地址 Deep Learning with Python(wang@123) 2.大神的twitter 大神的twitter

  9. Keras【Deep Learning With Python】—Keras基础

    文章目录 1.关于Keras 2.Keras的模块结构 3.使用Keras搭建一个神经网络 4. 主要概念 5.第一个示例 下载网站数据注意 1.关于Keras 1)简介 Keras是由纯python ...

最新文章

  1. python嵌套字典赋值_Python:更新深度嵌套字典中的值
  2. 什么地方容易刷出ak_男人会用什么理由拒绝表白?
  3. CentOS7 Python3安装redis
  4. 云效飞流Flow项目版本管理的最佳实践
  5. 树莓派+神经计算棒2实时人脸检测
  6. libghttp 编译及封装使用实例
  7. Android 8.0 学习(13)---开发者 FAQ
  8. Spark Shuffle 堆外内存溢出问题与解决(Shuffle通信原理)
  9. linux阻止程序,Linux:阻止某些应用程序/主机名的IPv6
  10. Linux : ext3_free_blocks: Freeing blocks not in datazone
  11. css摇杆,【宅家呗专题】Virtual Joystick虚拟摇杆插件教程
  12. mysql增加约束sql语句_sql语句添加约束
  13. select设置默认的option
  14. [转]PCI与PCIe
  15. Dorado7 notify非alert 输入框prompt confirm layer dialoger,layer.msg,toast效果,几秒关闭layer.load layer.open
  16. android 支付选择按钮,微信支付新增“确认”按钮,更安全还是更麻烦?
  17. Type com.xxx.xxx.BuildConfig is defined multiple times:...
  18. SSM (JDK 1.7) 使用Rabbit MQ
  19. 3D Fe3O4@Au@Ag nanoflowers assembled magnetoplasmonic chains for in situ SERS monitoring of plasmon-
  20. S3C4412学习笔记

热门文章

  1. android-ndk-r15c libncurses.so.5
  2. 双系统 win10 时间不对
  3. pytorch 预处理
  4. 四丶KingFeng搭建+青龙对接kingfeng
  5. HuaWei设置镜像端口和观察端口
  6. php array_merge内存不够,php array_merge函数使用需要注意的一个问题
  7. mysql server windows_Windows下mysql community server 8.0.16安装配置方法图文教程
  8. android毛玻璃效果,Android 中实现毛玻璃效果
  9. python正则表达式面试_【正则表达式Python面试题】面试问题:Scrapy之… - 看准网...
  10. mysql限制小数位_Mysql中设置小数点用什么数据类型 decimal