教程原文在这里Tutorial,这篇文章中用LSTM实现了一个简单的词类标注模型。下面是一些具体的解析:

# Author: Robert Guthrieimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimtorch.manual_seed(1)
# 引用库函数

我们首先了解如何初始化一个nn.LSTM实例, 以及它的输入输出。初始化nn.LSTM实例, 可以设定的参数如下:

常用的是前两个,用来描述LSTM输入的词向量维度和输出的向量的维度(与hidden state相同),其中num_layer指的是这样的结构:

这种称作stacked LSTM,如上就是两层LSTM堆叠起来。bi-direction指的是双向,双向的LSTM会从正反两个方向读句子,依次输入词向量, 两个方向的hidden state也并不是公共的,如下图:

对应到下面代码的第一行,就是创建了一个输入输出的维度均为3、单层单向的LSTM网络。

lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

这个网络的输入输出怎么样呢?LSTM的基本功能是接收一个句子(一个词向量序列),从第一个词开始逐个后移,移到每一个词的时候,根据hidden state、cell state以及当前的词向量计算输出,并更新hidden state 和cell state,因此输入首先是一个词向量列,同时也可以设定一开始的hidden state和cell state,如果不设定那就自动初始化为0。而输出有三个,一个是每一步的输出构成的序列,这里每一个输出对应句子中的每一个词,第二个输出是最后的hidden state,第三个则是最后的cell state,具体的输入输出如下图:

# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),torch.randn(1, 1, 3))for i in inputs:# Step through the sequence one element at a time.# after each step, hidden contains the hidden state.out, hidden = lstm(i.view(1, 1, -1), hidden)# alternatively, we can do the entire sequence all at once.
# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument  to the lstm at a later time
# Add the extra 2nd dimension
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3), torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)

如上,hidden实际上是(h0, c0),它的三个维度的含义是(num_layer*num_direction, batch_size, hidden_size),第一维和第三维都是由LSTM实例的参数确定的,batch_size倒是很特别,如果hidden的第二个维度不为1,那难道不同的batch的hidden state不会公用吗?
LSTM的输入维度也很有意思,三个维度的含义为:(sequence_len, batch_size, input_size),第三个维度就是词向量embedding长度,sequence len是句长,batch_size是分批的大小,这跟我们一般常用表示方式(batch_size, sequence_len, input_size) 相反。而且我们看到使用时总是配合tensor.view这个方法, 这个方法其实就是tensor的reshape, 它究竟会让tensor怎样重新排列呢?我们可以看下面的结果:

   input = np.array([11, 12, 21, 22, 31, 32])input_tensor = torch.from_numpy(input)print(input_tensor.view(2, 3, 1))

输出:

tensor([[[11],[12],[21]],[[22],[31],[32]]])

按照LSTM的理解,这样的input就是3个长为2的句子为一批,每一个词向量维度都是1。我们的第一个句子是[11, 22],第二个是[12, 31],这就有一个问题,原本是连续输入的单词,现在反而被隔开了,假设我们的输入数据是一句话连着一句话,LSTM反而会理解为每一个句子的开头连成一句话。比如上面,LSTM神经网络就是先输入i = [[11], [12], [21]],这三者没有先后顺序,而是平行地进行向量运算:
联想到刚刚h的维度是(batch_size, hidden_size)(不考虑双向和stack),更加可以说明这里batch中各个矩阵运算其实就是平行的)

然后再输入每句话的第二个词[[22], [31], [32]]这时,会分别利用到前面三个的hidden state、cell state并且更新。
其实我觉得这样设计还挺迷的

