我们今天开始分析著名的attention is all you need 论文的pytorch实现的源码解析。
由于项目很大,所以我们会分开几讲来进行讲解。
先上源码:https://github.com/Eathoublu/attention-is-all-you-need-pytorch
大家可以先自行下载并理解。
今天是第一讲,我们先讲解数据的预处理部分:preprocess.py
项目结构:

-transfomer
—__init__.py
—Beam.py
—Constants.py
—Layers.py
—Models.py
—Module.py
—Optim.py
—SubLayers.py
—Translator.py

datasets.py
preprocess.py
train.py
translate.py

我们今天先介绍preprocess.py文件即数据预处理的部分,数据清洗以及词表的构建非常重要,我会用注释的方式进行解析请大家从标号为1的注释开始阅读到标号为12的源码,我将尽可能用简洁的语言解读源码,保证大家都能够读懂。当然,有能力的同学完全可以跳过这一节直接阅读github里面的preprocess.py源码。

好的,现在我们往下翻,翻到def main的位置。
文件源码:

''' Handling the data io '''
import argparse
import torch
import transformer.Constants as Constants#以上引入了解析命令行参数的库以及pytorch和一个transfomer文件夹下的constants文件
#我们先下翻到def main的位置吧#1.好了,就是这里。我们首先看main函数:
def main():''' Main function '''parser = argparse.ArgumentParser()parser.add_argument('-train_src', required=True)parser.add_argument('-train_tgt', required=True)parser.add_argument('-valid_src', required=True)parser.add_argument('-valid_tgt', required=True)parser.add_argument('-save_data', required=True)parser.add_argument('-max_len', '--max_word_seq_len', type=int, default=50)parser.add_argument('-min_word_count', type=int, default=5)parser.add_argument('-keep_case', action='store_true')parser.add_argument('-share_vocab', action='store_true')parser.add_argument('-vocab', default=None)
#2.以上是一些命令行运行时需要传入的参数,required=True的字段是必须传入的,其他是可选opt = parser.parse_args() #3. 解析命令行参数opt.max_token_seq_len = opt.max_word_seq_len + 2 # include the <s> and </s> # 4.我们在调用的时候参数里面有一个是告诉程序我们传入的句子里面最多有多少个词,然后程序会自动帮我们+2,因为可能存在的</s>标签。至于为什么要传入这个最长词序列的长度我们先不管它。# Training set# 5.训练集,以下的两行代码的意思是,我们要调用read_instances_from_file这个函数(本来是有的为了方便阅读在这里我不摆出来了。)这个函数的作用是:传入三个参数(数据集的绝对路径、最长的句子里面有多少个词,是否全是小写)函数的主要功能是逐行读入目标文件的内容(文件中一行就是一个句子),并将每行的句子进行分词转换成一个词的列表,并将所有句子的词的列表组合成一个大的句子的列表,例如:[[什么,?,大清,亡,了,?],  [我,爱,时崎狂三],  [暴走大事件,更新,了],  [我, 来自,东北大学],[他,酒驾,进去,了]] 返回值是就是这样的一个列表啦!只不过源码示例示英文的而已~train_src_word_insts = read_instances_from_file(opt.train_src, opt.max_word_seq_len, opt.keep_case)train_tgt_word_insts = read_instances_from_file(opt.train_tgt, opt.max_word_seq_len, opt.keep_case)#6.这行是做一个规范,规定数据集的数据条数一定要等于标签集的数据条数,否则我们取同样个数的数据集以及标签集,例如100个data,103个target,那么我们data target都取100个。if len(train_src_word_insts) != len(train_tgt_word_insts):print('[Warning] The training instance count is not equal.')min_inst_count = min(len(train_src_word_insts), len(train_tgt_word_insts))train_src_word_insts = train_src_word_insts[:min_inst_count]train_tgt_word_insts = train_tgt_word_insts[:min_inst_count]
#7.接下来是将那些不合法的数据和标签清洗掉,例如把有数据,标签只是一个空格这样的数据去掉。#- Remove empty instancestrain_src_word_insts, train_tgt_word_insts = list(zip(*[(s, t) for s, t in zip(train_src_word_insts, train_tgt_word_insts) if s and t]))
#8.这一步是制作验证集,方法和上面是一样的,都是调用 read_instances_from_file函数,我就不赘述了。# Validation setvalid_src_word_insts = read_instances_from_file(opt.valid_src, opt.max_word_seq_len, opt.keep_case)valid_tgt_word_insts = read_instances_from_file(opt.valid_tgt, opt.max_word_seq_len, opt.keep_case)
#9.接下来的7行代码,是和清洗训练集一样,对验证集进行清洗。if len(valid_src_word_insts) != len(valid_tgt_word_insts):print('[Warning] The validation instance count is not equal.')min_inst_count = min(len(valid_src_word_insts), len(valid_tgt_word_insts))valid_src_word_insts = valid_src_word_insts[:min_inst_count]valid_tgt_word_insts = valid_tgt_word_insts[:min_inst_count]#- Remove empty instancesvalid_src_word_insts, valid_tgt_word_insts = list(zip(*[(s, t) for s, t in zip(valid_src_word_insts, valid_tgt_word_insts) if s and t]))#9.好的,至此我们已经完成了数据清洗的步骤,得到了训练集以及验证集两个部分。现在我们要构建词表了。# Build vocabulary
#10. 请注意,下面的这几个if opt.vocab到else这个代码块在源码示例里面并没有使用到,因为这几个参数都是可选的,我们大可以跳过,暂时不看。请跳到11.的位置继续阅读。if opt.vocab: predefined_data = torch.load(opt.vocab)assert 'dict' in predefined_dataprint('[Info] Pre-defined vocabulary found.')src_word2idx = predefined_data['dict']['src']tgt_word2idx = predefined_data['dict']['tgt']else:if opt.share_vocab:print('[Info] Build shared vocabulary for source and target.')word2idx = build_vocab_idx(train_src_word_insts + train_tgt_word_insts, opt.min_word_count)src_word2idx = tgt_word2idx = word2idx#11. 10以下,11以上的代码是可选参数的处理,我们可以暂时不去理解,我们假定我们运行程序的时候,没有传入这些参数,那么我们将会进入下面的else,创建一个新的词表。这个build_vocab_idx函数就是用来将词语转化成词表的:原理很简单,就是将刚刚产生的所有的句子列表里面的所有的词给拿出来,并给每一个词一个编号,做成一个字典并返回,这货就叫做词表。else:print('[Info] Build vocabulary for source.')src_word2idx = build_vocab_idx(train_src_word_insts, opt.min_word_count)print('[Info] Build vocabulary for target.')tgt_word2idx = build_vocab_idx(train_tgt_word_insts, opt.min_word_count)
#12.下面,我们将每一个训练集里面出现过的单词转化为词表里面的一个下标index,并将原本是词语序列构成的句子转化为以词语在词表中的下标序列构成的列表。例如:我=1,爱=2,时崎狂三=3,那么原本的句子[我,爱,时崎狂三]就变成[1, 2, 3] 实现这个功能的函数就是convert_instance_to_idx_seq,它的返回值就是上述的这个列表。# word to indexprint('[Info] Convert source word instances into sequences of word index.')train_src_insts = convert_instance_to_idx_seq(train_src_word_insts, src_word2idx)valid_src_insts = convert_instance_to_idx_seq(valid_src_word_insts, src_word2idx)print('[Info] Convert target word instances into sequences of word index.')train_tgt_insts = convert_instance_to_idx_seq(train_tgt_word_insts, tgt_word2idx)valid_tgt_insts = convert_instance_to_idx_seq(valid_tgt_word_insts, tgt_word2idx)
#12.好了,现在我们构建一个数据集的字典对象,里面包括了传入的参数、词表以及训练集、验证集。然后用torch.save方法持久化这个字典对象,方便以后调用这个数据集进行训练和测试。至此,源码解析的数据预处理部分就结束了。data = {'settings': opt,'dict': {'src': src_word2idx,'tgt': tgt_word2idx},'train': {'src': train_src_insts,'tgt': train_tgt_insts},'valid': {'src': valid_src_insts,'tgt': valid_tgt_insts}}print('[Info] Dumping the processed data to pickle file', opt.save_data)torch.save(data, opt.save_data)print('[Info] Finish.')if __name__ == '__main__':main()

