Recurrent Models of Visual Attention(RAM)

摘要

在大量图片上使用深度卷积网络的计算量随着图片像素的增加而线性增加。本文提出一个循环神经网络模型,能够选择性地从图片或视频中提取一系列的区域位置,使其有很高的像素,而区域之外的像素很低,实验过程中只对高像素区域进行处理。和卷及神经网络一样,循环神经网络具有平移不变性,但其计算量与图片大小无关。由于所提模型是不可微分的,所以使用强化学习得到基于任务的策略。模型用于手写数字分类,能够得到比卷积神经网络更好的效果,同时在一个动态视觉控制游戏中,模型也能取得很好的控制效果。

背景介绍

面对整个场景时,人类是有选择地分配注意力的,聚焦不同的区域,并将信息整合以得到整个空间的视觉表达,进而指导后面的眼球运动及策略制定。这样能够过滤掉无用信息,将与任务相关的物体放在视野中心。对于视觉注意力的研究一般基于低级的视觉特征及自下而上的过程(比如视觉显著图,见我的博客),但任务导向也是视觉注意力的重要特征。我们将基于注意力的视觉处理过程当成一个控制任务,使其能够应用于静态图,视频,或一个与动态场景实时交互的Agent模块,比如移动机器人。
本文描述了一个端对端(end-to-end)的优化过程,使得模型在一个任务下直接得以训练,并且基于模型在整个时间序列上的决策优化其性能。该过程使用反向传播法训练神经网络的参数,用强化学习中的策略梯度解决控制问题中的不可微问题。

循环注意力模型(The Recurrent Attention Model,RAM)

本文中我们将注意力问题看成是与视觉环境交互的agent基于目标的一系列决策过程。在某一时间点,agent通过限制带宽(视野大小)的传感器观察场景,并提取局部区域的信息,该agent能够:1)控制传感器放置区域,2)也能够通过执行动作影响环境状态。随时间推移,将局部信息整合,为更有效地放置传感器提供决策辅助,每一步,agent都将得到一定数量的奖励(基于agent行为效果),agent的目标就是最大化总的奖励(reward)。具体地,在一个游戏场景中,agent动作对应着游戏杆的操作,而奖励对应着分数。
模型如下图,模型解释见图下面的英文:
 是:随机搜索法

居中数字训练:我们用上述的训练方法成功学习了一个glimpse策略,在MNIST手写字数据集上,用它训练最多7次glimpse的RAM模型,传感器大小为8*8,只能提取数字的部分信息,我们要验证其组合部分信息的能力。我们还训练了一个标准的反馈神经网络,它包含两个隐层,每个256个线性转换单元,输出作为基准值(baseline)。不同模型在测试集上得到的误差率见下表1a,当有6次glimpse的时候,RAM模型的正确率就差不多和传统的在28*28图片上训练的两层全连接神经网络(FC)相媲美了,这证明了RAM模型具备信息组合的能力。
非居中数字训练:如果数字不在居中了,而是像图2a中的那样,可以在其他区域。表1b列出了不同模型在60*60图片上的分类错误率,除了RAM模型和全连接模型(FC)我们还训练了一个卷积神经网络模型,卷基层有16个10*10 的滤波器(卷积核),步长为5,后面跟着一个有256个变换单元(rectifier units)的全连接层。卷基网络、RAM、全连接模型具有相同数目的参数。由于距安吉审计网络具有平移不变性(translation invariance),所以它比全连接层的错误率小,为2.3%左右,但4次glimpse的RAM模型的错误率与它相当,而6次和8次glimpse的RAM错误率基本达到1.9%,这可能是因为RAM模型能够将视野集中在数字上,学习一个平移不变的策略。该实验说明RAM模型能够成功地在一幅大图片中找到不居中的物体。

杂乱的非居中数字训练:一个挑战性的工作是在有干扰的图片中进行分类,除了要分类的数字外,还有其他的杂乱的图形,这些图形是从其他数字的图片下提取的8*8的小块,直接加入待分类图片中。上图2b是加入干扰的非居中数字图片。该实验为了证明注意力机制能够在复杂的环境下忽略不相关信息,而集中于相关部分。下表2a列出了各模型训练60*60带有4块干扰小块图片的分类结果,为了进一步验证模型的效果,还在含有8块干扰小块的100*100图片上进行了实验,结果如表2b,RAM都能达到最低的错误率,且相比于60*60的图片,其计算量并没有增加,而卷积神经网络模型的计算量却线性增加了。
下图中是学习的策略,空心圆圈是初始化的glimpse中心位置,实心圆圈是最终的glimpse中心位置,中间的glimpse位置用绿线连起来了。
动态环境:在一个控制策略游戏中,使用了RAM模型,85%的时间都能达到控制效果,具体见原文。

总结

本文提出了一个全新的视觉注意力模型,循环神经网络以一个glimpse窗口为输入,利用网络的内部状态选择下一个聚焦的位置,并在动态环境中生成一个控制信号。尽管该模型是不可微的,所提统一框架使用一个策略梯度模型从像素到动作进行端对端(end-to-end)的训练。RAM模型有以下两个巨大的优势: 1)计算量与输入图片像素大小没有关系。 2)可以排除其他干扰,将视野放在相关物体上。而且模型本身比较灵活,可以控制glimpse的次数,改变采样区域的大小使其适应各种大小的物体。
具体代码实现见这篇博文,基于torch的代码框架如下:
感觉基于注意力和记忆的计算机视觉是很有前景的研究方向。有没有朋友想在github上开源基于其他深度学习平台(caffe,MXNet)的,这篇文章的python实现版?对torch实在不熟悉啊...

