在TPU上运行PyTorch的技巧总结
TPU芯片介绍
Google定制的打机器学习专用晶片称之为TPU(Tensor Processing Unit),Google在其自家称,由于TPU专为机器学习所运行,得以较传统CPU、 GPU降低精度,在计算所需的电晶体数量上,自然可以减少,也因此,可从电晶体中挤出更多效能,每秒执行更复杂、强大的机器学习模组,并加速模组的运用,使得使用者更快得到答案,Google最早是计划用FPGA的,但是财大气粗,考虑到自己的特殊应用,就招了很多牛人来做专用芯片TPU。
TPUs已经针对TensorFlow进行了优化,并且主要用于TensorFlow。但是Kaggle和谷歌在它的一些比赛中分发了免费的TPU时间,并且一个人不会简单地改变他最喜欢的框架,所以这是一个关于我在GCP上用TPU训练PyTorch模型的经验的备忘录(大部分是成功的)。
PyTorch/XLA是允许这样做的项目。它仍在积极的开发中,问题得到了解决。希望在不久的将来,运行它的体验会更加顺畅,一些bug会得到修复,最佳实践也会得到更好的交流。
https://github.com/pytorch/xla
设置
这里有两种方法可以获得TPU的使用权
GCP计算引擎虚拟机与预构建的PyTorch/XLA映像并按照PyTorch/XLA github页面上的“使用预构建的计算VM映像”部分进行设置。
或者使用最简单的方法,使用google的colab笔记本可以获得免费的tpu使用。
针对一kaggle的比赛您可以在虚拟机上使用以下代码复制Kaggle API令牌并使用它下载竞争数据。还可以使用gsutil cp将文件复制回GS bucket。
gcloud auth login
gsutil cp gs://bucket-name/kaggle-keys/kaggle.json ~/.kaggle
chmod 600 ~/.kaggle/kaggle.json
kaggle competitions download -c recursion-cellular-image-classification
除了谷歌存储之外,我还使用github存储库将数据和代码从我的本地机器传输到GCP虚拟机,然后再返回。
注意,在TPU节点上也有运行的软件版本。它必须匹配您在VM上使用的conda环境。由于PyTorch/XLA目前正在积极开发中,我使用最新的TPU版本:
使用TPU训练
让我们看看代码。PyTorch/XLA有自己的多核运行方式,由于TPUs是多核的,您希望利用它。但在你这样做之前,你可能想要把你的模型中的device = ’ cuda '替换为
import torch_xla_py.xla_model as xm...
device = xm.xla_device()...
xm.optimizer_step(optimizer)
xm.mark_step()
仅在TPU的一个核上测试您的模型。上面代码片段中的最后两行替换了常规的optimizer.step()调用。
对于多核训练,PyTorch/XLA使用它自己的并行类。在这里的测试目录中可以找到一个使用并行训练循环的示例(https://github.com/pytorch/xla/blob/master/test/test_train_mnist.py)
我想强调与它相关的以下三点。
- DataParallel并行持有模型对象的副本(每个TPU设备一个),并以相同的权重保持同步。你可以通过访问其中一个模型进行保存,因为权重都是同步的:
torch.save(model_parallel._models[0].state_dict(), filepath)
每个并行内核必须运行相同批数量,并且只允许运行完整批。因此,每个历元在小于100%的样本下运行,剩余部分被忽略。对于数据集变换,这对于训练循环来说不是大问题,但对于推理来说却是个问题。如前所述,我只能使用单核运行进行推理。
直接在jupyter笔记本上运行的DataParallel代码对我来说非常不稳定。它可能运行一段时间,但随后会抛出系统错误、内核崩溃。运行它作为一个脚本似乎是稳定的,所以我们使用以下命令进行转换
!jupyter nbconvert --to script MyModel.ipynb
!python MyModel.py
工作的局限性
PyTorch/XLA的设计导致了一系列PyTorch功能的限制。事实上,这些限制一般适用于TPU设备,并且显然也适用于TensorFlow模型,至少部分适用。具体地说
张量形状在迭代之间是相同的,这也限制了mask的使用。
应避免步骤之间具有不同迭代次数的循环。
不遵循准则会导致(严重)性能下降。 不幸的是,在损失函数中,我需要同时使用掩码和循环。 就我而言,我将所有内容都移到了CPU上,现在速度要快得多。 只需对所有张量执行 my_tensor.cpu().detach().numpy() 即可。 当然,它不适用于需要跟踪梯度的张量,并且由于迁移到CPU而导致自身速度降低。
性能比较
我的Kaggle比赛队友Yuval Reina非常同意分享他的机器配置和训练速度,以便在本节中进行比较。 我还为笔记本添加了一列(这是一台物理机),但它与这些重量级对象不匹配,并且在其上运行的代码未针对性能进行优化。
网络的输入是具有6个通道的512 x 512图像。 我们测量了在训练循环中每秒处理的图像,根据该指标,所描述的TPU配置要比Tesla V100好得多。
如上所述(不带DataParallel)的单核TPU的性能为每秒26张图像,比所有8个核在一起的速度慢约4倍。
由于竞争仍在进行中,我们没有透露Yuval使用的体系结构,但其大小与resnet50并没有太大差异。 但是请注意,由于我们没有运行相同的架构,因此比较是不公平的。
尝试将训练映像切换到GCP SSD磁盘并不能提高性能。
总结
总而言之,我在PyTorch / XLA方面的经验参差不齐。 我遇到了多个错误/工件(此处未全部提及),现有文档和示例受到限制,并且TPU固有的局限性对于更具创意的体系结构而言可能过于严格。 另一方面,它大部分都可以工作,并且当它工作时性能很好。
最后,最重要的一点是,别忘了在完成后停止GCP VM!
作者:Zahar Chikishev
deephub翻译组
在TPU上运行PyTorch的技巧总结相关推荐
- 百度Ai studio上运行pytorch和tensorflow(转载)
转载 链接:https://www.zhihu.com/question/336485090/answer/1017905011 来源:知乎 著作权归作者所有.商业转载请联系作者获得授权,非商业转载请 ...
- cuda无法在电脑上运行_办公技巧 | 专治PPT在别的电脑上无法播放的神器!
不坑老师 教学办公技巧分享 你有135个好友已关注 关注在前面的文章中,不坑老师给大家分享过"解决换电脑后PPT中字体不显示"的办法,不知道大家掌握得怎么样了?今天,不坑老师又来给 ...
- 在linux上运行python脚本(安装pytorch踩坑记录,pyinstaller使用方式,构建docker镜像)
背景 脚本需要导入pytorch等库才能运行. 脚本在windows上运行成功,尝试放到linux上运行. linux服务器内存较小. 方法一:在linux上安装依赖 把脚本放到linux上,直接安装 ...
- PyTorch训练:多个项目在同一块GPU显卡上运行
多个项目在同一块GPU显卡上运行 多个项目在同一块GPU显卡上运行注意事项:
- 编写高效的PyTorch代码技巧(下)
点击上方"算法猿的成长",关注公众号,选择加"星标"或"置顶" 总第 133 篇文章,本文大约 3000 字,阅读大约需要 15 分钟 原文 ...
- pytorch macos_Windows,Linux和MacOS上的PyTorch安装
pytorch macos The installation of PyTorch is pretty straightforward and can be done on all major ope ...
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- 从Qcheck 1.3 不能在不同操作系统上运行问题(chro124、chro342)说开来------
[本文重在技巧学习,授人以鱼,不如授人以渔!!!] 因为公司项目需要对带宽占用进行测试, 最近看电子工业出版社<网络管理工具使用详解>就qcheck 1.3 不能在不同的操作系统之间运行 ...
- 让win7系统高速运行的优化技巧
这里要跟大家分享的是关于如何让win7系统高速运行的优化技巧,任何一款电脑系统用久了之后速度都会变慢,因此很多用户会到处寻找各种优化渠道,如果想要让自己的win7系统像新安装前一样的速度,其实我们只需 ...
- 为什么要把进程/线程绑定到特定cpu核上运行?(cpu core id coreIdx)opdevsdk_sys_bindThreadCoreId()
看海康hikflow_demo代码,在线程处理函数里调用了绑定函数,把这个线程绑定到某个cpu核上,不知为何要这么做? 原因 答1 现在大家使用的基本上都是多核cpu,一般是4核的.平时应用程序在运行 ...
最新文章
- python flask 返回值 状态码 设置
- BZOJ 4386 Luogu P3597 [POI2015]Wycieczki (矩阵乘法)
- 352. Data Stream as Disjoint Intervals
- jquery获取select选中的文本的值
- 电子商务网站 数据库产品表设计方案
- lua运行外部程序_Lua 协同程序(coroutine)
- mac中的csv文件到windows平台乱码的解决办法
- [转] CPU GPU TPU
- linux底下dig命令报错
- 为什么手机发射功率这么小而基站却能收到信号?
- redhat linux防火墙状态,Redhat下配置iptables防火墙
- 支付宝PC(二维码扫码)支付(Java开发)完整版
- Hello Juejin
- 旋转弹飞控系统半实物仿真平台ETest
- jsonObject.getString()解析任意字段均可强转为string
- 2022中科院自动化所人工智能暑期学校(部分内容)
- 什么是软文营销?为什么做软文营销?
- hihocode-2月29
- Nginx网站服务配置(Nginx服务基础,访问状态统计,访问控制,虚拟主机)
- 如何利用自己的开发能力在国内创建数字藏品 ----如何在国内创建合约发行数字藏品(nft)
热门文章
- 一、Maven-单一架构案例(创建工程,引入依赖,搭建环境:持久化层,)
- Clipboard.js实现复制文本到剪贴板功能
- 【理财】指数基金投资指南
- Ubuntu搜狗输入法不能显示问题
- 电脑桌面的计算机网络回收站图标不见了,桌面回收站图标不见了怎么办 回收站图标找回方法【图文】...
- 五脏与五声 五脏排毒法(五声功)
- 浅谈微信域名防封 微信域名检测工作原理
- 【工控老马】MODBUS通讯协议及编程详解
- Java将PDF转换成图片
- oracle12c安全补丁包,Oracle 12c 及以上版本补丁更新说明及下载方法