文章均从个人微信公众号“ AI牛逼顿”转载,文末扫码,欢迎关注!


过年的脚步越来越近,是不是该给家里贴上一副对联呢?除了买买买,有没有想过自己动手写出一副对联?来吧,撸起袖子加油干!只是提笔之后,有没有像穿肠兄这样呢

锅锅我当然没有华安的文采,又不能让夺命书生耀武扬威,只能借助AI去学习已有的对联数据,然后替我力挽狂澜。言归正传,开始本期的正题——基于attention机制的seq2seq模型的序列生成任务。

一、seq2seq模型的理解

1.  什么是seq2seq模型

seq2seq模型全称是sequence to sequence模型,是一个编码--解码(encoder--decoder )结构的网络,如图一所示。它的输入是一个序列,输出也是一个序列。 encoder(编码) 将一个可变长度的信号序列变为固定长度的向量表达(图一中的context向量),decoder 将这个固定长度的向量变成可变长度的目标信号序列。应用场景有:比如这里的对联生成——输入上联,就能输出下联。又比如图一中的场景——双语翻译。

图一

                                                                                               

2.  模型的输入是什么

输入肯定不是图一里的文字,计算机很智障的,怎么能理解这么复杂的东西!!需要把原始的文字进行向量化表示,模型才能开始进行训练。如图二所示,输入应该为这样。

图二

这里的每一个词都用一个长度为4的向量表示(举例说明,实际使用的过程中,这个向量长度就是词向量的维度。另外还假设了向量的取值,颜色的深浅与数字没有必然联系,只是为了加以区分)。

3.   seq2seq的计算流程

输入进来的一个训练句子,先要经过encoder过程,从而获取整个句子的信息,相当于计算机“理解”句子的含义。encoder过程如图三所示。

图三

这就是RNN网络的计算过程(如果不了解RNN,可以在网易云课堂观看吴恩达的《深度学习》课程,目前已经免费了)。图中的符号的意义分别为:

hidden  state  #0是网络计算所需的初始化参数(向量)。

input表示的是每个词的向量,如图二所示。

每个蓝色圆圈就是RNN细胞单元,可以是基础的RNN单元,也可以是LSTM或者GRU单元。图中可以看到,每个细胞单元有两个输入,一个输出(这里是seq2seq模型,所以只输出当前时间步的hidden  state,不输出output)。

hidden state #3是最后一个时间步上,细胞单元输出的结果(向量)。从计算流程图可以看出,这个输出要用到所有时间步上的input和中间的hidden state,所以hidden state #3就能包含整个输入语句的意思了。这个就是图一中的context向量。

接下来就是decoder过程了,就是把计算机“理解”的内容,最终转化成输出。该过程如图四所示(要比encoder过程复杂一些)

图四

计算过程的按图中的箭头方向推进。网络结构和encoder没有区别,只是输出方面有些差异。图中的细胞单元有两个输出,除了往右边的输出(hidden state,因为和图三一样,所以没有画出),还有往上边的输出,即每个时间步还会输出output。

(根据RNN的计算方法outputi=f(hidden statei),下标表示时间步,f就是激活函数)

input输入的是目标句子里的词向量。这里在第一个词的前面插入了一个‘<s>’标识符,表示开始,相当于告诉计算机,decoder过程开始了。

projection layers就是常见的全连接层网络加上softmax激活。

通过projection layers层,就能输出预测的词了,结尾处输出了一个‘<e>’标识符,表示句子的结尾。

可以看到,输出的预测结果是‘I am a driver’(假设的),与原始的目标句子不同。这样,就可以用预测结果与原始目标构建损失函数,常用交叉熵。迭代优化损失函数,则整个模型的参数不断的调整,直到优化结果满足条件。

要说明的是,图四展示的是训练过程的计算流程。如果模型训练好了,只做预测,则input过程有些变化:第一个input依然是‘<s>’,以此预测出一个词,然后把这个预测结果作为下一步的input,然后又预测出一个词,后面依次类推。

4.  seq2seq模型的特点

属于端对端模型,或者说,在序列模型里属于many to many类型的建模。可以解决句子的变长问题(通过设定最大句子长度和padding操作来实现)。由于考虑了词语在句子中的顺序,可以学到整个句子要表达的语义信息。不足在于,如果输入的句子较长时,encoder部分很难学到全部的句子信息。

二、attention机制的理解

如图五所示,先从整体上看attention机制,区别在于把encoder过程里,每个时间步输出的hidden state都保留下来,利用这每一步的hidden state来进行decoder过程。

图五

