本文将LSTM+attention用于时间序列预测

class lstm(torch.nn.Module):def __init__(self, output_size, hidden_size, embed_dim, sequence_length):super(lstm, self).__init__()self.output_size = output_sizeself.hidden_size = hidden_size#对应特征维度self.embed_dim = embed_dimself.dropout = 0.8#对应时间步长self.sequence_length = sequence_length#1层lstmself.layer_size = 1self.lstm = nn.LSTM(self.embed_dim,self.hidden_size,self.layer_size,dropout=self.dropout,)self.layer_size = self.layer_sizeself.attention_size = 30#(4,30)self.w_omega = Variable(torch.zeros(self.hidden_size * self.layer_size, self.attention_size))#(30)self.u_omega = Variable(torch.zeros(self.attention_size))#将隐层输入全连接self.label = nn.Linear(hidden_size * self.layer_size, output_size)

LSTM输入输出说明

1. 输入数据包括input,(h_0,c_0):
input就是shape==(seq_length,batch_size,input_size)的张量
h_0的shape==(num_layers×num_directions,batch,hidden_size)的张量
,它包含了在当前这个batch_size中每个句子的初始隐藏状态,num_layers就是LSTM的层数,如果bidirectional=True,num_directions=2,否则就是1,表示只有一个方向,
c_0和h_0的形状相同,它包含的是在当前这个batch_size中的每个句子的初始细胞状态。
==h_0,c_0如果不提供,那么默认是0
==

2. 输出数据包括output,(h_n,c_n):
output的shape==(seq_length,batch_size,num_directions×hidden_size),
它包含的LSTM的最后一层的输出特征(h_t),t是batch_size中每个句子的长度.
h_n.shape==(num_directions × num_layers,batch,hidden_size)
c_n.shape==h_n.shape
h_n包含的是句子的最后一个单词的隐藏状态,c_n包含的是句子的最后一个单词的细胞状态,所以它们都与句子的长度seq_length无关。
output[-1]与h_n是相等的,因为output[-1]包含的正是batch_size个句子中每一个句子的最后一个单词的隐藏状态,注意LSTM中的隐藏状态其实就是输出,cell
state细胞状态才是LSTM中一直隐藏的,记录着信息

def attention_net(self, lstm_output):#print(lstm_output.size()) = (squence_length, batch_size, hidden_size*layer_size)output_reshape = torch.Tensor.reshape(lstm_output, [-1, self.hidden_size*self.layer_size])#print(output_reshape.size()) = (squence_length * batch_size, hidden_size*layer_size)#tanh(H)attn_tanh = torch.tanh(torch.mm(output_reshape, self.w_omega))#print(attn_tanh.size()) = (squence_length * batch_size, attention_size)#张量相乘attn_hidden_layer = torch.mm(attn_tanh, torch.Tensor.reshape(self.u_omega, [-1, 1]))#print(attn_hidden_layer.size()) = (squence_length * batch_size, 1)exps = torch.Tensor.reshape(torch.exp(attn_hidden_layer), [-1, self.sequence_length])#print(exps.size()) = (batch_size, squence_length)alphas = exps / torch.Tensor.reshape(torch.sum(exps, 1), [-1, 1])#print(alphas.size()) = (batch_size, squence_length)alphas_reshape = torch.Tensor.reshape(alphas, [-1, self.sequence_length, 1])#print(alphas_reshape.size()) = (batch_size, squence_length, 1)state = lstm_output.permute(1, 0, 2)#print(state.size()) = (batch_size, squence_length, hidden_size*layer_size)attn_output = torch.sum(state * alphas_reshape, 1)#print(attn_output.size()) = (batch_size, hidden_size*layer_size)return attn_outputdef forward(self, input):# input = self.lookup_table(input_sentences)input = input.permute(1, 0, 2)# print('input.size():',input.size())s,b,f=input.size()h_0 = Variable(torch.zeros(self.layer_size, b, self.hidden_size))c_0 = Variable(torch.zeros(self.layer_size, b, self.hidden_size))print('input.size(),h_0.size(),c_0.size()',input.size(),h_0.size(),c_0.size())lstm_output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0))attn_output = self.attention_net(lstm_output)logits = self.label(attn_output)return logits
在计算attention时主要分为三步:第一步是将query和每个key进行相似度计算得到权重,常用的相似度函数有点积,拼接,感知机等;
第二步一般是使用一个softmax函数对这些权重进行归一化;
最后将权重和相应的键值value进行加权求和得到最后的attention。

