❤️点击上方,选择星标置顶,每天给你送上干货❤️

最近总有学妹问我,论文要截稿了,模型来不及跑怎么办?

还有学妹问我,有好多idea,验证一个就要跑一周怎么办?

这时候我想起了下面这张图,我想这句话反映了大多数从事人工智能的科研工作者们目前的状态。

于是我告诉学妹,想要快,找我就对了,我教你怎样让你的模型训练加速3倍以上!

这里我们需要用到的就是字节跳动AI Lab最近开源的「新版训练加速引擎」——LightSeq。

具体的原理这里就不做过多介绍了,过两天会专门发布一篇介绍技术细节的文章,敬请期待。今天我来教大家如何使用LightSeq。

安装步骤

源码安装

你可以从源码进行安装,使用如下命令:

git clone https://github.com/bytedance/lightseq.git
cd lightseq
pip install -e .

如果你想执行LightSeq提供的现成样例,或者使用它的单元测试工具,那最好从源码安装。

pip安装

当然如果你想直接调用LightSeq的接口,不需要它的样例或者单元测试工具,我更推荐你用下面pip的方式安装,更加方便:

pip install lightseq

使用教程

自定义模型

大多数同学可能想自己搭建一个Transformer模型,然后用来训练各种数据,那我这里就教大家如何快速搭建一个LightSeq版本的Transformer编码层。

你只需要创建一个配置对象,然后用它创建LightSeq的编码层即可。

我写了一份完整的训练代码,非常浅显易懂,看注释就行了,亲测可以直接运行哦:

import torch
from lightseq.training.ops.pytorch.transformer_encoder_layer import LSTransformerEncoderLayerdef train(model, inputs, masks):inputs = inputs.to(device="cuda:0")masks = masks.to(device="cuda:0")model.to(device="cuda:0")model.train()opt = torch.optim.Adam(model.parameters(), lr=1e-3)for epoch in range(1000):opt.zero_grad()outputs = model(inputs, masks)loss = torch.square(outputs).mean()loss.backward()opt.step()if epoch % 200 == 0:print("epoch {:>3d}: loss = {:>5.3f}".format(epoch, loss))if __name__ == "__main__":# 定义LightSeq配置config = LSTransformerEncoderLayer.get_config(max_batch_tokens=4096,max_seq_len=256,hidden_size=1024,intermediate_size=4096,nhead=16,attn_prob_dropout_ratio=0.1,activation_dropout_ratio=0.1,hidden_dropout_ratio=0.1,pre_layer_norm=True,fp16=False,local_rank=0)# 随机生成输入bsz, sl = 10, 80inputs = torch.randn(bsz, sl, config.hidden_size)masks = torch.zeros(bsz, sl)# 定义LightSeq编码层并进行训练model = LSTransformerEncoderLayer(config)train(model, inputs, masks)


下面两个样例都放在了lightseq/training/examples目录下,推荐大家采用源码安装的方式安装LightSeq,这样可以直接运行样例。

Hugging Face

Hugging Face是目前用的最多的预训练模型库了吧,主要是用起来太方便了,模型也很全。直接pip install transformers安装即可。

以BERT在NER任务上微调为例,直接运行LightSeq提供的脚本就行:

sh lightseq/training/examples/huggingface/run_ner.sh

Fairseq

Fairseq是目前最主流的序列生成库之一,用来做机器翻译、文本生成等任务都是非常方便的。安装的话也很简单,直接pip install fairseq即可。

LightSeq同样提供了现成的运行脚本,如果想运行LightSeq加速后的模型,执行如下命令:

sh lightseq/training/examples/fairseq/ls_fairseq_wmt14en2de.sh

再来看看细节,一般如果我们想用Fairseq来训练一个机器翻译模型,通常首先会准备好数据集,然后执行如下命令:

fairseq-train DATA_DIR \--arch transformer_wmt_en_de_big_t2t \--optimizer adam \--criterion label_smoothed_cross_entropy \...

注意这里我们只列出了同LightSeq有关的三个参数:--arch--optimizer--criterion,分别指定了模型结构、参数优化器和损失函数。

如果想用LightSeq进行加速,直接将上面的运行命令改为下面这样:

lightseq-train DATA_DIR \--arch ls_transformer_wmt_en_de_big_t2t \--optimizer ls_adam \--criterion ls_label_smoothed_cross_entropy \...

注意改动的地方有4个。fairseq-train改成lightseq-train,这是为了导入LightSeq的目录。--arch--optimizer--criterion都加上了ls_前缀,这样就快速替换为了LightSeq的组件。

训练速度

说了这么多,实际速度到底怎么样?我用Fairseq测了一下训练的总耗时:

不同模型大小、不同批处理大小、不同显卡上加速效果都是有区别的,但总体上都能缩短一半左右的训练时间。

如果你的显卡比较老旧(我相信大多数学校实验室都是这样的),显存又比较小,那么批处理大小只能设置的很小,那加速比甚至能达到3倍以上。

项目地址

学妹试了直叫好,说用起来确实快。

你也别忘了点个star,让更多的人享受到极速的快乐。

「LightSeq地址:」
https://github.com/bytedance/lightseq

- END -

我是godweiyang,华东师范大学计算机系本硕专业第一,字节跳动AI Lab NLP算法工程师,秋招斩获上海三家互联网大厂ssp offer,主要研究方向为机器翻译、句法分析、模型压缩与加速。最大特点就是脾气好、有耐心,有任何问题都可以随时咨询我,不管是技术上的还是生活上的。