在给出详细的计算公式之前,先给出计算流程,如图六所示:

图六
哈哈,跟着箭头上的数字走,就不会看得懵逼

第1步:decoder过程计算出第1步的hidden state(这里用dec_h1表示,图中没有画出来)。输入是encoder过程的最后一步hidden state(这里用enc_h3表示)和’<s>’的向量表示。

第2步:将dec_h1和enc_h1、enc_h2、enc_h3来计算相似性得分,最简单的得分计算办法就是求向量的內积,这个得分再经过一个softmax归一化后,得到每个enc_h的权重。

第3步:用第2步得到的权重和enc_h1、enc_h2、enc_h3进行加权求和,这就是attention机制的计算办法。

第4步:第3步的计算结果就是context_1向量,这个向量不同程度的融合了enc_h1、enc_h2、enc_h3的信息。比单纯用enc_h3当做context向量要好。

第5步:把dec_h1与context_1进行拼接,拼接的结果输入全连接层和softmax去预测目标词。

第6步:输入第二个时间步的word embedding,去计算dec_h2。

第7步:将dec_h2和enc_h1、enc_h2、enc_h3来计算相似性得分。和第2步完全一样。不过要强调的是,由于dec_h2和dec_h1不同,所以这一步得到的权重与第2步得到的权重不同。

第8步:与第3步的计算过程相同。

第9步:与第4步的计算过程相同。由于权重的值变了,所以得到的context_2向量和contex_1向量是不同的。这里也能再次看出,没有attention机制的过程里,解码依赖的context向量是相同的;而attention机制里,每一步的解码,依赖的context向量都不同。

第10步:与第5步完全一样。

有没有一步一步,似魔鬼的爪牙?!没关系,下图给出的公式,能够辅助上面的文字叙述。公式的重点是理解下标的含义!!!下标的含义!!!下标的含义!!!

图七

恭喜各位,到此,基于attention机制的seq2seq模型的原理就差不多说完了。还有一些细节问题有待后面的文章中再分享,另有一些就直接在文末附上链接供大家学习。接下来就是案例实现的代码详解了。

三、代码详解

1.  文件说明

代码实验环境为:win7 + Python3.6 +TensorFlow1.12。尤其是TensorFlow的版本要注意了,不同版本,有些函数的用法会发生变化。图八显示的是整个代码里包含的文件。

图八

couplet文件夹里有训练数据和测试数据,还有词表数据

models文件夹存放训练阶段里保存的模型

attention_new_utils.py文件里写好了带有attention机制的seq2seq  模型

data_utils.py文件里封装了数据处理的类

eval_function.py文件里定义了BLEU计算的函数,用了评估训练阶段的模型效果

model.py文件里定义了seq2seq模型,是整个代码块的主函数

2.  代码详解

(1) 数据处理部分(data_utils.py)

图九

padding_seg

函数用来补齐句子的长度(如果句子长度小于最大长度,补零填充到最大长度);

encode_text

函数用来把文本序列映射为下标id序列,这个函数在模型的输入阶段要用到。(例如输入的句子为‘我 要 去 打球’,映射为‘[59, 79, 42, 31]’,这里的输入句子已经分过词,下标值是举例假设的,具体的值根据词表里的词与id对应关系决定)

decode_text

函数用来把输出的下标id序列映射回文本序列,这个函数在模型的输出阶段要用到,模型输出的是id序列,要输出最后的文本结果,需要这个函数进行转换。这里的’</s>’是词表里表示句子结尾的符号。

read_vocab

函数读入词表。由于原始词表里只有’<s>’、’</s>’分别表示句子的开头和结尾,但是用零填充句子后,id值为0的数字没有标识符与之对应,所以要添加一个’<pad>’标识符与之对应。

(2) seq2seq模型

(attention_new_utils.py)

图十

getLayeredCell

函数定义了多层RNN的网络结构,RNN的单元结构为LSTM结构,每层输出进行dropout操作。

bi_encoder

函数定义了双向RNN网络,并对网络的输出进行了处理。这个函数的主体是两部分:

调用自定义的getLayeredCell函数,返回双向RNN需要传入的RNNCell实例;

调用函数tf.nn.bidirectional_dynamic_rnn,创建一个双向循环神经网络。由于参数里默认选择了time_major=False,所以输出的结果里,bi_encoder_output是一个(output_fw, output_bw)元组,包含前向和后向RNN输出的张量。

这里各自输出的形状如下:

output_fw:

[batch_size,max_time,cell_fw.output_size]

output_bw:

[batch_size,max_time,cell_bw.output_size]

