阅读大概需要 9分钟!

导读

本文讨论了最新爆款论文(Training RNNs as Fast as CNNs)提出的LSTM变种SRU(Simple Recurrent Unit),以及基于pytorch实现了SRU,并且在四个句子分类的数据集上测试了准确性以及与LSTM、CNN的速度对比。

一.为什么要提出SRU?

  • 深度学习的许多进展目前很多均是来源于增加的模型能力以及相关的计算,这经常涉及到更大、更深的深层神经网络,然而,虽然深层神经网络带来了明显的提升,但是也耗费了巨大的训练时间,特别是在语音识别以及机器翻译的模型训练上,要想获得一个最优的模型,往往要耗费几天的时间。

  • 为了解决训练模型的计算能力,像利用GPU进行加速训练的并行化方法在深度学习领域已经广泛使用,使用GPU进行加速的卷积神经网络在训练速度上有提升的很明显,但是,像RNN、LSTM却无法实现并行化方法,熟悉RNN、LSTM的人都知道,在其典型的实现中,要想计算 ht必须等到前一时刻ht-1计算完成,这明显的限制了其实现并行化处理,然而论文提出的简单循环单元(SRU)解除了这种限制,ht 的计算不在依赖于前一时刻的计算,这样就可以实现并行化处理,训练速度要比LSTM快,能够达到与CNN的一样的训练速度。

二.SRU实现及其优化

1、SRU实现

熟悉LSTM和GRU的人都知道,它们是根据神经门来控制信息流来缓解梯度消失与梯度爆炸问题,所以,接下来我们看一下典型的SRU实现。
我们首先对输入的x进行简单的线性变换:

接下来计算遗忘门(forget gate)和 输入门,他们两个都是Sigmoid门:

接下来我们计算c,在计算c的过程中,我们使用了共轭表达式 it = 1 - ft 来简化运算: 

最后,我们把c传递给激活函数g来计算最终的输出h:

以上就是SRU的经典实现,熟悉LSTM的人一定能够看出来,这样的SRU与LSTM一样都是依赖于前一时刻的计算,这样的做法没有什么意义,接下来我们我们在对其进一步的改进。

SRU的实现中添加了两个附加的特征:

  • Skip Connection

    具体来说,skip connection就是Highway Connection,对训练深层神经网络很有效果,我们来具体看一下公式:
    先设置一个重置门( reset gate),和遗忘门、输入门一样都是Sigmoid门:

    然后利用Skip Connection,ht’ 就是最后的输出:

    在后文的测试中,为什单层的SRU很难达到与LSTM相同的效果,而堆叠起来的多层SRU能够达到与LSTM相差无几甚至更好的效果,这里起到了很大的作用。

  • Variational dropout

    为了RNN的正则化除了使用标准的dropout外,还使用了Variational dropout,Variational dropout 在不同的时间步骤 t 上共享 dropout mask。在 RNN 每一个矩阵乘法计算中(即 W * drop(x)),mask 需要应用到输入 x。标准的 dropout 是在 h上执行的,即没有馈送到高速连接的输出状态。

2、SRU加速优化

根据上文中的公式看出 ft 、 rt 都与 ht-1 有关,也就是要想计算 ht 必须等到前一时刻ht-1计算完成,这样就破换了并行性和独立性,无法实现并行化处理,针对此问题,提出了完全drop连接,就是去除了 ht-1 的依赖,以下是SRU的公式:

从上述(8)、(9)、(10)三个公式中可以看出,已经解除了ht-1 的依赖,这样依赖就可以实现程序的并行化处理,而公式(11),(12)能够非常迅速和简洁的执行计算,因为它们的运算都是对应元素之间的操作。

3、CUDA优化

在上述公式8 — 10中,虽然解除了前一时刻的依赖,但是仍然存在一定的瓶颈,就是三个矩阵乘法的运算,在这里提供了更深的优化策略。

  • 矩阵乘法在所有的时间步骤中可以进行批处理,可以显著的提高计算的强度和提高GPU的利用率,在8 — 10 的公式中,可以把矩阵乘法可以合成一个,以后的处理就可以根据索引查找,具体如下:

  • 对于序列中的元素间的操作可以编译合并到一个内核函数中并在隐藏维度上并行化。

三.基于pytorch实现SRU Networks

1、SRU Networks Structure Diagram

熟悉LSTM的人很容易理解SRU的网络结构图,下图是SRU的网络结构图:
xt 代表 t 时刻的输入;
W、b 代表权重和偏置;
ft 代表 t 时刻的遗忘门(forget gate);
rt 代表 t 时刻的重置门(reset gate);
ct 和 ht 分别代表 t 时刻的状态和最终的输出;
σ 和 g 分别代表Sigmoid函数和激活函数(tanh、relu);
公式中的 ⊙ 代表矩阵对应元素间的操作;

