上一篇:NLP【06】RCNN原理及文本分类实战(附代码详解)

下一篇:NLP【08】深度学习模型在NLP中的发展——从Word2vec到Bert的演变过程

完整代码下载:https://github.com/ttjjlw/NLP/tree/main/Classify%E5%88%86%E7%B1%BB/rnn-cnn/tf1.x/models

一、前言

当前,bert横行,而bert的基础是transformer,自然,掌握transformer成为了基操。论文中的transformer是seq2seq模型,分为编码和解码,而在本文中,我所讲的transformer主要是transformer中的attention,也就是利用transformer中的attention与位置编码,替换cnn或者rnn部分完成文本分类。建议在看我这篇之前,先整体看一篇trnasformer的介绍,细节没看明白的,再来瞧瞧我这篇。

二、transformer的简单介绍

transformer分为encoder和decoder,这里主要讲encoder的三部分Positional Encoding、multi-head attention以及残差连接。Positional Encoding即位置编码也就是每个位置利用一个向量表示,具体公式如下:

这个公式怎么理解呢?pos就是词的位置,2i或2i+1就是 词向量的维度上的偶数或奇数的位置,举个例子:定义一个长度为100的位置向量,位置向量的维度为64,那么最终这个位置向量pos_embs的shape即为(100,64),那么这个位置向量表怎么得到呢,就是通过上面的公式。具体就是pos_embs[pos][i%2==0]=sin()这一串,而pos_embs[pos][i%2==1]=cos()这一串。而pos取值是从0到99,i取值是从0到63。我想看到这,应该都明白了,如果没明白,可以结合下面的代码实现来理解。

第二部分就是multi-head attention,首先,对比一下以前self.attention的做法,在做文本分类时,lookup词嵌入矩阵后,再经cnn或者rnn,会得到shape为(batch-size,seq_len,dim)的向量,记为M,然后我们是怎么做self.attention的呢?

1、先初始化一个可以训练的权重,shape为(batch_size, dim,1),记为W

2、然后M和W做矩阵相乘就的得到shape为(batch_size,seq_len)然后经softmax处理,就得到了每个词的权重

3、再把这个权重和原来的M做相乘(multiply),最后在seq_len的维度上做reduce_sum(),也就是output=reduce_sum(M,axis=1),则output的shape变为(batch-size,dim),也就是attention的最后输出,以上省略了所有的reshape过程。

该过程实现可参考https://github.com/ttjjlw/NLP/blob/main/Classify%E5%88%86%E7%B1%BB/rnn-cnn/tf1.x/models/bilstmatten.py。那transformer中的self.attention又是怎么做的呢?

1、把M复制三份,命名为query,key,value

2、分别初始化三个矩阵q_w,k_w,v_w,然后query,keyvalue与对应的矩阵做矩阵相乘,得Q,K,V,此时三者的shape都为(batch_size,seq_len,dim)

3、如果head为1的话,那就是similarity=matmul(Q,K的转置),所以similarity的shape为(batch_size,seq_len,seq_len),其这个矩阵记录就是每个词与所有词的相似性

4、output=matmul(similarity,V),所以output 的shape为(batch_size,seq_len,dim)

第三部分残差连接,残差连接就简单了:公式为 :H(x) = F(x) + x,这里就是H(x)=query+output,然后给H(x)进行层归一化。

整个过程大概就是如此,其中省略些细节,如similarity的计算其实还要除于根号dim,防止softmax(similarity)后非0即1,不利于参数学习。举个例子就明白了,softmax([1,10]) —> [1.2339458e-04, 9.9987662e-01] 而 softmax([0.1,1.0]) —> [0.2890505, 0.7109495]。

三、代码详解

1、位置编码

    def _position_embedding(self):"""生成位置向量:return:"""batch_size = self.config["batch_size"]sequence_length = self.config["sequence_length"]embedding_size = self.config["embedding_size"]# 生成位置的索引,并扩张到batch中所有的样本上position_index = tf.tile(tf.expand_dims(tf.range(sequence_length), 0), [batch_size, 1])position_embedding = np.zeros([sequence_length, embedding_size])for pos in range(sequence_length):for i in range(embedding_size):denominator = np.power(10000.0, i/ embedding_size)if i % 2 == 0:position_embedding[pos][i] = np.sin(pos / denominator)else:position_embedding[pos][i] = np.cos(pos / denominator)position_embedding = tf.cast(position_embedding, dtype=tf.float32)# 得到三维的矩阵[batchSize, sequenceLen, embeddingSize]embedded_position = tf.nn.embedding_lookup(position_embedding, position_index)return embedded_position

lookup后的词向量与位置向量相加形成新的向量

embedded_words = tf.nn.embedding_lookup(embedding_w, self.inputs)
embedded_position = self._position_embedding()
embedded_representation = embedded_words + embedded_position

