本文主要参考github上一个开源的seq2seq教程,在此基础上稍作修改

https://github.com/bentrevett/pytorch-seq2seq/blob/master/3%20-%20Neural%20Machine%20Translation%20by%20Jointly%20Learning%20to%20Align%20and%20Translate.ipynb

1.Seq2Seq模型


在我上一篇文章有这个代码,原理就是一开始利用编码器的hidden,解码器去生成相应的字符。

2.Attention机制

虽然Seq2Seq模型可以通过encoder生成的上文信息来生成相应的字符或者词语,但是却不能理解encoder中输入序列中句子内部的词语和词语,字符和字符之间潜在的关系。例如我们做翻译的时候,要将英文中的动词转化为中文,我们应该更关注英文中的动词词汇。这就是attention机制,它可以告诉我们每次翻译时更应该关注英文中的哪一部分。

计算权重的方式有很多,这里介绍最基础的一种:concat
原理是通过encoder每层最后一个状态和encoder中每个输入进行相应计算,最后得到一个权重向量。如图所示,将encoder的输入和s0叠加之后,经过两个线性变换得到权重向量。得到的权重每次decoder做解码的时候都要更新一次。因此attention的计算量要比传统的Seq2Seq要大的多。

3.代码

首先是Encoder,和传统的Seq2Seq基本没有区别,只是多了一个要输出最后一层的状态。

# 此例子默认编码器解码器的hidden_size相同
class Encoder(nn.Module):def __init__(self, input_dim, emb_dim, hidden_size, n_layers, dropout=0.5, bidirectional=True):super(Encoder, self).__init__()self.hidden_size = hidden_sizeself.n_layers = n_layersself.embedding = nn.Embedding(input_dim, emb_dim)  # input_dim数量等于源语言的字符数self.gru = nn.GRU(emb_dim, hidden_size, n_layers, dropout=dropout, bidirectional=bidirectional)self.fc = nn.Linear(hidden_size*2, hidden_size)def forward(self, input_seqs):# input_seqs  [seq_len, batch]embedded = self.embedding(input_seqs)# embedded  [seq_len, batch, embed_dim]outputs, hidden = self.gru(embedded)# outputs  [seq_len, batch, hidden_size * 2]# hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]# outputs are always from the last layerh_hat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)# h_hat [batch, hidden_size*2]# h_hat 是最后一层的最后一个状态,用于计算注意力权重# 通过线性层将h_hat转为编码器每次最后一层所输出的形式h_hat = torch.tanh(self.fc(h_hat))# h_hat [bacth, hidden]return outputs, hidden, h_hat

任意取一组数据输入到Encoder

INPUT_DIM = 8
OUTPUT_DIM = 6
HIDDEN_SIZE = 10
N_LAYERS = 2
EMB_SIZE = 12
BATCH_SIZE = 2
SEQ_LEN = 5
encoder = Encoder(INPUT_DIM, EMB_SIZE, HIDDEN_SIZE, N_LAYERS)
x = torch.randint(1, 7, (SEQ_LEN, BATCH_SIZE))
en_out, hidden, h_hat = encoder(x)

重点来了,Attention部分的代码如下:
这里有很多的拼接操作

class Attention(nn.Module):def __init__(self, hidden_size):super().__init__()self.attn = nn.Linear(3*hidden_size, hidden_size)self.v = nn.Linear(hidden_size, 1, bias = False)def forward(self, h_hat, encoder_outputs):# h_hat [batch, hidden_size]# encoder_outputs [seq_len, bacth, hidden_size*2]batch_size = encoder_outputs.shape[1]src_len = encoder_outputs.shape[0]  # 序列长度(时间步)# 编码器每个时间步最后的输出都要和h_hat做拼接,因此需要将h_hat复制序列长度份h_hat = h_hat.unsqueeze(1).repeat(1, src_len, 1)encoder_outputs = encoder_outputs.permute(1, 0, 2)#h_hat = [batch, seq_len, hidden_size]#encoder_outputs = [batch, seq_len, hidden_size * 2]energy = torch.tanh(self.attn(torch.cat((h_hat, encoder_outputs), dim = 2))) #energy  [batch, seq_len, hidden_size]attention = self.v(energy).squeeze(2)# attention [batch, seq_len]# 返回每个batch每个时间步的权重return F.softmax(attention, dim=1)

