Disclaimer: the code below is based mainly on https://github.com/wzg16/tensorflow-convlstm-cell/blob/master/cell.py

with the variable names renamed to match the gate notation used in this post. For a very accessible introduction to LSTMs, I recommend https://zhuanlan.zhihu.com/p/32085405.

Personally, I prefer to name the gates as follows:

z_f: the forget gate, since this gate is responsible for forgetting some of the information in the cell state c.

z_i: the memory gate, since this gate is responsible for extracting the information worth remembering from the current fused input z.

z: the fused input. This is not a gate: it is the transformation that fuses (h, x) into one tensor, even though its formula looks just like that of the gates.

z_o: the output gate, since this gate is responsible for extracting the information to be output from the cell state c.
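Written with these names, and matching the code below, a single ConvLSTM step computes (in the standard formulation, leaving out the optional peephole and layer-normalization terms):

$$
\begin{aligned}
z   &= \tanh(W_z * [x_t, h_{t-1}] + b_z) \\
z_i &= \sigma(W_i * [x_t, h_{t-1}] + b_i) \\
z_f &= \sigma(W_f * [x_t, h_{t-1}] + b_f + \text{forget\_bias}) \\
z_o &= \sigma(W_o * [x_t, h_{t-1}] + b_o) \\
c_t &= z_f \odot c_{t-1} + z_i \odot z \\
h_t &= z_o \odot \tanh(c_t)
\end{aligned}
$$

where $*$ is convolution, $[\cdot,\cdot]$ is concatenation along the channel axis, and $\odot$ is element-wise multiplication.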

cell.py

import tensorflow as tf


class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
  """An LSTM cell with convolutions instead of multiplications, defined as a subclass of RNNCell.

  Reference:
    Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning
    approach for precipitation nowcasting." Advances in Neural Information
    Processing Systems. 2015.
  """

  def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh,
               normalize=True, peephole=True, data_format='channels_last', reuse=None):
    """
    :param shape: spatial size of the images, shape=[height, width]
    :param filters: scalar, number of kernels, i.e. the channel count of the output hidden/cell state
    :param kernel: kernel size, shape=[kernel_height, kernel_width]
    :param forget_bias: bias added to the forget gate
    :param activation: activation function
    :param normalize: whether to apply layer normalization
    :param peephole: whether the gates also receive the cell state as input
    :param data_format: 'channels_last' for [N,H,W,C] or 'channels_first' for [N,C,H,W]
    :param reuse: variable reuse flag passed on to RNNCell
    """
    super(ConvLSTMCell, self).__init__(_reuse=reuse)
    self._kernel = kernel            # convolution kernel size, a list
    self._filters = filters          # number of kernels, scalar: channel count of cell/state
    self._forget_bias = forget_bias
    self._activation = activation
    self._normalize = normalize
    self._peephole = peephole        # whether the gates receive the cell state as input
    if data_format == 'channels_last':
      self._size = tf.TensorShape(shape + [self._filters])  # state shape, channels last
      self._feature_axis = self._size.ndims                 # axis of the channel dimension
      self._data_format = None                              # data format, i.e. channel ordering
    elif data_format == 'channels_first':
      self._size = tf.TensorShape([self._filters] + shape)
      self._feature_axis = 0
      self._data_format = 'NC'
    else:
      raise ValueError('Unknown data_format')

  @property
  def state_size(self):
    return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)

  @property
  def output_size(self):
    return self._size

  def call(self, x, state):
    """
    c: cell state
    h: hidden state
    x: input tensor
    """
    c, h = state

    # Concatenate input and hidden state along the channel axis, shape=[N,H,W,C].
    x = tf.concat([x, h], axis=self._feature_axis)

    # Convolution + bias.
    input_channel = x.shape[-1].value                                # number of input channels
    output_channel = 4 * self._filters if self._filters > 1 else 4   # channel count after convolution
    W = tf.get_variable('kernel', self._kernel + [input_channel, output_channel])
    moment_ = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
    if not self._normalize:
      moment_ += tf.get_variable('bias', [output_channel], initializer=tf.zeros_initializer())

    # Split into the fused input and the three gates, in four equal parts along the channel axis.
    z, z_i, z_f, z_o = tf.split(moment_, 4, axis=self._feature_axis)

    if self._peephole:  # let the gates see the cell state
      z_i += tf.get_variable('W_ci', c.shape[1:]) * c
      z_f += tf.get_variable('W_cf', c.shape[1:]) * c

    if self._normalize:  # layer normalization, helps convergence
      z = tf.contrib.layers.layer_norm(z)
      z_i = tf.contrib.layers.layer_norm(z_i)
      z_f = tf.contrib.layers.layer_norm(z_f)

    # Squash the memory and forget gates into [0,1]: 0 means forget completely, 1 remember completely.
    z_f = tf.sigmoid(z_f + self._forget_bias)  # forget gate
    z_i = tf.sigmoid(z_i)                      # memory gate
    # Update the cell state: forget part of the old state, add part of the new information.
    c_new = c * z_f + z_i * self._activation(z)

    if self._peephole:  # let the new cell state take part in the output gate
      z_o += tf.get_variable('W_co', c_new.shape[1:]) * c_new

    if self._normalize:
      z_o = tf.contrib.layers.layer_norm(z_o)
      c_new = tf.contrib.layers.layer_norm(c_new)

    z_o = tf.sigmoid(z_o)                  # output gate, squashed into [0,1]
    h_new = z_o * self._activation(c_new)  # new hidden state

    state = tf.nn.rnn_cell.LSTMStateTuple(c_new, h_new)

    return h_new, state


class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
  """A GRU cell with convolutions instead of multiplications."""

  def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True,
               data_format='channels_last', reuse=None):
    super(ConvGRUCell, self).__init__(_reuse=reuse)
    self._filters = filters
    self._kernel = kernel
    self._activation = activation
    self._normalize = normalize
    if data_format == 'channels_last':
      self._size = tf.TensorShape(shape + [self._filters])
      self._feature_axis = self._size.ndims
      self._data_format = None
    elif data_format == 'channels_first':
      self._size = tf.TensorShape([self._filters] + shape)
      self._feature_axis = 0
      self._data_format = 'NC'
    else:
      raise ValueError('Unknown data_format')

  @property
  def state_size(self):
    return self._size

  @property
  def output_size(self):
    return self._size

  def call(self, x, h):
    channels = x.shape[self._feature_axis].value

    with tf.variable_scope('gates'):
      inputs = tf.concat([x, h], axis=self._feature_axis)
      n = channels + self._filters
      m = 2 * self._filters if self._filters > 1 else 2
      W = tf.get_variable('kernel', self._kernel + [n, m])
      y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
      if self._normalize:
        r, u = tf.split(y, 2, axis=self._feature_axis)
        r = tf.contrib.layers.layer_norm(r)
        u = tf.contrib.layers.layer_norm(u)
      else:
        y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
        r, u = tf.split(y, 2, axis=self._feature_axis)
      r, u = tf.sigmoid(r), tf.sigmoid(u)

    with tf.variable_scope('candidate'):
      inputs = tf.concat([x, r * h], axis=self._feature_axis)
      n = channels + self._filters
      m = self._filters
      W = tf.get_variable('kernel', self._kernel + [n, m])
      y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
      if self._normalize:
        y = tf.contrib.layers.layer_norm(y)
      else:
        y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
      h = u * h + (1 - u) * self._activation(y)

    return h, h
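As a quick illustration of the cell interface (a minimal sketch of my own, not part of the referenced repository; it assumes TensorFlow 1.x and that the cell.py above is on the import path), a single step can also be run without dynamic_rnn:

import tensorflow as tf
from cell import ConvLSTMCell

batch_size = 2
shape = [16, 16]  # [height, width]
cell = ConvLSTMCell(shape, filters=8, kernel=[3, 3])

# One frame, i.e. no time axis: [batch, height, width, channels].
x = tf.placeholder(tf.float32, [batch_size] + shape + [3])
state = cell.zero_state(batch_size, tf.float32)  # LSTMStateTuple(c, h) of zeros
h, state = cell(x, state)                        # h.shape = (2, 16, 16, 8)

zero_state is inherited from tf.nn.rnn_cell.RNNCell, and calling the cell directly applies one update of the equations above.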

Running and testing the code:

test.py

import tensorflow as tf

batch_size = 32
timesteps = 100
shape = [640, 480]  # [height, width]
kernel = [3, 3]     # convolution kernel size
channels = 3        # input_channel
filters = 12        # output_channel

# Create a placeholder for videos.
# Required input format: inputs.shape = [batch_size, time_steps, height, width, input_channels],
# i.e. one extra time_step dimension right after batch_size compared with an ordinary CNN.
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])

# Add the ConvLSTM step.
# tf.nn.dynamic_rnn infers time_steps from inputs and unrolls the network accordingly.
from cell import ConvLSTMCell
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)
# outputs contains the output of every time step:
#   outputs.shape = [batch_size, time_steps, height, width, output_channels]
# state is a tuple holding the last time step's Ct and ht, where ht equals the
# last time step of outputs:
#   state.c: Ct, state.c.shape = [batch_size, height, width, output_channel]
#   state.h: ht, state.h.shape = [batch_size, height, width, output_channel]

# There's also a ConvGRUCell that is more memory efficient.
from cell import ConvGRUCell
cell = ConvGRUCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)

# It's also possible to enter 2D input or 4D input instead of 3D.
shape = [100]
kernel = [3]
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype)

shape = [50, 50, 50]
kernel = [1, 3, 5]
inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
cell = ConvGRUCell(shape, filters, kernel)
outputs, state = tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype)

The code has been tested and runs. For more details on the parameters and outputs of outputs, state = tf.nn.dynamic_rnn, see https://blog.csdn.net/Strive_For_Future/article/details/103605290, which explains them quite well.
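To double-check the shapes quoted in the comments above without allocating the huge [32, 100, 640, 480, 3] video tensor, a scaled-down smoke test such as the following can be run (a sketch of my own, assuming TensorFlow 1.x and NumPy; all sizes are deliberately tiny):

import numpy as np
import tensorflow as tf
from cell import ConvLSTMCell

batch_size, timesteps = 2, 5
shape, kernel, channels, filters = [8, 8], [3, 3], 3, 4

inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels])
cell = ConvLSTMCell(shape, filters, kernel)
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = np.random.rand(batch_size, timesteps, 8, 8, channels).astype(np.float32)
    out, st = sess.run([outputs, state], {inputs: feed})
    print(out.shape)                      # (2, 5, 8, 8, 4)
    print(st.c.shape, st.h.shape)         # (2, 8, 8, 4) (2, 8, 8, 4)
    print(np.allclose(st.h, out[:, -1]))  # True: ht is the last time step of outputs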
