文章目录

  • 前言
  • 一、大幅减少显存占用方法
    • 1. 模型
    • 2. 数据
  • 二、小幅减少显存占用方法
    • 1. 使用inplace
    • 2. 加载、存储等能用CPU就绝不用GPU
    • 3. 低精度计算
    • 4. torch.no_grad
    • 5. 及时清理不用的变量
    • 6. 分段计算
  • 总结

前言

如今的深度学习发展的如火如荼,相信各行各业的大家都或多或少接触过深度学习的知识。相信很多人在跑模型时都见过以下语句:
RuntimeError: CUDA out of memory.
显存不足是很多人感到头疼的问题,毕竟能拥有大量显存的实验室还是少数,而现在的模型已经越跑越大,模型参数量和数据集也越来越大。在别人的Batchsize在32甚至64以上时,我们还在为强行塞下1个Batch而努力。

以下给大家提供一些节省PyTorch显存占用的小技巧,虽然提升不大,但或许能帮你达到可以勉强运行的及格线。

一、大幅减少显存占用方法

想大幅减少显存占用,必定要从最占用显存的方面进行缩减,即 模型数据

1. 模型

在模型上主要是将Backbone改用轻量化网络或者减少网络层数等方法,可以很大程度上减少模型参数量,从而减少显存占用。

2. 数据

数据方面上要减少显存占用主要是使用小 BatchSize 或者将输入数据 Resize 到较小的尺寸。

二、小幅减少显存占用方法

有时候我们可能不想更改模型,而又恰好差一点点显存或者想尽量多塞几个BatchSize,有一些小技巧可以挤出一点点显存。

1. 使用inplace

PyTorch中的一些函数,例如 ReLU、LeakyReLU 等,均有 inplace 参数,可以对传入Tensor进行就地修改,减少多余显存的占用。

2. 加载、存储等能用CPU就绝不用GPU

GPU存储空间宝贵,我们可以选择使用CPU做一些可行的分担,虽然数据传输会浪费一些时间,但是以时间换空间,可以视情况而定,在模型加载中,如 torch.load_state_dict 时,先加载再使用 model.cuda(),尤其是在 resume 断点续训时,可能会报显存不足的错误。数据加载也是,在送入模型前在送入GPU。其余中间的数据处理也可以依循这个原则。

3. 低精度计算

可以使用 float16 半精度混合计算,也可以有效减少显存占用,但是要注意一些溢出情况,如 mean 和 sum等。

4. torch.no_grad

对于 eval 等不需要 bp 及 backward 的时候,可已使用with torch.no_grad,这个和model.eval()有一些差异,可以减少一部分显存占用。

5. 及时清理不用的变量

对于一些使用完成后的变量,及时del掉,例如 backward 完的 Loss,缓存torch.cuda.empty_cache()等。

6. 分段计算

骚操作,我们可以将模型或者数据分段计算。

  1. 模型分段,利用checkpoint将模型分段计算

    # 首先设置输入的input=>requires_grad=True
    # 如果不设置可能会导致得到的gradient为0
    input = torch.rand(1, 10, requires_grad=True)
    layers = [nn.Linear(10, 10) for _ in range(1000)]# 定义要计算的层函数,可以看到我们定义了两个
    # 一个计算前500个层,另一个计算后500个层
    def run_first_half(*args):x = args[0]for layer in layers[:500]:x = layer(x)return xdef run_second_half(*args):x = args[0]for layer in layers[500:-1]:x = layer(x)return x# 引入checkpoint
    from torch.utils.checkpoint import checkpointx = checkpoint(run_first_half, input)
    x = checkpoint(run_second_half, x)
    # 最后一层单独执行
    x = layers[-1](x)
    x.sum.backward()
    
  2. 数据分段,例如原来需要64个batch的数据forward一次后backward一次,现在改为32个batch的数据forward两次后backward一次。

总结

以上是我总结的一些PyTorch节省显存的一些小技巧,希望可以帮助到大家,如果有其它好方法,也欢迎和我讨论。