我们这里可以测试一下Attention的输出:

atten = Attention(HIDDEN_SIZE)
print(atten(h_hat, en_out))
"""
tensor([[0.1896, 0.1891, 0.2035, 0.2095, 0.2083],[0.1951, 0.2025, 0.2011, 0.1986, 0.2026]], grad_fn=<SoftmaxBackward>)
"""

最后是Decoder

class Decoder(nn.Module):def __init__(self, output_dim, emb_dim, hidden_size, n_layers, attention, dropout=0.5, bidirectional=True):super().__init__()self.output_dim = output_dimself.attention = attentionself.embedding = nn.Embedding(output_dim, emb_dim)self.rnn = nn.GRU(2*hidden_size + emb_dim, hidden_size, n_layers, bidirectional=True)self.fc_out = nn.Linear(4*hidden_size+emb_dim, output_dim)self.dropout = nn.Dropout(dropout)def forward(self, input, h_hat, hidden, encoder_outputs):# input  [batch]# h_hat  [batch, hidden_size]# encoder_outputs  [seq_len, batch, hidden_size * 2]# hidden [direction*n_layers, bacth, hidden_size]input = input.unsqueeze(0)# input  [1, batch]embedded = self.dropout(self.embedding(input))#embedded  [1, batch, emb_dim]a = self.attention(h_hat, encoder_outputs)# a  [batch, seq_len]a = a.unsqueeze(1)# a  [batch, 1, seq_len]encoder_outputs = encoder_outputs.permute(1, 0, 2)# encoder_outputs  [batch, seq_len, hidden_size * 2]weighted = torch.bmm(a, encoder_outputs)# weighted  [batch, 1, hidden_size * 2]weighted = weighted.permute(1, 0, 2)# weighted  [1, batch, hidden_size * 2]rnn_input = torch.cat((embedded, weighted), dim = 2)#rnn_input  [1, batch, 2*hidden_size+emb_dim]output, hidden = self.rnn(rnn_input, hidden)# output  [1, batch, hidden_size*2]# hidden  [layers * directions, batch, hidden_size]embedded = embedded.squeeze(0)output = output.squeeze(0)weighted = weighted.squeeze(0)prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))# prediction = [batch, output dim]# hidden [n_layers*directions, batch, hidden_size]return prediction, hidden

最后进行测试:

decoder = Decoder(OUTPUT_DIM, EMB_SIZE, HIDDEN_SIZE, N_LAYERS, atten)
decoder_input = torch.tensor([1, 2])
pre, hidden = decoder(decoder_input, h_hat, hidden, en_out)
print(pre.shape)
print(hidden.shape)"""
torch.Size([2, 6])
torch.Size([4, 2, 10])
"""

