调试The Annotated Transformer
The Annotated Transformer 应该是我见过最贴心的‘Attention is All You Need’的复现了。看网页链接像是哈佛大学复现的,质量应该还不错,于是就照着代码按顺序ctrl + c +v了一遍。其实在github上也有代码可以直接下载,只不过是.ipynb格式的。
在调试代码的过程中,遇到了一些问题,在这里记录一下。
1 环境安装
作者没有说明每个依赖库的版本,以下是我个人的版本,可以参考。
python==3.8.8
torch==1.9.0
numpy==1.20.1
matplotlib==3.3.4
spacy==2.2.2
torchtext==0.6.0
numpy和matplotlib的版本影响应该不大;python的版本影响大不大不知道;torch的版本有点影响,这个版本会导致一个小错误,不过可以被解决;spacy和torchtext的版本影响很大!那是2018年发布的博客,torch、spacy和torchtext的版本应该比较低。博客中还提到要安装seaborn,我没有装,好像没影响。
2 遇到的问题
问题1:
在执行以下代码的时候,
python -m spacy download en
python -m spacy download de
出现网络链接的错误:
requests.exceptions.ConnectionError: HTTPSConnectionPool(host=‘raw.githubusercontent.com’, port=443): Max retries exceeded with url: /explosion/spacy-models/master/shortcuts-v2.json (Caused by NewConnectionError(’<urllib3.connection.HTTPSConnection object at 0x11bd5caf0>: Failed to establish a new connection: [Errno 61] Connection refused’))
参考这个方法,安装了en_core_web_sm
和de_core_news_sm
,先手动下载安装包(百度云盘,提取码:0cic),再用pip安装。
pip install en_core_web_sm-2.2.5.tar.gz
pip install de_core_news_sm-2.2.5.tar.gz
相应的,代码要改一下,把
import spacy
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')
改成
import en_core_web_sm
import de_core_news_smif True:spacy_de = de_core_news_sm.load()spacy_en = en_core_web_sm.load()
问题2
torchtext版本太高导致from torchtext import data, datasets
的data.Field
和datasets.IWSLT
不存在,把版本降到0.6.0就存在了。
问题3
代码执行到train, val, test = datasets.IWSLT.splits(...)
时,程序会下载数据,也报网络链接的错。经过debug发现,程序先在本地找.data/iwslt/de-gn.tgz
这个文件,找不到才去下载。所以,可以先把.data/iwslt/de-gn.tgz
文件准备好就行了。那么,这是个什么文件呢?这个文件来自WIT3,得在google drive下载,下载下来的文件名是2016-01.tgz
,解压后在里面找到一个叫de-gn.tgz
的文件(得翻几层文件夹),放在.data/iwslt/
目录下就可以了,注意这是一个相对路径。
问题4
OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.
解决方法:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
问题5
IndexError: invalid index of a 0-dim tensor. Use tensor.item()
in Python or tensor.item<T>()
in C++ to convert a 0-dim tensor to a number
pytorch版本导致的问题,解决方法:
把SimpleLossCompute
的最后一行的return loss.data[0] * norm
改成return loss.data.item() * norm
3 整理代码
把一些零散的代码粘贴到一个函数里:
def train_on_cpu():"""Train the model on cpu."""# For data loading.from torchtext import data, datasetsimport en_core_web_smimport de_core_news_smspacy_de = de_core_news_sm.load()spacy_en = en_core_web_sm.load()def tokenize_de(text):return [tok.text for tok in spacy_de.tokenizer(text)]def tokenize_en(text):return [tok.text for tok in spacy_en.tokenizer(text)]BOS_WORD = '<s>'EOS_WORD = '</s>'BLANK_WORD = "<blank>"SRC = data.Field(tokenize=tokenize_de, pad_token=BLANK_WORD)TGT = data.Field(tokenize=tokenize_en, init_token=BOS_WORD,eos_token=EOS_WORD, pad_token=BLANK_WORD)MAX_LEN = 100train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(SRC, TGT),filter_pred=lambda x: len(vars(x)['src']) <= MAX_LEN andlen(vars(x)['trg']) <= MAX_LEN)MIN_FREQ = 2SRC.build_vocab(train.src, min_freq=MIN_FREQ)TGT.build_vocab(train.trg, min_freq=MIN_FREQ)# Make a model and data iteratorspad_idx = TGT.vocab.stoi["<blank>"]model = make_model(len(SRC.vocab), len(TGT.vocab), N=6)criterion = LabelSmoothing(size=len(TGT.vocab), padding_idx=pad_idx, smoothing=0.1)BATCH_SIZE = 8train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=0,repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),batch_size_fn=batch_size_fn, train=True)valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=0,repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),batch_size_fn=batch_size_fn, train=False)# Trainmodel_opt = NoamOpt(model.src_embed[0].d_model, 1, 2000,torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))for epoch in range(10):model.train()run_epoch((rebatch(pad_idx, b) for b in train_iter),model,SimpleLossCompute(model.generator, criterion, model_opt))model.eval()loss = run_epoch((rebatch(pad_idx, b) for b in valid_iter),model,SimpleLossCompute(model.generator, criterion, None))print(loss)for i, batch in enumerate(valid_iter):src = batch.src.transpose(0, 1)[:1]src_mask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)out = greedy_decode(model, src, src_mask,max_len=60, start_symbol=TGT.vocab.stoi["<s>"])print("Translation:", end="\t")for i in range(1, out.size(1)):sym = TGT.vocab.itos[out[0, i]]if sym == "</s>": breakprint(sym, end=" ")print()print("Target:", end="\t")for i in range(1, batch.trg.size(0)):sym = TGT.vocab.itos[batch.trg.data[i, 0]]if sym == "</s>": breakprint(sym, end=" ")print()break
在脚本末尾写个程序入口:
if __name__ == '__main__':train_on_cpu() # I have no GPU
最后run一下,就可以看到以下舒心的画面了:
Epoch Step: 1 Loss: 9.118329 Tokens per Sec: 11.898214
Epoch Step: 51 Loss: 8.626973 Tokens per Sec: 23.280647
Epoch Step: 101 Loss: 7.953571 Tokens per Sec: 20.981972
完整的代码可以在github下载。
调试The Annotated Transformer相关推荐
- The Annotated Transformer
The Annotated Transformer Apr 3, 2018 from IPython.display import Image Image(filename='images/aiayn ...
- 《The Annotated Transformer》翻译——注释和代码实现《Attention Is All You Need》
文章目录 预备工作 背景 模型架构 Encoder and Decoder 堆栈 Encoder Decoder Attention 模型中Attention的应用 基于位置的前馈网络 Embeddi ...
- NLP-生成模型-2017-Transformer(二):Transformer各模块代码分析
一.WordEmbedding层模块(文本嵌入层) Embedding Layer(文本嵌入层)的作用:无论是源文本嵌入还是目标文本嵌入,都是为了将文本中词汇的数字表示转变为向量表示, 由一维转为多维 ...
- Dissecting BERT Part 1: The Encoder 解析BERT解码器(transformer)
原文:https://medium.com/dissecting-bert/dissecting-bert-part-1-d3c3d495cdb3 A meaningful representatio ...
- attention seq2seq transformer bert 学习总结 _20201107
https://blog.csdn.net/weixin_44388679/article/details/102575223 Seq2Seq原理详解 一文读懂BERT(原理篇) 2018年的10月1 ...
- Transformer的PyTorch实现
Google 2017年的论文 Attention is all you need 阐释了什么叫做大道至简!该论文提出了Transformer模型,完全基于Attention mechanism,抛弃 ...
- 用Transformer实现OCR字符识别!
Datawhale干货 作者:安晟.袁明坤,Datawhale成员 在CV领域中,transformer除了分类还能做什么?本文将采用一个单词识别任务数据集,讲解如何使用transformer实现一个 ...
- 10分钟带你深入理解Transformer原理及实现
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自|深度学习这件小事 基于 Transformer<A ...
- Transformer 模型详解
Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer.Transformer 模型使用了 Self- ...
- 【NLP】关于Transformer的常见问题及解答
作者 | Adherer 编辑 | NewBeeNLP PDF版文末自行下载哈~ 写在前面 前些时间,赶完论文,开始对 Transformer.GPT.Bert 系列论文来进行仔仔细细的研读,然后顺手 ...
最新文章
- select下拉框下拉跳转代码
- struts2之单个文件上传
- linux内核之内存管理.doc,linux内核之内存管理.doc
- 在lcd屏幕上窗口显示一个bitmap_SmartDrop——LED/LCD大屏内容投屏管理软件
- u8 和 char如何转化_EXCEL小知识——如何快速实现文本与数值的互相转化
- Windows Server 2012/2012 R2:安装和配置 SMTP 服务器
- 面对安利,谁能笑到最后
- [转]ios面试题收集(二)
- The Backrooms - Level 0.2 - 我爱杏仁水
- Android-组件化开发
- 27服务-安全访问状态转换
- win7万能声卡驱动_驱动精灵标准版 v9.61.3708.3054下载
- 概率论与数理统计前两章总结
- 解决VMware虚拟机中鼠标闪烁问题
- 【深度】最新万字综述自动驾驶,深度解构核心技术!
- 递归专题---[2]开根号
- android 正在检查更新,关于在app启动android上检查更新的新手问题
- 小时,速度,筒仓团队和甘特斯
- python微软雅黑字体_win10+python3.7下matplotlib显示中文,可使用微软雅黑.md
- 帮助你拿到offer的金融测试面试题
热门文章
- 西电软工oop面向对象程序设计实验三上机报告
- 常用性能工具:工欲善其事,必先利其器
- css 好看滚动条样式大全,CSS 个性化滚动条样式
- 解决:adb devices error protocol falut(no status)
- 【- Flutter 桌面篇 -】 FlutterUnit win版闪亮登场
- wireshark 过滤omci包_中兴OLT、ONU常见故障问题处理
- 图片提取文字很神奇?试试三步实现OCR!
- [0CTF 2016] piapiapia 题解
- 小米2A com.android.phone,104.android 简单的检查小米手机系统和华为手机系统是否打开通话自动录音功能,跳转通话录音页面...
- BZOJ1023 [SHOI2008]cactus仙人掌图