刚从图像处理的hole中攀爬出来,刚走一步竟掉到了另一个hole(fire in the hole*▽*)

1.RNN中的attention
pytorch官方教程:https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
首先,RNN的输入大小都是(1,1,hidden_size),即batch=1,seq_len=1,hidden_size=embed_size,相对于传统的encoder-decoder模型,attention机制仅在decoder处有所不同。下面具体看看:
1>保存了rnn每个词向量对应隐藏层的输出状态(encoder_outputs),用于decoder的attention机制

#train代码部分
for ei in range(input_length):
encoder_output, encoder_hidden = encoder(
input_tensor[ei], encoder_hidden)
encoder_outputs[ei] = encoder_output[0, 0]
1
2
3
4
5
2>AttnDecoderRNN的forward
1.输入的input经过embed

embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
1
2
2.获取关于输入的attention权重,这里的Q=decoder_rnn的input,K=decoder_rnn的隐藏元
2.1求Q和K相似度的方法有很多,这里让全连接层自己来学习,把embedded和hidden连接在一起经过fc层(部分修改了下)

similarity=self.attn(torch.cat((embedded[0], hidden[0]), 1))
1
2.2 经过softmax获得归一化的权重

attn_weights = F.softmax(similarity, dim=1)
1
3.权重应用于encoder输出的所有词对应的词向量上(对应相乘即可)->获得attention结果

attn_applied = torch.bmm(attn_weights.unsqueeze(0),encoder_outputs.unsqueeze(0))
1
4.把attention结果和decoder的输入cat在一起,使用1个全连接层来融合二者,最终生成带注意力机制的词向量

