作者 | 小宋是呢

转载自CSDN博客

一、Self-Attention概念详解

了解了模型大致原理,我们可以详细的看一下究竟Self-Attention结构是怎样的。其基本结构如下

对于self-attention来讲,Q(Query), K(Key), V(Value)三个矩阵均来自同一输入,首先我们要计算Q与K之间的点乘,然后为了防止其结果过大,会除以一个尺度标度640?wx_fmt=svg,其中640?wx_fmt=svg为一个query和key向量的维度。再利用Softmax操作将其结果归一化为概率分布,然后再乘以矩阵V就得到权重求和的表示。该操作可以表示为

640?wx_fmt=svg

这里可能比较抽象,我们来看一个具体的例子(图片来源于https://jalammar.github.io/illustrated-transformer/),该博客讲解的极其清晰,强烈推荐),假如我们要翻译一个词组Thinking Machines,其中Thinking的输入的embedding vector用640?wx_fmt=svg表示,Machines的embedding vector用640?wx_fmt=svg表示。

当我们处理Thinking这个词时,我们需要计算句子中所有词与它的Attention Score,这就像将当前词作为搜索的query,去和句子中所有词(包含该词本身)的key去匹配,看看相关度有多高。我们用640?wx_fmt=svg代表Thinking对应的query vector,640?wx_fmt=svg640?wx_fmt=svg分别代表Thinking以及Machines对应的key vector,则计算Thinking的attention score的时候我们需要计算640?wx_fmt=svg640?wx_fmt=svg的点乘,同理,我们计算Machines的attention score的时候需要计算640?wx_fmt=svg640?wx_fmt=svg的点乘。如上图中所示我们分别得到了640?wx_fmt=svg640?wx_fmt=svg的点乘积,然后我们进行尺度缩放与softmax归一化,如下图所示:

显然,当前单词与其自身的attention score一般最大,其他单词根据与当前单词重要程度有相应的score。然后我们在用这些attention score与value vector相乘,得到加权的向量。

如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示:

其中640?wx_fmt=svg是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式:

二、Self_Attention模型搭建

笔者使用Keras来实现对于Self_Attention模型的搭建,由于网络中间参数量比较多,这里采用自定义网络层的方法构建Self_Attention。

Keras实现自定义网络层。需要实现以下三个方法:(注意input_shape是包含batch_size项的

  • build(input_shape): 这是你定义权重的地方。这个方法必须设 self.built = True,可以通过调用 super([Layer], self).build() 完成。

  • call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入 call 的第一个参数:输入张量,除非你希望你的层支持masking。

  • compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。

    实现代码如下:

    
    

    这里可以对照一中的概念讲解来理解代码

    如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示

    上述内容对应

    
    

    其中640?wx_fmt=svg是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式:

    上述内容对应(为什么使用batch_dot呢?这是由于input_shape是包含batch_size项的

    
    

    这里 QK = QK / (64**0.5) 是除以一个归一化系数,(64**0.5)是笔者自己定义的,其他文章可能会采用不同的方法。

    三、训练网络

    项目完整代码如下,这里使用的是Keras自带的imdb影评数据集。

    
    

    四、结果输出

    
    

    参考链接:

    https://zhuanlan.zhihu.com/p/47282410

    原文链接

    https://blog.csdn.net/xiaosongshine/article/details/90600028

    (*本文为 AI科技大本营转载文章,转载请联系原作者)

    精彩推荐

    大会开幕倒计时8天!

    2019以太坊技术及应用大会特邀以太坊创始人V神与众多海内外知名技术专家齐聚北京,聚焦区块链技术,把握时代机遇,深耕行业应用,共话以太坊2.0新生态。即刻扫码,享优惠票价。

    推荐阅读

    • 真正的博士是如何参加AAAI, ICML, ICLR等AI顶会的?

    • Python最抢手、Java最流行、Go最有前途,7000位程序员揭秘2019软件开发现状

    • 程序员学Python编程或许不知的十大提升工具

    • 不要让 Chrome 成为下一个 IE!

    • 这位博士跑赢“地震波”:提前 10 秒预警宜宾地震!

    • 一张图告诉你到底学Python还是Java!

    • 鸿蒙将至,安卓安否?

    • 25岁创立加密城堡, 曾经独角兽创始人社会名流天才黑客是这里的沙发客, 如今却无人问津……

    • 352万帧标注图片,1400个视频,亮风台推最大单目标跟踪数据集

    你点的每个“在看”,我都认真当成了喜欢

    机器如何读懂人心:Keras实现Self-Attention文本分类相关推荐

    1. python attention机制_[深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心)...

      [深度应用]·Keras实现Self-Attention文本分类(机器如何读懂人心) 笔者在[深度概念]·Attention机制概念学习笔记博文中,讲解了Attention机制的概念与技术细节,本篇内 ...

    2. 万字总结Keras深度学习中文文本分类

      摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...

    3. [Python人工智能] 二十八.Keras深度学习中文文本分类万字总结(CNN、TextCNN、LSTM、BiLSTM、BiLSTM+Attention)

      从本专栏开始,作者正式研究Python深度学习.神经网络及人工智能相关知识.前一篇文章分享了BiLSTM-CRF模型搭建及训练.预测,最终实现医学命名实体识别实验.这篇文章将详细讲解Keras实现经典 ...

    4. 基于keras中IMDB的文本分类 demo

      本次demo主题是使用keras对IMDB影评进行文本分类: importtensorflow as tffrom tensorflow importkerasimportnumpy as nppri ...

    5. 能读懂人心的人工智能 甚至可能植入人类大脑

      人工智能系统正在变得越来越聪明,它们不仅能下围棋.炒股票,现在还学会了写代码.由微软和剑桥大学研究员一同开发的人工智能系统DeepCoder,完成了人类编程挑战赛所设定的基本挑战. DeepCoder ...

    6. 【NLP】让AI读懂法律文书:一种基于多分类的关键句识别方法

      法律领域是近年来在 NLP 社区兴起的一个研究场景,许多研究者从不同的角度对其进行了大量研究,例如对当事人的情感分析.提取案件当事人信息,提取侦破案件的关键判决信息,预测案件的结果等等. 近日,来自斯 ...

    7. 朱松纯团队新作:让AI「读懂」人类价值观!登上Science Robotics

      视学算法专栏 作者:朱松纯团队 今日(7月14日),国际顶级学术期刊<Science Robotics >发表了朱松纯团队(UCLA袁路遥.高晓丰.北京通用人工智能研究院郑子隆.北京大学人 ...

    8. 公开课 | 让机器读懂你的意图——人体姿态估计入门

      机器视觉的主要任务是让机器看懂世界,而世界的主要组成是人类社会.我们一直在围绕物和人的识别展开研究:物品检测识别.行人检测与跟踪.人脸识别. 事实上,行人检测是人的整体粗粒度识别,人脸识别是人的局部特 ...

    9. 读懂女人心:互联网产业GDP未来靠女性撑起

      "一个成功的男人背后一定有一个伟大的女人,但马云除外,他成功的背后有千千万万的女人."这虽然是句玩笑话,却印证着一个事实,即女性用户已成为网购消费最大的贡献者. 无论是每年&quo ...

    最新文章

    1. Virtual Machine Remote Control Client Plus
    2. 【采用】概率图模型在反欺诈的应用(无监督机器学习)
    3. linux下改变python的版本
    4. 网络规划设计师学习攻略(2)
    5. 简述java的线程_Java多线程的简述
    6. Android GPS及地磁传感器 API
    7. 作者:周一懋(1982-),男,江苏汇誉通数据科技有限公司大数据事业部总监、工程师...
    8. 中国通信业:那些年,我们给用户挖的坑
    9. pdf格式压缩大小,pdf如何压缩大小?
    10. 解决:openstack-dashboard-登陆后显示报错
    11. Matlab提示Ill-conditioned covariance created at iteration
    12. Java详解:java对象转json字符串不加引号
    13. 低门槛,多玩法打金游戏 Tiny World
    14. 3D建模软件功能解析之Maya篇
    15. Spark Streaming简介 (三十四)
    16. mysql数据库特别大怎么备份_如何备份还原mysql数据库 mysql数据库太大备份与还原方法...
    17. 在Ubuntu安装和使用Anbox完整说明(一种在Linux使用Android应用的方法)
    18. 伦敦经济学院开设加密货币相关课程
    19. 如何在 Ubuntu 20.04 / KylinOS-V10-SP1 上安装 Sublime Text 4
    20. 09_JavaScript数据结构与算法(九)字典

    热门文章

    1. 3、JPA一些常用的注解
    2. 谈谈UI架构设计的演化
    3. Delphi开发的IOCP测试Demo以及使用说明。
    4. linux/nginx 安全增强
    5. php pkcs 1格式的公钥,解说--2--微信支付RSA公钥PKCS1格式转化成PKCS8格式的公钥
    6. C++基本知识点集锦(2022秋招)
    7. PAT1036:Boys vs Girls
    8. 快速入门linux系统的iptables防火墙 1 本机与外界的基本通信管理
    9. HTML5 Canvas编写五彩连珠(3):设计
    10. XML(eXtensible Markup Language)文件的解析