LSTM整体架构图如下:

遗忘门如下:

第一个遗忘门得到的结果是不是全都属于0-1的数,相当于不同的权重。

输入门(其实也可以叫更新门)如下:

输出门如下:

对于输出门,有两个分支,一个是直接变成下一层的隐藏变量,一个是表示这一层的输出。

代码来源:BiLSTM的PyTorch应用 - mathor

'''code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor
'''
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Datadtype = torch.FloatTensor

准备数据

sentence = ('GitHub Actions makes it easy to automate all your software workflows from continuous integration and delivery to issue triage and more'
)
word2idx = {w: i for i, w in enumerate(list(set(sentence.split())))}
idx2word = {i: w for i, w in enumerate(list(set(sentence.split())))}
n_class = len(word2idx) # classification problem
max_len = len(sentence.split())
n_hidden = 5
#word2idx={'automate': 0,'all': 1,'and': 2,'integration': 3,'your': 4,'issue': 5,'continuous': 6,'triage': 7, 'delivery': 8,'Actions': 9, 'from': 10, 'easy': 11, 'software': 12,'makes': 13, 'it': 14, 'workflows': 15, 'GitHub': 16, 'to': 17,'more': 18}#id2word就直接与word2idx的键值对调换就行了#n_class=19
#max_len=21

处理数据

def make_data(sentence):input_batch = []target_batch = []words = sentence.split()for i in range(max_len - 1):input = [word2idx[n] for n in words[:(i + 1)]]input = input + [-1] * (max_len - len(input))target = word2idx[words[i + 1]]input_batch.append(np.eye(n_class)[input])target_batch.append(target)return torch.Tensor(input_batch), torch.LongTensor(target_batch)# input_batch: [max_len - 1, max_len, n_class]
input_batch, target_batch = make_data(sentence)
dataset = Data.TensorDataset(input_batch, target_batch)
loader = Data.DataLoader(dataset, 16, True)#16表示batch_size根据自己的电脑配置更改,

相关变量可视化

class BiLSTM(nn.Module):def __init__(self):super(BiLSTM, self).__init__()self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)# fcself.fc = nn.Linear(n_hidden * 2, n_class)def forward(self, X):# X: [batch_size, max_len, n_class]batch_size = X.shape[0]input = X.transpose(0, 1)  # input : [max_len, batch_size, n_class]hidden_state = torch.randn(1*2, batch_size, n_hidden)   # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]cell_state = torch.randn(1*2, batch_size, n_hidden)     # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))outputs = outputs[-1]  # [batch_size, n_hidden * 2]model = self.fc(outputs)  # model : [batch_size, n_class]return modelmodel = BiLSTM()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

pytorch对于LSTM的输入输出格式如下图所示:

# Training
for epoch in range(10000):for x, y in loader:pred = model(x)loss = criterion(pred, y)if (epoch + 1) % 1000 == 0:print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))optimizer.zero_grad()loss.backward()optimizer.step()


# Pred
predict = model(input_batch).data.max(1, keepdim=True)[1]
print(sentence)
print([idx2word[n.item()] for n in predict.squeeze()])

效果比原作者的效果要好一点,因为占位符我更改成了 -1,这样就不会影响到标签了。

注意,对于代码的模型的前项传播不懂的,可以看下面:

class BiLSTM_1(nn.Module):def __init__(self):super(BiLSTM_1, self).__init__()self.lstm = nn.LSTM(input_size=10, hidden_size=5, bidirectional=False)# fcself.fc = nn.Linear(n_hidden * 1, n_class)def forward(self, X):# X: [batch_size, max_len, n_class]batch_size = X.shape[0]input = X.transpose(0, 1)  # input : [max_len, batch_size, n_class]#随机初试化隐藏变量和记忆细胞变量hidden_state = torch.randn(1*1, batch_size, n_hidden)   # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]cell_state = torch.randn(1*1, batch_size, n_hidden)     # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]outputs, (hc, c) = self.lstm(input, (hidden_state, cell_state))outputs = outputs  # [batch_size, n_hidden * 2]#model = self.fc(outputs)  # model : [batch_size, n_class]return outputs,hc, cmodel = BiLSTM_1()
a=torch.randn(2,5,10)
output,hc,c=model(a)

其实每一层的输出,都直接拼接在一起了,而hc只表示最后一层的输出,所以output[-1]==hc的。

最后祝大家学有所成!

