文章目录

  • 1 引入
  • 2 长短期记忆
    • 2.1 输入门、遗忘门和输出门
    • 2.2 候选记忆细胞
    • 2.3 记忆细胞
    • 2.4 隐藏状态
  • 3 代码
  • 致谢

1 引入

  本文介绍一种常用的门控循环神经网络:长短期记忆 (long short-term memory, LSTM)。它比门控循环单元的结构稍微复杂一点。

2 长短期记忆

  LSTM引入了333个门,即输入门 (input gate)、遗忘门 (forget gate)和输出门 (output gate),以及与隐藏状态形状相同的记忆细胞,从而记录额外的信息。

2.1 输入门、遗忘门和输出门

  与门控循环单元中的重置门和更新门一样,如下图,长短期记忆的门的输入均为当前时间步输入Xt\boldsymbol{X}_tXt​与上一时间步隐藏状态Ht−1\boldsymbol{H}_{t-1}Ht−1​,输出由激活函数为sigmoid函数的全连接层计算得到。如此一来,这333个门元素的值域均为[0,1][0,1][0,1]。

  具体来说,假设隐藏单元个数为hhh,给定时间步ttt的小批量输入Xt∈Rn×d\boldsymbol{X}_t\in\mathbb{R}^{n\times d}Xt​∈Rn×d和上一时间步隐藏状态Ht−1∈Rn×h\boldsymbol{H}_{t-1}\in\mathbb{R}^{n \times h}Ht−1​∈Rn×h。时间步ttt的输入门It∈Rn×h\boldsymbol{I}_t\in\mathbb{R}^{n\times h}It​∈Rn×h、遗忘门Ft∈Rn×h\boldsymbol{F}_t\in\mathbb{R}^{n\times h}Ft​∈Rn×h和输出门Ot∈Rn×h\boldsymbol{O}_t\in\mathbb{R}^{n\times h}Ot​∈Rn×h分别计算如下:
It=σ(XtWxi+Ht−1Whi+bi),Ft=σ(XtWxf+Ht−1Whf+bf),Ot=σ(XtWxo+Ht−1Who+bo),\begin{aligned} \boldsymbol{I}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x i}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h i}+\boldsymbol{b}_{i}\right), \\ \boldsymbol{F}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x f}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h f}+\boldsymbol{b}_{f}\right), \\ \boldsymbol{O}_{t} &=\sigma\left(\boldsymbol{X}_{t} \boldsymbol{W}_{x o}+\boldsymbol{H}_{t-1} \boldsymbol{W}_{h o}+\boldsymbol{b}_{o}\right), \end{aligned} It​Ft​Ot​​=σ(Xt​Wxi​+Ht−1​Whi​+bi​),=σ(Xt​Wxf​+Ht−1​Whf​+bf​),=σ(Xt​Wxo​+Ht−1​Who​+bo​),​其中Wxi,Wxf,Wxo∈Rd×h\boldsymbol{W}_{xi}, \boldsymbol{W}_{xf}, \boldsymbol{W}_{xo} \in \mathbb{R}^{d\times h}Wxi​,Wxf​,Wxo​∈Rd×h和Whi,Whf,Who∈Rh×h\boldsymbol{W}_{hi}, \boldsymbol{W}_{hf}, \boldsymbol{W}_{ho} \in \mathbb{R}^{h\times h}Whi​,Whf​,Who​∈Rh×h是权重参数,bi,bf,bo∈Rh×h\boldsymbol{b}_{i}, \boldsymbol{b}_{f}, \boldsymbol{b}_{o} \in \mathbb{R}^{h\times h}bi​,bf​,bo​∈Rh×h是偏差参数。

2.2 候选记忆细胞

  长短期记忆需要计算候选记忆细胞C~t\tilde{\boldsymbol{C}}_tC~t​。它的计算与上面介绍的333个门类似,但使用了值域在[−1,1][-1,1][−1,1]的tanh函数作为激活函数,如下图所示。

  具体来说,时间步ttt的候选记忆细胞C~t∈Rn×h\tilde{\boldsymbol{C}}_t\in\mathbb{R}^{n\times h}C~t​∈Rn×h的计算为:
C~t=tanh(XtWxc+Ht−1Whc+bc),\tilde{\boldsymbol{C}}_t = \text{tanh}(\boldsymbol{X}_t\boldsymbol{W}_{xc}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{hc}+\boldsymbol{b}_c), C~t​=tanh(Xt​Wxc​+Ht−1​Whc​+bc​),其中Wxc∈Rd×h\boldsymbol{W}_{xc}\in\mathbb{R}^{d\times h}Wxc​∈Rd×h和Whc∈Rh×h\boldsymbol{W}_{hc}\in\mathbb{R}^{h\times h}Whc​∈Rh×h,bc∈R1×h\boldsymbol{b}_c\in\mathbb{R}^{1\times h}bc​∈R1×h是偏差参数。

