1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目刚发布就揽星600+
点击上方“视学算法”,选择加"星标"或“置顶”
重磅干货,第一时间送达
丰色 发自 凹非寺
量子位 报道 | 公众号 QbitAI
CUDA error: out of memory.
多少人用PyTorch“炼丹”时都会被这个bug困扰。
一般情况下,你得找出当下占显存的没用的程序,然后kill掉。
如果不行,还需手动调整batch size到合适的大小……
有点麻烦。
现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。
有多厉害?
相关项目在GitHub才发布没几天就收获了600+星。
一行代码解决内存溢出错误
软件包名叫koila,已经上传PyPI,先安装一下:
pip install koila
现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。
先定义input、label和model:
# A batch of MNIST image
input = torch.randn(8, 28, 28)# A batch of labels
label = torch.randn(0, 10, [8])class NeuralNetwork(Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = Flatten()self.linear_relu_stack = Sequential(Linear(28 * 28, 512),ReLU(),Linear(512, 512),ReLU(),Linear(512, 10),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits
然后定义loss函数、计算输出和losses。
loss_fn = CrossEntropyLoss()# Calculate losses
out = nn(t)
loss = loss_fn(out, label)# Backward pass
nn.zero_grad()
loss.backward()
好了,如何使用koila来防止内存溢出?
超级简单!
只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——
koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。
在本例中,batch=0,则修改如下:
input = lazy(torch.randn(8, 28, 28), batch=0)
完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。
灵感来自TensorFlow的静态/懒惰评估
下面就来说说koila背后的工作原理。
“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。
koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。
它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。
而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。
又是算shape又是算内存的,koila听起来就很慢?
NO。
即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。
而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。
你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?
是的,它也可以。
但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。
而koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。
不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU。
以及现在只适用于常见的nn.Module类。
ps. koila作者是一位叫做RenChu Wang的小哥。
项目地址:
https://github.com/rentruewang/koila
参考链接:
https://www.reddit.com/r/MachineLearning/comments/r4zaut/p_eliminate_pytorchs_cuda_error_out_of_memory/
— 完 —
本文系网易新闻•网易号特色内容激励计划签约账号【量子位】原创内容,未经账号授权,禁止随意转载。
点个在看 paper不断!
1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目刚发布就揽星600+相关推荐
- 【学习react中遇到的坑:内存泄漏报错】
学习react中遇到的坑:内存泄漏报错 对就是这个错误 Can't perform a React state update on an unmounted component. This is a ...
- SAP QM执行事务代码QE23为检验批录入结果,报错-No selected set exists for the inspection point 200 or plant NMDC-
SAP QM执行事务代码QE23为检验批录入结果,报错-No selected set exists for the inspection point 200 or plant NMDC- 检验批#8 ...
- SAP LSMW 事务代码HUPAST的录屏后台执行报错 - Runtime error RAISE_EXCEPTION has occurred - 之分析
SAP LSMW 事务代码HUPAST的录屏后台执行报错 - Runtime error RAISE_EXCEPTION has occurred - 之分析 因项目上成品库存管理启用了handlin ...
- ANSYS-CFX,计算时报错,内存参数报错,return code 1【终极解决方案】
ANSYS-CFX,计算时报错,内存参数报错,return code 1[终极解决方案] 在CFX计算时经常会遇到内存不足的错误报告,有的算例网格并不多也会出现这样的问题,本文就最近遇到的内存错误问题 ...
- oracle lms进程 内存,【案例】Oracle ges resource消耗内存高报错ORA-04031 MOS解决办法...
天萃荷净 Oracle研究中心案例分析:运维DBA反映Oracle数据库10.2.0.4.12每间隔一段时间就必须重启,运行一断时间报ORA-04031错误oracle ges res cache l ...
- 一段简单的代码告诉你什么叫内存溢出
#include <stdio.h>int FooArray[4] = {1, 1, 1, 1}; int VeryImportantValue = 7;void main() {prin ...
- android华为手机获取内存目录,华为手机读取内存文件报错
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 public String getDataColumn(Context context, Uri uri, String selection, Strin ...
- git第一次提交代码到码云,git pull 报错:fatal: refusing to merge unrelated histories
第一次提交的步骤: 1.进入项目目录,执行 git init 2.连接远程仓库 git remote add origin 远程仓库地址(从码云乎哟这github上复制地址即可) 3.报错:git p ...
- matlab stk 代码,STK与matlab互联,stkSetPropClassical报错
该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 使用stkSetPropClassical设置卫星参数,新手上路,代码报错,在线等大佬 代码: stkNewObj('*/','Satellite','S ...
最新文章
- VC中基于 Windows 的精确定时
- 奇葩错误 -- modelsim波形显示no data(全X)
- 兴趣部落的 Git 迁移实践
- 手机重写alert方法(去除网址和关闭网页按钮)
- lombok @Builder 是如何实现的
- 分布式事务 -- seata框架AT模式实现原理
- FragmentPagerAdapter实现刷新
- roads 用户体验标准_世界智能大会与ROAD用户体验报告
- 密度图的密度估计_不同类型的二维密度图小教程
- android studio1.5 for mac,适用于Mac的Android Studio 1.5.x随机崩溃
- cookie分号后面没有值_浏览器Cookie介绍
- C++ static关键字作用讲解
- 使用Xshell连接Linux虚拟机(NAT)
- C语言编程>第十六周 ① 给定程序的功能是求1/4的圆周长。函数通过形参得到圆的直径,函数返回1/4的圆周长(圆周长公式为:L=Πd,在程序中定义的变量名要与公式的变量相同)。
- galgame序列号怎么查看_国行Switch能完整体验的游戏有哪些?Switch支架掉了怎么办? | Jump指南...
- win10 许可证即将过期
- 地理位置坐标标准以及转换
- Nginx 502的解决方法
- 2021-11-15 基于音乐商店NetMusicShop的复杂查询(二)
- ADI电路设计电子书课件分享