写在后面:在自然语言处理任务中,数据清洗是非常重要的一步,因此希望大家十分重视,正所谓垃圾in垃圾out。另外,由于本人水平有限,如果我有什么没说明白或者说错了的地方,非常欢迎大家指出,可以留言,另外本人工作邮箱:1012950361@qq.com 我将以最快的速度更正以及补全,谢谢大家!

敬请关注下一期:Attention is all you need pytorch实现 源码解析02 - 模型的训练 (train.py)

Attention is all you need pytorch实现 源码解析01 - 数据预处理、词表的构建相关推荐

  1. AlphaFold2源码解析(3)--数据预处理

    AlphaFold2源码解析(3)–数据预处理 数据预处理整体流程 数据处理入口: feature_dict = data_pipeline.process( input_fasta_path=fas ...

  2. weiler-atherton多边形裁剪算法_EAST算法超详细源码解析:数据预处理与标签生成...

    作者简介 CW,广东深圳人,毕业于中山大学(SYSU)数据科学与计算机学院,毕业后就业于腾讯计算机系统有限公司技术工程与事业群(TEG)从事Devops工作,期间在AI LAB实习过,实操过道路交通元 ...

  3. EAST算法超详细源码解析:数据预处理与标签生成

    作者简介 CW,广东深圳人,毕业于中山大学(SYSU)数据科学与计算机学院,毕业后就业于腾讯计算机系统有限公司技术工程与事业群(TEG)从事Devops工作,期间在AI LAB实习过,实操过道路交通元 ...

  4. android 输入法如何启动流程_android输入法02:openwnn源码解析01—输入流程

    android 输入法 02:openwnn 源码解析 01-输入流程 之后要开始 android 日文输入法的测试,因此现在开始研究 android 输入法.之前两 篇文章已经对 android 自 ...

  5. YOLOv3源码解析2-数据预处理Dataset()

    YOLOv3源码解析1-代码整体结构 YOLOv3源码解析2-数据预处理Dataset() YOLOv3源码解析3-网络结构YOLOV3() YOLOv3源码解析4-计算损失compute_loss( ...

  6. 谷歌BERT预训练源码解析(二):模型构建

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_39470744/arti ...

  7. pytorch YoLOV3 源码解析 train.py

    train.py 总体分为三部分(不算import 库) 初始的一些设定 + train函数 + main函数 源码地址: https://github.com/ultralytics/yolov3 ...

  8. 大数据之-Hadoop3.x_MapReduce_ReduceTask源码解析---大数据之hadoop3.x工作笔记0127

    然后我们接着去reducetask的源码: 可以看到上面,maptask执行以后,数据被分区,然后溢写到磁盘文件中,然后 就到了执行reducetask的时候,首先走到reducetask的上面这个位 ...

  9. 大数据之-Hadoop3.x_MapReduce_MapTask源码解析---大数据之hadoop3.x工作笔记0126

    然后我们来看一下maptask的源码,这个对理解maptask如何工作很重要 我们在一个例子的基础上去debug,去看,可以看到,我们用 partitioner2这个案例,这个是我们之前,用来区分,把 ...

最新文章

  1. 递归/回溯:八皇后问题N-Queens
  2. python nDPI 流量分析框架 Nfstream 简介
  3. AUTOSAR从入门到精通100讲(八十一)-AUTOSAR基础篇之FiM
  4. 数据结构和算法练习网站_视频和练习介绍了10种常见数据结构
  5. 应届生开40万年薪?OPPO大手笔招揽芯片人才引热议
  6. leetcode题解279-完全平方数
  7. 姓名的首字母组成的图案C语言怎么编,c语言编写一个程序,根据用户输入英文名和姓先显示姓氏,其后跟一个逗号,然后显示名的首字母:...
  8. 弹出层之3:JQuery.tipswindow
  9. 3月7日 当前动力电池竞争格局
  10. android 微信浮窗实现_Android仿微信文章悬浮窗效果的实现代码
  11. HIVE SQL分位数percentile使用方法案例
  12. 图论(十四)——图的着色
  13. python 手写字符识别
  14. 华为面试题(小朋友高矮排序,要求移动距离最小)-java版
  15. #内存泄露# #valgrind# valgrind使用
  16. 渐变折射率(GRIN)透镜的建模
  17. Java初学者——小白篇(一)
  18. 在EXCEL中导入txt文本数据
  19. 一个可实施的技术方案模板
  20. stm32+cubemx+adc+time定时采集+dma多通道采集

热门文章

  1. 并查集及路径压缩模板
  2. [转帖]你不曾见过的国产CPU:可能是最全的龙芯系列芯片家谱(下)
  3. css 圆形背景icon_css3画实心圆和圆角的方法
  4. MediaPlayer控件的属性集合(完整版)
  5. 联想电脑忘记密码了解决办法
  6. 1086: ASCII码排序(多实例测试)
  7. 段码液晶屏可以修复吗?
  8. mybatis iterate
  9. A-Level经济真题每期一练(54)
  10. 数据结构——数据结构模拟银行排号叫号系统参考