这里cell_fw.output_size与cell_bw.output_size不一定相等,要根据cell_fw和cell_bw的传入的尺寸来决定。

encoder_output:

[batch_size, max_time, cell_bw.output_size + cell_bw.output_size]

bi_encoder_state是一个(output_state_fw, output_state_bw)的元组,包含双向RNN的前向和后向的最终状态。

图十一

TensorFlow里已经封装好了attention机制的代码块,则三步即可实现attention机制:

定义一个RNNCell实例;选择注意力机制(这里提供了两种不同的机制BahdanauAttention和LuongAttenion,区别在于计算的方式有点区别);利用AttentionWrapper进行封装。

图十二
(说明:截图是初始seq2seq模型,后续增加了beam_search算法的功能)

模型的搭建流程非常明了:

embedding向量----encoder层----带attention机制的decoder层。

解码时要将训练阶段的解码和预测阶段的解码分开。两种区别就在于helper的定义不同,产生这种差别的原因在于,训练阶段,是有真实的目标句子的。而在预测阶段,模型已经训练好,根据输入的测试句子,要产生最可能的句子作为输出。最后,输出的outputs里,包含的内容有:

对于不使用beam_search的时候,它里面包含两项:

(rnn_outputs, sample_id),形状分别为

rnn_output:

[batch_size, decoder_targets_length, vocab_size]

sample_id:

[batch_size, decoder_targets_length]

对于使用beam_search的时候,它里面包含两项

(predicted_ids, beam_search_decoder_output),形状分别为

predicted_ids:

[batch_size, decoder_targets_length, beam_size]

beam_search_decoder_output:

(scores, predicted_ids, parent_ids)

(3)评估函数(eval_function.py)

这里的函数实现bleu得分的计算,关于bleu得分的详细计算,阅读文末的链接内容。

(4)主函数(model.py)

里面定义一个类,主要实现的功能就训练、评估和推断。由于模型已经搭建好,所以这里的代码块,主要是根据TensorFlow的语法,定义图、会话、占位符,然后传入数据运行会话。附带实现的功能包括保存训练模型,训练日志的记录,模型的加载等。部分说明见详细代码的注释。

代码调试这一块的巨坑,谁调谁知道!!个人经验是——必须先弄懂模型的原理,这样的话,代码的流程看起来就不会吃力;另外就是要知道每一步计算的输出长什么样,这是调试最耗费时间的地方,稍不留意,分分钟给你报错;最后要适当看看源码,里面的注释或者案例写法有巨大的帮助。

3.  实验结果

直接上图。开始训练50轮后,输出的结果如下图。这是什么鬼?(说明:src是测试集里输入的上联,output是机器输出的下联,target是测试集里与上联对应的下联。)

图十三

训练2000轮后,输出的结果如下图。看起来好多了,这水平能不能怼过穿肠兄不要紧,反正可以碾压我了。

图十四

4.  代码自评

(1)功能

实现多层的双向LSTM网络结构

带有attention机制的解码过程

beam_search算法搜索预测结果

bleu得分的计算

(2)不足

灵活性不足,体现在:只设置了多层的双向LSTM网络结构,如果要使用其他结果,要重新设置;

bleu得分的计算是针对评估过程中所有的输出结果,对整体进行计算一个得分,可以理解为平均值,没有做到对单个输出结果的逐个得分计算;

模型的调用不是很方便,如果从头开始,要先调用一次,给定训练的参数。训练完成后,还要再调用一次模型,给出推断的参数,此时才能完成智能输出的目的;

模型推理的结果有beam size个,但是输出好像不是按照一定的大小顺序排列的,这个结果应该还是要人工来选择最优结果。

千里之行始于足下!定期分享人工智能的干货,通俗展现原理和案例实现,并探索案例在中学物理教育过程中的使用。还有各种有趣的物理科普哟。坚持原创分享!坚持理解并吸收后的转发分享!欢迎大家的关注与交流。

BLEU得分:

https://blog.csdn.net/weixin_40240670/article/details/85112078

beam_search算法:

https://zhuanlan.zhihu.com/p/28048246

参考:

Visualizing A Neural Machine Translation Model

原始论文:

https://arxiv.org/pdf/1409.0473.pdf

代码下载链接:https://download.csdn.net/download/weixin_43917778/11635581