公众号后台回复【内推

可以通过我的内推码投递简历,加我微信还能随时查进度、咨询问题。

公众号后台回复【加群

可以进我的技术交流群和内推群。

记得一键③连,今天的你格外的可爱????

只用几行代码,我让模型『训练』加速了3倍以上!相关推荐

  1. Python3:我只用1行代码就下载全网视频,我被我的才华和颜值征服了!!

    you-get库使用 1.引言 2.代码实战 2.1 you-get介绍 2.2 you-get安装 2.3 you-get下载视频 2.3.1 指定存储和重命名 2.3.2 查看视频信息 2.3.3 ...

  2. 只用3行代码,让Python提速4倍!最强辅助

    Python是一门非常适合处理数据和自动化完成重复性工作的编程语言.我们在用数据训练机器学习模型之前,通常都需要对数据进行预处理,而Python就非常适合完成这项工作,比如需要重新调整几十万张图像的尺 ...

  3. 只用2000行代码实现google protocol buffer c++版的功能

    2019独角兽企业重金招聘Python工程师标准>>> google protocol buffer (下面简称gpb)功能强大,应用广泛,但在实际应用中,gpb需要写.proto脚 ...

  4. 95行代码实现最大熵模型训练

    关于最大熵模型的介绍请看:http://www.cnblogs.com/hexinuaa/p/3353479.html 以下是GIS训练算法的python实现,代码不到100行. from colle ...

  5. 那些下载不了的视频,Python只用1行代码就能直接下载

    现在有很多网站都并不支持直接下载的,例如我们常去的B站里面的视频,在页面是没有下载按钮的,还有的视频需要我们下载客户端才能下载-虽然这并不能拦住多少人,有些聪明的小伙伴就会去下载一些第三方软件去下载, ...

  6. 只用70行代码,手把手教你遍历当前windows所有进程!

    大家好,我是KookNut39,在CSDN写文,分享一些自己认为在学习过程中比较重要的东西,希望可以帮助你进步.最近在更新C/C++方面的知识,感兴趣的可以欢迎关注博主,去专栏查看之前的文章,希望未来 ...

  7. Python如何只用20行代码给证件照换底色,学会了不怕没有女朋友!!!

    本文只是一种实现思路,当然PS很好用(一张的话建议使用PS哦~,多张图片的话用代码快很多哦~),希望大家能够学习更多的知识,才分享了这个文章.更多精彩,请关注公众号:[测试员小何],可以获取最新软件测 ...

  8. 一天star量破千,300行代码,特斯拉AI总监Karpathy写了个GPT的Pytorch训练库

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 整理:公众号@机器之心 本文仅做学术分享,如有侵权,请联系删除. 如果说 GPT 模型是所向披靡的战舰 ...

  9. python实现目标识别_Python10行代码实现目标检测

    ImageAI可以让程序员和软件开发者只用几行代码,就能轻易地把最先进的计算机视觉技术整合到他们现有的以及新的应用程序里面. 用ImageAI实现目标检测,你只需要以下步骤: 安装Python 安装I ...

  10. 大道至简,仅需4行代码提升多标签分类性能!ICCV21 南大提出Residual Attention

    ▊ 写在前面 多标签图像识别是一项具有挑战性的计算机视觉任务.然而,目前解决这一任务的方法复杂.计算量大.缺乏直观解释 .为了能够有效地感知不同类别物体所占据的空间区域,作者提出了一个非常简单的模块, ...

最新文章

  1. SSR rendering exceeded timeout 3000, fallbacking to CSR for
  2. SAP Hybris install.bat工作原理分析
  3. Scrapy 教程(十)-管道与数据库
  4. 清华大学全面审查文科博士论文!
  5. BugkuCTF-MISC题1和0的故事
  6. 5.产品的三种流程图,你都知道吗?
  7. Leetcode 242. 有效的字母异位
  8. 计算机系统基础知识——进制转换(二进制、八进制、十进制、十六进制)
  9. STL容器-queue队列
  10. oblog商业版本4.6注射漏洞,直接拿管理员
  11. Python茅台抢购脚本详细教程
  12. MIUI通知类短信权限的坑
  13. 最短路径之Bellman-Ford
  14. CodeForces 954D-Fight Against Traffic(加边最短路)
  15. 计组_定点数一位乘_布斯公式
  16. 《Unity着色器和屏幕特效开发秘笈》—— 第3章 利用镜面反射让游戏闪耀起来...
  17. 没有配置resolv.conf
  18. 微信公众号文章爬取下载各种格式
  19. 布尔教育mysql优化_布尔教育燕十八mysql优化视频课件源码分享
  20. pygame之pygame模块

热门文章

  1. EntityFramework Code-First 简易教程(八)-------一对一
  2. 长连接与心跳包 Persistent connection and HearBeats
  3. Oracle EBS之把自定义concurrent加入Pick Release Document Set(All Pick Release Documents)的几个注意点...
  4. maven如果正常配置不成功,就按照我的就可以配置成功了
  5. 笔记3:数字和数学计算
  6. 传输层协议TCP和UDP分析
  7. [NOIP2000]方格取数
  8. WCF生成的json与Extjs交互的日期型问题
  9. ZeroC IceGrid demo构建(继承Ice::Application类)
  10. C/C++编程语言中“crosses initialization”编译错误分析