2、基于pytorch实现SRU Formula

pytorch搭建神经网络一般需要继承nn.Module这个类,然后实现里面的forward()函数,现在搭建SRU Networks需要另外写一个SRU Cell 类,Cell 里面实现SRU的全部运算,具体代码如下:

SRU_Formula类:

SRU Cell类:

在这里我实现了多层的SRU搭建,对于维度不等的经过线性转换(Linear),以下是这部分的代码:

calculate one layer 函数实现了SRU的计算:

以上是SRU的公式实现,由于代码没有进行CUDA优化也没有进行并行化处理,所以速度上并没有明显的改变。

Github链接:https://github.com/bamtercelboo/pytorch_SRU

3、调用论文代码实现SRU

由于论文封装的代码比较不错,可以像LSTM一样简单调用:

其中cuda_functional是论文中已经封装好的SRU,在这里SRU实现了CUDA的优化,并对程序进行了并行化处理,所以速度上有了明显的提升,下文的测试也是基于此SRU与pytorch优化过的LSTM、CNN进行对比,测试结果参考下文。具体的使用可以参考论文的Github,以下是链接:

Github链接:https://github.com/bamtercelboo/pytorch_SRU

Paper Github链接:https://github.com/taolei87/sru/tree/master/classification

四.实验结果

1、数据集

本次实验任务是情感分类任务(二分类),数据来源于MR(电影评论数据集)、CR(客户对各种产品评价的数据集)、Subj(主观性数据集)以及Twitter情感分类数据集,以下是各个数据集的详细信息:

下图是MR、CR、Subj数据集的详细信息,测试采用十折交叉验证,下载数据从 Github:https://github.com/harvardnlp/sent-conv-torch/tree/master/data 

下图是Twitter情感分类数据集的详细信息:

2、SRU、LSTM、CNN准确率对比

以下实验结果是在CR、Subj、MR、Twitter四个句子分类数据集上测试的结果:

实验结果:在四个数据集上SRU与LSTM的准确率相差不大,有的数据集(像CR、Subj)一层的SRU效果就能达到一层LSTM的效果,但是在MR、Twitter数据集上一层的效果反而不是很好,需要叠加多层SRU才能达到LSTM一层的效果,这与上文提及的Highway Connection有很大的关系。

3、SRU、LSTM、CNN速度对比

以下实验结果是在Twitter数据集上对forward和backward测试的平均运行时间,其中SRU、LSTM、CNN都是经过CUDA优化的,CNN的kernel-size=3,SRU和LSTM的隐层维度是300,三个模型的batch size是16,以毫秒为单位计算,图中SRU-1代表一层的SRU模型:

实验结果:从上述实验结果能够说明在句子分类任务上,单层的SRU能够达到与CNN相同的速度,比LSTM快2 — 3倍;上文测试需要4层SRU才能达到一层LSTM效果的情况下,4层SRU能与一层LSTM的达到相同的速度。

References

[1] Tao Lei and Yu Zhang. Training RNNs as Fast as CNNs. arXiv:1709.02755, 2017.
[2] James Bradbury, Stephen Merity, Caiming Xiong, and Richard Socher. Quasi-recurrent neural networks. In ICLR, 2017.
[3] Yarin Gal and Zoubin Ghahramani. A theoretically grounded application of dropout in recurrent neural networks. In Advances in Neural Information Processing Systems 29 (NIPS), 2016.
[4] Jeremy Appleyard, Tomas Kocisky, and Phil Blunsom. Optimizing performance of recurrent neural networks on gpus. arXiv preprint arXiv:1604.01946, 2016.


欢迎关注深度学习自然语言处理公众号,我会在这里记录自己在路上的一点一滴!再小的人也有自己的品牌!期待和你一起进步!

