Pytorch LSTM 长短期记忆网络

0. 环境介绍

环境使用 Kaggle 里免费建立的 Notebook

教程使用李沐老师的 动手学深度学习 网站和 视频讲解

小技巧:当遇到函数看不懂的时候可以按 Shift+Tab 查看函数详解。

1. LSTM

LSTM 的设计灵感来自于计算机的逻辑门。
LSTM 引入了记忆单元(Memory cell)。
有些文献认为记忆单元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计的目的是用于记录附加的信息。LSTM 使用门控结构控制记忆单元。

在 LSTM 网络中,记忆单元 ccc 可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。记忆单元 ccc 中保存信息的生命周期要长于短期记忆 hhh,但又远远短于长期记忆, 长短期记忆是指长的 “短期记忆”。因此称为长短期记忆(Long Short-Term Memory)。

2. 门

  • 遗忘门:ftf_tft​ 通过上一时刻隐状态 ht−1h_{t-1}ht−1​ 和 当前输入 xtx_txt​ 控制上一时刻记忆单元 ct−1c_{t-1}ct−1​ 需要遗忘多少信息。
  • 输入门:iti_tit​ 控制当前时刻的候选记忆单元 c~t\widetilde{c}_tct​ 有多少信息需要保存。
  • 输出门:oto_tot​ 控制当前时刻的记忆单元 ctc_tct​ 有多少信息需要输出给当前隐状态 hth_tht​。

公式表示:

3. 从零开始实现代码

3.0 导包

!pip install -U d2l
import torch
from torch import nn
from d2l import torch as d2l

3.1 加载数据集

设置批量大小和步数

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

3.2 初始化模型参数

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆单元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

3.3 初始状态函数

在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆单元, 单元的值为 000,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

3.4 定义模型

其中要注意:@ 表示矩阵乘法,*+ 表示矩阵 point-wise 计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)

3.5 训练预测

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

4. 简洁实现

使用 nn.LSTM,训练更快:

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

Q & A

Pytorch LSTM 长短期记忆网络相关推荐

  1. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  2. RNN循环神经网络 、LSTM长短期记忆网络实现时间序列长期利率预测

    全文链接:http://tecdat.cn/?p=25133 2017 年年中,R 推出了 Keras 包 _,_这是一个在 Tensorflow 之上运行的综合库,具有 CPU 和 GPU 功能(点 ...

  3. 【思维导图】利用LSTM(长短期记忆网络)来处理脑电数据

    文章来源| 脑机接口社区群友 认知计算_茂森的授权分享 在此非常感谢 认知计算_茂森! 本篇文章主要通过思维导图来介绍利用LSTM(长短期记忆网络)来处理脑电数据. 文章的内容来源于社区分享的文章&l ...

  4. 利用LSTM(长短期记忆网络)来处理脑电数据

    目录 LSTM 原理介绍 LSTM的核心思想 一步一步理解LSTM 代码案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 .QQ交流群:903290195 Rose小哥今天介绍一下用LS ...

  5. Maltab GUI课程设计——LSTM长短期记忆网络回归预测

    文章目录 课程设计 平台:Matlab App designer 功能实现:LSTM长短期记忆网络回归预测 目的: 演示: 欢迎交流 课程设计 平台:Matlab App designer 功能实现: ...

  6. 神经网络学习笔记3——LSTM长短期记忆网络

    目录 1.循环神经网络 1.1循环神经网络大致结构 1.2延时神经网络(Time Delay Neural Network,TDNN) 1.3按时间展开 1.4反向传播 1.5 梯度消失,梯度爆炸 2 ...

  7. LSTM(长短期记忆网络)原理介绍

    相关学习资料: Pytorch:RNN.LSTM.GRU.Bi-GRU.Bi-LSTM.梯度消失.爆炸 难以置信!LSTM和GRU的解析从未如此清晰 RNN_了不起的赵队-CSDN博客_rnn 如何从 ...

  8. LSTM(长短期记忆网络)原理与在脑电数据上的应用

    LSTMs(Long Short Term Memory networks,长短期记忆网络)简称LSTMs,很多地方用LSTM来指代它.本文也使用LSTM来表示长短期记忆网络.LSTM是一种特殊的RN ...

  9. LSTM -长短期记忆网络(RNN循环神经网络)

    文章目录 基本概念及其公式 输入门.输出门.遗忘门 候选记忆元 记忆元 隐状态 从零开始实现 LSTM 初始化模型参数 定义模型 训练和预测 简洁实现 小结 基本概念及其公式 LSTM,即(long ...

最新文章

  1. php的bom头会影响格式,phpBOM头(字符#65279;)出现的原因以及解决方法_PHP程序员博客|高蒙个人博客...
  2. Informix存储过程
  3. Institute for Manufacturing virtual check in part 1
  4. zoj3777(状态压缩)
  5. ubuntu14.04修改limits.conf后链接限制仍然不生效
  6. Linux系统中,各种小动物
  7. 学python可以做什么-学会Python后都能做什么?介绍五种Python的实用场景
  8. RQNOJ 342 最不听话的机器人:网格dp
  9. 软考《软件设计师教程》(第五版)
  10. MAC地址-集线器-ARP
  11. 0ctf-2017-pwn-char 题解
  12. op 圣诞节活动_圣诞节到了–这是我们精选的IT饼干笑话
  13. cesium导入骨骼动画
  14. igrp和eigrp详解
  15. 微软第四朵智能云:低代码平台Power Platform
  16. sin和cos的爱恋
  17. 2021年新春佳节,《经济学人》是如何报道的?
  18. pta 构造哈夫曼树-有序输入 优先队列做法
  19. 简单开发的android阅读器源码,包含了读取数据库和文件流处理功能
  20. 舌尖美味实践团采访活动

热门文章

  1. Caffe2 Synchronous SGD
  2. 红旗linux 6,初识 Linux(红旗Linux 6)
  3. 计算机图形学 (二) 图元的属性 - 概念、轮廓
  4. Eclipse4.4.2(luna) JDK1.8.0_212 PyDev5.2.0 Python2.7
  5. Cesium 热力图(可直接使用)
  6. java Servlet 笔记
  7. 在GitHub 上下载指定的文件夹的两种方法
  8. CSS五种方式实现 Footer 置底
  9. 数字图像-理想低通滤波器
  10. 【TouchDesigner】Feedback的应用NO.3