点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

丰色 发自 凹非寺
量子位 报道 | 公众号 QbitAI

CUDA error: out of memory.

多少人用PyTorch“炼丹”时都会被这个bug困扰。

一般情况下,你得找出当下占显存的没用的程序,然后kill掉。

如果不行,还需手动调整batch size到合适的大小……

有点麻烦。

现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

有多厉害?

相关项目在GitHub才发布没几天就收获了600+星。

一行代码解决内存溢出错误

软件包名叫koila,已经上传PyPI,先安装一下:

pip install koila

现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。

先定义input、label和model:

# A batch of MNIST image
input = torch.randn(8, 28, 28)# A batch of labels
label = torch.randn(0, 10, [8])class NeuralNetwork(Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = Flatten()self.linear_relu_stack = Sequential(Linear(28 * 28, 512),ReLU(),Linear(512, 512),ReLU(),Linear(512, 10),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits

然后定义loss函数、计算输出和losses。

loss_fn = CrossEntropyLoss()# Calculate losses
out = nn(t)
loss = loss_fn(out, label)# Backward pass
nn.zero_grad()
loss.backward()

好了,如何使用koila来防止内存溢出?

超级简单!

只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——

koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。

在本例中,batch=0,则修改如下:

input = lazy(torch.randn(8, 28, 28), batch=0)

完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。

灵感来自TensorFlow的静态/懒惰评估

下面就来说说koila背后的工作原理。

“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。

koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。

它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。

而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。

又是算shape又是算内存的,koila听起来就很慢?

NO。

即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。

而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。

你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?

是的,它也可以。

但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。

koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。

不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU

以及现在只适用于常见的nn.Module类。

ps. koila作者是一位叫做RenChu Wang的小哥。

项目地址:
https://github.com/rentruewang/koila

参考链接:
https://www.reddit.com/r/MachineLearning/comments/r4zaut/p_eliminate_pytorchs_cuda_error_out_of_memory/

本文系网易新闻•网易号特色内容激励计划签约账号【量子位】原创内容,未经账号授权,禁止随意转载。

点个在看 paper不断!

1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目刚发布就揽星600+相关推荐

  1. 【学习react中遇到的坑:内存泄漏报错】

    学习react中遇到的坑:内存泄漏报错 对就是这个错误 Can't perform a React state update on an unmounted component. This is a ...

  2. SAP QM执行事务代码QE23为检验批录入结果,报错-No selected set exists for the inspection point 200 or plant NMDC-

    SAP QM执行事务代码QE23为检验批录入结果,报错-No selected set exists for the inspection point 200 or plant NMDC- 检验批#8 ...

  3. SAP LSMW 事务代码HUPAST的录屏后台执行报错 - Runtime error RAISE_EXCEPTION has occurred - 之分析

    SAP LSMW 事务代码HUPAST的录屏后台执行报错 - Runtime error RAISE_EXCEPTION has occurred - 之分析 因项目上成品库存管理启用了handlin ...

  4. ANSYS-CFX,计算时报错,内存参数报错,return code 1【终极解决方案】

    ANSYS-CFX,计算时报错,内存参数报错,return code 1[终极解决方案] 在CFX计算时经常会遇到内存不足的错误报告,有的算例网格并不多也会出现这样的问题,本文就最近遇到的内存错误问题 ...

  5. oracle lms进程 内存,【案例】Oracle ges resource消耗内存高报错ORA-04031 MOS解决办法...

    天萃荷净 Oracle研究中心案例分析:运维DBA反映Oracle数据库10.2.0.4.12每间隔一段时间就必须重启,运行一断时间报ORA-04031错误oracle ges res cache l ...

  6. 一段简单的代码告诉你什么叫内存溢出

    #include <stdio.h>int FooArray[4] = {1, 1, 1, 1}; int VeryImportantValue = 7;void main() {prin ...

  7. android华为手机获取内存目录,华为手机读取内存文件报错

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 public String getDataColumn(Context context, Uri uri, String selection, Strin ...

  8. git第一次提交代码到码云,git pull 报错:fatal: refusing to merge unrelated histories

    第一次提交的步骤: 1.进入项目目录,执行 git init 2.连接远程仓库 git remote add origin 远程仓库地址(从码云乎哟这github上复制地址即可) 3.报错:git p ...

  9. matlab stk 代码,STK与matlab互联,stkSetPropClassical报错

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 使用stkSetPropClassical设置卫星参数,新手上路,代码报错,在线等大佬 代码: stkNewObj('*/','Satellite','S ...

最新文章

  1. VC中基于 Windows 的精确定时
  2. 奇葩错误 -- modelsim波形显示no data(全X)
  3. 兴趣部落的 Git 迁移实践
  4. 手机重写alert方法(去除网址和关闭网页按钮)
  5. lombok @Builder 是如何实现的
  6. 分布式事务 -- seata框架AT模式实现原理
  7. FragmentPagerAdapter实现刷新
  8. roads 用户体验标准_世界智能大会与ROAD用户体验报告
  9. 密度图的密度估计_不同类型的二维密度图小教程
  10. android studio1.5 for mac,适用于Mac的Android Studio 1.5.x随机崩溃
  11. cookie分号后面没有值_浏览器Cookie介绍
  12. C++ static关键字作用讲解
  13. 使用Xshell连接Linux虚拟机(NAT)
  14. C语言编程>第十六周 ① 给定程序的功能是求1/4的圆周长。函数通过形参得到圆的直径,函数返回1/4的圆周长(圆周长公式为:L=Πd,在程序中定义的变量名要与公式的变量相同)。
  15. galgame序列号怎么查看_国行Switch能完整体验的游戏有哪些?Switch支架掉了怎么办? | Jump指南...
  16. win10 许可证即将过期
  17. 地理位置坐标标准以及转换
  18. Nginx 502的解决方法
  19. 2021-11-15 基于音乐商店NetMusicShop的复杂查询(二)
  20. ADI电路设计电子书课件分享

热门文章

  1. 用Quartus II Timequest Timing Analyzer进行时序分析 :实例讲解 (一)
  2. myeclipse中安装svn插件
  3. 【组队学习】【31期】青少年编程(Scratch 四级)
  4. 【通知】2021-2022-1线性代数课程答疑安排
  5. 【复盘】端端,棒棒哒!
  6. 【通俗理解线性代数】 -- 内积与相关
  7. 算法基础知识科普:8大搜索算法之红黑树(下)
  8. 【Python】Scrapy爬虫实战(腾讯社会招聘职位检索)
  9. 终于“打造”出了一个可以随时随地编程的工具
  10. 都是程序员,凭什么他能站在鄙视链的顶端?