梯度累积 - gradient accumulation

在深度学习训练的时候,数据的batch size大小受到GPU内存限制,batch size大小会影响模型最终的准确性和训练过程的性能。在GPU内存不变的情况下,模型越来越大,那么这就意味着数据的batch size只能缩小,这个时候,梯度累积(Gradient Accumulation)可以作为一种简单的解决方案来解决这个问题。

梯度累积(Gradient Accumulation)是一种不需要额外硬件资源就可以增加批量样本数量(Batch Size)的训练技巧。这是一个通过时间换空间的优化措施,它将多个Batch训练数据的梯度进行累积,在达到指定累积次数后,使用累积梯度统一更新一次模型参数,以达到一个较大Batch Size的模型训练效果。累积梯度等于多个Batch训练数据的梯度的平均值

所谓梯度累积过程,其实很简单,我们梯度下降所用的梯度,实际上是多个样本算出来的梯度的平均值,以batch_size=128为例,你可以一次性算出128个样本的梯度然后平均,我也可以每次算16个样本的平均梯度,然后缓存累加起来,算够了8次之后,然后把总梯度除以8,然后才执行参数更新。当然,必须累积到了8次之后,用8次的平均梯度才去更新参数,不能每算16个就去更新一次,不然就是batch_size=16了。

传统的深度学习

for i, (inputs, labels) in enumerate(trainloader):optimizer.zero_grad()                   # 梯度清零outputs = net(inputs)                   # 正向传播loss = criterion(outputs, labels)       # 计算损失loss.backward()                         # 反向传播,计算梯度optimizer.step()                        # 更新参数if (i+1) % evaluation_steps == 0:evaluate_model()

具体流程:

  1. optimizer.zero_grad(),将前一个batch计算之后的网络梯度清零
  2. 正向传播,将数据和标签传入网络,过infer计算得到预测结果
  3. 根据预测结果与label,计算损失值
  4. loss.backward() ,利用损失进行反向传播,计算参数梯度
  5. optimizer.step(),利用计算的参数梯度更新网络参数

简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络。

梯度累积方式

for i, (inputs, labels) in enumerate(trainloader):outputs = net(inputs)                   # 正向传播loss = criterion(outputs, labels)       # 计算损失函数loss = loss / accumulation_steps        # 梯度均值,损失标准化loss.backward()                         # 梯度均值累加,反向传播,计算梯度# 累加到指定的 steps 后再更新参数if (i+1) % accumulation_steps == 0:     optimizer.step()                    # 更新参数optimizer.zero_grad()               # 梯度清零if (i+1) % evaluation_steps == 0:evaluate_model()

具体流程:

  1. 正向传播,将数据传入网络,得到预测结果
  2. 根据预测结果与label,计算损失值
  3. 利用损失进行反向传播,计算参数梯度
  4. 重复1-3,不清空梯度,而是将梯度累加
  5. 梯度累加达到固定次数之后,更新参数,然后将梯度清零

梯度累积时,每个batch 仍然正常前向传播以及反向传播,但是反向传播之后并不进行梯度清零,因为 PyTorch 中的backward() 执行的是梯度累加的操作,所以当我们调用N次 loss.backward() 后,这N个batch 的梯度都会累加起来。但是,我们需要的是一个平均的梯度,或者说平均的损失,所以我们应该将每次计算得到的 loss除以 accum_steps。

总结来讲,梯度累积就是每计算一个batch的梯度,不进行清零,而是做梯度的累加,当累加到一定的次数(accumulation_steps)之后,再更新网络参数,然后将梯度清零。
        通过这种参数延迟更新的手段,可以实现与采用大batch size相近的效果。在平时的实验过程中,我一般会采用梯度累加技术,大多数情况下,采用梯度累加训练的模型效果,要比采用小batch size训练的模型效果要好很多。

注意事项:

  • 一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大