PyTorch 轻松节省显存的小技巧相关推荐

  1. PyTorch节省显存占用方法

    1-使用inplace操作 2-使用混合精度运算 参考: [1]混合精度训练 http://kevinlt.top/2018/09/14/mixed_precision_training/ [2]py ...

  2. pytorch节省显存_节省新房子的照明

    pytorch节省显存 Our final move into the new house is this weekend. We did a three phase, three week move ...

  3. 释放pytorch占用的gpu显存_Pytorch 节省显存的训练方法总结

    前言 最近的工作中,用到了Pytorch框架训练医学图像分割模型.精心设计的模型经常会因为显存不足而失败.减小模型训练过程中对显存的占用,可能我们能想到最简单的方法就是减小batchsize,减少卷积 ...

  4. torch.cuda.amp自动混合精度训练 —— 节省显存并加快推理速度

    torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 文章目录 torch.cuda.amp自动混合精度训练 -- 节省显存并加快推理速度 1.什么是amp? 2.为什么需要自动 ...

  5. pytorch 优化GPU显存占用,避免out of memory

    pytorch 优化GPU显存占用,避免out of memory 分享一个最实用的招: 用完把tensor删掉,pytorch不会自动清理显存! 代码举例,最后多删除一个,gpu显存占用就会下降,训 ...

  6. 内存和显存_小科普 |“内存”和“显存”有啥关系?

    上周,我们一起了解了什么是DIMM.什么是DDR内存(戳这里),相信有不少人心里还有个疑惑:"内存与显存有什么差别?为什么显卡都GDDR6了,CPU还在用DDR4?"那么我们今天就 ...

  7. 《南溪的目标检测学习笔记》——训练PyTorch模型遇到显存不足的情况怎么办(“OOM: CUDA out of memory“)

    1 前言 在目标检测中,可能会遇到显存不足的情况,我们在这里记录一下解决方案: 2 如何判断真正是出现显存溢出(不是"软件误报") 当前需要分配的显存在600MiB以下, 例如: ...

  8. pytorch训练时显存溢出

    网络在前期可以正常训练,但训练几轮后就发生显存爆炸的问题,调整输入大小或者每次循环都清除显存 也无法解决问题,后来经过查询,是在对loss求和时,直接使用 tl += loss 可以看到,loss是张 ...

  9. pytorch如何查看显存利用情况

    最近搞LSTM优化,但是显存利用率不稳定,想看一下LSTM的显存占用情况,搜罗了一通,发现一个不错的开源工具,记录分享一下. 首先上项目地址:https://github.com/Oldpan/Pyt ...

  10. 不优雅地解决pytorch模型测试阶段显存溢出问题

    在一次测试一个超分辨模型LESRCNN(作者提供了已训练好的模型)时,发生了CUDA out of memory的错误(虽然显卡有8G显存,但还是差了些): RuntimeError: CUDA ou ...

最新文章

  1. 使用元组输入进行计算和归约
  2. ios底部栏设计规范_UI设计:iOS 界面规范
  3. Android开发--AsyncTask异步任务(二)
  4. sql 分类汇总 列_分类汇总哪家强?R、Python、SAS、SQL?
  5. GoLand配置数据库、远程host以及远程调试
  6. PHP开发APP接口(二)
  7. intel服务器ssd系列,英特尔发布S3710/S3610服务器SSD新品
  8. 华为交换机vlan配置
  9. string不能输入空格,如何输入有空格字符串呢
  10. android 投屏 版本号,安卓设备投屏画质模糊及投屏延迟的调整方法
  11. 引入tinymce-vue后调试器报错 Refused to apply styl
  12. Flutter更改主题颜色报错:type ‘Color‘ is not a subtype of type ‘MaterialColor‘
  13. 3分钟搞定下载微信视频号视频!无需第三方软件,亲测有效!
  14. CS229 --Lecture1 Introduction
  15. PLSQL - 递归子查询RSF打破CONNECT BY LOOP限制
  16. 如何了解舆情传播的平台及路径?
  17. E-PUCK机器人-软件
  18. 笔记本可以跑虚拟机吗_什么笔记本跑虚拟机不卡?
  19. Apache Durid (HDFS原理 特性 读写测试 集群部署 架构设计)
  20. 你绝对没用过的三电源切换电路

热门文章

  1. “视”不可挡:征兵招警,近视手术成“通关法宝”
  2. 【Win10电脑更新】Win10电脑更新后小娜Cortana不能登录、咨询和兴趣不能查看的问题怎么解决
  3. windows 10 cortana搜索功能失效
  4. 1000:从今天开始入坑C语言
  5. Maxcompute Sql性能调优(1)
  6. php三D立体模拟,CSS3使用3D环境实现立体魔方效果的实例代码分享
  7. 腾讯实习结束总结+感悟
  8. 【大学总结】迟到但未缺席的大学总结
  9. Python非线性拟合自定义函数参数(对标MATLAB-nlinfit函数)
  10. 京东自动评价助手/京东评价