对联智能生成的原理(学习笔记附代码实现与详解)相关推荐

  1. 小猫爪:i.MX RT1050学习笔记26-RT1xxx系列的FlexCAN详解

    i.MX RT1050学习笔记26-RT1xxx系列的FlexCAN详解 1 前言 2 FlexCAN简介 2.1 MB(邮箱)系统 2.1.1 正常模式下 2.1.2 激活了CAN FD情况下 2. ...

  2. JDBC学习笔记02【ResultSet类详解、JDBC登录案例练习、PreparedStatement类详解】

    黑马程序员-JDBC文档(腾讯微云)JDBC笔记.pdf:https://share.weiyun.com/Kxy7LmRm JDBC学习笔记01[JDBC快速入门.JDBC各个类详解.JDBC之CR ...

  3. IP地址和子网划分学习笔记之《IP地址详解》

    在学习IP地址和子网划分前,必须对进制计数有一定了解,尤其是二进制和十进制之间的相互转换,对于我们掌握IP地址和子网的划分非常有帮助,可参看如下目录详文. IP地址和子网划分学习笔记相关篇章: 1.I ...

  4. 我的学习笔记——CSS背景渐变(Gradients)详解

    我的学习笔记--CSS背景渐变(Gradients)详解 一.线性渐变(Linear Gradients) 1.语法 background-image: linear-gradient(directi ...

  5. IP地址和子网划分学习笔记之《子网划分详解》

    一,子网划分概述 IP地址和子网划分学习笔记相关篇章: 1.IP地址和子网划分学习笔记之<预备知识:进制计数> 2.IP地址和子网划分学习笔记之<IP地址详解> 3.IP地址和 ...

  6. redis学习笔记(2)之redis主从详解

    redis主从详解 主从详解 主从配置 拓扑 原理 数据同步 概念 复制偏移量 复制积压缓冲区 主节点运行ID Psync命令 全量复制流程 部分复制流程 心跳 缓冲大小调节 读写分离 内容来源为六星 ...

  7. redis学习笔记(7)之redis哨兵详解

    redis哨兵详解 sentinel命令 客户端连接 素材代码 思路 实现过程 哨兵的切换实现原理 发布订阅基础 哨兵的实现原理 部署建议 需要关注的问题 代码流程 内容来源为六星教育,这里仅作为学习 ...

  8. Apollo星火计划学习笔记——Apollo决策规划技术详解及实现(以交通灯场景检测为例)

    文章目录 前言 1. Apollo决策技术详解 1.1 Planing模块运行机制 1.2 Apollo决策功能的设计与实现 1.2.1参考路径 Reference Line 1.2.2 交规决策 T ...

  9. MIT 6.824 学习笔记(一)--- RPC 详解

    从本文开始,将记录作者学习 MIT 6.824 分布式系统的学习笔记,如果有志同道合者,欢迎一起交流. RPC 的定义和结构 RPC 全称为 Remote Procedure Call,他表示一种远程 ...

最新文章

  1. AI化身监工,上班还能摸鱼吗?
  2. php源码之计算两个文件的相对路径
  3. 用云服务器实现janus之web端与web通话!
  4. 手动实现kt(java)同步工作流和异步工作流
  5. 姜汝祥的-赢在执行 - 制度执行力的三要三化
  6. 解释一下Spring支持的几种bean的作用域
  7. 前端学习(1979)vue之电商管理系统电商系统之让文本框获得焦点
  8. 【转】HTML5杂谈 概念与现行游戏 割绳子 宝石迷阵
  9. 【转】嵌入式软件:C语言编码规范
  10. word计算机公式怎么算,word怎么实现自动计算公式
  11. 【原创】导读”淘宝褚霸关于 gen_tcp 的分享“
  12. R语言knn算法的两种方法:class包与kknn包
  13. layui表格合并的方法
  14. 对拍--from Altf4
  15. 20系列显卡服务器,RTX20系列被严重低估,他不仅是一张游戏显卡
  16. cannot unbox null value
  17. Linux最全面试题100问答,纯纯爽文
  18. 一张图片切割成九宫格,微信朋友圈发布
  19. CSS的世界(十四)
  20. 京津冀计算机学科大学排名,2021京津冀地区民办大学排名前十

热门文章

  1. java 职级评定申报_职位等级评价方法(职级评价法)
  2. 空心圆圈里的数字,实现
  3. NYOJ 455题 黑色帽子
  4. executeQuery() 实现什么功能?
  5. “数字公务员”纷纷上岗,提高12345热线工单处理效率
  6. 恢复Win10中缺少的电源计划
  7. 希腊棺材之谜——复盘
  8. 实战 | 我用“大白鲨”让你看见 TCP
  9. 程序员面试必备——《Java程序员面试笔试宝典》pdf
  10. 基于51单片机的直流电机转速显示+加速减速启停