torch之optimizer.step() 和loss.backward()和scheduler.step()的关系与区别
torch之optimizer.step() 和loss.backward()和scheduler.step()的关系与区别
- 由于接触torch时间不久,所有对此比较困惑,遇到如下博文解释十分详细,故转载至此。(原文地址)
1.optimizer.step()
因为有人问我optimizer的step为什么不能放在min-batch那个循环之外,还有optimizer.step和loss.backward的区别;那么我想把答案记录下来。
首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方法论。
- 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:1. 优化器需要知道当前的网络或者别的什么模型的参数空间,这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用pytorch的话总会出现类似如下的代码:
optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G) # lr 使用的是初始lr
optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)
- 需要知道反向传播的梯度信息,我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分
def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = d_p.clone()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss
从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候(loss.backward()的具体运算过程可以参看Pytorch 入门),这也就是经常会碰到,如下情况
total_loss.backward()
optimizer_G.step()
loss.backward()在前,然后跟一个step。
那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。
2.scheduler.step()
scheduler.step() 按照Pytorch的定义是用来 更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。通常我们有:
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)
在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。
torch之optimizer.step() 和loss.backward()和scheduler.step()的关系与区别相关推荐
- Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别
参考 Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 - 云+社区 - 腾讯云 首先需要明确optimzier优化 ...
- Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解)
https://blog.csdn.net/xiaoxifei/article/details/87797935
- Pytorch:optim.zero_grad()、pred=model(input)、loss=criterion(pred,tgt)、loss.backward()、optim.step()的作用
在用pytorch训练模型时,通常会在遍历epochs的每一轮batach的过程中依次用到以下三个函数 optimizer.zero_grad(): loss.backward(): optimize ...
- optimizer.zero_grad()和loss.backward()
1.optimizer.zero_grad()和loss.backward()先后问题 刚开始学习深度学习,就是不明白,为什么第一次运行的时候就要optimizer.zero_grad()(梯度清零) ...
- optimizer.zero_grad(),loss.backward(),optimizer.step()的作用原理
目录 前言 一.optimizer.zero_grad() 二. loss.backward() 三.optimizer.step() 前言 在用pytorch训练模型时,通常会在遍历epochs的过 ...
- model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】
1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...
- Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
引言 一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.s ...
- optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用
optimizer.zero_grad,loss.backward,optimizer.step 用法介绍 optimizer.zero_grad(): loss.backward(): optimi ...
- Pytorch的model.train() model.eval() torch.no_grad() 为什么测试的时候不调用loss.backward()计算梯度还要关闭梯度
使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval model.train() 启用 BatchNormalization 和 Dropout 告诉我们的网络,这 ...
最新文章
- QuarkXPress 2020中文版
- 堆栈转化8进制_11张卡片学会进制转换
- oracle socct用户,oracle 19c 添加 scott用户与表
- beaninfo详解源码解析 java_【Spring源码分析】Bean加载流程概览
- geek软件是干什么的_极客是什么?
- JS双引号替换单引号
- C# List最大值最小值问题 List排序问题 List Max/Min
- Event Loop - JavaScript和node运行机制
- 4针串口线接法图_​RS232串口线常见接法与引脚定义
- 2017 CCPC 秦皇岛 G 题 ZOJ 3987 - Numbers (高精度+贪心)
- Linux系统无网络安装nginx
- [Java] 身份证号码验证
- 利用sklearn.cluster实现k均值聚类
- mysql中字段长度到底是字符数还是字节数?
- oracle指定导出低版本号,oracle 高版本导出低版本数据库并且导入到低版本数据的方法...
- MIKE 21 教程 1.7 网格生成过程中的常见报错与问题
- 局域网病毒入侵原理和防御
- 视频系统 流媒体 rtsp hls h264 h265 aac 高并发 低延时 系统 设计 录像 视频合成 转发 点播 快进 快退 单步播放 分布式集群
- 个人博客标签和文章的表结构设计
- 用计算机弹暖暖数字代码,奇迹暖暖网页版计算器
热门文章
- STM32 ADC最大采样频率
- 微信小程序收款手续费_【微信支付】微信小程序支付开发者文档
- 个人计算机是国产芯片,全球最纯国产PC诞生!所有芯片/系统都是国产
- java 算出下一个工作日_Java 计算一段时间段内除去周六日、节假日的工作日数———超详细(全)...
- samba配置共享打印机
- Gradle 学习 ----Gradle 进阶说明
- PPT制作(文字排版)
- MFC如何让背景图随窗口大小改变
- 配置完dcom需要重启计算机,DCOM电脑自动重启(win7系统一直反复重启)
- 中关村-DIY之国外网盘下载测试