torch之optimizer.step() 和loss.backward()和scheduler.step()的关系与区别

  • 由于接触torch时间不久,所有对此比较困惑,遇到如下博文解释十分详细,故转载至此。(原文地址)

1.optimizer.step()

因为有人问我optimizer的step为什么不能放在min-batch那个循环之外,还有optimizer.step和loss.backward的区别;那么我想把答案记录下来。

首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方法论。

  1. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:1. 优化器需要知道当前的网络或者别的什么模型的参数空间,这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用pytorch的话总会出现类似如下的代码:
optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G)   # lr 使用的是初始lr
optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)
  1. 需要知道反向传播的梯度信息,我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分
def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = d_p.clone()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss

从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候(loss.backward()的具体运算过程可以参看Pytorch 入门),这也就是经常会碰到,如下情况

total_loss.backward()
optimizer_G.step()

loss.backward()在前,然后跟一个step。

那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。


2.scheduler.step()

scheduler.step() 按照Pytorch的定义是用来 更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。通常我们有:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

torch之optimizer.step() 和loss.backward()和scheduler.step()的关系与区别相关推荐

  1. Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别

    参考   Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 - 云+社区 - 腾讯云 首先需要明确optimzier优化 ...

  2. Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解)

    https://blog.csdn.net/xiaoxifei/article/details/87797935

  3. Pytorch:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用

    在用pytorch训练模型时,通常会在遍历epochs的每一轮batach的过程中依次用到以下三个函数 optimizer.zero_grad(): loss.backward(): optimize ...

  4. optimizer.zero_grad()和loss.backward()

    1.optimizer.zero_grad()和loss.backward()先后问题 刚开始学习深度学习,就是不明白,为什么第一次运行的时候就要optimizer.zero_grad()(梯度清零) ...

  5. optimizer.zero_grad(),loss.backward(),optimizer.step()的作用原理

    目录 前言 一.optimizer.zero_grad() 二. loss.backward() 三.optimizer.step() 前言 在用pytorch训练模型时,通常会在遍历epochs的过 ...

  6. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

  7. Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

    引言 一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.s ...

  8. optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用

    optimizer.zero_grad,loss.backward,optimizer.step 用法介绍 optimizer.zero_grad(): loss.backward(): optimi ...

  9. Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度

    使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...

最新文章

  1. QuarkXPress 2020中文版
  2. 堆栈转化8进制_11张卡片学会进制转换
  3. oracle socct用户,oracle 19c 添加 scott用户与表
  4. beaninfo详解源码解析 java_【Spring源码分析】Bean加载流程概览
  5. geek软件是干什么的_极客是什么?
  6. JS双引号替换单引号
  7. C# List最大值最小值问题 List排序问题 List Max/Min
  8. Event Loop - JavaScript和node运行机制
  9. 4针串口线接法图_​RS232串口线常见接法与引脚定义
  10. 2017 CCPC 秦皇岛 G 题 ZOJ 3987 - Numbers (高精度+贪心)
  11. Linux系统无网络安装nginx
  12. [Java] 身份证号码验证
  13. 利用sklearn.cluster实现k均值聚类
  14. mysql中字段长度到底是字符数还是字节数?
  15. oracle指定导出低版本号,oracle 高版本导出低版本数据库并且导入到低版本数据的方法...
  16. MIKE 21 教程 1.7 网格生成过程中的常见报错与问题
  17. 局域网病毒入侵原理和防御
  18. 视频系统 流媒体 rtsp hls h264 h265 aac 高并发 低延时 系统 设计 录像 视频合成 转发 点播 快进 快退 单步播放 分布式集群
  19. 个人博客标签和文章的表结构设计
  20. 用计算机弹暖暖数字代码,奇迹暖暖网页版计算器

热门文章

  1. STM32 ADC最大采样频率
  2. 微信小程序收款手续费_【微信支付】微信小程序支付开发者文档
  3. 个人计算机是国产芯片,全球最纯国产PC诞生!所有芯片/系统都是国产
  4. java 算出下一个工作日_Java 计算一段时间段内除去周六日、节假日的工作日数———超详细(全)...
  5. samba配置共享打印机
  6. Gradle 学习 ----Gradle 进阶说明
  7. PPT制作(文字排版)
  8. MFC如何让背景图随窗口大小改变
  9. 配置完dcom需要重启计算机,DCOM电脑自动重启(win7系统一直反复重启)
  10. 中关村-DIY之国外网盘下载测试