pytorch实现BiLSTM代码相关推荐

  1. pytorch实现BiLSTM+CRF用于NER(命名实体识别)

    pytorch实现BiLSTM+CRF用于NER(命名实体识别) 在写这篇博客之前,我看了网上关于pytorch,BiLstm+CRF的实现,都是一个版本(对pytorch教程的翻译), 翻译得一点质 ...

  2. 基于pytorch的Bi-LSTM中文文本情感分类

    基于pytorch的Bi-LSTM中文文本情感分类 目录 基于pytorch的Bi-LSTM中文文本情感分类 一.前言 二.数据集的准备与处理 2.1 数据集介绍 2.2 文本向量化 2.3 数据集处 ...

  3. Pytorch Bert+BiLstm文本分类

    文章目录 前言 一.运行环境 二.数据 三.模型结构 四.训练 五.测试及预测 前言 昨天按照该文章(自然语言处理(NLP)Bert与Lstm结合)跑bert+bilstm分类的时候,没成功跑起来,于 ...

  4. Resnet的pytorch官方实现代码解读

    Resnet的pytorch官方实现代码解读 目录 Resnet的pytorch官方实现代码解读 前言 概述 34层网络结构的"平原"网络与"残差"网络的结构图 ...

  5. pytorch geometric GraphSAGE代码样例reddit和ogbn_products_sage,为何subgraph_loader将sizes设成[-1]

    pytorch geometric GraphSAGE代码样例reddit和ogbn_products_sage,为何subgraph_loader将sizes设成[-1] loader infere ...

  6. 史上最详细的Pytorch版yolov3代码中文注释详解(四)

    史上最详细的Pytorch版yolov3代码中文注释详解(一):https://blog.csdn.net/qq_34199326/article/details/84072505 史上最详细的Pyt ...

  7. EDSR MDSR IRAN RCN -pytorch实现及代码常见问题

    EDSR MDSR IRAN RCN -pytorch实现及代码常见问题 代码下载地址:https://github.com/sanghyun-son/EDSR-PyTorch 环境配置 Depend ...

  8. 一步步读懂Pytorch Chatbot Tutorial代码(四) - 为模型准备数据

    文章目录 自述 有用的工具 代码出处 目录 头大 代码及说明 Prepare Data for Models 重点关注 indexesFromSentence zeroPadding binaryMa ...

  9. 一步步读懂Pytorch Chatbot Tutorial代码(二) - 数据处理

    文章目录 自述 代码出处 目录 代码 Create formatted data file (为了方便理解,把代码的顺序略微改一下, 此章节略长.) 1. `loadLines` 将文件的每一行拆分为 ...

最新文章

  1. python array 语法_Python基本语法
  2. android VectorDrawable使用笔记(五)
  3. 计算机视觉方面代码和论文
  4. 全球及中国海洋工程装备行业产值规模价值及投资风险预警报告2021-2027年版
  5. spark wordcount完整工程代码(含pom.xml)
  6. #1117. 编码 ( 字典树版 ) 题解分析
  7. git提交代码报错解决方法 Git-remote Incorrect username or password ( access token )
  8. Java应用程序中的验证
  9. angularJs关于指令的一些冷门属性
  10. [记录] ---阿里云java.io.IOException: Connection reset by peer的问题
  11. 齐鲁计算机函授学院,【齐鲁师范学院继续教育学院自考网站】2021自考本科|电话|专业有哪些...
  12. Java中文jsp页面_java中文乱码解决之道(七)—–JSP页面编码过程
  13. 基于MRG_MyISAM引擎的Mysql分表
  14. 增值税发票识别,智能自动识别
  15. cesium-加载天地图影像
  16. 迪杰斯特拉(Dijkstra)
  17. android 分享文件功能实现
  18. 常见的网络状态检测及分析工具
  19. JAVA面向对象(2)
  20. Mac Scrcpy无线连接

热门文章

  1. 【群晖】两种常用下载器设置及使用
  2. JS实现动画效果(利用定时器)
  3. Zemax操作--5(热分析)
  4. 一个低成本的FOC控制方案分享
  5. 吴军博士被ChatGPT粉丝 疯狂炮轰!他究竟做错了什么?
  6. 新东方和北大青鸟可能倒闭吗?
  7. 有理数的不定积分(真分式采用待定系数求解)——高等数学
  8. 研报精选230512
  9. OGC标准介绍 11
  10. 数学符号arg的含义