原文

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import init
from torch import Tensor
import mathclass NaiveLSTM(nn.Module):"""Naive LSTM like nn.LSTM"""def __init__(self, input_size: int, hidden_size: int):super(NaiveLSTM, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_size# input gateself.w_ii = Parameter(Tensor(hidden_size, input_size))self.w_hi = Parameter(Tensor(hidden_size, hidden_size))self.b_ii = Parameter(Tensor(hidden_size, 1))self.b_hi = Parameter(Tensor(hidden_size, 1))# forget gateself.w_if = Parameter(Tensor(hidden_size, input_size))self.w_hf = Parameter(Tensor(hidden_size, hidden_size))self.b_if = Parameter(Tensor(hidden_size, 1))self.b_hf = Parameter(Tensor(hidden_size, 1))# output gateself.w_io = Parameter(Tensor(hidden_size, input_size))self.w_ho = Parameter(Tensor(hidden_size, hidden_size))self.b_io = Parameter(Tensor(hidden_size, 1))self.b_ho = Parameter(Tensor(hidden_size, 1))# cellself.w_ig = Parameter(Tensor(hidden_size, input_size))self.w_hg = Parameter(Tensor(hidden_size, hidden_size))self.b_ig = Parameter(Tensor(hidden_size, 1))self.b_hg = Parameter(Tensor(hidden_size, 1))self.reset_weigths()def reset_weigths(self):"""reset weights"""stdv = 1.0 / math.sqrt(self.hidden_size)for weight in self.parameters():init.uniform_(weight, -stdv, stdv)def forward(self, inputs: Tensor, state: Tuple[Tensor]) \-> Tuple[Tensor, Tuple[Tensor, Tensor]]:"""ForwardArgs:inputs: [1, 1, input_size]state: ([1, 1, hidden_size], [1, 1, hidden_size])"""
#         seq_size, batch_size, _ = inputs.size()if state is None:h_t = torch.zeros(1, self.hidden_size).t()c_t = torch.zeros(1, self.hidden_size).t()else:(h, c) = stateh_t = h.squeeze(0).t()c_t = c.squeeze(0).t()hidden_seq = []seq_size = 1for t in range(seq_size):x = inputs[:, t, :].t()# input gatei = torch.sigmoid(self.w_ii @ x + self.b_ii + self.w_hi @ h_t +self.b_hi)# forget gatef = torch.sigmoid(self.w_if @ x + self.b_if + self.w_hf @ h_t +self.b_hf)# cellg = torch.tanh(self.w_ig @ x + self.b_ig + self.w_hg @ h_t+ self.b_hg)# output gateo = torch.sigmoid(self.w_io @ x + self.b_io + self.w_ho @ h_t +self.b_ho)c_next = f * c_t + i * gh_next = o * torch.tanh(c_next)c_next_t = c_next.t().unsqueeze(0)h_next_t = h_next.t().unsqueeze(0)hidden_seq.append(h_next_t)hidden_seq = torch.cat(hidden_seq, dim=0)return hidden_seq, (h_next_t, c_next_t)def reset_weigths(model):"""reset weights"""for weight in model.parameters():init.constant_(weight, 0.5)### test
inputs = torch.ones(1, 1, 10)
h0 = torch.ones(1, 1, 20)
c0 = torch.ones(1, 1, 20)
print(h0.shape, h0)
print(c0.shape, c0)
print(inputs.shape, inputs)# test naive_lstm with input_size=10, hidden_size=20
naive_lstm = NaiveLSTM(10, 20)
reset_weigths(naive_lstm)output1, (hn1, cn1) = naive_lstm(inputs, (h0, c0))print(hn1.shape, cn1.shape, output1.shape)
print(hn1)
print(cn1)
print(output1)
# Use official lstm with input_size=10, hidden_size=20
lstm = nn.LSTM(10, 20)
reset_weigths(lstm)
output2, (hn2, cn2) = lstm(inputs, (h0, c0))
print(hn2.shape, cn2.shape, output2.shape)
print(hn2)
print(cn2)
print(output2)

使用pytorch动手实现LSTM模块相关推荐

  1. 【深度学习】在PyTorch中使用 LSTM 进行新冠病例预测

    时间序列数据,顾名思义是一种随时间变化的数据.例如,24 小时时间段内的温度,一个月内各种产品的价格,特定公司一年内的股票价格.长短期记忆网络(LSTM)等高级深度学习模型能够捕捉时间序列数据中的模式 ...

  2. 【Pytorch学习笔记2】Pytorch的主要组成模块

    个人笔记,仅用于个人学习与总结 感谢DataWhale开源组织提供的优秀的开源Pytorch学习文档:原文档链接 本文目录 1. Pytorch的主要组成模块 1.1 完成深度学习的必要部分 1.2 ...

  3. 运用PyTorch动手搭建一个共享单车预测器

    本文摘自 <深度学习原理与PyTorch实战> 我们将从预测某地的共享单车数量这个实际问题出发,带领读者走进神经网络的殿堂,运用PyTorch动手搭建一个共享单车预测器,在实战过程中掌握神 ...

  4. 【Pytorch】BERT+LSTM+多头自注意力(文本分类)

    [Pytorch]BERT+LSTM+多头自注意力(文本分类) 2018年Google提出了BERT[1](Bidirectional Encoder Representations from Tra ...

  5. lstm 根据前文预测词_干货 | Pytorch实现基于LSTM的单词检测器

    Pytorch实现 基于LSTM的单词检测器 字幕组双语原文: Pytorch实现基于LSTM的单词检测器 英语原文: LSTM Based Word Detectors 翻译: 雷锋字幕组(Icar ...

  6. PyTorch基础-使用LSTM神经网络实现手写数据集识别-08

    import numpy as np import torch from torch import nn,optim from torch.autograd import Variable from ...

  7. PyTorch基础(part8)--LSTM

    学习笔记,仅供参考,有错必纠 文章目录 代码 问题分析 初始设置 导包 载入数据 模型 模型持久化 模型的保存 模型的载入 代码 问题分析 我们将图片看成是一种序列问题,比如将28*28的图像中的每行 ...

  8. 深度学习实战(2)用Pytorch搭建双向LSTM

    用Pytorch搭建双向LSTM 应最近的课程实验要求,要做LSTM和GRU的实验效果对比.LSTM的使用和GRU十分相似,欢迎参考我的另外一篇介绍搭建双向GRU的Blog:https://blog. ...

  9. Pytorch实现的LSTM模型结构

    LSTM模型结构 1.LSTM模型结构 2.LSTM网络 3.LSTM的输入结构 4.Pytorch中的LSTM 4.1.pytorch中定义的LSTM模型 4.2.喂给LSTM的数据格式 4.3.L ...

最新文章

  1. Windows7 64bit VS2013 Caffe train MNIST操作步骤
  2. Python Logging模块实现运行的程序写入 日志
  3. python零基础怎么学-零基础如何学习Python?老男孩Python入门培训
  4. Python基础教程:内置类型之比较
  5. 如何制作linux系统硬盘,教你制作Linux操作系统的Boot/Root盘
  6. c++——reverse()函数的使用
  7. cad快捷键文件路径_办公格式转太难不会看这里!CAD、PDF、Word、Excel、TXT教你玩转...
  8. 从滴滴亏109亿说起
  9. Leetcode 207. 课程表(值得一刷的宽搜)
  10. 详解SpringBoot整合ace-cache缓存
  11. JVM监控及诊断工具命令行篇之jstatd
  12. 上月和本月对比叫什么_环比增长率怎么算月度(本周比上周是同比还是环比)...
  13. 分布式拒绝服务(DDoS)攻击原理介绍和防范措施
  14. API是什么?API的基础知识你知道多少
  15. 使用while循环实现xyz+yzz=532
  16. 【13】 数学建模 | 预测模型 | 灰色预测、BP神经网络预测 | 预测题型的思路 | 内附代码(清风课程,有版权问题,私聊删除)
  17. 什么是深度学习,深度学习和机器学习有什么关系?
  18. OpenStack Tracker
  19. 二进制有符号数补码计算器
  20. 钱大妈,一家卖猪肉的广告公司

热门文章

  1. 标准PSO辨识NARMAX模型源码程序
  2. Nginx Web服务应用
  3. mysql添加和root用户一样的权限
  4. Python -- dict 类
  5. cacti0.8.8安装文档
  6. js GPS 百度地图坐标转换
  7. iOS开发隐藏键盘方法总结
  8. QA:智能布线系统二十问
  9. 中小企业ERP快速实施的八大准则
  10. 把Zend Studio 5.5改为简体中文版的办法