Reference

1 Recurrent Models of Visual Attention, Volodymyr Mnih, Nicolas Heess, Alex Graves, Koray Kavukcuoglu.
2 http://www.cosmosshadow.com/ml/%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C/2016/03/08/Attention.html#_label2_3

视觉注意力的循环神经网络模型相关推荐

  1. 吴恩达深度学习笔记(109)-循环神经网络模型(RNN介绍)

    https://www.toutiao.com/a6652926357133066755/ 2019-02-06 20:15:53 循环神经网络模型(Recurrent Neural Network ...

  2. 1.3 循环神经网络模型-深度学习第五课《序列模型》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 1.2 数学符号 回到目录 1.4 通过时间的方向传播 循环神经网络模型 (Recurrent Neural Network Model) 上节视频中,你了解了我们用来定义序 ...

  3. 基于循环神经网络模型(GRU)的新型冠状病毒肺炎流行趋势预测

    资源下载地址:https://download.csdn.net/download/sheziqiong/85639079 资源下载地址:https://download.csdn.net/downl ...

  4. 【小白学习keras教程】五、基于reuters数据集训练不同RNN循环神经网络模型

    @Author:Runsen 文章目录 循环神经网络RNN Load Dataset 1. Vanilla RNN 2. Stacked Vanilla RNN 3. LSTM 4. Stacked ...

  5. DL-3利用MNIST搭建神经网络模型(三种方法):1.用CNN 2.用CNN+RNN 3.用自编码网络autoencoder

    Author:吾爱北方的母老虎 原创链接:https://blog.csdn.net/weixin_41010198/article/details/80286216 import tensorflo ...

  6. 【Pytorch神经网络实战案例】11 循环神经网络结构训练语言模型并进行简单预测

    1 语言模型步骤 简单概述:根据输入内容,继续输出后面的句子. 1.1 根据需求拆分任务 (1)先对模型输入一段文字,令模型输出之后的一个文字. (2)将模型预测出来的文字当成输入,再放到模型里,使模 ...

  7. 【seq2seq】深入浅出讲解seq2seq神经网络模型

    本文收录于<深入浅出讲解自然语言处理>专栏,此专栏聚焦于自然语言处理领域的各大经典算法,将持续更新,欢迎大家订阅! 个人主页:有梦想的程序星空 个人介绍:小编是人工智能领域硕士,全栈工程师 ...

  8. PyTorch框架:(2)使用PyTorch框架构建神经网络模型---气温预测

    目录 第一步:数据导入 第二步:将时间转换成标准格式(比如datatime格式) 第三步: 展示数据:(画了4个子图) 第四步:做独热编码 第五步:指定输入与输出 第六步:对数据做一个标准化 第七步: ...

  9. 作者解读ICML接收论文:如何使用不止一个数据集训练神经网络模型?

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:欧明锋,浙江大学 导读:在实际的深度学习项目中,难免遇到多个相似数 ...

最新文章

  1. 利用vc的mfc做的Excel表格处理工具
  2. 动态规划(浅层基础)
  3. java 数据类型 存储_Java数据类型以及存储
  4. 使用动画播放文件夹中的图片
  5. 美国团购网站Groupon的盈利模式
  6. Android ImageView设置图片原理
  7. PX4 CMakeLists.txt 文件剖析
  8. Scala与Java混编译:java日志不打印的问题
  9. 念荆轲[原创诗一首]
  10. Bugku-CTF之login3(SKCTF)(基于布尔的SQL盲注)
  11. Box plot (箱形图) 中 quartile (四分位数)原理,及python_matplotlib中Q1和Q3定义的不同
  12. Therefore, hence, so, then, thus
  13. 首次公开,用了三年的 pandas 速查表
  14. 【统计学】【2015.09】基于状态空间模型的时间序列预测与插值
  15. 卷积神经网络中全连接层、softmax与softmax loss理解
  16. 常见七种逻辑门真值表
  17. Cascade CNN
  18. PowerBI visuals共计246组2020年1月31日扒取(Power BI 视觉对象)
  19. nodename nor servname provided, or not known
  20. Android免打包多渠道统计如何实现?附带学习经验

热门文章

  1. 零基础入门金融风控之贷款违约预测挑战赛-task01
  2. 奶牛慢跑 (寒假每日一题 18)
  3. 解决Client.Timeout exceeded while awaiting headers报错
  4. Linux 能替代 Windows 吗?
  5. 计算机机房的网络属于,学校机房的网络属于()。
  6. scrapy 抓取豆瓣Top250书籍信息
  7. 尊敬的用户您好: 您访问的网站被机房安全管理系统拦截,可能是以下原因造成: 1.您
  8. 利用 Python 分析 MovieLens 1M 数据集
  9. 分贝通携手衡石科技,用心护好客户「钱袋子」 增收节流数百万
  10. Rockchip Linux PCIe 开发指南