PredNet阅读笔记——从视频预测的角度学习视频表征
看到有网站未经允许转载了,无奈加个原文地址:http://blog.csdn.net/zhangsipppcsdn/article/details/69907598
ICLR 2017论文《Deep Predicitve Coding Networks for Video prediction and unsupervised Learning》阅读笔记,作者是哈佛大学的William Lotter, Gabriel Kreiman & David Cox。
github:https://github.com/coxlab/prednet
这篇论文采用的也是CNN+LSTM的思路做视频预测,但是对网络结构做了很大调整(如下图),将图像预测误差在网络中前向传递,可以更好地学习到视频表征。
文章是从视频预测的角度设计网络PredNet:为了达到预测视频的目的,需要学习视频的特征表示。然而实验证明,PredNet在视频预测任务表现一般,预测时间短且不够清晰;但在学习视频表征方面表现突出,可以提取物体动态特征,将这些特征用于分类器、参数估算等任务,相比于从静态图像中提取的特征,物体识别准确度会提高。
背景
现有用于物体识别的方法是有监督的,效果非常好。有监督训练需要大量标注图像,使分类器可以在不同角度、背景、光照等条件下识别物体,但这与我们人类对物体的认知习惯不同,我们只需要一个或少数几个角度的物体图像就足够识别出物体。由于大量标注图像难以获取,限制了有监督方法的识别能力,因此需要尝试无监督学习的方法获得物体的特征表示,用于物体识别。
计算机视觉利用静态图像为物体建模,但现实世界中的物体或观察者总是在运动的,物体运动的时序信息也组成了物体特征的一部分,应当构建物体的动态模型。一些研究也尝试过将物体随时间的变化特点加入物体的特征表示,但是识别结果不理想,难以和一般基于图像的有监督方法匹敌。
这里,作者从视频预测的角度去学习物体的时间变化特征。因为想要预测一个物体的变化,本身就需要建立物体内在模型和它的运动模型。人脑的预测是基于不断获取的新图像,不断校正预测结果的。作者据此提出的prednet就设计了这样一个结构,根据产生的预测图像与实际下一帧图像的误差及各层特征图像的误差,去训练网络预测能力。
模型
- Prednet模型如上图。整个网络是上图左半部分在时间、网络层两个维度上的堆叠,右半部分是各step每层网络(称为一个模块)的具体实现。
- 每个模块由四个单元组成:
- AlA_l:输入卷积层,对于第一层,是目标图像;对于更高层,是前一层预测误差E的卷积+relu。
- RlR_l:卷积LSTM层。
- A^l\hat{A}_l:预测层,对R单元卷积+relu得到。
- ElE_l:误差表示层,f(Al−A^l)f{(A_l-\hat{A}_l)}。
- (为什么csdn写出来的公式后面都带一个竖线?)
- 具体地:
- ElE_l单元:由于采用激活函数Relu,AlA_l与A^l\hat{A}_l之差小于零的部分会被置零,因此需要AlA_l与A^l\hat{A}_l相互作差,拼接,再经过Relu层。是L1 loss(作者表示还没有尝试过其他诸如对抗loss之类的其他loss).
- ElE_l传给A^l+1\hat{A}_{l+1},作为下一层的输入,是自下而上的。
- RtlR^t_l单元:接受的输入是前一刻本层误差EtlE^t_l,本层状态Rt−1lR^{t-1}_l,本时刻高层预测特征Rtl+1R^{t}_{l+1}(由上而下),根据这三者进行特征级的预测。预测的特征在A^l\hat{A}_l单元卷积,得到特征图像,与AlA_l相比较。
- 总体loss是各层、各时刻预测误差的加权和。各层误差权重λl\lambda_l,各时刻误差权重λt\lambda_t由实验确定。
- 网络状态更新存在水平(时间)方向和竖直(各层)方向两方向的更新。竖直方向先更新,先自下而上前向传播计算得各层误差ElE_l,再自上而下计算RNN单元的状态RlR_l。t时刻网络更新好后,进行t+1时刻的更新。因此对于各t的网络,输入是前一刻RNN状态Rt−1R^{t-1},和本时刻目标输出图像A0A_0。
- 具体更新规则:
按照时间展开大概是这样的(2017/5/9改,之前图片有问题,按时间展开应该没有反向的箭头了):
本文LSTM的代码实现
PredNet各个时刻的网络,四个部分作为一个整体,可以看做一个完整的LSTM层,即没有堆叠,按时间递归循环的RNN。也就是下图的一个A模块。因此也可以理解,为什么更新状态量时先垂直再水平。
类似于LSTM,该模块除了输入输出外,还有状态量在各时刻间传递。
- 输入:
- 本时刻目标输出At0A^t_0
- 状态量(R单元所需):各层Rt−1R^{t-1}、Et−1E^{t-1}以及LSTM单元内部状态C
- 输出:
- 状态量
- 根据不同需求,可以输出三种形式
- 训练时:error mode,输出各层的平均误差,1维向量形式(各特征、样本间的平均,每层一个标量)。
- 测试时:prediction mode,最底层输出prediction图像
- 观测、调试网络时,mode=其他,可以根据需要输出某中间层的特征图像
代码
我仔细阅读了作者github上的代码,是keras的。大部分代码都是按照论文描述搭建模型,比较关键的是作者重载了rnn中的step()函数,实现了作者自己搭建的LSTM层。这样的LSTM层只需要一层,不用multilayer堆叠了。
下面注释是根据我对网络的理解写的,完整的代码注释太长就不放了。def step(self, a, states): # 重载rnn中的step r_tm1 = states[:self.nb_layers] # 读取输入的R、E、C(上时刻状态) c_tm1 = states[self.nb_layers:2*self.nb_layers] e_tm1 = states[2*self.nb_layers:3*self.nb_layers]if self.extrap_start_time is not None:t = states[-1]a = K.switch(t >= self.t_extrap, states[-2], a) # if past self.extrap_start_time, the previous prediction will be treated as the actualc = [] r = [] e = [] # R Unit for l in reversed(range(self.nb_layers)): # 由于R的计算需要前时刻和高一层的R,因此需要由上向下进行计算inputs = [r_tm1[l], e_tm1[l]]if l < self.nb_layers - 1:inputs.append(r_up) # 除了最高层,前面的输入都是R_t-1,R_l+1,E,以及隐含的状态C# 标准LSTM过程inputs = K.concatenate(inputs, axis=self.channel_axis) # 把各个特征图放到一起i = self.conv_layers['i'][l].call(inputs) # 按照相应的卷积门尺寸卷积f = self.conv_layers['f'][l].call(inputs)o = self.conv_layers['o'][l].call(inputs)_c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs) # c_t = f*c_t-1 + i*tanh(inputs)_r = o * self.LSTM_activation(_c) # r_t = o*tanh(c_t)c.insert(0, _c)r.insert(0, _r)if l > 0:r_up = self.upsample.call(_r) # 上采样for l in range(self.nb_layers):ahat = self.conv_layers['ahat'][l].call(r[l]) # Ahat是R的卷积if l == 0:ahat = K.minimum(ahat, self.pixel_max) # 第一层,Ahat限幅,准备作为输出图像frame_prediction = ahat # 当output_mode == 'prediction'时输出# compute errorse_up = self.error_activation(ahat - a)e_down = self.error_activation(a - ahat)e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))if self.output_layer_num == l:if self.output_layer_type == 'A':output = aelif self.output_layer_type == 'Ahat':output = ahatelif self.output_layer_type == 'R':output = r[l]elif self.output_layer_type == 'E':output = e[l]if l < self.nb_layers - 1:a = self.conv_layers['a'][l].call(e[l])a = self.pool.call(a) # target for next layerif self.output_layer_type is None:if self.output_mode == 'prediction':output = frame_predictionelse:for l in range(self.nb_layers):layer_error = K.mean(K.batch_flatten(e[l]), axis=-1, keepdims=True) # 各层平均误差,每层一个数all_error = layer_error if l == 0 else K.concatenate((all_error, layer_error), axis=-1)if self.output_mode == 'error':output = all_errorelse:output = K.concatenate((K.batch_flatten(frame_prediction), all_error), axis=-1)states = r + c + e if self.extrap_start_time is not None:states += [frame_prediction, t + 1] return output, states
实验
实验1:测试网络的预测误差
- 测试数据:合成的人脸3D图像,加了两个方向的随机旋转,得到旋转人脸视频
- metrics:各帧平均MSE,SSIM
- 对照设置:
- 普通编码-解码模式的卷积LSTM,传递预测图像A,而非误差。记为E.CNN-LSTM Enc.-Dec.
- 直接复制前一帧 Copy Last Frame
- 只保留最底层误差权重λ0\lambda_0,其余置0,记为PredNet L0L_0
- 各层误差都有权重,λ0=1\lambda_0=1,其他层均小一数量级,即0.1,记为PredNet LallL_{all}
- 结果:
PredNet生成的预测图像。注意到,PredNet需要不断喂数据,根据目标图像调整预测图像,在几步之后达到可以准确预测下一帧的效果,实际上仍是next-frame-prediction.
表一可见,使用L0L_0权重方案的PredNet对于下一帧图像的预测准确度最高。尤其是对于结构相似性指标SSIM的提升很明显,表明在结构水平上预测更准确。
实验2:测试网络学习隐含变量的能力
- 测试数据:仍是合成的人脸3D图像
- 将网络学习到的特征(各层R单元输出)拼接起来,作为一幅图的整体特征,输入一个全连接网络,进行参数学习的任务。
- 对于旋转人脸图像,参数(隐变量)包括:人脸主分量、初始旋转角度、旋转角速度。
- 评价指标:参数的准确度。
- 对照设置:提取R单元各层step2,3的特征图像,与未经训练的网络进行比较
- 结果如下图左:
实验2拓展:利用提取到的特征进行图像分类任务
- baseline: 自编码器、Ladder Network,使用重建误差训练,根据静态图像提取的特征
- 线性SVM分类器,对不同的3D人脸图像进行分类
- 实验结果见上图右
可以观察到使用PredNet学习到的特征,可以更好地完成参数回归、图像分类任务。这说明PredNet提取到的特征可以泛化到其他任务中。
还要注意到,在这两个任务中,使用LallL_{all}权重方案得到的参数回归更准确,提醒我们,对于不同的任务,要考虑调整PredNet各层误差权重。
实验3:处理真实场景图像
- 训练数据集:kitti,车载摄像头的录像集,同时记录了摄像头的运动和场景中物体的运动
- 测试数据集:CalTech Pedestrian
- baseline:CNN-LSTM Enc.-Dec.
- 具体定量比较结果见论文。一句话就是PredNet的预测误差小于CNN-LSTM。
- 为了验证提高不是来自于网络参数的不同,作者又采用了四组不同参数进行实验,PredNet平均预测误差仍低于CNN-LSTM 14.7%.
- 结果图
在不同场景下的预测都很精确。可以预测车辆的运动,还可以补全车开走后的空缺(第一组图)。还可以预测自己的运动——准确预测树影的变化(第2、5组图)。
预测实验
前面的实验实际需要不断输入图像,达到准确预测下一帧的效果,不能解决长时间视频预测的任务。
论文附录5.3中进一步展示了关于预测的实验。将各步预测输出作为网络输入,就可以预测很多步。结果显示,直接使用所述PredNet预测,效果不好且模糊。
但是用同样的网络结构,针对预测问题调整网络参数,如增加预测的时长,输入10帧图像,以后5帧作为输入喂回网络,得到共15帧的预测输出,各步的loss权重也有所调整。这样针对性调整后的网络,预测能力增强很多,也不那么模糊。
我自己运行了一下代码,用kitti数据集做预测。猜猜我会不会上传动图!
总结(我的看法)
本文提出的PredNet网络,因为需要不断喂目标输出,才能保证输出与输入相似。解决长时间视频预测的任务还不够好。所以我认为它作为物体的视频特征学习的方法的意义>视频预测方法。
本文提出的传递预测误差的CNN+LSTM结构的网络还是很新颖且有效的。相比于更常见的“编码-LSTM预测变化——解码”的结构,相当于把LSTM的作用范围扩大了,把编码和解码过程都放到LSTM里面了,各单元关系保留更多,但好像更不容易分块观测和提高各单元能力。
作者进行了大量的实验,比较网络各部分的设计细节,从网络参数、权重到各单元存在的必要性,都经过实验得到最优设置,或验证单元的存在是合理有效的。非常严谨,学习一下。严谨到论文正文的实验全是与baseline的比较,与state of the art的预测效果比较只能放到附录中了。。。
我的程序注释:https://github.com/hello-world-zsp/prednet/tree/master
PredNet阅读笔记——从视频预测的角度学习视频表征相关推荐
- 想自学python看哪位的视频比较好-python学习视频好的有哪些
原标题:python学习视频好的有哪些 Python视频,一般找的同学事打算自学的,但作为曾经自学过Python的一员,想给你提几点学习Python的建议,希望你的Python学习之路能平坦些. 第一 ...
- 论文阅读笔记(3)---基于深度学习的节律异常或传导阻滞多标签心电图自动诊断
论文地址:Automatic multilabel electrocardiogram diagnosis of heart rhythm or conduction abnormalities wi ...
- 讲解c语言算法的视频,c语言算法学习视频
[教程介绍] c语言看着简单,但等你深入学习以后,你就会发现需要学习的知识很多很多. 对于程序设计,算法是特别重要的,极为核心的一个知识. 在我们这部C语言算法教程,讲解的知识点极多,主要内容包括基本 ...
- 基于深度学习的视频预测研究综述
原址:http://html.rhhz.net/tis/html/201707032.htm (收集材料ing,为论文做准备)[综述性文章,,,可以做背景资料] 莫凌飞, 蒋红亮, 李煊鹏 摘要:近年 ...
- 基于深度学习的表面缺陷检测方法综述-论文阅读笔记
//2022.3.2日阅读笔记 原文链接:基于深度学习的表面缺陷检测方法综述 (aas.net.cn) 个人对本篇综述内容的大致概括 论文首先介绍了表面缺陷检测中不同场景下的成像方案,主要根据表面颜色 ...
- f-AnoGan阅读笔记
声明 本人是做织物瑕疵检测的,处理的小数据量,想从这篇文章找灵感,记下自己的阅读笔记.然后是接触深度学习的小萌新,也是第一次写博客,若有错误,请指证不胜感激. 补充一些内容 在此之前先了解下它的前身A ...
- Multi-Task Video Captioning with Video and Entailment Generation阅读笔记
这篇文章提出多任务学习去优化Video Captioning框架,模型框架图如上所示,共3个任务. 其中,UNSUPERVISED VIDEO PREDICTION(无监督视频预测):一个视频由n帧组 ...
- 一个模型通杀8大视觉任务,一句话生成图像、视频、P图、视频处理...都能行 | MSRA北大出品...
丰色 发自 凹非寺 量子位 报道 | 公众号 QbitAI 有这样一个模型. 它可以做到一句话生成视频: 不仅零样本就能搞定,性能还直达SOTA. 它的名字,叫"NüWA"(女娲) ...
- 西门子系列全套学习视频,免费领取!
这是一套西门子全套视频专辑,S7200,S7300,S71200,S71500,sinamic,sv20变频器,mm440系列变频器,winCC,触摸屏HMI系列教程!! 视频内容包含哪些? 这是西门 ...
最新文章
- window10+python3.7安装tensorflow--gpu tensorflow 安装
- reverseString
- mysql 导入txt数据到数据表【原创】
- 计算机硕士工资情况收集
- 跟我一起学.NetCore之中间件(Middleware)简介和解析请求管道构建
- mysql 转型_MySQL的未来在哪?
- esxi 5.5运行linux拯救模式,启用Esxi 5.5 SSH 功能
- Map3D中获取地图中心及Zoom到新的中心点
- POJ3262 Protecting the Flowers【贪心】
- Python+Flask+MysqL的web建设技术过程
- 学计算机的大学生买什么U盘,大学生最容易丢的几样东西,最后一件最让人着急,网友:真实了...
- jeDate实现日期联动
- 计算机怎么更改网络密码,该如何修改自家宽带帐号的密码?
- 10个优秀免费高清素材图库相册:各类美图应有尽有
- Java Shiro 设置 anon 无效
- 献给加班的各位同仁,祝工作顺利
- java 找不到方法_Java程序找不到主方法,在哪里加上呢
- Paragraph 对象'代表所选内容、范围或文档中的一个段落。Paragraph 对象是 Paragraphs 集合的一个成员。Paragraphs 集合包含所选内容、范围或文档中的所有段落。...
- ORACLE分区表查询
- Linux下内存检测工具:asan
热门文章
- java 字符串截取最后一位,获取最后一位前面的字符串
- 支持刷机(OpenWrt)的路由器大全
- r55600g和r75700g差别大吗 r55600g和r75700g区别
- 测试环境和测试分类的介绍
- 请至少列举 5 个 PEP8 规范(越多越好)
- [Android应用]《幽默笑话》V2.0 正式版震撼发布!
- java canvas 描边,HTML5 Canvas如何实现纹理填充与描边(Fill And Stroke)
- eclipse左侧栏目即包资源管理器怎么打开
- <From Zero to Hero>零基础学习Python基础语法【条件判断与条件嵌套】
- “一天宕机三次”,为什么高并发这么难?