史上最直白之Attention详解(原理+代码)
目录
- 为什么要了解Attention机制
- Attention 的直观理解
- 图解深度学习中的Attention机制
- 总结
为什么要了解Attention机制
在自然语言处理领域,近几年最火的是什么?是BERT!谷歌团队2018提出的用于生成词向量的BERT算法在NLP的11项任务中取得了非常出色的效果,堪称2018年深度学习领域最振奋人心的消息。而Transformer的Encoder部分是 BERT 模型的核心组成部分,Transformer中最为巧妙的结构又是attention机制,这次咱们从Attention机制的原理写这篇博客既是对我自己学习的一个总结,也希望或许能对你有所帮助!!!
Attention 的直观理解
Attention 机制直观理解很像人类看图片的逻辑,当我们看一张图片的时候,我们并没有看清图片的全部内容,而是将注意力集中在了图片的焦点上。大家看一下下面这张图:
我们一定会看清「锦江饭店」4个字,如下图:
但是我相信没人会第一时间去关注「路上的行人」也不会意识到路的尽头还有一个「優の良品」,所以,当我们看一张图片的时候,其实是这样的:
而我们上面所说的,通过引入我们的视觉系统这种关注图片中的突出信息的例子,就是我们深度学习中Attention机制的最直观的理解,在深度学习中Attention机制就是通过矩阵运算的方式将模型的注意力集中在输入信息重点特征上,从而节省资源,快速获得最有效的信息。
图解深度学习中的Attention机制
来拿seq2seq的模型来举例子,一般的基于seq2seq的翻译模型模型如下图:
机器翻译场景中,输入的中文句子为:我是一个学生,Encoder-Decoder框架通过encoding得出了一个包含中文句子全部信息的H6向量,并通过H6逐步生成中文单词:”I“、”am“、”a“、”student“。在翻译”student“这个单词的时候,分心模型里面每个英文单词对于翻译目标单词”student“的贡献程度是相同的,这很显然是不合道理的。显然”学生“对于翻译成”student“更为重要。
那么它会存在什么问题呢?类似RNN无法捕捉长序列的道理,没有引入Attention机制在输入句子较短时影响不大,但是如果输入句子比较长,此时所有语义通过一个中间语义向量表示,单词自身的信息避免不了会消失,也就是会丢失很多细节信息,这也是为何引入Attention机制的原因。例如上面的例子,如果引入Attention的话,在翻译”student“的时候,会体现出英文单词对于翻译当前中文单词的不同程度影响,比如给出类似下面的概率分布:
(我,0.2)
(是,0.1)
(一个,0.2)
(学生,0.5)
那么attention机制是通过什么方式来对于输入信息实现这种功能的能,答案是引入所谓的q、k、v三个矩阵并进行运算实现的,Attention有很多不同种类,本文具体以self-attention的来讲解Attention机制的实现过程(self-Attention中的Q是对自身(self)输入的变换,而在传统的Attention中,Q来自于外部):1、在self-attention中,会有三种矩阵向量,即Q(Query)查询向量、K(key)键值向量、V(value)值向量。它们是通过X乘以三个不同的权值矩阵WQW_QWQ、WkW_kWk、WvW_vWv具体操作步骤如下:
注意,这里的每个单词都会通过这三个向量产生这三种矩阵,而这三种向量是怎么把每个单词联系起来的呢?
答案是在进行Attention运算时,首先会把当前单词产生的q(查询矩阵)和所有的k(键值矩阵进行相乘)得到一个中间结果,最后把自己的v(值矩阵)向量乘上这个中间结果矩阵,得到一个含有句子所有词语上下文信息的新向量。
q,k,v这三个向量在通过反向传播不断的学习,而逐步习得句子中那些信息是模型需要关注的重要特征。
self-Attention的实现代码:
# Muti-head Attention 机制的实现
from math import sqrt
import torch
import torch.nnclass Self_Attention(nn.Module):# input : batch_size * seq_len * input_dim# q : batch_size * input_dim * dim_k# k : batch_size * input_dim * dim_k# v : batch_size * input_dim * dim_vdef __init__(self,input_dim,dim_k,dim_v):super(Self_Attention,self).__init__()self.q = nn.Linear(input_dim,dim_k)self.k = nn.Linear(input_dim,dim_k)self.v = nn.Linear(input_dim,dim_v)self._norm_fact = 1 / sqrt(dim_k)def forward(self,x):Q = self.q(x) # Q: batch_size * seq_len * dim_kK = self.k(x) # K: batch_size * seq_len * dim_kV = self.v(x) # V: batch_size * seq_len * dim_vatten = nn.Softmax(dim=-1)(torch.bmm(Q,K.permute(0,2,1))) * self._norm_fact # Q * K.T() # batch_size * seq_len * seq_lenoutput = torch.bmm(atten,V) # Q * K.T() * V # batch_size * seq_len * dim_vreturn output
Self-Attention可以通过qkv矩阵的计算过程中直接将句子中任意两个单词的联系通过一个计算步骤直接联系起来,所以远距离依赖特征之间的距离被极大缩短,有利于有效地利用这些特征。除此外,Self-Attention对于增加计算的并行性也有直接帮助作用。正好弥补了RNN机制的两个缺点,这就是为何Self-Attention现在被广泛使用的主要原因。
总结
Attention机制笔者认为是Transformer模型中最出彩的设计,效果很好的同时可解释性也很强,在笔者后续的文章中会向大家再介绍大名鼎鼎的Transformer和BERT。希望看到这里,能帮助小伙伴你搞懂Attention机制,这样才能更好的理解后续的Transformer和BERT模型。
史上最直白之Attention详解(原理+代码)相关推荐
- 史上最小白之Attention详解
1.前言 在自然语言处理领域,近几年最火的是什么?是BERT!谷歌团队2018提出的用于生成词向量的BERT算法在NLP的11项任务中取得了非常出色的效果,堪称2018年深度学习领域最振奋人心的消息. ...
- 史上最直白的RNN详解(结合torch的example)
本文主要是结合torch的代码介绍RNN模型的过程及原理 目录 为什么需要RNN RNN的基本结构 torch中的RNN RNN的优缺点 为什么需要RNN 在最基本的全连接神经网络中,我们所建立的网络 ...
- 史上最小白之Transformer详解
1.前言 博客分为上下两篇,您现在阅读的是下篇史上最小白之Transformer详解,在阅读该篇博客之前最好你能够先明白Encoder-Decoder,Attention机制,self-Attenti ...
- 史上最简单MySQL教程详解(进阶篇)之存储过程(一)
史上最简单MySQL教程详解(进阶篇)之存储过程(一) 史上最简单MySQL教程详解(进阶篇)之存储过程(一) 什么是存储过程 存储过程的作用 如何使用存储过程 创建存储过程 DELIMITER改变分 ...
- 史上最简单MySQL教程详解(进阶篇)之存储引擎介绍及默认引擎设置
什么是存储引擎? MySQL存储引擎种类 MyISAM 引擎 InnoDB引擎 存储引擎操作 查看存储引擎 存储引擎的变更 修改默认引擎 什么是存储引擎? 与其他数据库例如Oracle 和SQL Se ...
- 史上最简单MySQL教程详解(进阶篇)之索引及失效场合总结
史上最简单MySQL教程详解(进阶篇)之索引及其失效场合总结 什么是索引及其作用 索引的种类 各存储引擎对于索引的支持 简单介绍索引的实现 索引的设置与分析 普通索引 唯一索引(Unique Inde ...
- 史上最简单MySQL教程详解(进阶篇)之视图
史上最简单MySQL教程详解(进阶篇)之视图 为什么要用视图 视图的本质 视图的作用 如何使用视图 创建视图 修改视图 删除视图 查看视图 使用视图检索 变更视图数据 WITH CHECK OPTIO ...
- 史上最小白之BM25详解与实现
史上最小白之BM25详解与实现 原理 BM25算法是一种计算句子与文档相关性的算法,它的原理十分简单:将输入的句子sentence进行分词,然后分别计算句子中每个词word与文档doc的相关度,然后进 ...
- 史上最小白之Bert详解
1.前言 关于BERT,张俊林博士有一篇特别好的文章:从Word Embedding到Bert模型-自然语言处理中的预训练技术发展史 非常透彻地讲解了Bert是怎么样从NNLM->Word2Ve ...
最新文章
- OpenCV中的TermCriteria模板类
- linux启动程序api编程,Linux编程中关于API函数与系统调用间关系
- 无法加载mspdb140.dll
- js 宽窄屏切换效果代码优化
- C#网络编程(异步传输字符串) - Part.3[转自JimmyZhang博客]
- mysql 共享表空间存储_MySQL InnoDB共享表空间
- Metasploit Framework(6)客户端渗透(上)
- CVPR 2018 挑战赛
- [答案解析]华工数电实验:简易交通灯控制电路的设计
- hashmap java 排序_Java 对HashMap进行排序的三种常见方法
- 职场潜规则:非985院校的简历,一律扔进垃圾桶
- 廊坊-北京,一月期满,回顾、感恩、奋进。。。
- Proxmox VE7.3+Ceph超融合私有云建设案例(低成本高价值,拿走不谢)
- 19张插画让你秒懂Kubernetes
- GraphQL 学习笔记
- 个人云电脑-推荐方案 - Parsec / Fastlink
- phpStorm中使用模板快速创建html基本网页代码
- 爱码物联SaaS一物一码_化妆品二维码防伪溯源系统
- 大虾说工具 -- 横展开
- keytool-importkeypair 的使用