仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:作者 | 燃雪  来源 | 知乎专栏

编辑 | 机器学习算法与自然语言处理

地址 | https://zhuanlan.zhihu.com/p/435413218

01

前言

在动手实现transformer过程中,主要参考了李沐老师《动手学深度学习 v2》的代码和视频讲解,收获很大,非常感谢李沐老师在b站发了这么多视频来系统地进行深度学习教学和论文解读(建议教育部新增b站学位证)。但是就我自己的学习过程来说,感觉视频和教程中有些代码和tensor解释并不是很容易理解(比如“state[-1]是最后一个时刻的最后一层”),不少命令是在反复看了许多遍后才理解其中输出和输入的变化过程。因此本文以具体代码为主线,补充tensor变化例子及示意图,力求让初学者更直观地理解Seq2Seq中的向量操作。

代码来源:李沐 《动手学深度学习 v2》 9.7 序列到序列学习

https://zh-v2.d2l.ai/chapter_recurrent-modern/seq2seq.html#

02

文本预处理

考虑一组batch_size = 3的encoder句子输入,首先进行预处理,合法字符替换、按word拆开(Splitting)、词元化(tokenize)、字典化(vocabularizen)、按照每句话的合法长度num_steps = 8进行截断和填充(padding)处理,得到批输入X和输入语言的vocab_src.

03

编码器Encoder

参数举例:batch_size = 3; seq_length = 8, vocab_size = 10000; embed_size = 512;

RNN网络参数: num_hiddens = 24; num_layers = 2;

