已经有很多帖子对PredRNN++的理论和改进效果进行了解读,不再赘述。直接分析结构和代码。

Causal LSTM 单元

三层级联结构:
第一层(蓝色框)类似传统的LSTM结构用于更新时间状态C(temporal state) ;
第二层(绿色框)也是类似的结构,利用更新的C和上一层的M更新空间状态M(spatial state);
第三层(灰色框)是输出门的结构,利用X、新的C和M来更新H (hidden state)。
所以一个Causal LSTM 细胞单元的输入为X、H、C和M,输出为更新的H、C和M。
最后一个时间步的H用于预测生成目标序列。

根据上面的公式可以细看一下。
第一层根据当前步的X和前一步的H、C来决定门,然后获得当前步的C。
第二层利用刚更新的当前步的C、X和前一层的M来决定第二层的门。在对前一层的M做tanh非线性变换后获得更新的M。M、o和H的更新都使用了tanh,增加了非线性操作。
第三层利用X、更新的当前步的C和当前层的M来更新输出门。在作者源代码中还加入了前一步的H,对预测准确度的影响有多大未经验证。又或者是公式里遗漏了(但根据框图看不像是漏了)。在ST-LSTM结构中输出门确实是用到H的。

# CausalLSTMCell.py 124-127行。
if x is None:  o = tf.tanh(o_h + o_c + o_m)
else:o = tf.tanh(o_x + o_h + o_c + o_m)

最后利用输出门、C和M来更新H,输出给下一步。

GHU

理论证据表明,highway layers能够在非常深的前馈网络中有效地传递梯度,所以作者将这一思想应用到递归网络中,以防止长期梯度的快速消失,并提出了一种新的时空递归结构GHU,结构如下:

S是转换门,决定在P和Z之间各保留多大比例。这里的X相当于hidden state H,Z相当于更新后的H。

PredRNN++总体结构

具体看一下H、C、M和Z的传递过程:
H
H既有横向传递也有纵向传递。H在同一时间步的传递(纵向)是作为input(X)进行传递。在时间维度上是作为hidden state进行传递。
C
C是只在同层的时间步直接进行传递。
M
如果将不同时间步的网络单元全部堆叠起来就容易发现是一个M在持续向下传递。
Z
在同一时间步作为H1的更新向下一层传递一次,主要在横向进行传递,将长时特征传递给未来。

代码解读

输入

输入为5D数据,格式为(B,T,H,W,C)分别代表batch size,时间步数,高,宽,通道(channel)。对于有监督学习我们通常将数据分为输入X和对应标签y。以输入前10步预测后10步为例,对于一个样本X=(1:10,H,W,C),输出y_hat = (11:20,H,W,C), 然后将y_hat和y进行损失计算,反向传播更新参数。
但在这里作者没有区分y,X=(1:20,H,W,C)。而是利用一个0-1掩码矩阵随机选择生成的数据和真实数据来作为预测下一步的输入,增加模型的稳健性。

# train.py 75-87行。
# x的维度为(1,20,16,16,16)
self.x = tf.placeholder(tf.float32,[FLAGS.batch_size,FLAGS.seq_length,FLAGS.img_width/FLAGS.patch_size,FLAGS.img_width/FLAGS.patch_size,FLAGS.patch_size*FLAGS.patch_size*FLAGS.img_channel])
# mask_true的维度为(1,9,16,16,16)
# 为什么是9步呢,因为最后一步肯定是要被预测出来的,不会用作输入。
self.mask_true = tf.placeholder(tf.float32,[FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1,FLAGS.img_width/FLAGS.patch_size,FLAGS.img_width/FLAGS.patch_size,FLAGS.patch_size*FLAGS.patch_size*FLAGS.img_channel])