2.3 记忆细胞

  通过元素值域在[0,1][0,1][0,1]的输入门、遗忘门和输出门来控制隐藏状态中信息的流动,这一般也是通过使用按元素乘法⊙\odot⊙来实现的。当前时间步记忆细胞Ct∈Rn×h\boldsymbol{C}_t\in\mathbb{R}^{n \times h}Ct​∈Rn×h的计算组合了上一时间步记忆细胞和当前时间步候选记忆细胞的信息,并通过遗忘门和输入门来控制信息的流动:
Ct=Ft⊙Ct−1+It⊙C~t,\boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1}+\boldsymbol{I}_t\odot\tilde{\boldsymbol{C}}_t, Ct​=Ft​⊙Ct−1​+It​⊙C~t​,如下图所示。该设计可以应对RNN中的梯度衰减问题,并更好地捕捉时间序列中时间步间距较大依赖关系

2.4 隐藏状态

  有了记忆细胞以后,接下来可以通过输出门来控制从记忆细胞到隐藏状态Ht\boldsymbol{H}_tHt​的信息流动:
Ht=Ot⊙tanh(Ct).\boldsymbol{H}_t=\boldsymbol{O}_t\odot\text{tanh}(\boldsymbol{C}_t). Ht​=Ot​⊙tanh(Ct​).这里的tanh函数确保隐藏状态元素值在[−1,1][-1,1][−1,1]之间。需要注意的是,当输出门近似111时,记忆细胞信息将传递到隐藏状态供输出层使用;解决000时,记忆细胞的信息只自己保留,如下图。

3 代码

  代码的主题框架与博客周杰伦歌词数据集测试循环神经网络中的架构一致,不同之处在于需要将mainpy文件中的以下语句替换为:

rnn_layer = get_rnn_layer(input_size=dict_size, hidden_size=hidden_size)
model = RNNModel(rnn_layer, dict_size).to(device)

↓↓↓

lstm_layer = nn.LSTM(input_size=dict_size, hidden_size=hidden_size)
model = RNNModel(lstm_layer, dict_size).to(device)

  输出如下:

epoch 50, perplexity 1.017165, time 1.56 sec- 分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻       所有回忆对着我进攻    - 不分开 我知道这里很美但家乡的你更美走过了很多地方 我来到伊斯坦堡 就像是童话故事  有教堂有城堡 每天忙
epoch 100, perplexity 1.013933, time 1.57 sec- 分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让- 不分开 我对定会呵护著你 也逗你笑 你对我有多重要 我后悔没让你知道 安静的听你撒娇 看你睡著一直到老 就
epoch 150, perplexity 1.019680, time 1.55 sec- 分开始想像 爸和妈当年的模样 说著一口吴侬软语的姑娘缓缓走过外滩 消失的 旧时光 一九四三 在回忆 的路- 不分开 我知道这里很美但家乡的你更美原来我只想要你 陪我去吃汉堡  说穿了其实我的愿望就怎么小 就怎么每天
epoch 200, perplexity 2.351453, time 1.55 sec- 分开始想太你的 我都的可爱 不再考倒我 难过 我想躲 我不能再想 我不能再想 我不 我不 我不要再想 我- 不分开 那场外加油 你却还让我和狂的玩我 相思寄红豆 相思寄红豆走是人方的响尾蛇 无力的我爱你 爱情来的太
epoch 250, perplexity 1.014510, time 1.54 sec- 分开始打呼 管家是一只会说法语举止优雅的猪 吸血前会念约翰福音做为弥补 拥有一双蓝色眼睛的凯萨琳公主 专- 不分开 我用家二 在人海中 盲目跟从 别人的梦 全面放纵 恨没有用 疗伤止痛 不再感动 没有梦 痛不知轻重

致谢

感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书,为鄙人学习深度学习提供了很大的帮助。本文一系列关于深度学习的博客均无侵权之意,只为记录自己的深度学习历程。
  项目Github地址:https://github.com/ShusenTang/Dive-into-DL-PyTorch