pytorch 梯度累积(gradient accumulation)相关推荐

  1. 梯度累加(Gradient Accumulation)

    受显存限制,运行一些预训练的large模型时,batch-size往往设置的比较小1-4,否则就会'CUDA out of memory',但一般batch-size越大(一定范围内)模型收敛越稳定效 ...

  2. pytorch梯度累积

    增大batchsize训练模型,一般都能带来一定的提升.在显卡内存不够的情况下,可以通过梯度累积的方式,来扩大batchsize. 因为pytorch中,反向传播之后,梯度是不清零的,因此要实现梯度累 ...

  3. 通俗理解深度学习梯度累加(Gradient Accumulation)的原理

    首先你得明白什么是梯度,可以看我之前写的一篇博客 : 微分与梯度的概念理解 本质上,梯度是一种方向导数,是一个矢量,因此这里的梯度累加并不是简单的相加,而是类似于初高中物理学的力的合成,梯度作为一种方 ...

  4. Gradient Accumulation 梯度累加 (Pytorch)

    我们在训练神经网络的时候,batch_size的大小会对最终的模型效果产生很大的影响.一定条件下,batch_size设置的越大,模型就会越稳定.batch_size的值通常设置在 8-32 之间,但 ...

  5. PyTorch中的梯度累积

    我们在训练神经网络的时候,超参数batch_size的大小会对模型最终效果产生很大的影响,通常的经验是,batch_size越小效果越差:batch_size越大模型越稳定.理想很丰满,现实很骨感,很 ...

  6. [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

    [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 文章目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前 ...

  7. 梯度累积(Gradient Accumulation)

    随着深度学习模型参数量的增加,现有GPU加载一个深度模型(尤其是预训练模型)后,剩余显存无法容纳很多的训练数据,甚至会仅能容纳一条训练数据. 梯度累积(Gradient Accumulation)是一 ...

  8. pytorch DDP加速之gradient accumulation设置

    pytorch DDP 参考:https://zhuanlan.zhihu.com/p/250471767 GPU高效通信算法-Ring Allreduce: https://www.zhihu.co ...

  9. AI系统——梯度累积算法

    明天博士论文要答辩了,只有一张12G二手卡,今晚通宵要搞定10个模型实验 挖槽,突然想出一个T9开天霹雳模型,加载不进去我那张12G的二手卡,感觉要错过今年上台Best Paper领奖 上面出现的问题 ...

最新文章

  1. shell逐行读取文件
  2. SVN终端演练(个人开发\多人开发)
  3. 两种 js下载文件的方法(转)
  4. 理清Python网络编程
  5. Cpp 对象模型探索 / 对象访问成员变量的原理
  6. centos 7.6安装java_安装 QRadar Community Edition
  7. Java 新手习题()
  8. 十八般武艺玩转GaussDB(DWS)性能调优:SQL改写
  9. python深浅拷贝 面试_[面试题二]百度资深面试官:python赋值、浅拷贝与深拷贝
  10. SpringBoot重复配置数据库导致Access denied for user ‘root‘@‘localhost‘ (using password: YES)
  11. 牛客网多校第9场 E Music Game 【思维+数学期望】
  12. 计算机网络 - 物理层
  13. 让“王码五笔输入法”成为你的专用输入法!
  14. 伺服电机转矩常数的标定方法
  15. 阿里云服务器ECS不能通过浏览器(外网)访问的解决办法
  16. 去中心化存储项目终极指南 | Filecoin, Storj 和 PPIO 项目异同(上)
  17. 接入百家号流量的方法
  18. PMP章节练习(第六章:项目进度管理)
  19. 面试篇——Spring
  20. Fortran文件操作-open

热门文章

  1. 网络安全——漏洞扫描工具(AWVS的使用)
  2. imx6 linux 时钟,迅为-iMX6开发板-驱动-实时时钟RTC以及Linux-c测试例程
  3. Halcon形态学处理-腐蚀、膨胀、开运算、闭运算、顶帽运算和底帽运算
  4. 解决github下载慢的问题
  5. 条条大路通罗马:mPaaS 新增体验入口
  6. 关于建构主义的一点思考
  7. Python数据可视化Pyecharts的全局配置
  8. 计算机信息管理员考试知识点,网络管理员考试知识点1—计算机硬件基础
  9. ECCV 2020 论文大盘点-去雾去雨去模糊篇
  10. java的3大注释快捷键大全