MachineLearning(12)- RNN-LSTM-tf.nn.rnn_cell
RNN-LSTM
- 1.RNN
- 2.LSTM
- 3. tensorflow 中的RNN-LSTM
- 3.1 tf.nn.rnn_cell.BasicRNNCell()
- 3.2 tf.nn.rnn_cell.BasicLSTMCell()
- 3.3 tf.nn.dynamic_rnn()--多步执行循环神经网络
1.RNN
RNN-Recurrent Neural Network-循环神经网络
RNN用来处理序列数据。多层感知机MLP层间节点全联接,层内节点并无链接。 RNN层内节点的之间存在连接关系,用来反映上一层隐层状态作为下一层的输入,将直接输给中间隐藏层。
RNN网络模块图如下所示:
其中: xtx_txt 网络t时刻输入, hth_tht 网络t时刻的输出。 如果将RNN按时间序列展开,可以得到以下链式结构:
用简单的权重矩阵建模rnn, 输入-状态-输出之间存在以下的关系:[oto_tot即上文的hth_tht]
st=f(U⋅xt+W⋅st−1)s_{t} = f(U\cdot x_t + W\cdot s_{t-1})st=f(U⋅xt+W⋅st−1)
ot=g(V⋅st)o_t = g(V\cdot s_t)ot=g(V⋅st)
st−1s_{t-1}st−1能够建模历史信息对当前输出的影响。原始RNN随着时间的推移,历史状态对当前输出的影响减弱。但是很多任务需要长时依赖关系。LSTM营运而生。
2.LSTM
LSTM-Long Short-Term Memory 可以学习长时依赖信息,LSTM网络模块图如下所示:
状态传递机制决定了上一时刻状态信息的保留量,以及新输入信息的增量。LSTM包含三个关键的门用于实现这一传递机制。
**遗忘门:**上一时刻输出和这一个时刻输入决定上一时刻的状态保留百分比ftf_tft。【sigmoid输出0-1 之间的一个数】
输入门:C~t\tilde{C}_tC~t为新信息候选向量, iti_tit 决定了多少新信息候选向量能够通过。随后更新状态信息:
输出门: 当前状态CtC_tCt 和 ht−1h_{t-1}ht−1 以及xtx_txt 共同决定当前时刻输出信息hth_tht
参考资料:https://colah.github.io/posts/2015-08-Understanding-LSTMs/
3. tensorflow 中的RNN-LSTM
RNNCell是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。RNNCell是一个抽象类,实际使用时候,用它的两个子类BasicRNNCell [RNN的基础类] 和BasicLSTMCell [LSTM的基础类]。
3.1 tf.nn.rnn_cell.BasicRNNCell()
RNNCell,具有两个比较重要类属性:state_size–决定隐层的大小,output_size决定输出大小
例如将(batch_size, input_size)数据输入RNN,得到的隐层状态就是(batch_size, state_size),输出是(batch_size, output_size)。
import tensorflow as tf
import numpy as np
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128) # state_size = 128
print(cell.state_size) # 128inputs = tf.placeholder(np.float32, shape=(32, 100)) # batch_size=32
h0 = cell.zero_state(32, np.float32) # 全零状态(batch_size, state_size)
output, h1 = cell.call(inputs, h0)
print(h1.shape) # (32, 128)
、
3.2 tf.nn.rnn_cell.BasicLSTMCell()
tf.nn.rnn_cell.BasicLSTMCell(num_units, # int类型,隐层输出大小forget_bias=1.0, # float类型, 遗忘门偏置state_is_tuple=True, # 回的状态是h_t和c_t的2元tuple LSTM可以看做有两个隐状态h和cactivation=None, # 内部状态的激活函数。默认为tanhreuse=None,name=None,dtype=None)
3.3 tf.nn.dynamic_rnn()–多步执行循环神经网络
基础的RNNCell使用它的call函数进行运算时,只是在序列时间上前进了一步。例如使用(x1,h0)得到h1,(x2, h1)得到h2等。如果序列长度为10,需调用10次call函数,比较麻烦。
TensorFlow提供了一个tf.nn.dynamic_rnn函数,该函数可实现n次调用call函数。即通过{h0,x1, x2, …., xn}得{h1,h2…,hn}。
输入数据格式为(batch_size, time_steps, input_size),其中time_steps表示序列长度,input_size表示单个序列元素的特征长度。
tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None,parallel_iterations=None, swap_memory=False, time_major=False, scope=None
)
对于一个定义的的cell ,多次执行该cell 的demo 为:
outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
# outputs就是time_steps步里所有的输出-(batch_size, time_steps, cell.output_size)。state是最后一步的隐状态,它的形状为(batch_size, cell.state_size)。
动态调整sequence_length tf.nn.dynamic_rnn 详解
参考资料:
tensorflow学习之BasicLSTMCell详解
TensorFlow中RNN实现的正确打开方式
MachineLearning(12)- RNN-LSTM-tf.nn.rnn_cell相关推荐
- tf.nn.rnn_cell.DropoutWrapper用法细节案例1
前言:前面介绍了LSTM,下面介绍LSTM的几种变种 双向RNN Bidirectional RNN(双向RNN)假设当前t的输出不仅仅和之前的序列有关,并且 还与之后的序列有关,例如:预测一个语句中 ...
- deep_learning 03. tf.nn.rnn_cell.MultiRNNCell()
开始的话: 从基础做起,不断学习,坚持不懈,加油. 一位爱生活爱技术来自火星的程序汪 前面两节讲了两个基础的BasicRNNCell和BasicLSTMCell,今天就来看下怎么把这些简单的cell给 ...
- 成功解决没有tf.nn.rnn_cell属性
成功解决没有tf.nn.rnn_cell属性 目录 解决问题 解决思路 解决方法 解决问题 没有tf.nn.rnn_cell属性 解决思路 由于不同的TensorFlow版本之间某些函数的用法引起的错 ...
- python中tf.abs_python – Tensorflow:替换tf.nn.rnn_cell._linear(输入,大小,0,范围)
我试图从 https://arxiv.org/pdf/1609.05473.pdf开始运行SequenceGAN( https://github.com/LantaoYu/SeqGAN). 修复明显的 ...
- tf.nn.rnn_cell.DropoutWrapper用法细节案例2
-- coding: utf-8 -- import tensorflow as tf from tensorflow.contrib import rnn 导入 MINST 数据集 from ten ...
- 成功解决AttributeError: module 'tensorflow.nn.rnn_cell' has no attribute 'linear'
成功解决AttributeError: module 'tensorflow.nn.rnn_cell' has no attribute 'linear' 目录 解决问题 解决思路 解决方法 解决问题 ...
- tf.nn.dynamic_rnn的详解
tf.nn.dynamic_rnn 其和tf.nn.static_rnn,在输入,输出,参数上有很大的区别,请仔细阅读比较 tf.nn.dynamic_rnn(cell,inputs,sequence ...
- 深度学习总结:tensorflow和pytorch关于RNN的对比,tf.nn.dynamic_rnn,nn.LSTM
tensorflow和pytorch关于RNN的对比: tf.nn.dynamic_rnn很难理解,他的意思只是用数据走一遍你搭建的RNN网络. 可以明显看出pytorch封装更高,更容易理解,动态图 ...
- tf.compat.v1.nn.rnn_cell.BasicLSTMCell
可能已经弃用,现在用的是 tf.compat.v1.nn.rnn_cell.LSTMCell tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units,forg ...
最新文章
- kmp算法详解php,php中字符串匹配KMP算法实现例子
- 2009最后一天,为了期盼而祝福
- DevExpress控件使用经验总结
- linux select 进程id,Linux基础命令---显示进程ps
- 1224 哥德巴赫猜想(2)
- 基于JAVA+Servlet+JSP+MYSQL的电影院购票系统
- java新手的第一个小东西,或许小东西都算不上=。 =
- jQuery选择器经典案例
- java语言的主要特点有简单性,太厉害了!
- 关于JavaScript中变量的相互引用
- c# gerber文件读取_gerber文件查看器|gerber文件查看工具(GerbView)下载 v7.71 免费版 - 121下载站...
- C1认证学习笔记(第一章)
- 算王标准层的量如何计算机,算王软件常用功能技巧
- linux中()、[]、{}、(())、[[]]等各种括号的使用
- ValueError: y contains previously unseen labels: ‘103125‘
- 第6章 项目整体管理
- oracle一些基本函数
- android 端口查看工具,安卓模拟器连接端口一览表:(2018.11收录10款)
- iis 自动重启的bat
- 关于遥感中影像数据的组织方法BIL/BSQ/BIP
热门文章
- erpnext mysql_windows7+docker+erpnext部署
- python alter table_python(pymysql)之mysql简单操作
- 串口输出5v电压_为什么RS485比串口速度快距离远?--谈单端信号与差分信号之差异...
- activexobject对象不能创建_Oracle数据库用户管理之系统权限和对象权限
- 计算机组装与维护实验指导,计算机组装与维护实验指导书.pdf
- php获取跳转前的地址,PHP获取短链接跳转后的真实地址和响应头信息的方法
- 【转】刨根究底字符编码之七——ANSI编码与代码页
- [你必须知道的.NET]第三十回:.NET十年(下)
- 台式机电脑配置单_2020年电脑配置单重点硬件参考
- PWN-PRACTICE-BUUCTF-26