在训练时,每一次循环都重置mask_true。

        # train.py 173-199行。if itr < 50000:eta -= deltaelse:eta = 0.0# random_flip 是一个(1,9)的[0,1]区间随机数向量。random_flip = np.random.random_sample((FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1))# 判断哪些比eta小,为True则使用真实数据否则用生成的数据。true_token = (random_flip < eta)# 作者尝试过利用指数函数衰减的方式作为阈值判断。#true_token = (random_flip < pow(base,itr)) # ones = (16,16,16)ones = np.ones((int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))# zeros = (16,16,16)zeros = np.zeros((int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))mask_true = []for i in range(FLAGS.batch_size):for j in range(FLAGS.seq_length-FLAGS.input_length-1):if true_token[i,j]:mask_true.append(ones)else:mask_true.append(zeros)# mask_true = (9,16,16,16)mask_true = np.array(mask_true)# mask_true = (1,9,16,16,16)mask_true = np.reshape(mask_true, (FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1,int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))

随着循环次数增加,使用真实数据的比例不断下降。但是如果设置的最大循环次数过低,eta下降的幅度太小的话会出现每次都用真实值作为输入的情况。论文中提到在KTH action 数据集上设定的循环次数是200000,在MNIST数据集上未提及循环次数,代码中设定的是80。

# predrnn_pp.py 40-43行。
if t < input_length:inputs = images[:,t]
else:inputs = mask_true[:,t-10]*images[:,t] + (1-mask_true[:,t-10])*x_gen
# mask_true[:,t-10]为ones矩阵的话取真实数据images[:,t],否则取预测数据x_gen.

reverse_input
如果为True,则在训练时将数据倒序后再训练一次,误差取两次的平均。

cost = model.train(ims, lr, mask_true)
if FLAGS.reverse_input:ims_rev = ims[:,::-1]cost += model.train(ims_rev, lr, mask_true)cost = cost/2

输出

先看代码

    # predrnn_pp.pygen_images = []for t in range(seq_length-1): #predrnn++网络结构,略过后面讲# x_gen为生成的预测帧x_gen = tf.layers.conv2d(inputs=hidden[num_layers-1],filters=output_channels,kernel_size=1,strides=1,padding='same',name="back_to_pixel")gen_images.append(x_gen)gen_images = tf.stack(gen_images)# [batch_size, seq_length, height, width, channels]gen_images = tf.transpose(gen_images, [1,0,2,3,4])loss = tf.nn.l2_loss(gen_images - images[:,1:])#loss += tf.reduce_sum(tf.abs(gen_images - images[:,1:]))return [gen_images, loss]

这里有意思了,作者其实是搭了一个19步的网络结构,输入是20帧,输出是19帧,损失函数计算的也是19帧的误差。而且利用的不是前10帧预测下一帧,而是越往后利用的信息越多,即预测第t+1帧时利用的是前面所有帧的信息。但是前面9帧不属于要预测的范围,属于输入,而且利用的是更少的数据预测出来的,作者在这里也把他们加入到损失计算里了。正常来讲我们只需计算后10帧预测的损失。这部分读者可以根据自身需求进行适当改写。

CausalLSTMCell

可以先看一下公式右侧出现的变量各参与了几次卷积操作。

变量 卷积次数
X_t 7
C_t-1 3
H_t-1 4(输出门里也有一个)
M_k-1 4
C_t 5
M_k 2
总计 25

弄清楚了这个就容易看懂代码了,举个例子如下:

x_cc = tf.layers.conv2d(x, self.num_hidden*7,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='input_to_state')i_x, g_x, f_x, o_x, i_x_, g_x_, f_x_ = tf.split(x_cc, 7, 3)# 这里的x_cc有7个卷积,最后被分为7份参与7个门的计算。

layer_normalization

if self.layer_norm:x_cc = tensor_layer_norm(x_cc, 'x2c')
#相当于
from tensorflow.python.keras.layers.experimental import LayerNormalization
x_cc = LayerNormalization(axis=[1,2,3])(x_cc)