把添加了位置向量的词向量,输入到self._multihead_attention()中(该方法就是依次经过attention,残差连接与层归一化得到最终的向量,就是上面详细介绍的2,3步过程),然后再经过self._feed_forward(), self._multihead_attention()与self._feed_forward()组成一层transformer

        with tf.name_scope("transformer"):for i in range(self.config["num_blocks"]):with tf.name_scope("transformer-{}".format(i + 1)):with tf.name_scope("multi_head_atten"):# 维度[batch_size, sequence_length, embedding_size]multihead_atten = self._multihead_attention(inputs=self.inputs,queries=embedded_representation,keys=embedded_representation)with tf.name_scope("feed_forward"):# 维度[batch_size, sequence_length, embedding_size]embedded_representation = self._feed_forward(multihead_atten,[self.config["filters"],self.config["embedding_size"]])outputs = tf.reshape(embedded_representation,[-1, self.config["sequence_length"] * self.config["embedding_size"]])output_size = outputs.get_shape()[-1].value

其中num_blocks就是设置要过几层transformer,output就是最终的结果。

完整代码下载:https://github.com/ttjjlw/NLP/tree/main/Classify%E5%88%86%E7%B1%BB/rnn-cnn/tf1.x/models

NLP【07】transformer原理、实现及如何与词向量做对接进行文本分类(附代码详解)相关推荐

  1. 【NLP傻瓜式教程】手把手带你HAN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  2. 【NLP傻瓜式教程】手把手带你fastText文本分类(附代码)

    写在前面 已经发布: [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) 继续NLP傻瓜式教程系列,今天的教程是基于FAIR的Bag of ...

  3. NLP【05】pytorch实现glove词向量(附代码详解)

    上一篇:NLP[04]tensorflow 实现Wordvec(附代码详解) 下一篇:NLP[06]RCNN原理及文本分类实战(附代码详解) 完整代码下载:https://github.com/ttj ...

  4. 【NLP傻瓜式教程】手把手带你RCNN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  5. 【NLP保姆级教程】手把手带你RNN文本分类(附代码)

    写在前面 这是NLP保姆级教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关于循环神经网络在多任务文本分类上的应用:Recur ...

  6. 【NLP】保姆级教程:手把手带你CNN文本分类(附代码)

    分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...

  7. 【NLP保姆级教程】手把手带你CNN文本分类(附代码)

    分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...

  8. 【NLP傻瓜式教程】手把手带你RNN文本分类(附代码)

    文章来源于NewBeeNLP,作者kaiyuan 写在前面 这是NLP傻瓜式教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关 ...

  9. 【NLP傻瓜式教程】手把手带你CNN文本分类(附代码)

    文章来源于NewBeeNLP,作者kaiyuan 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classification[ ...

  10. Transformer代码详解: attention-is-all-you-need-pytorch

    Transformer代码详解: attention-is-all-you-need-pytorch 前言 Transformer代码详解-pytorch版 Transformer模型结构 各模块结构 ...

最新文章

  1. 十二、进程的同步与互斥
  2. php怎么创建事务,php事务的实现方法介绍(代码示例)
  3. Regex pattern in openresty
  4. Silverlight 4新控件PivotViewer介绍
  5. hdp对应hadoop的版本_好程序员大数据学习路线分享hadoop的知识总结
  6. Node+fs+定时器(node-schedule)+MySql
  7. 结对-结对编项目作业名称-最终程序
  8. 模块化,组件化,插件化简析
  9. 智能物流的常见应用场景及系统
  10. 登录mysql服务器的典型命令_Mysql 远程登录及常用命令
  11. 2159: H.ly的小迷弟
  12. 金融python入门书籍推荐_学习金融工程,有哪些推荐的入门书籍?
  13. 图床,全网最稳定的免费图床
  14. sRGB转CIEXYZ转CIELAB,以及色彩距离
  15. 英语语法三大从句刷题总结
  16. Java飞书三方网站对接
  17. 制造企业发展第三方物流的思路探讨 (zt)
  18. matlab 求旁瓣,主瓣、栅瓣和旁瓣(MATLAB源代码 解释)
  19. tp5 自动加上html,【TP5.1】HTML标签自动转义,导致CKEditor保存内容无法正常显示!...
  20. 存储计算机当前正执行的应用程序,存储计算机当前正在执行的应用程序和相应的数据的存储器是什么...

热门文章

  1. submail 发送国际短信,国内短信,国际国内邮箱工具类
  2. 硬见小百科:PCB多层板各层含义与设计原则
  3. 100层楼扔两个鸡蛋的问题
  4. On Visible Surface Generation BY A PRIORI TREE STRUCTURES
  5. 66个史上最全的行业数据研报网站
  6. 由于启用计算机,win10由于启动计算机时出现了页面文件配置问题的详细解决方案...
  7. Gmail配置邮箱客户端
  8. 计算机导论 ——绪论
  9. java判断百分数_Java 验证前台返回的是不是百分数 在后台用正则表达式验证百分比数据...
  10. 10.恩智浦-车规级-MCU:S32K11X FTM-PWM输出实验