对联智能生成的原理(学习笔记附代码实现与详解)
文章均从个人微信公众号“ AI牛逼顿”转载,文末扫码,欢迎关注!
过年的脚步越来越近,是不是该给家里贴上一副对联呢?除了买买买,有没有想过自己动手写出一副对联?来吧,撸起袖子加油干!只是提笔之后,有没有像穿肠兄这样呢
锅锅我当然没有华安的文采,又不能让夺命书生耀武扬威,只能借助AI去学习已有的对联数据,然后替我力挽狂澜。言归正传,开始本期的正题——基于attention机制的seq2seq模型的序列生成任务。
一、seq2seq模型的理解
1. 什么是seq2seq模型
seq2seq模型全称是sequence to sequence模型,是一个编码--解码(encoder--decoder )结构的网络,如图一所示。它的输入是一个序列,输出也是一个序列。 encoder(编码) 将一个可变长度的信号序列变为固定长度的向量表达(图一中的context向量),decoder 将这个固定长度的向量变成可变长度的目标信号序列。应用场景有:比如这里的对联生成——输入上联,就能输出下联。又比如图一中的场景——双语翻译。
2. 模型的输入是什么
输入肯定不是图一里的文字,计算机很智障的,怎么能理解这么复杂的东西!!需要把原始的文字进行向量化表示,模型才能开始进行训练。如图二所示,输入应该为这样。
这里的每一个词都用一个长度为4的向量表示(举例说明,实际使用的过程中,这个向量长度就是词向量的维度。另外还假设了向量的取值,颜色的深浅与数字没有必然联系,只是为了加以区分)。
3. seq2seq的计算流程
输入进来的一个训练句子,先要经过encoder过程,从而获取整个句子的信息,相当于计算机“理解”句子的含义。encoder过程如图三所示。
这就是RNN网络的计算过程(如果不了解RNN,可以在网易云课堂观看吴恩达的《深度学习》课程,目前已经免费了)。图中的符号的意义分别为:
hidden state #0是网络计算所需的初始化参数(向量)。
input表示的是每个词的向量,如图二所示。
每个蓝色圆圈就是RNN细胞单元,可以是基础的RNN单元,也可以是LSTM或者GRU单元。图中可以看到,每个细胞单元有两个输入,一个输出(这里是seq2seq模型,所以只输出当前时间步的hidden state,不输出output)。
hidden state #3是最后一个时间步上,细胞单元输出的结果(向量)。从计算流程图可以看出,这个输出要用到所有时间步上的input和中间的hidden state,所以hidden state #3就能包含整个输入语句的意思了。这个就是图一中的context向量。
接下来就是decoder过程了,就是把计算机“理解”的内容,最终转化成输出。该过程如图四所示(要比encoder过程复杂一些)
计算过程的按图中的箭头方向推进。网络结构和encoder没有区别,只是输出方面有些差异。图中的细胞单元有两个输出,除了往右边的输出(hidden state,因为和图三一样,所以没有画出),还有往上边的输出,即每个时间步还会输出output。
(根据RNN的计算方法outputi=f(hidden statei),下标i 表示时间步,f就是激活函数)
input输入的是目标句子里的词向量。这里在第一个词的前面插入了一个‘<s>’标识符,表示开始,相当于告诉计算机,decoder过程开始了。
projection layers就是常见的全连接层网络加上softmax激活。
通过projection layers层,就能输出预测的词了,结尾处输出了一个‘<e>’标识符,表示句子的结尾。
可以看到,输出的预测结果是‘I am a driver’(假设的),与原始的目标句子不同。这样,就可以用预测结果与原始目标构建损失函数,常用交叉熵。迭代优化损失函数,则整个模型的参数不断的调整,直到优化结果满足条件。
要说明的是,图四展示的是训练过程的计算流程。如果模型训练好了,只做预测,则input过程有些变化:第一个input依然是‘<s>’,以此预测出一个词,然后把这个预测结果作为下一步的input,然后又预测出一个词,后面依次类推。
4. seq2seq模型的特点
属于端对端模型,或者说,在序列模型里属于many to many类型的建模。可以解决句子的变长问题(通过设定最大句子长度和padding操作来实现)。由于考虑了词语在句子中的顺序,可以学到整个句子要表达的语义信息。不足在于,如果输入的句子较长时,encoder部分很难学到全部的句子信息。
二、attention机制的理解
如图五所示,先从整体上看attention机制,区别在于把encoder过程里,每个时间步输出的hidden state都保留下来,利用这每一步的hidden state来进行decoder过程。
在给出详细的计算公式之前,先给出计算流程,如图六所示:
第1步:decoder过程计算出第1步的hidden state(这里用dec_h1表示,图中没有画出来)。输入是encoder过程的最后一步hidden state(这里用enc_h3表示)和’<s>’的向量表示。
第2步:将dec_h1和enc_h1、enc_h2、enc_h3来计算相似性得分,最简单的得分计算办法就是求向量的內积,这个得分再经过一个softmax归一化后,得到每个enc_h的权重。
第3步:用第2步得到的权重和enc_h1、enc_h2、enc_h3进行加权求和,这就是attention机制的计算办法。
第4步:第3步的计算结果就是context_1向量,这个向量不同程度的融合了enc_h1、enc_h2、enc_h3的信息。比单纯用enc_h3当做context向量要好。
第5步:把dec_h1与context_1进行拼接,拼接的结果输入全连接层和softmax去预测目标词。
第6步:输入第二个时间步的word embedding,去计算dec_h2。
第7步:将dec_h2和enc_h1、enc_h2、enc_h3来计算相似性得分。和第2步完全一样。不过要强调的是,由于dec_h2和dec_h1不同,所以这一步得到的权重与第2步得到的权重不同。
第8步:与第3步的计算过程相同。
第9步:与第4步的计算过程相同。由于权重的值变了,所以得到的context_2向量和contex_1向量是不同的。这里也能再次看出,没有attention机制的过程里,解码依赖的context向量是相同的;而attention机制里,每一步的解码,依赖的context向量都不同。
第10步:与第5步完全一样。
有没有一步一步,似魔鬼的爪牙?!没关系,下图给出的公式,能够辅助上面的文字叙述。公式的重点是理解下标的含义!!!下标的含义!!!下标的含义!!!
恭喜各位,到此,基于attention机制的seq2seq模型的原理就差不多说完了。还有一些细节问题有待后面的文章中再分享,另有一些就直接在文末附上链接供大家学习。接下来就是案例实现的代码详解了。
三、代码详解
1. 文件说明
代码实验环境为:win7 + Python3.6 +TensorFlow1.12。尤其是TensorFlow的版本要注意了,不同版本,有些函数的用法会发生变化。图八显示的是整个代码里包含的文件。
couplet文件夹里有训练数据和测试数据,还有词表数据
models文件夹存放训练阶段里保存的模型
attention_new_utils.py文件里写好了带有attention机制的seq2seq 模型
data_utils.py文件里封装了数据处理的类
eval_function.py文件里定义了BLEU计算的函数,用了评估训练阶段的模型效果
model.py文件里定义了seq2seq模型,是整个代码块的主函数
2. 代码详解
(1) 数据处理部分(data_utils.py)
padding_seg
函数用来补齐句子的长度(如果句子长度小于最大长度,补零填充到最大长度);
encode_text
函数用来把文本序列映射为下标id序列,这个函数在模型的输入阶段要用到。(例如输入的句子为‘我 要 去 打球’,映射为‘[59, 79, 42, 31]’,这里的输入句子已经分过词,下标值是举例假设的,具体的值根据词表里的词与id对应关系决定)
decode_text
函数用来把输出的下标id序列映射回文本序列,这个函数在模型的输出阶段要用到,模型输出的是id序列,要输出最后的文本结果,需要这个函数进行转换。这里的’</s>’是词表里表示句子结尾的符号。
read_vocab
函数读入词表。由于原始词表里只有’<s>’、’</s>’分别表示句子的开头和结尾,但是用零填充句子后,id值为0的数字没有标识符与之对应,所以要添加一个’<pad>’标识符与之对应。
(2) seq2seq模型
(attention_new_utils.py)
getLayeredCell
函数定义了多层RNN的网络结构,RNN的单元结构为LSTM结构,每层输出进行dropout操作。
bi_encoder
函数定义了双向RNN网络,并对网络的输出进行了处理。这个函数的主体是两部分:
调用自定义的getLayeredCell函数,返回双向RNN需要传入的RNNCell实例;
调用函数tf.nn.bidirectional_dynamic_rnn,创建一个双向循环神经网络。由于参数里默认选择了time_major=False,所以输出的结果里,bi_encoder_output是一个(output_fw, output_bw)元组,包含前向和后向RNN输出的张量。
这里各自输出的形状如下:
output_fw:
[batch_size,max_time,cell_fw.output_size]
output_bw:
[batch_size,max_time,cell_bw.output_size]
这里cell_fw.output_size与cell_bw.output_size不一定相等,要根据cell_fw和cell_bw的传入的尺寸来决定。
encoder_output:
[batch_size, max_time, cell_bw.output_size + cell_bw.output_size]
bi_encoder_state是一个(output_state_fw, output_state_bw)的元组,包含双向RNN的前向和后向的最终状态。
TensorFlow里已经封装好了attention机制的代码块,则三步即可实现attention机制:
定义一个RNNCell实例;选择注意力机制(这里提供了两种不同的机制BahdanauAttention和LuongAttenion,区别在于计算的方式有点区别);利用AttentionWrapper进行封装。
模型的搭建流程非常明了:
embedding向量----encoder层----带attention机制的decoder层。
解码时要将训练阶段的解码和预测阶段的解码分开。两种区别就在于helper的定义不同,产生这种差别的原因在于,训练阶段,是有真实的目标句子的。而在预测阶段,模型已经训练好,根据输入的测试句子,要产生最可能的句子作为输出。最后,输出的outputs里,包含的内容有:
对于不使用beam_search的时候,它里面包含两项:
(rnn_outputs, sample_id),形状分别为
rnn_output:
[batch_size, decoder_targets_length, vocab_size]
sample_id:
[batch_size, decoder_targets_length]
对于使用beam_search的时候,它里面包含两项
(predicted_ids, beam_search_decoder_output),形状分别为
predicted_ids:
[batch_size, decoder_targets_length, beam_size]
beam_search_decoder_output:
(scores, predicted_ids, parent_ids)
(3)评估函数(eval_function.py)
这里的函数实现bleu得分的计算,关于bleu得分的详细计算,阅读文末的链接内容。
(4)主函数(model.py)
里面定义一个类,主要实现的功能就训练、评估和推断。由于模型已经搭建好,所以这里的代码块,主要是根据TensorFlow的语法,定义图、会话、占位符,然后传入数据运行会话。附带实现的功能包括保存训练模型,训练日志的记录,模型的加载等。部分说明见详细代码的注释。
代码调试这一块的巨坑,谁调谁知道!!个人经验是——必须先弄懂模型的原理,这样的话,代码的流程看起来就不会吃力;另外就是要知道每一步计算的输出长什么样,这是调试最耗费时间的地方,稍不留意,分分钟给你报错;最后要适当看看源码,里面的注释或者案例写法有巨大的帮助。
3. 实验结果
直接上图。开始训练50轮后,输出的结果如下图。这是什么鬼?(说明:src是测试集里输入的上联,output是机器输出的下联,target是测试集里与上联对应的下联。)
训练2000轮后,输出的结果如下图。看起来好多了,这水平能不能怼过穿肠兄不要紧,反正可以碾压我了。
4. 代码自评
(1)功能
实现多层的双向LSTM网络结构
带有attention机制的解码过程
beam_search算法搜索预测结果
bleu得分的计算
(2)不足
灵活性不足,体现在:只设置了多层的双向LSTM网络结构,如果要使用其他结果,要重新设置;
bleu得分的计算是针对评估过程中所有的输出结果,对整体进行计算一个得分,可以理解为平均值,没有做到对单个输出结果的逐个得分计算;
模型的调用不是很方便,如果从头开始,要先调用一次,给定训练的参数。训练完成后,还要再调用一次模型,给出推断的参数,此时才能完成智能输出的目的;
模型推理的结果有beam size个,但是输出好像不是按照一定的大小顺序排列的,这个结果应该还是要人工来选择最优结果。
千里之行始于足下!定期分享人工智能的干货,通俗展现原理和案例实现,并探索案例在中学物理教育过程中的使用。还有各种有趣的物理科普哟。坚持原创分享!坚持理解并吸收后的转发分享!欢迎大家的关注与交流。
BLEU得分:
https://blog.csdn.net/weixin_40240670/article/details/85112078
beam_search算法:
https://zhuanlan.zhihu.com/p/28048246
参考:
Visualizing A Neural Machine Translation Model
原始论文:
https://arxiv.org/pdf/1409.0473.pdf
代码下载链接:https://download.csdn.net/download/weixin_43917778/11635581
对联智能生成的原理(学习笔记附代码实现与详解)相关推荐
- 小猫爪:i.MX RT1050学习笔记26-RT1xxx系列的FlexCAN详解
i.MX RT1050学习笔记26-RT1xxx系列的FlexCAN详解 1 前言 2 FlexCAN简介 2.1 MB(邮箱)系统 2.1.1 正常模式下 2.1.2 激活了CAN FD情况下 2. ...
- JDBC学习笔记02【ResultSet类详解、JDBC登录案例练习、PreparedStatement类详解】
黑马程序员-JDBC文档(腾讯微云)JDBC笔记.pdf:https://share.weiyun.com/Kxy7LmRm JDBC学习笔记01[JDBC快速入门.JDBC各个类详解.JDBC之CR ...
- IP地址和子网划分学习笔记之《IP地址详解》
在学习IP地址和子网划分前,必须对进制计数有一定了解,尤其是二进制和十进制之间的相互转换,对于我们掌握IP地址和子网的划分非常有帮助,可参看如下目录详文. IP地址和子网划分学习笔记相关篇章: 1.I ...
- 我的学习笔记——CSS背景渐变(Gradients)详解
我的学习笔记--CSS背景渐变(Gradients)详解 一.线性渐变(Linear Gradients) 1.语法 background-image: linear-gradient(directi ...
- IP地址和子网划分学习笔记之《子网划分详解》
一,子网划分概述 IP地址和子网划分学习笔记相关篇章: 1.IP地址和子网划分学习笔记之<预备知识:进制计数> 2.IP地址和子网划分学习笔记之<IP地址详解> 3.IP地址和 ...
- redis学习笔记(2)之redis主从详解
redis主从详解 主从详解 主从配置 拓扑 原理 数据同步 概念 复制偏移量 复制积压缓冲区 主节点运行ID Psync命令 全量复制流程 部分复制流程 心跳 缓冲大小调节 读写分离 内容来源为六星 ...
- redis学习笔记(7)之redis哨兵详解
redis哨兵详解 sentinel命令 客户端连接 素材代码 思路 实现过程 哨兵的切换实现原理 发布订阅基础 哨兵的实现原理 部署建议 需要关注的问题 代码流程 内容来源为六星教育,这里仅作为学习 ...
- Apollo星火计划学习笔记——Apollo决策规划技术详解及实现(以交通灯场景检测为例)
文章目录 前言 1. Apollo决策技术详解 1.1 Planing模块运行机制 1.2 Apollo决策功能的设计与实现 1.2.1参考路径 Reference Line 1.2.2 交规决策 T ...
- MIT 6.824 学习笔记(一)--- RPC 详解
从本文开始,将记录作者学习 MIT 6.824 分布式系统的学习笔记,如果有志同道合者,欢迎一起交流. RPC 的定义和结构 RPC 全称为 Remote Procedure Call,他表示一种远程 ...
最新文章
- AI化身监工,上班还能摸鱼吗?
- php源码之计算两个文件的相对路径
- 用云服务器实现janus之web端与web通话!
- 手动实现kt(java)同步工作流和异步工作流
- 姜汝祥的-赢在执行 - 制度执行力的三要三化
- 解释一下Spring支持的几种bean的作用域
- 前端学习(1979)vue之电商管理系统电商系统之让文本框获得焦点
- 【转】HTML5杂谈 概念与现行游戏 割绳子 宝石迷阵
- 【转】嵌入式软件:C语言编码规范
- word计算机公式怎么算,word怎么实现自动计算公式
- 【原创】导读”淘宝褚霸关于 gen_tcp 的分享“
- R语言knn算法的两种方法:class包与kknn包
- layui表格合并的方法
- 对拍--from Altf4
- 20系列显卡服务器,RTX20系列被严重低估,他不仅是一张游戏显卡
- cannot unbox null value
- Linux最全面试题100问答,纯纯爽文
- 一张图片切割成九宫格,微信朋友圈发布
- CSS的世界(十四)
- 京津冀计算机学科大学排名,2021京津冀地区民办大学排名前十