pytorch实现lstm分类模型相关推荐

  1. Python实现PSO粒子群优化循环神经网络LSTM分类模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 PSO是粒子群优化算法(Particle Swarm Optim ...

  2. pytorch实现简易分类模型

    1 导入库 import torch import matplotlib.pyplot as plt import torch.nn.functional as F 2 数据处理 n_data=tor ...

  3. pytorch实现二分类模型

    使用的数据集是iris 一共150行数据, 三种花各有50行数据, 这里取了前100行, 选两种花进行二分类. 数据集地址:https://github.com/hydra-ZD/AI/blob/ma ...

  4. 动手学深度学习(PyTorch实现)(二)--softmax与分类模型

    softmax与分类模型 1. 关于softmax 1.1 基本概念 1.2 交叉熵损失函数 1.3 模型训练与预测 2. 获取Fashion-MNIST训练集 2.1 下载数据集 2.2 处理数据集 ...

  5. 【项目实战】Python实现深度神经网络RNN-LSTM分类模型(医学疾病诊断)

    说明:这是一个机器学习实战项目(附带数据+代码+视频+文档),如需数据+完整代码可以直接到文章最后获取. 1.项目背景 随着互联网+的不断深入,我们已步入人工智能时代,机器学习作为人工智能的一个分支越 ...

  6. python文本分类模型_下载 | 最全中文文本分类模型库,上手即用

    原标题:下载 | 最全中文文本分类模型库,上手即用 本文转自『大数据文摘』 如何选择合适的模型上手进行中文文本分类呢? 别慌,福利来了,GitHub上一位名为"huwenxing" ...

  7. 循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用

    循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用 1. Pytorch中LSTM和GRU模块使用 1.1 LSTM介绍 LSTM和GRU都是由torch.nn提供 通过观察文档, ...

  8. 独家 | 教你用Pytorch建立你的第一个文本分类模型!

    作者:Aravind Pai 翻译:王威力 校对:张一豪 本文约3400字,建议阅读10+分钟 本文介绍了利用Pytorch框架实现文本分类的关键知识点,包括使用如何处理Out of Vocabula ...

  9. Pytorch《LSTM模型》

    前面的博文我们讲了LSTM的原理与分析,这一篇我们用pytorch类LSTM做测试 完整测试代码如下,用于进行MNIST数据集测试,主要学习LSTM类的输入输出维度. 这里定义的LSTM模型是用了三层 ...

最新文章

  1. To rename a docker image
  2. BugkuCTF-Misc:隐写3
  3. iPhone开发之BASE64加密和解密
  4. AOL search
  5. Webpack基础之加载器
  6. wpf使用webbrowser时提示当前页面脚本发生错误_win7系统internet脚本错误的应对办法...
  7. 阿里云MaxCompute被Forrester评为全球云端数据仓库领导者
  8. 《深入理解Elasticsearch(原书第2版)》一2.3.3 把查询模板保存到文件
  9. 利用JDK1.5的工具对远程的Java应用程序进行监测(摘录)
  10. 拓端tecdat|Matlab正态分布、历史模拟法、加权移动平均线 EWMA估计风险价值VaR和回测Backtest标准普尔指数 SP500时间序列
  11. LINUX的awk和sed的常用用法 正则表达式 grep egrep用法
  12. VTD的文件结构和Project建立的思路
  13. 初中英语语法(002)-be动词和一般动词的一般现在时
  14. 金融贷款逾期模型 -- 029
  15. Spring-IOC与AOP是解决什么问题的?
  16. org.jboss.netty.util.internal.jzlib.ZStream scanned from multiple locations: jar:
  17. java.lang.NoClassDefFoundError: javax/servlet/http/HttpServlet
  18. BAT批处理如何去写Windows防火墙规则
  19. 如何用计算机进行绘画,怎么用电脑画画-PS电脑手绘的5个基本步骤,轻松自学成PS手绘达人...
  20. 对话系统 | (1) 任务导向型对话系统 -- 对话管理模型研究最新进展

热门文章

  1. 魅族手机MX4 MX4 Pro 魅蓝note 无法连接USB调试,adb连不上问题的解决方案
  2. 分区格式化大于2 TiB磁盘
  3. @GenericValue和@GenericGenerator详解
  4. Windows sever中域、域树、域森林之间的区别与联系
  5. spring注解详解与用法(总览)
  6. JS 即时刷新验证码图片代码
  7. 笨办法学python3进阶篇_笨办法学Python 3 进阶篇
  8. 网易我的世界导入皮肤服务器显示吗,网易我的世界导入皮肤方法 | 手游网游页游攻略大全...
  9. “讲得清,控得住,降得下”——红辽公司备件全生命周期管理创新实践
  10. 微立体岗位竞聘PPT模板