torch学习 (三十二):周杰伦歌词数据集与长短期记忆 (LSTM)相关推荐

  1. torch学习 (三十四):迁移学习之微调

    文章目录 引入 1 微调 2 热狗识别 2.1 数据集载入 2.2 数据集预处理 2.3 定义和初始化模型 2.4 微调模型 致谢 引入   场景:   从图像中识别出不同种类的椅子,然后将购买链接推 ...

  2. Java多线程学习三十二:Callable 和 Runnable 的不同?

    为什么需要 Callable?Runnable 的缺陷 先来看一下,为什么需要 Callable?要想回答这个问题,我们先来看看现有的 Runnable 有哪些缺陷? 不能返回一个返回值 第一个缺陷, ...

  3. ballerina 学习 三十二 编写安全的程序

    ballerina编译器已经集成了部分安全检测,在编译时可以帮助我们生成错误提示,同时ballerina 标准库 已经对于常见漏洞高发的地方做了很好的处理,当我们编写了有安全隐患的代码,编译器就已经提 ...

  4. 深度学习入门(三十二)卷积神经网络——BN批量归一化

    深度学习入门(三十二)卷积神经网络--BN批量归一化 前言 批量归一化batch normalization 课件 批量归一化 批量归一化层 批量归一化在做什么? 总结 教材 1 训练深层网络 2 批 ...

  5. tensorflow学习笔记(三十二):conv2d_transpose (解卷积)

    tensorflow学习笔记(三十二):conv2d_transpose ("解卷积") deconv解卷积,实际是叫做conv_transpose, conv_transpose ...

  6. OpenCV学习笔记(三十一)——让demo在他人电脑跑起来 OpenCV学习笔记(三十二)——制作静态库的demo,没有dll也能hold住 OpenCV学习笔记(三十三)——用haar特征训练自己

    OpenCV学习笔记(三十一)--让demo在他人电脑跑起来 这一节的内容感觉比较土鳖.这从来就是一个老生常谈的问题.学MFC的时候就知道这个事情了,那时候记得老师强调多次,如果写的demo想在人家那 ...

  7. JavaScript学习(三十二)— Keycode常用键位码对照表

    JavaScript学习(三十二)- Keycode常用键位码对照表 (一).字母和数字键的键码值(keyCode) (二).控制键键码值(keyCode) (三).多媒体键码值(keyCode)

  8. C++语言学习(十二)——C++语言常见函数调用约定

    C++语言学习(十二)--C++语言常见函数调用约定 一.C++语言函数调用约定简介 C /C++开发中,程序编译没有问题,但链接的时候报告函数不存在,或程序编译和链接都没有错误,但只要调用库中的函数 ...

  9. axi dma 寄存器配置_FPGA Xilinx Zynq 系列(三十二)AXI 接口

    大侠好,欢迎来到FPGA技术江湖,江湖偌大,相见即是缘分.大侠可以关注FPGA技术江湖,在"闯荡江湖"."行侠仗义"栏里获取其他感兴趣的资源,或者一起煮酒言欢. ...

最新文章

  1. java 布局教程_java布局学习(新)
  2. 做一个完整的Java Web项目太难了,因为这些你不会!
  3. AI公开课:19.03.06何晓冬博士《自然语言与多模态交互前沿技术》课堂笔记以及个人感悟
  4. iOS 应用性能测试的相关方法、工具及技巧
  5. 实验2 java_《Java程序设计》实验2
  6. 28岁程序员狂赚上亿宣布退休,网友:这就是命!
  7. 信息学奥赛一本通 1067:整数的个数 | OpenJudge NOI 1.5 11
  8. android按钮周围阴影,Android 上的按钮填充和阴影
  9. .NET中的设计模式——一步步发现装饰模式
  10. java人账户atm模拟存款,模拟银行ATM系统(基础版)
  11. 某LINUX平台,管道open直接崩溃
  12. curl安装使用【超级无敌简单】
  13. 程序员写代码都用什么样的笔记本?
  14. Java-满天繁星案例(1)
  15. 类和对象,属性和方法
  16. 在VMware 14虚拟机下,ndn-cxx和NFD平台搭建
  17. allure如何设置新logo
  18. java 解析p12_java引用微信支付的p12证书文件
  19. 上班被监控屏幕和摄像头,拒绝就直接开除,员工起诉公司获赔52万元
  20. java中a= b_Java中a+=b和a=a+b的区别

热门文章

  1. HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions
  2. 第二届“红明谷”杯数据安全大赛-安全意识赛
  3. 2016阿里安全峰会(附峰会议题一览表)
  4. 【浏览器】1219- 换一种风格理解 Chrome 浏览器渲染全过程
  5. shp,sde,xmd的理解
  6. 华为P40或将搭载鸿蒙,厉害了任正菲的华为:P40或将搭载鸿蒙系统很快就要上市了...
  7. 设置邮件规则,轻松整理你的收件箱!
  8. 基于 bioMart 构建绵羊(非常见物种) OrgDb 包/数据库
  9. 计算机日语考研的学校,专科日语考研考什么学校
  10. DLL搜索路径和DLL劫持