【干货】神经网络SRU相关推荐

  1. 干货|神经网络及理解反向传播

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 一.人工神经网络简述 下面开始说神经网络.注意,当我们说N层神经网 ...

  2. 干货 | 神经网络原来这么简单,机器学习入门贴送给你

    白交 发自 凹非寺  量子位 报道 | 公众号 QbitAI 你想学机器学习吗?这里有一个入门贴适合你. 什么神经网络.随机森林.计算机视觉通通一网打尽. 这个Facebook软件工程师做了一个入门贴 ...

  3. 干货 | 神经网络与深度学习精选文章汇总

    AI有道 不可错过的AI技术公众号 关注 下面这部分列出了吴恩达深度学习专项课程中关于NN和DNN方面的所有精炼笔记.主要包括:神经网络与深度学习.优化神经网络.构建机器学习项目三块内容. 如果你对我 ...

  4. Rnn Lstm Gru Sru学习小结

    1.Rnn Rnn的详细介绍可以参考 深度学习之RNN(循环神经网络) 零基础入门深度学习(5) - 循环神经网络 详解循环神经网络(Recurrent Neural Network) 基本原理和算法 ...

  5. 力荐 | 吴恩达《序列模型》精炼笔记(1)-- 循环神经网络(RNN)

    AI有道 不可错过的AI技术公众号 关注 序列模型(Recurrent Neural Networks)是Andrw Ng深度学习专项课程中的第五门课,也是最后一门课.这门课主要介绍循环神经网络(RN ...

  6. 吴恩达《卷积神经网络》精炼笔记(2)-- 深度卷积模型:案例研究

    AI有道 不可错过的AI技术公众号 关注 1 Why Look at Case Studies 本文将主要介绍几个典型的CNN案例.通过对具体CNN模型及案例的研究,来帮助我们理解知识并训练实际的模型 ...

  7. 推荐一位零基础学 NLP 的大佬,内含成长历程

    大佬介绍 大佬:笔名zenRRan,方向自然语言处理,方法主要是深度学习. 未来的目标:人工智能之自然语言处理博士. 写公众号目的:将知识变成开源,让每个渴求知识而难以入门人工智能的小白以及想进阶的小 ...

  8. 一份从入门到精通NLP的完整指南 | NLPer

    该小博主介绍 本人:笔名zenRRan,方向自然语言处理,方法主要是深度学习. 未来的目标:人工智能之自然语言处理博士. 写公众号目的:将知识变成开源,让每个渴求知识而难以入门人工智能的小白以及想进阶 ...

  9. 最“燃”研究生!浙工大 64 岁研究生毕业,老师称其毕业论文写的最好

    点击上方"视学算法",选择"星标"公众号 重磅干货,第一时间送达 ‍ 整合 | 募格学术 来源 | 中国新闻网.杭州网 "如果能重选一次,我依旧会选择 ...

  10. 550 万华人在美人才现状:7 诺奖、300 院士,320 八大常春藤高校终身正教授......

    原文来源 | 北美学生圈 转载来源 | BioWorld 早在 150 年前就有华人移居美国的历史,从过去的劳工.公派留学生,到现在的投资者.高科技人才或者是看中养老福利过去养老的人.从古至今,移民美 ...

最新文章

  1. java byter是字节吗_GitHub - XXQAQ/Byter: 字节对象转换框架,一个基于字节的 Gson/FastJson...
  2. linux php 升级5.3,Linux php5.2.10升级到PHP5.3.29
  3. AMD为何要选择捆绑中国市场?
  4. redis 公网ip访问_Redis很重要,怎么只允许指定IP访问?
  5. c3p0 参数 模糊查询_Hibernate day03笔记
  6. LINUX无法运行navixat,关于RX5700XT的驱动方法以及bug解决方案
  7. 高德地图安卓 拖拽选点_行车记录仪当“眼睛” 高德地图手机AR导航再次升级...
  8. l298n电机哪一端为正_汽车维修要知道的几个答案,交流发电机、调节器有什么功用?...
  9. iOS开发:几种静态扫描工具的使用与对比
  10. java jtextfield 密码_Java Swing实战(三)文本组件JTextField和密码组件JPasswordField
  11. 惠州物联网产业规模 明年争取达400亿元
  12. ubuntu 12.04 修改 grub 启动参数
  13. shift 位置参数左移命令
  14. 去掉高德api上的logo图标
  15. 移动端屏幕适配(750px设计稿)
  16. Java使用ObjectInputStream时报错:java.lang.ClassNotFoundException: commen.User
  17. python方差齐性检验_【Python】统计科学之方差齐性检验
  18. Springboot2中文件上传报java.io.FileNotFoundException: C:\Users\WIzarder\AppData\Local\Temp\tomcat.8080.589
  19. 如何从0开始在鸿蒙OS中制作一个APP!
  20. 周末被马云的无人超市刷屏了

热门文章

  1. QToolBox学习笔记
  2. python之logging模块简单用法
  3. Swift 3必看:新的访问控制fileprivate和open
  4. Windows下Python,setuptools,pip,virtualenv的安装
  5. 高效程序员秘籍(9):快速查找硬盘上的文件和目录
  6. 数据库的跨平台设计(转)
  7. MySQL的安装与连接方法
  8. vue 数据劫持 响应式原理 Observer Dep Watcher
  9. YII2 搭建redis拓展(教程)
  10. 运行批处理bat文件不出现黑框