参考:
https://blog.csdn.net/Princeicon/article/details/108058822
https://blog.csdn.net/weixin_43643246/article/details/107785089

假设情景:
batch_size = 10 #每批次大小
total_num = 1000 #数据总量
按照 训练一个批次数据,更新一次梯度;
训练步数 train_steps = 1000 / 10 = 100
梯度更新步数 = 1000 / 10 = 100

当显存不足以支持每次 10 的训练量!
需要减小 batch_size

通过设置gradient_accumulation_steps = 2
batch_size = 10 / 2 =5
即训练2个批次数据,更新一次梯度,每个批次数据量为5(减小了显存压力,但未改变梯度更新数据量–10个数据一更新)

结果:训练步数 tran_steps = 1000 / 5 = 200 增加了一倍
梯度更新步数 1000 / 10 = 100 未改变

总结:梯度累加就是,每获取1个batch的数据,计算一次梯度,梯度不清空,不断累加,累加到gradient_accumulation_steps次数后,根据累加梯度更新参数,清空梯度,进行下一次循环。

gradient_accumulation_steps --梯度累加理解相关推荐

  1. 通俗理解深度学习梯度累加(Gradient Accumulation)的原理

    首先你得明白什么是梯度,可以看我之前写的一篇博客 : 微分与梯度的概念理解 本质上,梯度是一种方向导数,是一个矢量,因此这里的梯度累加并不是简单的相加,而是类似于初高中物理学的力的合成,梯度作为一种方 ...

  2. pytorch多gpu DataParallel 及梯度累加解决显存不平衡和显存不足问题

      最近在做图像分类实验时,在4个gpu上使用pytorch的DataParallel 函数并行跑程序,批次为16时会报如下所示的错误:   RuntimeError: CUDA out of mem ...

  3. 梯度累加策略对准确率的影响

    从曲线整体分析来看等效的(geng_xing_bu_chang*batch_size=等效batch_size的大小,)倍数越大准确率损失越严重(虽然30到300的采样太稀疏但是可以忽略) 如下图 从 ...

  4. TF实现多minibatch梯度累加及反向更新

    参考链接: TF中optimizor源码: https://blog.csdn.net/Huang_Fj/article/details/102688509 如何累加梯度进行反向: https://s ...

  5. 【计数网络】梯度累加增加LCFCN的BatchSize

    LCFCN是一个以分割网络为基础的专用于计数的网络. LCFCN模型由于loss的特殊性 batch size 目前只能为1 LCFCN代码 https://github.com/ElementAI/ ...

  6. Pytorch分布式训练/多卡训练(二) —— Data Parallel并行(DDP)(2.2)(代码示例)(BN同步主卡保存梯度累加多卡测试inference随机种子seed)

    DDP的使用非常简单,因为它不需要修改你网络的配置.其精髓只有一句话 model = DistributedDataPrallel(model, device_ids=[local_rank], ou ...

  7. Gradient Accumulation 梯度累加 (Pytorch)

    我们在训练神经网络的时候,batch_size的大小会对最终的模型效果产生很大的影响.一定条件下,batch_size设置的越大,模型就会越稳定.batch_size的值通常设置在 8-32 之间,但 ...

  8. 梯度累加(Gradient Accumulation)

    受显存限制,运行一些预训练的large模型时,batch-size往往设置的比较小1-4,否则就会'CUDA out of memory',但一般batch-size越大(一定范围内)模型收敛越稳定效 ...

  9. 【深度学习训练小技巧】1080ti与2080ti区别、apex与梯度累加

    文章目录 1080ti与2080ti区别 在目标检测和分割任务中使用apex 梯度累加(一般不在目标检测中使用) torch.no_grad() 当我们没有足够的显卡训练模型时,apex和梯度累加是有 ...

最新文章

  1. ABS是啥,为什么区块链可以与它完美结合?
  2. 大年初一微信闪退?看看如何修复的
  3. Hadoop学习之HDFS架构(二)
  4. 汇编对sp指针进行修改_从汇编理解函数调用的过程
  5. Python折半查找(二分查找)
  6. vb怎么自动连接服务器,VB 如何制作连接服务器的进程
  7. Windows下安装postgresql_psycopg2时出现 Unabled to find vcvarsall.bat 的解决办法
  8. ns2 java_【NS2】用eclipse调试NS2(转载)
  9. 实现一个定时任务管理器
  10. 详细讲述matlab中矩阵的卷积函数convn
  11. 110kV变电站电气一次系统设计
  12. Hadoop的学习前景怎么样,Hadoop培训后的职业规划
  13. Ztree Fa-Awesome 图标使用
  14. 冒烟测试的7个好处,你是否经常用到它?
  15. Error creating bean with name ‘sqlSessionFactory’ defined in class path reso
  16. wordcloud词云可视化
  17. 生活中软件易用性的例子_多用“举出例子”“比如说”,来进行生活中的语言交流...
  18. c语言在测绘工程中的作用,测绘C程序设计实习报告
  19. 扬州大学计算机控制技术课设,计算机控制技术的课设.doc
  20. 百度地图 由起点和终点 获取中间路线的坐标集

热门文章

  1. 如何找到蓝奏云网盘登录后的ylogin、phpdisk_info?
  2. celeron处理器_NISE3600E 基于第三代Intel Core i7/i5/i3处理器,无风扇系统,1个PCIex4扩展槽...
  3. 6-3 定义Person类
  4. ROS path问题解决方案
  5. 非理工科编程零基础文科生秒懂python学习笔记:pandas库数据表格创建和运算基础有哪些?
  6. 基于PHP的租赁商城系统(包括数据库和后台)
  7. 应该怎样学习Unity3D
  8. SAP开发框架系列之 自定义单据
  9. FPGA零基础学习:基于FPGA的音乐蜂鸣器设计(附代码)
  10. JAVA实体Do-Dto转换类 Converter