Pytorch LSTM 长短期记忆网络
Pytorch LSTM 长短期记忆网络
0. 环境介绍
环境使用 Kaggle 里免费建立的 Notebook
教程使用李沐老师的 动手学深度学习 网站和 视频讲解
小技巧:当遇到函数看不懂的时候可以按 Shift+Tab
查看函数详解。
1. LSTM
LSTM 的设计灵感来自于计算机的逻辑门。
LSTM 引入了记忆单元(Memory cell)。
有些文献认为记忆单元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计的目的是用于记录附加的信息。LSTM 使用门控结构控制记忆单元。
在 LSTM 网络中,记忆单元 ccc 可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔。记忆单元 ccc 中保存信息的生命周期要长于短期记忆 hhh,但又远远短于长期记忆, 长短期记忆是指长的 “短期记忆”。因此称为长短期记忆(Long Short-Term Memory)。
2. 门
- 遗忘门:ftf_tft 通过上一时刻隐状态 ht−1h_{t-1}ht−1 和 当前输入 xtx_txt 控制上一时刻记忆单元 ct−1c_{t-1}ct−1 需要遗忘多少信息。
- 输入门:iti_tit 控制当前时刻的候选记忆单元 c~t\widetilde{c}_tct 有多少信息需要保存。
- 输出门:oto_tot 控制当前时刻的记忆单元 ctc_tct 有多少信息需要输出给当前隐状态 hth_tht。
公式表示:
3. 从零开始实现代码
3.0 导包
!pip install -U d2l
import torch
from torch import nn
from d2l import torch as d2l
3.1 加载数据集
设置批量大小和步数
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
3.2 初始化模型参数
def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three() # 输入门参数W_xf, W_hf, b_f = three() # 遗忘门参数W_xo, W_ho, b_o = three() # 输出门参数W_xc, W_hc, b_c = three() # 候选记忆单元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params
3.3 初始状态函数
在初始化函数中, 长短期记忆网络的隐状态需要返回一个额外的记忆单元, 单元的值为 000,形状为(批量大小,隐藏单元数)。 因此,我们得到以下的状态初始化。
def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))
3.4 定义模型
其中要注意:@
表示矩阵乘法,*
和 +
表示矩阵 point-wise 计算。
def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H, C)
3.5 训练预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
4. 简洁实现
使用 nn.LSTM
,训练更快:
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
Q & A
Pytorch LSTM 长短期记忆网络相关推荐
- 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别
深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...
- RNN循环神经网络 、LSTM长短期记忆网络实现时间序列长期利率预测
全文链接:http://tecdat.cn/?p=25133 2017 年年中,R 推出了 Keras 包 _,_这是一个在 Tensorflow 之上运行的综合库,具有 CPU 和 GPU 功能(点 ...
- 【思维导图】利用LSTM(长短期记忆网络)来处理脑电数据
文章来源| 脑机接口社区群友 认知计算_茂森的授权分享 在此非常感谢 认知计算_茂森! 本篇文章主要通过思维导图来介绍利用LSTM(长短期记忆网络)来处理脑电数据. 文章的内容来源于社区分享的文章&l ...
- 利用LSTM(长短期记忆网络)来处理脑电数据
目录 LSTM 原理介绍 LSTM的核心思想 一步一步理解LSTM 代码案例 本分享为脑机学习者Rose整理发表于公众号:脑机接口社区 .QQ交流群:903290195 Rose小哥今天介绍一下用LS ...
- Maltab GUI课程设计——LSTM长短期记忆网络回归预测
文章目录 课程设计 平台:Matlab App designer 功能实现:LSTM长短期记忆网络回归预测 目的: 演示: 欢迎交流 课程设计 平台:Matlab App designer 功能实现: ...
- 神经网络学习笔记3——LSTM长短期记忆网络
目录 1.循环神经网络 1.1循环神经网络大致结构 1.2延时神经网络(Time Delay Neural Network,TDNN) 1.3按时间展开 1.4反向传播 1.5 梯度消失,梯度爆炸 2 ...
- LSTM(长短期记忆网络)原理介绍
相关学习资料: Pytorch:RNN.LSTM.GRU.Bi-GRU.Bi-LSTM.梯度消失.爆炸 难以置信!LSTM和GRU的解析从未如此清晰 RNN_了不起的赵队-CSDN博客_rnn 如何从 ...
- LSTM(长短期记忆网络)原理与在脑电数据上的应用
LSTMs(Long Short Term Memory networks,长短期记忆网络)简称LSTMs,很多地方用LSTM来指代它.本文也使用LSTM来表示长短期记忆网络.LSTM是一种特殊的RN ...
- LSTM -长短期记忆网络(RNN循环神经网络)
文章目录 基本概念及其公式 输入门.输出门.遗忘门 候选记忆元 记忆元 隐状态 从零开始实现 LSTM 初始化模型参数 定义模型 训练和预测 简洁实现 小结 基本概念及其公式 LSTM,即(long ...
最新文章
- php的bom头会影响格式,phpBOM头(字符#65279;)出现的原因以及解决方法_PHP程序员博客|高蒙个人博客...
- Informix存储过程
- Institute for Manufacturing virtual check in part 1
- zoj3777(状态压缩)
- ubuntu14.04修改limits.conf后链接限制仍然不生效
- Linux系统中,各种小动物
- 学python可以做什么-学会Python后都能做什么?介绍五种Python的实用场景
- RQNOJ 342 最不听话的机器人:网格dp
- 软考《软件设计师教程》(第五版)
- MAC地址-集线器-ARP
- 0ctf-2017-pwn-char 题解
- op 圣诞节活动_圣诞节到了–这是我们精选的IT饼干笑话
- cesium导入骨骼动画
- igrp和eigrp详解
- 微软第四朵智能云:低代码平台Power Platform
- sin和cos的爱恋
- 2021年新春佳节,《经济学人》是如何报道的?
- pta 构造哈夫曼树-有序输入 优先队列做法
- 简单开发的android阅读器源码,包含了读取数据库和文件流处理功能
- 舌尖美味实践团采访活动
热门文章
- Caffe2 Synchronous SGD
- 红旗linux 6,初识 Linux(红旗Linux 6)
- 计算机图形学 (二) 图元的属性 - 概念、轮廓
- Eclipse4.4.2(luna) JDK1.8.0_212 PyDev5.2.0 Python2.7
- Cesium 热力图(可直接使用)
- java Servlet 笔记
- 在GitHub 上下载指定的文件夹的两种方法
- CSS五种方式实现 Footer 置底
- 数字图像-理想低通滤波器
- 【TouchDesigner】Feedback的应用NO.3