class Seq2SeqEncoder(d2l.Encoder):"""用于序列到序列学习的循环神经网络编码器。"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqEncoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)      #Embedding NN,输入维度vocab_size,输出维度embed_sizeself.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)#GRU神经网络,输入维度embed_size,num_layers隐藏层,每层num_hiddens个神经元,是否dropoutdef forward(self, X, *args):X = self.embedding(X)# Embedding StepX = X.permute(1, 0, 2)# Switching Step, 将 batch size 和 word position 调换位置output, state = self.rnn(X)# 输出RNN网络的output和state,state表示完整的RNN神经网络的状态,是隐藏神经元状态按时间步的整合# 比如state[0]表示第1个时间步(word position/Time Step 1: I/aa/A)对应的RNN隐藏神经元状态return output, state

Batch input ➡️ Onehot 编码 ➡️ Embedding ➡️ 调换顺序 ➡️ 调整为RNN网络时序输入

Encoder中输入的逐层变化

Encoder_RNN神经网络的输入、隐藏状态及输出

04

解码器Decoder

Decoder输入的是target语言,此处对于target语言设置参数:batch_size = 3; seq_length = 8, vocab_size = 25000; embed_size = 512;

假设decoder使用与encoder相同设置的RNN网络参数: num_hiddens = 24; num_layers = 2;

class Seq2SeqDecoder(d2l.Decoder):"""用于序列到序列学习的循环神经网络解码器。"""def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):super(Seq2SeqDecoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)#同样的embedding NN与RNN,区别在于此处embedding网络根据target语言的embed_size和vocab_size参数设置#且RNN网络的输入进行了一次拼接(Concat Step),输入维度为embed_size+num_hiddensself.dense = nn.Linear(num_hiddens, vocab_size)#输出层,encoder的RNN网络不需要输出,decoder的RNN网络输出翻译的目标句子每个word位置(time step)上,#长度为vocab size的概率分布,表示vocab中每个word出现的概率def init_state(self, enc_outputs, *args):return enc_outputs[1]   #enc_outputs = [output, state],此处是将encoder最后一个时间步 time step 8的隐藏层参数取出作为初始化def forward(self, X, state):X = self.embedding(X).permute(1, 0, 2)#Embedding & Permute Step,将目标语言的batch input进行embedding and permute stepcontext = state[-1].repeat(X.shape[0], 1, 1)# Expand Step,将从encoder拿到的state[7]的最后一层隐藏层的状态作为表示原句信息的`context`#并将其扩展成tensor,使其具有与目标语言输入`X`相同的`seq_length`X_and_context = torch.cat((X, context), 2)#Concat Step,将句子信息context与目标语言Embedding tensor进行拼接,作为RNN_decoder的输入output, state = self.rnn(X_and_context, state)#decoder的RNN的输出和状态output = self.dense(output).permute(1, 0, 2)#Permute Step,将decoder的RNN输出后两位对调位置,方便下文处理 vocab 概率return output, state

Target Batch input ➡️ Onehot 编码 ➡️ Embedding ➡️ 调换顺序 ➡️ Concat Step

从Encoder拿到最后一个时间步的State ➡️ Extract Step ➡️ Expand Step ➡️ Concat Step

Decoder_RNN神经网络的输入、隐藏状态及输出

05

交叉熵损失

Mask Function,0值化屏蔽

def sequence_mask(X, valid_len, value=0):"""在序列中屏蔽不相关的项。"""maxlen = X.size(1)mask = torch.arange((maxlen), dtype=torch.float32,device=X.device)[None, :] < valid_len[:, None]X[~mask] = valuereturn X

举例:Input X = tensor( [1, 2, 3, 4, 5, 6, 7, 8], [a, b, c, d, e, f, g, h] ), 有效字符数 [4, 6]

sequence_mask,零值化屏蔽

带遮蔽的softmax交叉熵损失函数,MaskedCrossEntropy

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):"""带遮蔽的softmax交叉熵损失函数"""# `pred` 的形状:(`batch_size`, `seq_length_tgt`, `vocab_size_tgt`)# `label` 的形状:(`batch_size`, `seq_length_tgt`)# `valid_len` 的形状:(`batch_size`,)def forward(self, pred, label, valid_len):weights = torch.ones_like(label)              #按照label的shape进行单位矩阵初始化weights = sequence_mask(weights, valid_len)   #按照合法字符数量valid_len进行mask,遮蔽<pad>self.reduction='none'unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label)weighted_loss = (unweighted_loss * weights).mean(dim=1)return weighted_loss

06

训练过程

Train Function

def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):"""训练序列到序列模型。"""def xavier_init_weights(m):if type(m) == nn.Linear:nn.init.xavier_uniform_(m.weight)if type(m) == nn.GRU:for param in m._flat_weights_names:if "weight" in param:nn.init.xavier_uniform_(m._parameters[param])net.apply(xavier_init_weights)net.to(device)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = MaskedSoftmaxCELoss()net.train()animator = d2l.Animator(xlabel='epoch', ylabel='loss',xlim=[10, num_epochs])for epoch in range(num_epochs):timer = d2l.Timer()metric = d2l.Accumulator(2)  # 训练损失总和,词元数量for batch in data_iter:X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],device=device).reshape(-1, 1)dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 教师强制Y_hat, _ = net(X, dec_input, X_valid_len)l = loss(Y_hat, Y, Y_valid_len)l.sum().backward()      # 损失函数的标量进行“反传”d2l.grad_clipping(net, 1)num_tokens = Y_valid_len.sum()optimizer.step()with torch.no_grad():metric.add(l.sum(), num_tokens)if (epoch + 1) % 10 == 0:animator.add(epoch + 1, (metric[0] / metric[1],))print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} 'f'tokens/sec on {str(device)}')

开始训练

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,dropout)
net = d2l.EncoderDecoder(encoder, decoder)
train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

07

预测及评估

预测网络搭建

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,device, save_attention_weights=False):"""序列到序列模型的预测"""# 在预测时将`net`设置为评估模式net.eval()src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]enc_valid_len = torch.tensor([len(src_tokens)], device=device)src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])# 添加批量轴enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)enc_outputs = net.encoder(enc_X, enc_valid_len)dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)# 添加批量轴dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)output_seq, attention_weight_seq = [], []for _ in range(num_steps):Y, dec_state = net.decoder(dec_X, dec_state)# 我们使用具有预测最高可能性的词元,作为解码器在下一时间步的输入dec_X = Y.argmax(dim=2)pred = dec_X.squeeze(dim=0).type(torch.int32).item()# 保存注意力权重(稍后讨论)if save_attention_weights:attention_weight_seq.append(net.decoder.attention_weights)# 一旦序列结束词元被预测,输出序列的生成就完成了if pred == tgt_vocab['<eos>']:breakoutput_seq.append(pred)return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq

BLEU预测序列评价指标

def bleu(pred_seq, label_seq, k):  """计算 BLEU"""pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')len_pred, len_label = len(pred_tokens), len(label_tokens)score = math.exp(min(0, 1 - len_label / len_pred))for n in range(1, k + 1):num_matches, label_subs = 0, collections.defaultdict(int)for i in range(len_label - n + 1):label_subs[''.join(label_tokens[i: i + n])] += 1for i in range(len_pred - n + 1):if label_subs[''.join(pred_tokens[i: i + n])] > 0:num_matches += 1label_subs[''.join(pred_tokens[i: i + n])] -= 1score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))return score

预测实例

engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)print(f'{eng} => {translation}, bleu {bleu(translation, fra, k=2):.3f}')

08

Related Work

[1] LSTM神经网络状态,知乎用户 Scofield:

https://www.zhihu.com/question/41949741/answer/318771336

[2] @盛源车,知乎原文找不到了,此处是微信公众号的原文,可视化理解做的很直观:

https://mp.weixin.qq.com/s/0k71fKKv2SRLv9M6BjDo4w

[3] 可视化图解Attention based Seq2Seq模型:

https://zhuanlan.zhihu.com/p/60127009

[4] 基于TensorFlow的Seq2Seq代码实现:

https://zhuanlan.zhihu.com/p/47929039

[5] 本文使用的可视化工具——ML Visuals:

https://github.com/dair-ai/ml-visuals

推荐阅读:

我的2022届互联网校招分享

我的2021总结

浅谈算法岗和开发岗的区别

互联网校招研发薪资汇总

2022届互联网求职现状,金9银10快变成铜9铁10!!

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

发送【蜗牛】获取一份《手把手AI项目》(AI蜗牛车著)

发送【1222】获取一份不错的leetcode刷题笔记

发送【AI四大名著】获取四本经典AI电子书

深度剖析Seq2Seq原理代码相关推荐

  1. 深度剖析hdfs原理

    大数据底层技术的三大基石起源于Google在2006年之前的三篇论文GFS.Map-Reduce. Bigtable,其中GFS.Map-Reduce技术直接支持了Apache Hadoop项目的诞生 ...

  2. Mysql binlog应用场景与原理深度剖析

    1 基于binlog的主从复制 Mysql 5.0以后,支持通过binary log(二进制日志)以支持主从复制.复制允许将来自一个MySQL数据库服务器(master) 的数据复制到一个或多个其他M ...

  3. 好文推荐 | MySQL binlog应用场景与原理深度剖析

    作者:田守枝 来自:田守枝的博客(公众号) 本文深入介绍Mysql Binlog的应用场景,以及如何与MQ.elasticsearch.redis等组件的保持数据最终一致.最后通过案例深入分析binl ...

  4. 单片机c语言必背代码_【典藏】深度剖析单片机程序的运行(C程序版)

    1.日常聊一聊 今天为大家带来一篇对于单片机学习的小伙伴非常重量级的一篇文章<深度剖析单片机程序的运行(C语言版本)>,该文章会比较全面的为大家解析我们的用C语言编译出来的程序是如何在单片 ...

  5. git原理详解与实操指南_全网最精:学git一套就够了,从入门到原理深度剖析

    以上资源收集至互联网 如有侵权请联系删除 资源获取方式 扫码关注资源库公众号 回复密码'20190812' 即可获得 截图展示 课程信息 课程难度:中级 学习人数:148352 课程状态:已完结 时长 ...

  6. 深度剖析浏览器渲染性能原理,你到底知道多少?

    深度剖析浏览器渲染性能原理,你到底知道多少? 渲染卡顿是怎么回事? 网页不仅应该被快速加载,同时还应该流畅运行,比如快速响应的交互,如丝般顺滑的动画等. 大多数设备的刷新频率是60次/秒,也就说是浏览 ...

  7. 【微信小程序控制硬件④】 深度剖析微信公众号配网 Airkiss 原理与过程,esp8266如何自定义回调参数给微信,实现绑定设备第一步!(附带源码)

    [微信小程序控制硬件第1篇 ] 全网首发,借助 emq 消息服务器带你如何搭建微信小程序的mqtt服务器,轻松控制智能硬件! [微信小程序控制硬件第2篇 ] 开始微信小程序之旅,导入小程序Mqtt客户 ...

  8. 唯一插件化Replugin源码及原理深度剖析--插件的安装、加载原理

    上一篇 唯一插件化Replugin源码及原理深度剖析–唯一Hook点原理 在Replugin的初始化过程中,我将他们分成了比较重要3个模块,整体框架的初始化.hook系统ClassLoader.插件的 ...

  9. 老夫带你深度剖析Redisson实现分布式锁的原理

    Redis实现分布式锁的原理 前面讲了Redis在实际业务场景中的应用,那么下面再来了解一下Redisson功能性场景的应用,也就是大家经常使用的分布式锁的实现场景. 引入redisson依赖 < ...

最新文章

  1. 重磅!阿里达摩院发布《2020十大科技趋势》
  2. MOSA 4600 Plus IP PBX FAQ(应用常见知识点-故障排除)(2)
  3. IE userdata
  4. 如何部署 Hyperic ,使得从内网监测外网服务器
  5. idea如何设置类头注释和方法注释
  6. centos 5.6安装nginx+mysql+php(php-fpm)+phpmyadmin总结
  7. 毕业五年同是程序员为什么差距这么大?他年薪百万,他月薪一万
  8. 语音信号处理(赵力)作业答案第8章——语音合成
  9. 【计算理论】计算理论总结 ( 非确定性有限自动机 NFA 转为确定性有限自动机 DFA ) ★★
  10. 基于springboot网上商城交易平台源码
  11. java数字转为大写_java 数字转大写汉字
  12. 荣耀v10图片是html格式,荣耀V10真机上手图赏 参数配置分析详解
  13. LT-mapper,LT-SLAM代码运行与学习
  14. 18天掌握Java SE jvav梳理总结 从jvav到架构师
  15. 解决夜神模拟器连接eclipse的问题
  16. 最新全球学术排名出炉:23所中国大学跻身世界前100名!
  17. GEE遥感云大数据如何应用在林业生态领域中?监测森林扰动、火灾、砍伐退化、生理参数、植被状态
  18. Google打印没有彩色,浏览器打印预览没有背景颜色和没有颜色
  19. 金蝶K3 WISE版本过服务期后打补丁方法
  20. 关于初学者用哪种C/C++编译器(集成开发环境)的问题

热门文章

  1. web版本 开源压测工具_14款好用开源的Web应用压力负载,性能测试工具推荐
  2. 普法 | 如果你被裁员,赔偿金是N、N+1、2N呢?
  3. Something of Information Security Management
  4. 整合第三方登录之微信扫码登录
  5. 读不读博士的问题--转载
  6. Java随机生成大乐透号码
  7. 烤仔TVの尚书房 | 关于量子霸权,你以为你以为的就是你以为的吗?
  8. 虚拟机安装我的世界服务器,如何在Ubuntu 20.04上搭建我的世界Minecraft服务器
  9. Excel单元格中引用当前工作表名称
  10. 网速测试利器-iperf3