GradientHighwayUnit

    def __call__(self, x, z):if z is None:z = self.init_state(x, self.num_features)with tf.variable_scope(self.layer_name):# z和x各卷积2次,生成z_concat和x_concat,用于P,S的计算。z_concat = tf.layers.conv2d(z, self.num_features*2,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='state_to_state')if self.layer_norm:z_concat = tensor_layer_norm(z_concat, 'state_to_state')x_concat = tf.layers.conv2d(x, self.num_features*2,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='input_to_state')if self.layer_norm:x_concat = tensor_layer_norm(x_concat, 'input_to_state')gates = tf.add(x_concat, z_concat)#这里的u就是论文中的S(switch gate).p, u = tf.split(gates, 2, 3)p = tf.nn.tanh(p)u = tf.nn.sigmoid(u)z_new = u * p + (1-u) * zreturn z_new

predrnn_pp

def rnn(images, mask_true, num_layers, num_hidden, filter_size, stride=1,seq_length=20, input_length=10, tln=True):gen_images = []lstm = []cell = []hidden = []shape = images.get_shape().as_list()output_channels = shape[-1]# 定义num_layers=4个cslstm层。初始化cell state 和hidden state 都为None.for i in range(num_layers):if i == 0: # 如果是第一层则传入最后一层的hidden数,实现的是M的z字形传递。num_hidden_in = num_hidden[num_layers-1]else:num_hidden_in = num_hidden[i-1]new_cell = cslstm('lstm_'+str(i+1),filter_size,num_hidden_in,num_hidden[i],shape,tln=tln)lstm.append(new_cell)cell.append(None)hidden.append(None)# 定义ghu层#在这里每个时间步上有5个层,其中4个cslstm层和一个ghu层,ghu处在第一个和第二个cslstm中间gradient_highway = ghu('highway', filter_size, num_hidden[0], tln=tln)# 初始化M和Z,它们会不断更新。mem = Nonez_t = None# 下面是predrnn的完整结构# 19步,每步5层。for t in range(seq_length-1):  reuse = bool(gen_images)with tf.variable_scope('predrnn_pp', reuse=reuse):# 决定每一时间步的输入Xif t < input_length:inputs = images[:,t]else: #随机选择利用真实数据还是上一步刚生成的预测数据。inputs = mask_true[:,t-10]*images[:,t] + (1-mask_true[:,t-10])*x_gen  # outputs,第二部分好像都为0没有意义##################################################################
#这部分就是同一时间步内的5层网络结构
# z_t每一个时间步更新一次
# mem沿着层级和时间步一直更新
# hidden 和 cell在同层之间沿着时间步更新hidden[0], cell[0], mem = lstm[0](inputs, hidden[0], cell[0], mem)z_t = gradient_highway(hidden[0], z_t)hidden[1], cell[1], mem = lstm[1](z_t, hidden[1], cell[1], mem)for i in range(2, num_layers):hidden[i], cell[i], mem = lstm[i](hidden[i-1], hidden[i], cell[i], mem)
################################################################### 利用最后一个hidden state 进行2D卷积生成预测序列。x_gen = tf.layers.conv2d(inputs=hidden[num_layers-1],filters=output_channels,kernel_size=1,strides=1,padding='same',name="back_to_pixel")gen_images.append(x_gen)gen_images = tf.stack(gen_images)# [batch_size, seq_length, height, width, channels]gen_images = tf.transpose(gen_images, [1,0,2,3,4])#对预测的19帧进行了l2损失计算。loss = tf.nn.l2_loss(gen_images - images[:,1:])#loss += tf.reduce_sum(tf.abs(gen_images - images[:,1:]))return [gen_images, loss]