output = torch.cat((embedded[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
1
2
5.根据decoder的上一个输出单词来预测下一个单词,这里多插一句,decoder的首个输入为起始标志符’sos’,其根据encode最后的隐藏元来预测第一个单词,后面依次类推。

output = F.relu(output)
output, hidden = self.gru(output, hidden)
output = F.log_softmax(self.out(output[0]), dim=1)
return output, hidden, attn_weights
1
2
3
4
2.transformer中的attention
“Attention is All You Need”(霸气标题),pytorch代码推荐2篇:
哈佛大学NLP研究组:http://nlp.seas.harvard.edu/2018/04/03/attention.html
台湾小哥的代码(较通俗):https://github.com/jadore801120/attention-is-all-you-need-pytorch:

下面以soft_attention为例(*input和output的attention,仅和self_attention做下区分,第1篇代码标记src_attn,第2篇代码标记dec_enc_attn),soft_attention的目标:给定序列Q(query,长度记为lq,维度dk),键序列K(key,长度记为lk,维度dk),值序列V(value,长度记为lv,维度dv),计算Q和K的相似度权重,最后再乘上V。

下面直接贴上attention-is-all-you-need-pytorch中MultiHeadAttention代码

def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
#这里把batch和分块数放在一起,便于使用bmm
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)

output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

output = self.dropout(self.fc(output))
output = self.layer_norm(output + residual)
return output, attn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
和RNN中的attention的不同,这里的batch_size和seq_len均不为1,其把序列视为一个整体,求Q和V的相似度可使用点乘(V可以视为上面提及的encoder_outputs),获得的是一个相似度矩阵,比如Q是一个长度为10的序列,K是一个长度为16的序列,其相似度矩阵就是一个10*16的矩阵,再如矩阵第一行表示Q的第一个单词和K序列所有单词的相似度。
similarity:=(lq,dk)∗(dk,lk)=(lq,lk) similarity:=(lq,dk)*(dk,lk)=(lq,lk)
similarity:=(lq,dk)∗(dk,lk)=(lq,lk)

然后,生成带注意力机制的词向量(通常K和V取相同的值,因而有lv=lk),另外上面整合attn_applied和input使用的是cat操作,而这里使用的是残差(类似于unet和resnet),最后使用PositionwiseFeedForward(2个fc层)来融合attn_applied和input,最终生成带注意力机制的词向量。
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv) attention\_applied=(lq,lk)*(lv,dv)=(lq,dv)
attention_applied=(lq,lk)∗(lv,dv)=(lq,dv)

细节部分
在数据预处理部分,对序列s都进行了首尾标记,比如s=’’+ s + ‘’,刚看transform(之前跳过了seq2seq),对下面的代码甚是不解

decoder_input=target_seq[:, :-1] #这里不是去掉终止标记<eos>,去掉的可能是padding_0,只为兼容target_ground_y的序列长度?
encoder_input=input_seq[:, 1:] #encoder的输入序列去掉了起始标记<sos>
target_ground_y= target_seqtrg[:, 1:] #用于计算模型loss的target,去掉了起始标记<sos>
1
2
3
其实在pytorch官方教程中说的比较清楚,看下图

encoder的输入序列和ground_true只需要一个终止符即可,而decoder的输入序列开始必须指定一个起始符,让其根据context预测输出序列的第一个单词,后面根据前一个单词再预测下一个单词,依次类推直到当前预测的单词为终止标记’eos’,才计算loss.
---------------------
作者:PJ-Javis
来源:CSDN
原文:https://blog.csdn.net/jiangpeng59/article/details/84859640
版权声明:本文为博主原创文章,转载请附上博文链接!

转载于:https://www.cnblogs.com/jfdwd/p/11068075.html

pytorch笔记:09)Attention机制相关推荐

  1. Multimodal —— 看图说话(Image Caption)任务的论文笔记(二)引入attention机制

    在上一篇博客中介绍的论文"Show and tell"所提出的NIC模型采用的是最"简单"的encoder-decoder框架,模型上没有什么新花样,使用CNN ...

  2. 【TensorFlow实战笔记】对于TED(en-zh)数据集进行Seq2Seq模型实战,以及对应的Attention机制(tf保存模型读取模型)

    个人公众号 AI蜗牛车 作者是南京985AI硕士,CSDN博客专家,研究方向主要是时空序列预测和时间序列数据挖掘,获国家奖学金,校十佳大学生,省优秀毕业生,阿里天池时空序列比赛rank3.公众号致力于 ...

  3. Attention机制理解笔记(空间注意力+通道注意力+CBAM+BAM)

    Attention机制理解笔记 声明 Attention分类(主要SA和CA) spitial attention channel attention SA + CA(spitial attentio ...

  4. 什么是Attention机制以及Pytorch如何使用

    文章目录 前言 注意力概况 标准注意力 变种注意力 QKV 应用 前言 看了网上大部分人做的,都是说一个比较长的项目(特别是机器翻译的多).其实没有必要,很多人并不是想看一个大项目,只是想看看怎么用, ...

  5. Attention机制的总结笔记

    人类的视觉注意力 Attention机制借鉴了人类的视觉注意力机制.视觉注意力机制是人类视觉所特有的大脑信号处理机制.人类视觉通过快速扫描全局图像,获得需要重点关注的目标区域,也就是一般所说的注意力焦 ...

  6. 收藏 | PyTorch实现各种注意力机制

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 P ...

  7. seq2seq与Attention机制

    学习目标 目标 掌握seq2seq模型特点 掌握集束搜索方式 掌握BLEU评估方法 掌握Attention机制 应用 应用Keras实现seq2seq对日期格式的翻译 4.3.1 seq2seq se ...

  8. PYTORCH笔记 actor-critic (A2C)

    理论知识见:强化学习笔记:Actor-critic_UQI-LIUWJ的博客-CSDN博客 由于actor-critic是policy gradient和DQN的结合,所以同时很多部分和policy ...

  9. (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(2)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅰ

    开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...

最新文章

  1. Uber自动驾驶汽车被赶出了亚利桑那,近300人被裁
  2. CV公开课报名 | 快速搭建基于Python的车辆信息识别系统
  3. 「从源码中学习」面试官都不知道的Vue题目答案
  4. Spring Cloud 2020.0.0 正式发布,对开发者来说意味着什么?
  5. Oracle 排序中使用nulls first 或者nulls last 语法
  6. 四面阿里,看看你都会吗
  7. :after伪类+content经典应用举例
  8. python深度学习库keras——安装
  9. everedit 格式化json_Mac Init
  10. VSCode安装LeetCode插件
  11. BlackBerry Enterprise Service 10 for Android下载
  12. java中double..compare_为什么Java的Double.compare(double,double)实现了它的样子?
  13. 十款代码表白特效,一个比一个浪漫
  14. 为什么博客图片不显示?
  15. 价值工程杂志价值工程杂志社价值工程编辑部2022年第23期目录
  16. 软件架构风格介绍和总结
  17. 浅谈 Lempel-Ziv压缩方法
  18. 上线7天融资1.5亿,还有50多家VC在排队,子弹短信为何受追捧?
  19. 餐厅设置套餐 html,餐厅如何设计爆款套餐?掌握这5个原则就够了
  20. 软件接口设计 六大原则

热门文章

  1. Python进阶之一
  2. Android studio 使用NDK工具实现JNI编程
  3. 基于吉日嘎拉的通用权限管理WebForm版扩展:字典选项管理和缓存管理
  4. 关于.h .lib .dll的总结
  5. WebService安全 身份验证与访问控制
  6. C# C/S系统软件开发平台架构图(原创)
  7. 初识Ajax以及简单应用
  8. Python多线程学习
  9. IPC 之 Binder 初识
  10. Python流程控制语句