PredRNN++:网络结构和代码解读
已经有很多帖子对PredRNN++的理论和改进效果进行了解读,不再赘述。直接分析结构和代码。
Causal LSTM 单元
三层级联结构:
第一层(蓝色框)类似传统的LSTM结构用于更新时间状态C(temporal state) ;
第二层(绿色框)也是类似的结构,利用更新的C和上一层的M更新空间状态M(spatial state);
第三层(灰色框)是输出门的结构,利用X、新的C和M来更新H (hidden state)。
所以一个Causal LSTM 细胞单元的输入为X、H、C和M,输出为更新的H、C和M。
最后一个时间步的H用于预测生成目标序列。
根据上面的公式可以细看一下。
第一层根据当前步的X和前一步的H、C来决定门,然后获得当前步的C。
第二层利用刚更新的当前步的C、X和前一层的M来决定第二层的门。在对前一层的M做tanh非线性变换后获得更新的M。M、o和H的更新都使用了tanh,增加了非线性操作。
第三层利用X、更新的当前步的C和当前层的M来更新输出门。在作者源代码中还加入了前一步的H,对预测准确度的影响有多大未经验证。又或者是公式里遗漏了(但根据框图看不像是漏了)。在ST-LSTM结构中输出门确实是用到H的。
# CausalLSTMCell.py 124-127行。
if x is None: o = tf.tanh(o_h + o_c + o_m)
else:o = tf.tanh(o_x + o_h + o_c + o_m)
最后利用输出门、C和M来更新H,输出给下一步。
GHU
理论证据表明,highway layers能够在非常深的前馈网络中有效地传递梯度,所以作者将这一思想应用到递归网络中,以防止长期梯度的快速消失,并提出了一种新的时空递归结构GHU,结构如下:
S是转换门,决定在P和Z之间各保留多大比例。这里的X相当于hidden state H,Z相当于更新后的H。
PredRNN++总体结构
具体看一下H、C、M和Z的传递过程:
H
H既有横向传递也有纵向传递。H在同一时间步的传递(纵向)是作为input(X)进行传递。在时间维度上是作为hidden state进行传递。
C
C是只在同层的时间步直接进行传递。
M
如果将不同时间步的网络单元全部堆叠起来就容易发现是一个M在持续向下传递。
Z
在同一时间步作为H1的更新向下一层传递一次,主要在横向进行传递,将长时特征传递给未来。
代码解读
输入
输入为5D数据,格式为(B,T,H,W,C)分别代表batch size,时间步数,高,宽,通道(channel)。对于有监督学习我们通常将数据分为输入X和对应标签y。以输入前10步预测后10步为例,对于一个样本X=(1:10,H,W,C),输出y_hat = (11:20,H,W,C), 然后将y_hat和y进行损失计算,反向传播更新参数。
但在这里作者没有区分y,X=(1:20,H,W,C)。而是利用一个0-1掩码矩阵随机选择生成的数据和真实数据来作为预测下一步的输入,增加模型的稳健性。
# train.py 75-87行。
# x的维度为(1,20,16,16,16)
self.x = tf.placeholder(tf.float32,[FLAGS.batch_size,FLAGS.seq_length,FLAGS.img_width/FLAGS.patch_size,FLAGS.img_width/FLAGS.patch_size,FLAGS.patch_size*FLAGS.patch_size*FLAGS.img_channel])
# mask_true的维度为(1,9,16,16,16)
# 为什么是9步呢,因为最后一步肯定是要被预测出来的,不会用作输入。
self.mask_true = tf.placeholder(tf.float32,[FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1,FLAGS.img_width/FLAGS.patch_size,FLAGS.img_width/FLAGS.patch_size,FLAGS.patch_size*FLAGS.patch_size*FLAGS.img_channel])
在训练时,每一次循环都重置mask_true。
# train.py 173-199行。if itr < 50000:eta -= deltaelse:eta = 0.0# random_flip 是一个(1,9)的[0,1]区间随机数向量。random_flip = np.random.random_sample((FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1))# 判断哪些比eta小,为True则使用真实数据否则用生成的数据。true_token = (random_flip < eta)# 作者尝试过利用指数函数衰减的方式作为阈值判断。#true_token = (random_flip < pow(base,itr)) # ones = (16,16,16)ones = np.ones((int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))# zeros = (16,16,16)zeros = np.zeros((int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))mask_true = []for i in range(FLAGS.batch_size):for j in range(FLAGS.seq_length-FLAGS.input_length-1):if true_token[i,j]:mask_true.append(ones)else:mask_true.append(zeros)# mask_true = (9,16,16,16)mask_true = np.array(mask_true)# mask_true = (1,9,16,16,16)mask_true = np.reshape(mask_true, (FLAGS.batch_size,FLAGS.seq_length-FLAGS.input_length-1,int(FLAGS.img_width/FLAGS.patch_size),int(FLAGS.img_width/FLAGS.patch_size),FLAGS.patch_size**2*FLAGS.img_channel))
随着循环次数增加,使用真实数据的比例不断下降。但是如果设置的最大循环次数过低,eta下降的幅度太小的话会出现每次都用真实值作为输入的情况。论文中提到在KTH action 数据集上设定的循环次数是200000,在MNIST数据集上未提及循环次数,代码中设定的是80。
# predrnn_pp.py 40-43行。
if t < input_length:inputs = images[:,t]
else:inputs = mask_true[:,t-10]*images[:,t] + (1-mask_true[:,t-10])*x_gen
# mask_true[:,t-10]为ones矩阵的话取真实数据images[:,t],否则取预测数据x_gen.
reverse_input
如果为True,则在训练时将数据倒序后再训练一次,误差取两次的平均。
cost = model.train(ims, lr, mask_true)
if FLAGS.reverse_input:ims_rev = ims[:,::-1]cost += model.train(ims_rev, lr, mask_true)cost = cost/2
输出
先看代码
# predrnn_pp.pygen_images = []for t in range(seq_length-1): #predrnn++网络结构,略过后面讲# x_gen为生成的预测帧x_gen = tf.layers.conv2d(inputs=hidden[num_layers-1],filters=output_channels,kernel_size=1,strides=1,padding='same',name="back_to_pixel")gen_images.append(x_gen)gen_images = tf.stack(gen_images)# [batch_size, seq_length, height, width, channels]gen_images = tf.transpose(gen_images, [1,0,2,3,4])loss = tf.nn.l2_loss(gen_images - images[:,1:])#loss += tf.reduce_sum(tf.abs(gen_images - images[:,1:]))return [gen_images, loss]
这里有意思了,作者其实是搭了一个19步的网络结构,输入是20帧,输出是19帧,损失函数计算的也是19帧的误差。而且利用的不是前10帧预测下一帧,而是越往后利用的信息越多,即预测第t+1帧时利用的是前面所有帧的信息。但是前面9帧不属于要预测的范围,属于输入,而且利用的是更少的数据预测出来的,作者在这里也把他们加入到损失计算里了。正常来讲我们只需计算后10帧预测的损失。这部分读者可以根据自身需求进行适当改写。
CausalLSTMCell
可以先看一下公式右侧出现的变量各参与了几次卷积操作。
变量 | 卷积次数 |
---|---|
X_t | 7 |
C_t-1 | 3 |
H_t-1 | 4(输出门里也有一个) |
M_k-1 | 4 |
C_t | 5 |
M_k | 2 |
总计 | 25 |
弄清楚了这个就容易看懂代码了,举个例子如下:
x_cc = tf.layers.conv2d(x, self.num_hidden*7,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='input_to_state')i_x, g_x, f_x, o_x, i_x_, g_x_, f_x_ = tf.split(x_cc, 7, 3)# 这里的x_cc有7个卷积,最后被分为7份参与7个门的计算。
layer_normalization
if self.layer_norm:x_cc = tensor_layer_norm(x_cc, 'x2c')
#相当于
from tensorflow.python.keras.layers.experimental import LayerNormalization
x_cc = LayerNormalization(axis=[1,2,3])(x_cc)
GradientHighwayUnit
def __call__(self, x, z):if z is None:z = self.init_state(x, self.num_features)with tf.variable_scope(self.layer_name):# z和x各卷积2次,生成z_concat和x_concat,用于P,S的计算。z_concat = tf.layers.conv2d(z, self.num_features*2,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='state_to_state')if self.layer_norm:z_concat = tensor_layer_norm(z_concat, 'state_to_state')x_concat = tf.layers.conv2d(x, self.num_features*2,self.filter_size, 1, padding='same',kernel_initializer=self.initializer,name='input_to_state')if self.layer_norm:x_concat = tensor_layer_norm(x_concat, 'input_to_state')gates = tf.add(x_concat, z_concat)#这里的u就是论文中的S(switch gate).p, u = tf.split(gates, 2, 3)p = tf.nn.tanh(p)u = tf.nn.sigmoid(u)z_new = u * p + (1-u) * zreturn z_new
predrnn_pp
def rnn(images, mask_true, num_layers, num_hidden, filter_size, stride=1,seq_length=20, input_length=10, tln=True):gen_images = []lstm = []cell = []hidden = []shape = images.get_shape().as_list()output_channels = shape[-1]# 定义num_layers=4个cslstm层。初始化cell state 和hidden state 都为None.for i in range(num_layers):if i == 0: # 如果是第一层则传入最后一层的hidden数,实现的是M的z字形传递。num_hidden_in = num_hidden[num_layers-1]else:num_hidden_in = num_hidden[i-1]new_cell = cslstm('lstm_'+str(i+1),filter_size,num_hidden_in,num_hidden[i],shape,tln=tln)lstm.append(new_cell)cell.append(None)hidden.append(None)# 定义ghu层#在这里每个时间步上有5个层,其中4个cslstm层和一个ghu层,ghu处在第一个和第二个cslstm中间gradient_highway = ghu('highway', filter_size, num_hidden[0], tln=tln)# 初始化M和Z,它们会不断更新。mem = Nonez_t = None# 下面是predrnn的完整结构# 19步,每步5层。for t in range(seq_length-1): reuse = bool(gen_images)with tf.variable_scope('predrnn_pp', reuse=reuse):# 决定每一时间步的输入Xif t < input_length:inputs = images[:,t]else: #随机选择利用真实数据还是上一步刚生成的预测数据。inputs = mask_true[:,t-10]*images[:,t] + (1-mask_true[:,t-10])*x_gen # outputs,第二部分好像都为0没有意义##################################################################
#这部分就是同一时间步内的5层网络结构
# z_t每一个时间步更新一次
# mem沿着层级和时间步一直更新
# hidden 和 cell在同层之间沿着时间步更新hidden[0], cell[0], mem = lstm[0](inputs, hidden[0], cell[0], mem)z_t = gradient_highway(hidden[0], z_t)hidden[1], cell[1], mem = lstm[1](z_t, hidden[1], cell[1], mem)for i in range(2, num_layers):hidden[i], cell[i], mem = lstm[i](hidden[i-1], hidden[i], cell[i], mem)
################################################################### 利用最后一个hidden state 进行2D卷积生成预测序列。x_gen = tf.layers.conv2d(inputs=hidden[num_layers-1],filters=output_channels,kernel_size=1,strides=1,padding='same',name="back_to_pixel")gen_images.append(x_gen)gen_images = tf.stack(gen_images)# [batch_size, seq_length, height, width, channels]gen_images = tf.transpose(gen_images, [1,0,2,3,4])#对预测的19帧进行了l2损失计算。loss = tf.nn.l2_loss(gen_images - images[:,1:])#loss += tf.reduce_sum(tf.abs(gen_images - images[:,1:]))return [gen_images, loss]
PredRNN++:网络结构和代码解读相关推荐
- Inception代码解读
Inception代码解读 目录 Inception代码解读 概述 Inception网络结构图 inception网络结构框架 inception代码细节分析 概述 inception相比起最开始兴 ...
- AlexNet代码解读
AlexNet代码解读 目录 AlexNet代码解读 概述 网络结构图 AlexNet代码细节分析 概述 AlexNet的网络结构很简单,是最初级版本的CNN,没有使用什么技巧. 网络分成两个部分,分 ...
- Resnet的pytorch官方实现代码解读
Resnet的pytorch官方实现代码解读 目录 Resnet的pytorch官方实现代码解读 前言 概述 34层网络结构的"平原"网络与"残差"网络的结构图 ...
- Memory-Associated Differential Learning论文及代码解读
Memory-Associated Differential Learning论文及代码解读 论文来源: 论文PDF: Memory-Associated Differential Learning论 ...
- ResNet及其变种的结构梳理、有效性分析与代码解读(PyTorch)
点击我爱计算机视觉标星,更快获取CVML新技术 本文来自知乎,作者费敬敬,现为同济大学计算机科学与技术硕士. https://zhuanlan.zhihu.com/p/54289848 温故而知新,理 ...
- 编译原理语义分析代码_Pix2Pix原理分析与代码解读
原理分析: 图像.视觉中很多问题都涉及到将一副图像转换为另一幅图像(Image-to-Image Translation Problem),这些问题通常都使用特定的方法来解决,不存在一个通用的方法.但 ...
- BigGAN代码解读(gpt3.5帮助)——生成器部分
代码来源于Github中点赞最多的BigGAN复现 作者个人学习记录 BigGAN的生成器代码内部引用了代码人员编写的谱正则化(SN)以及批正则化(BN),关于这部分的解读地址在这里: 批正则化 谱正 ...
- mask rcnn 超详细代码解读(一)
mask r-cnn 代码解读(一) 文章目录 1 代码架构 2 model.py 的结构 3 train过程代码解析 3.1 Resnet Graph 3.2 Region Proposal Net ...
- siris 显著性排序网络代码解读(training过程)Inferring Attention Shift Ranks of Objects for Image Saliency
阅前说明 前面已经出现的代码用 - 代替. 本文仅解析train部分的代码(inference的部分会后续更新). 不对网络结构做过多解释,默认已经熟悉 mrcnn 的结构以及读过这篇论文了. 另:i ...
最新文章
- [IE9] 解决了傲游、搜狗浏览器在IE9下网页截图的问题
- 13.2System类中的常用方法
- 机器学习笔记(十一)----降维
- 医学专业失业率最高 三类相关行业人才紧缺
- 最新自动发卡网源码V7.0
- mysql 数据库设计实例_一个简单数据库设计例子
- jq UI-引入、拖动效果、api文档位置
- python 办公自动化 视频教程_Python自学爬虫/办公自动化视频教程
- echart:legend中显示value+自定义文字样式
- 高校三维地图校内导航系统解决方案
- 请冷静地对待手中的EOS——EOS数据分析
- 【R语言】R语言编程规范
- PLC常用标志位信号时序编程注意事项
- [易飞]包材Forcast四周滚动需求
- ZXR10 1809 路由器 1800开启WEB配置界面调试方法
- 在疫苗生产、包装、入库、放行、质量管理、电子数据采集/输入应用电子签名
- 高中计算机专业教师 教学计划,信息技术教师教学计划
- Dplayer实现弹幕功能
- 近期金三银四旺季,网上出现各种各样的面试文章跟视频,以下是我整理的一些拙见
- 自媒体人如何积累素材?素材整理四步法get