PredRNN++:网络结构和代码解读相关推荐

  1. Inception代码解读

    Inception代码解读 目录 Inception代码解读 概述 Inception网络结构图 inception网络结构框架 inception代码细节分析 概述 inception相比起最开始兴 ...

  2. AlexNet代码解读

    AlexNet代码解读 目录 AlexNet代码解读 概述 网络结构图 AlexNet代码细节分析 概述 AlexNet的网络结构很简单,是最初级版本的CNN,没有使用什么技巧. 网络分成两个部分,分 ...

  3. Resnet的pytorch官方实现代码解读

    Resnet的pytorch官方实现代码解读 目录 Resnet的pytorch官方实现代码解读 前言 概述 34层网络结构的"平原"网络与"残差"网络的结构图 ...

  4. Memory-Associated Differential Learning论文及代码解读

    Memory-Associated Differential Learning论文及代码解读 论文来源: 论文PDF: Memory-Associated Differential Learning论 ...

  5. ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)

    点击我爱计算机视觉标星,更快获取CVML新技术 本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士. https://zhuanlan.zhihu.com/p/54289848 温故而知新,理 ...

  6. 编译原理语义分析代码_Pix2Pix原理分析与代码解读

    原理分析: 图像.视觉中很多问题都涉及到将一副图像转换为另一幅图像(Image-to-Image Translation Problem),这些问题通常都使用特定的方法来解决,不存在一个通用的方法.但 ...

  7. BigGAN代码解读(gpt3.5帮助)——生成器部分

    代码来源于Github中点赞最多的BigGAN复现 作者个人学习记录 BigGAN的生成器代码内部引用了代码人员编写的谱正则化(SN)以及批正则化(BN),关于这部分的解读地址在这里: 批正则化 谱正 ...

  8. mask rcnn 超详细代码解读(一)

    mask r-cnn 代码解读(一) 文章目录 1 代码架构 2 model.py 的结构 3 train过程代码解析 3.1 Resnet Graph 3.2 Region Proposal Net ...

  9. siris 显著性排序网络代码解读(training过程)Inferring Attention Shift Ranks of Objects for Image Saliency

    阅前说明 前面已经出现的代码用 - 代替. 本文仅解析train部分的代码(inference的部分会后续更新). 不对网络结构做过多解释,默认已经熟悉 mrcnn 的结构以及读过这篇论文了. 另:i ...

最新文章

  1. [IE9] 解决了傲游、搜狗浏览器在IE9下网页截图的问题
  2. 13.2System类中的常用方法
  3. 机器学习笔记(十一)----降维
  4. 医学专业失业率最高 三类相关行业人才紧缺
  5. 最新自动发卡网源码V7.0
  6. mysql 数据库设计实例_一个简单数据库设计例子
  7. jq UI-引入、拖动效果、api文档位置
  8. python 办公自动化 视频教程_Python自学爬虫/办公自动化视频教程
  9. echart:legend中显示value+自定义文字样式
  10. 高校三维地图校内导航系统解决方案
  11. 请冷静地对待手中的EOS——EOS数据分析
  12. 【R语言】R语言编程规范
  13. PLC常用标志位信号时序编程注意事项
  14. [易飞]包材Forcast四周滚动需求
  15. ZXR10 1809 路由器 1800开启WEB配置界面调试方法
  16. 在疫苗生产、包装、入库、放行、质量管理、电子数据采集/输入应用电子签名
  17. 高中计算机专业教师 教学计划,信息技术教师教学计划
  18. Dplayer实现弹幕功能
  19. 近期金三银四旺季,网上出现各种各样的面试文章跟视频,以下是我整理的一些拙见
  20. 自媒体人如何积累素材?素材整理四步法get

热门文章

  1. Odoo相关资源(持续更新中)
  2. 如何使用报表工具制作条形码报表
  3. android(9)_数据存储和访问3_scard基本介绍
  4. 解决chrome浏览器应用商店排版混乱问题
  5. SAP中税码、税率、税务科目的几个表及其中的勾稽关系
  6. Matlab之数据筛选
  7. Fabric-ca与现有fabric网络组织绑定
  8. 点云八个方向极值点获取
  9. python之父考虑重构python解释器_Python之父考虑重构Python解释器
  10. 机器学习作业(第十八次课堂作业)