Attention机制--concat方式相关推荐

  1. seq2seq与Attention机制

    学习目标 目标 掌握seq2seq模型特点 掌握集束搜索方式 掌握BLEU评估方法 掌握Attention机制 应用 应用Keras实现seq2seq对日期格式的翻译 4.3.1 seq2seq se ...

  2. 一文看懂 Bahdanau 和 Luong 两种 Attention 机制的区别

    来自 | 知乎  作者 | Flitter 链接 | https://zhuanlan.zhihu.com/p/129316415 编辑 | 深度学习这件小事公众号 本文仅作学术交流,如有侵权,请联系 ...

  3. [深度学习] 自然语言处理 --- 基于Attention机制的Bi-LSTM文本分类

    Peng Zhou等发表在ACL2016的一篇论文<Attention-Based Bidirectional Long Short-Term Memory Networks for Relat ...

  4. 深度学习中Attention机制的“前世今生”

    关注公众号,发现CV技术之美 本文转载自FightingCV. [写在前面] 随着注意力在机器学习中的普及,包含注意力机制的神经结构也在逐渐发展.但是大多数人似乎只知道Transformer中的Sel ...

  5. NLP中的attention机制总结

    目录 1 attention机制原理 2 attention类型 2.1 按照是否可直接BP算法计算梯度进行分类 2.1.1 Soft attention 2.1.2 Hard attention 2 ...

  6. 深度学习视觉领域中的attention机制的汇总解读(self-attention、交叉self-attention、ISSA、通道注意、空间注意、位置注意、Efficient Attention等)

    self-attention来自nlp的研究中,在深度学习视觉领域有不少新的attention版本,为了解各种attention机制.博主汇集了6篇视觉领域中attention相关的论文,分别涉及DA ...

  7. Attention机制(一)基本原理及应用

    提纲: 1. 动机 2. 发展过程 3. 应用点 4. 代码实现 1. 动机 1.1 人类的视觉注意力 视觉注意力机制是人类视觉所特有的大脑信号处理机制,人类通过快速扫描全局图像,获得需要重点关注的目 ...

  8. 【TensorFlow实战笔记】对于TED(en-zh)数据集进行Seq2Seq模型实战,以及对应的Attention机制(tf保存模型读取模型)

    个人公众号 AI蜗牛车 作者是南京985AI硕士,CSDN博客专家,研究方向主要是时空序列预测和时间序列数据挖掘,获国家奖学金,校十佳大学生,省优秀毕业生,阿里天池时空序列比赛rank3.公众号致力于 ...

  9. Attention机制学习记录(四)之Transformer

    前言 注意力(Attention)机制[2]由Bengio团队与2014年提出并在近年广泛的应用在深度学习中的各个领域,例如在计算机视觉方向用于捕捉图像上的感受野,或者NLP中用于定位关键token或 ...

  10. Attention机制详解

    一.Attention 原理 在Encoder-Decoder结构中,Encoder把所有的输入序列都编码成一个统一的语义特征c再解码,因此, c中必须包含原始序列中的所有信息,它的长度就成了限制模型 ...

最新文章

  1. c语言中void跟argv,argc和argv []在C语言中
  2. 一个月6次泄露,为啥大家用Elasticsearch总不设密码?
  3. [小猫学NA]CCNA学习指南第二章笔记
  4. 6大最流行、最有用的自然语言处理库对比
  5. ways to improve your presentation by your own
  6. map内置函数分析所得到的思路
  7. OpenCV基本的SIMD的实例(附完整代码)
  8. 将一个项目中已有的文档添加到另一个项目中的方法
  9. 数据库-优化-MYSQL的执行顺序
  10. azure kinect三维点云_万众期待的 【三维点云处理】 课程来啦!
  11. 演示:使用Sniffer统计与分析流量
  12. Ajax中的JSON
  13. Cortex-M3 I-Code,D-Code,系统总线及其他总线接口
  14. 46个不可不知的生活小常识
  15. C++模板中关键字typename与class的区别
  16. python tkinter Checkbutton控件
  17. mfc mysql delete_MFC中简单的数据库文件操作(添加,修改,查找,删除)
  18. python 字典 列表 深度遍历_如何完全遍历未知深度的复杂字典?
  19. 如何用python画太阳花
  20. html如何修改title前的小图标

热门文章

  1. 免费顺丰快递单号查询电子面单api接口对接【快递鸟API】
  2. 基于Android的家庭财务管理流程图,基于android的个人财务管理系统的设计与实现.pdf...
  3. GandCrab4.0勒索病毒解密工具
  4. 【数据分析自学】二、Excel基础知识
  5. Java多线程及锁相关面试题
  6. 计算机一寸照编辑教程,Photoshop教您快速的制作标准一寸证件照教程
  7. mbot机器人自动超声波模式程序_测评 | mBot机器人秒变编程达人
  8. List集合去重的三种方法
  9. python判断成语是abac型_ABAC型的成语
  10. 计算机软件本科毕业生一般起薪多少,南京邮电大学本科毕业生平均薪资多少?一起来看看吧...