系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释。

前言:上一节笔记是RNN相关,链接如下:nlp-tutorial代码注释3-1,RNN简介,普通的RNN有一个问题,就是梯度消失,本节介绍解决此问题的一个方法:使用LSTM单元。

梯度消失问题:

考虑第i个时间步的损失对第j个时间步的激活值的导数,通过链式法则可以得到下图的求导公式:

当i、j距离较远时,若Wh较小,整个式子会指数级的变小,这是梯度消失;若Wh较大,整个式子会指数级的变大,这是梯度爆炸。
由于梯度消失,较远处的计算产生的影响也会消失,最终更新梯度时就只会收到较近的计算的影响,而很难受到长期的影响,亦即RNN记性很差。
举例:当RNN处理下图这样的语言模型预测问题时,根据上下文,很显然空格处是tickets,而由于上一个tickets距离太远,RNN很难预测出这个词是tickets。

主要的问题是:RNN很难去保存很多个时间步之前的信息,即不具有记忆性。我们需要一个具有记忆的RNN!这就是LSTM的主要想法。

LSTM

在第t个时间步,有一个隐藏状态h(t)h^{(t)}h(t)和一个单元状态c(t)c^{(t)}c(t),他们都是长度为n的向量,单元状态c可以存储长期信息,LSTM可以对单元状态c进行删除、写入、读取信息的操作。
单元状态c的信息被删除、写入、读取分别由三个对应的门控制。在每个时间步,门的每个元素可以是1(打开)、0(关闭),也可以是介于两者之间的值。具体公式如下,在时间步t计算h(t)h^{(t)}h(t)和c(t)c^{(t)}c(t):

遗忘门f(t)f^{(t)}f(t):控制对上一个时间步的单元状态c(t−1)c^{(t-1)}c(t−1)是保持还是遗忘;
输入门i(t)i^{(t)}i(t):控制写入新单元状态的哪些内容;
输出门o(t)o^{(t)}o(t):控制单元内容的哪些部分输出到h(t)h^{(t)}h(t);
c~\tilde{c}c~(t)^{(t)}(t):新单元状态;
c(t)c^{(t)}c(t):通过遗忘一些上个时间步的单元状态c(t−1)c^{(t-1)}c(t−1)并写入一部分c~\tilde{c}c~(t)^{(t)}(t)而在本时间步产生的新单元状态;
h(t)h^{(t)}h(t):从单元状态c(t)c^{(t)}c(t)中读取一部分作为本时间步的隐藏状态。
LSTM架构让RNN更容易保存很多个时间步之前的信息,例如如果遗忘门一直被设置为0,那么信息就能够得到永久的保存。LSTM并不能保证没有梯度消失,但他确实让模型更容易学期长期的依赖关系。

代码实现

pytorch代码及详细注释如下:(源代码为github中nlp-tutorial项目,项目地址:nlp-tutorial)
首先import一些需要的库,并设置元素默认的type为float:

import numpy as np                          #引入numpy库
import torch                                #引入torch
import torch.nn as nn                       #torch.nn是torch的神经网络库
import torch.optim as optim                 #torch.optim是优化库,包含很多优化函数
from torch.autograd import Variable         #现在的pytorch版本variable已经回归tensor了,直接用tensor即可
dtype = torch.FloatTensor

接下来是建立字典,本次代码的目的是根据前三个字母预测单词的第四个字母,字典中是26个字母:

char_arr = [c for c in 'abcdefghijklmnopqrstuvwxyz']# 建立字母列表
word_dict = {n: i for i, n in enumerate(char_arr)}  # 这两行分别建立字母到序号的和序号到字母的索引
number_dict = {i: w for i, w in enumerate(char_arr)}
n_class = len(word_dict)  # 字典大小seq_data = ['make', 'need', 'coal', 'word', 'love', 'hate', 'live', 'home', 'hash', 'star']   # 数据集

接着设置一些参数:步长为3,即根据3个字母预测下一个,n_hidden是隐藏层单元个数:

n_step = 3
n_hidden = 128

处理数据集,获得输入和对应的标记:

def make_batch(seq_data):input_batch, target_batch = [], [] #空列表for seq in seq_data:input = [word_dict[n] for n in seq[:-1]] # 'm', 'a' , 'k' is inputtarget = word_dict[seq[-1]] # 'e' is targetinput_batch.append(np.eye(n_class)[input])target_batch.append(target)return Variable(torch.Tensor(input_batch)), Variable(torch.LongTensor(target_batch))

接着定义模型:

class TextLSTM(nn.Module):

首先是_init_,先继承父类,再使用nn.LSTM搭建LSTM层,再初始化隐藏层的参数W和b:

    def __init__(self):super(TextLSTM, self).__init__()self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden)self.W = nn.Parameter(torch.randn([n_hidden, n_class]).type(dtype))self.b = nn.Parameter(torch.randn([n_class]).type(dtype))

再是forward,这里首先要初始化第0个时间步的h(0)h^{(0)}h(0)和c(0)c^{(0)}c(0),这里output是所有时间步的输出,这里是RNN语言模型,只需要最后一步的输出即可:

    def forward(self, X):input = X.transpose(0, 1)  # 将X的形状变换为:[n_step, batch_size, n_class]#初始化第0个时间步的ht和cthidden_state = Variable(torch.zeros(1, len(X), n_hidden))   # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]cell_state = Variable(torch.zeros(1, len(X), n_hidden))     # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))outputs = outputs[-1]  # 取最后一步的输出,形状为:[batch_size, num_directions(=1) * n_hidden]model = torch.mm(outputs, self.W) + self.b  # model : [batch_size, n_class]return model