LSTM+attention代码原理详解相关推荐

  1. 循环神经网络RNN、LSTM、GRU原理详解

    一.写在前面 这部分内容应该算是近几年发展中最基础的部分了,但是发现自己忘得差不多了,很多细节记得不是很清楚了,故写这篇博客,也希望能够用更简单清晰的思路来把这部分内容说清楚,以此能够帮助更多的朋友, ...

  2. 51单片机实训项目之“万年历”代码原理详解

    一.原理图 二.芯片器件 STC89C52 DS18B20(温度传感器) DS1302(时钟芯片) LCD1602液晶显示 独立按键 杜邦线 三.仿真图 四.程序代码详解 (一).子程序 EEPROM ...

  3. EventHub代码原理详解

    一.EventHub简述 Android系统基于Linux系统,由多个子系统组合而成,各子系统分工合作,在各自功能域中扮演关键角色.其中一个比较重要的子系统是Input子系统,正如其名地,挂载于And ...

  4. 串口通信(SBUF代码原理详解)

    这里写目录标题 基本概念 读取数据手册 串口中断 代码讲解 基本概念 前言: 时钟对于单片机来说是非常重要的,它能为单片机提供一个稳定的机器周期从而使系统能够正常工作.它就像我们人类的心脏一样重要,一 ...

  5. LSTM内部实现原理详解

    https://cloud.tencent.com/developer/article/1528295

  6. Transformer 初识:模型结构+attention原理详解

    Transformer 初识:模型结构+原理详解 参考资源 前言 1.整体结构 1.1 输入: 1.2 Encoder 和 Decoder的结构 1.3 Layer normalization Bat ...

  7. Attention原理详解

    Attention原理详解 Attention模型 对齐 模型介绍 Attention整体流程 Step1 计算Encoder的隐藏状态和Decoder的隐藏状态 Step2 获取每个编码器隐藏状态对 ...

  8. 视频教程-深度学习原理详解及Python代码实现-深度学习

    深度学习原理详解及Python代码实现 大学教授,美国归国博士.博士生导师:人工智能公司专家顾问:长期从事人工智能.物联网.大数据研究:已发表学术论文100多篇,授权发明专利10多项 白勇 ¥88.0 ...

  9. 图像质量损失函数SSIM Loss的原理详解和代码具体实现

    本文转自微信公众号SIGAI 文章PDF见: http://www.tensorinfinity.com/paper_164.html http://www.360doc.com/content/19 ...

  10. TOPSIS(逼近理想解)算法原理详解与代码实现

    写在前面: 个人理解:针对存在多项指标,多个方案的方案评价分析方法,也就是根据已存在的一份数据,判断数据中各个方案的优劣.中心思想是首先确定各项指标的最优理想值(正理想值)和最劣理想值(负理想解),所 ...

最新文章

  1. oracle @spool,Oracle spool 用法小结
  2. Java中Split函数的用法技巧
  3. 二十二、linux定时器
  4. 十七、二叉树的建立与基本操作
  5. n分频器 verilog_基于Verilog的分频器实现
  6. 千万别用树套树(线段树)
  7. Linux系统编程(八)线程
  8. 领域驱动设计(DDD)前夜:面向对象思想
  9. Oracle的tnsnames.ora配置(PLSQL Developer)
  10. 【STM32】【STM32CubeMX】STM32CubeMX的使用之九:ADC
  11. LeetCode问题7
  12. 阿里巴巴基于Java容器的多应用部署技术实践
  13. Android逆向基础入门
  14. JAVA实现邮箱注册功能
  15. QT 使用QModbus类实现modbus TCP踩过的坑
  16. 构造函数创造对象--创建四大天王的对象
  17. Revit 绘制幕墙系统
  18. (附源码)anjule客户信息管理系统 毕业设计 181936
  19. 让猛男娇羞的AI算法
  20. 满爷的2019年终总结: 趋势、反思及展望

热门文章

  1. 点击流日志分析项目实战开发流程
  2. zookeeper分布式原理实战解析
  3. iQOO5G手机卡槽公布
  4. [AST实战]从零开始写一个wepy转VUE的工具
  5. 专访UCloud徐亮:UCloud虚拟网络的演进之路
  6. SpringBoot 使用小技巧合集
  7. Virtualbox以及VWare在Win10下的不兼容
  8. vue $emit 父组件与子组件之间的通信(父组件向子组件传参)
  9. 游戏服务器当中的唯一名设计方法
  10. [TWAIN] 3句话总结TWAIN在Windows Server 2008 R2 SP1的使用