接着是训练前的准备工作,调用make_batch函数获得输入和输出,接着选择损失函数和优化方法:

input_batch, target_batch = make_batch(seq_data)
model = TextLSTM()
criterion = nn.CrossEntropyLoss()                    #损失函数为交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) #使用Adam算法进行优化

接下来训练模型:

for epoch in range(1000):optimizer.zero_grad()          #每次训练前清除梯度缓存output = model(input_batch)    #模型计算outputloss = criterion(output, target_batch)       #计算lossif (epoch + 1) % 100 == 0:print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))loss.backward()                              #反向传播、自动求导optimizer.step()                             #优化、更新参数

最后对训练好的模型进行测试:

inputs = [sen[:3] for sen in seq_data]predict = model(input_batch).data.max(1, keepdim=True)[1]
print(inputs, '->', [number_dict[n.item()] for n in predict.squeeze()])

nlp-tutorial代码注释3-2,LSTM简介相关推荐

  1. nlp-tutorial代码注释3-1,RNN简介

    系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 前言:针对之前n-gram等具有 ...

  2. nlp-tutorial代码注释笔记

    系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 传送门: nlp-tutoria ...

  3. nlp-tutorial代码注释3-3,双向RNN简介

    系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 前言:3-1节介绍了普通的RNN ...

  4. yolov3网络结构图_目标检测——YOLO V3简介及代码注释(附github代码——已跑通)...

    GitHub: liuyuemaicha/PyTorch-YOLOv3​github.com 注:该代码fork自eriklindernoren/PyTorch-YOLOv3,该代码相比master分 ...

  5. nlp-tutorial代码注释1-1,语言模型、n-gram简介

    系列语:本系列是nlp-tutorial代码注释系列,github上原项目地址为:nlp-tutorial,本系列每一篇文章的大纲是相关知识点介绍 + 详细代码注释. 本文知识点介绍来自斯坦福大学CS ...

  6. tensorflow笔记:流程,概念和简单代码注释

    tensorflow是google在2015年开源的深度学习框架,可以很方便的检验算法效果.这两天看了看官方的tutorial,极客学院的文档,以及综合tensorflow的源码,把自己的心得整理了一 ...

  7. PHP中类和文件的代码注释规范

    编写好的文档对于任何软件项目都至关重要,不仅是因为文档的质量可能比代码的质量更重要,还因为良好的第一印象会促使开发人员进一步查看代码以及后续的迭代. 文件注释 /*** Sample file com ...

  8. c++ doxygen 注释规范_C语言代码注释参考

    简述 该参考是基于Doxygen注释规范进行简单归纳,可以适当根据自己的需求进行约定. Doxygen可以从一套归档源文件开始,生成HTML格式的在线类浏览器,或离线的LATEX.RTF参考手册.简单 ...

  9. Python代码注释 - Python零基础入门教程

    目录 一.什么是代码注释 二.为什么写代码要注释 三.代码注释的方式 1.单行注释,使用英文符号 # 2.多行注释 方法一:英文状态下使用单引号 """ 方法二:英文状态 ...

最新文章

  1. 微信小程序开发 笔记
  2. 生成对抗网络GANs理解(附代码)
  3. 浅谈视觉设计的准确性
  4. if condition 大于_小函数,大用处!巧用AND函数,避开IF函数嵌套
  5. 利物浦大学图书馆官网西交利物浦大学图书馆官网
  6. [转]大话企业级Android应用开发实战 音乐播放器的开发
  7. 基于verilog的FFT设计与实现
  8. android studio adil位置,在Android Studio 中正确使用adil ”绝对经典“
  9. linux交换区使用过多导致的性能问题
  10. 节俭,是一种了不起的能力
  11. 7款纯CSS3实现的炫酷动画应用
  12. Dxdesigner SCH to Mentor PCB
  13. 关于手机ping电脑和电脑ping手机
  14. 大会没看够?2021 Google 开发者大会总结看这里!
  15. 泡泡堂、QQ堂游戏通信架构分析
  16. 关于xmind6图标组导入教程
  17. 知到答案 环境学概论 智慧树网课章节测试答案
  18. 馋猫美食记录本_隐私政策
  19. 2021秋招笔试(1)_乐鑫
  20. 文字下划线效果(标题hover效果)

热门文章

  1. visual studio 代码分析利器 FxCop
  2. (转):GOF设计模式趣解(23种设计模式)
  3. 背起行囊,就是过客;放下包袱,就有归宿。
  4. codewars--js--Happy numbers++无穷大判断
  5. centos 下 sphinx安装和配置
  6. HDOJ 5542 The Battle of Chibi
  7. 手机网页宽度自动适应屏幕宽度的方…
  8. 【渗透测试实战】PHP语言有哪些后门?以及利用方法
  9. Linux操作Oracle(11)——Oracle用户密码过期 设置密码永不过期方法
  10. Linux故障解决(4)——新安装的CentOS 系统无法上网解决方法 (未知的名称或服务)