废话不多说,直接上代码吧~

model.zero_grad()

optimizer.zero_grad()

首先,这两种方式都是把模型中参数的梯度设为0

当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器

def zero_grad(self):

"""Sets gradients of all model parameters to zero."""

for p in self.parameters():

if p.grad is not None:

p.grad.data.zero_()

补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

引言

一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)

optimizer_d.zero_grad()

# 判别器对虚假数据进行训练

fake_ab = torch.cat((real_a, fake_b), 1)

pred_fake = net_d.forward(fake_ab.detach())

loss_d_fake = criterionGAN(pred_fake, False)

# 判别器对真实数据进行训练

real_ab = torch.cat((real_a, real_b), 1)

pred_real = net_d.forward(real_ab)

loss_d_real = criterionGAN(pred_real, True)

# 判别器损失

loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()

optimizer_d.step()

上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。

参数更新和反向传播

上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。

我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。

调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。

optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。

以上这篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

本文标题: PyTorch中model.zero_grad()和optimizer.zero_grad()用法

本文地址: http://www.cppcns.com/jiaoben/python/323103.html

python grad_PyTorch中model.zero_grad()和optimizer.zero_grad()用法相关推荐

  1. pytorch之model.zero_grad() 与 optimizer.zero_grad()

    转自 https://cloud.tencent.com/developer/article/1710864 1. 引言 在PyTorch中,对模型参数的梯度置0时通常使用两种方式:model.zer ...

  2. Python pandas 中loc函数的意思及用法,及跟iloc的区别

    Python pandas 中loc函数的意思及用法,及跟iloc的区别 loc和iloc的意思 loc和iloc的区别及用法展示 参考文献 loc和iloc的意思 首先,loc是location的意 ...

  3. model.train()、model.eval()、optimizer.zero_grad()、loss.backward()、optimizer.step作用及原理详解【Pytorch入门手册】

    1. model.train() model.train()的作用是启用 Batch Normalization 和 Dropout. 如果模型中有BN层(Batch Normalization)和D ...

  4. python中什么是关键字参数_如何使用python语言中函数的关键字参数的用法

    一般情况下,在调用函数时,使用的是位置参数,即是按照参数的位置来传值:关键字参数是按照定义函数传入的参数名称来传值的.那么,关键字参数怎么使用? 工具/原料 python pycharm 截图工具 W ...

  5. 理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward()和optimizer.step()三个函数,如下所示: ...

  6. 梯度值与参数更新optimizer.zero_grad(),loss.backward、和optimizer.step()、lr_scheduler.step原理解析

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward.和optimizer.step().lr_schedule ...

  7. optimizer.zero_grad()和loss.backward()

    1.optimizer.zero_grad()和loss.backward()先后问题 刚开始学习深度学习,就是不明白,为什么第一次运行的时候就要optimizer.zero_grad()(梯度清零) ...

  8. optimizer.zero_grad(),loss.backward(),optimizer.step()的作用原理

    目录 前言 一.optimizer.zero_grad() 二. loss.backward() 三.optimizer.step() 前言 在用pytorch训练模型时,通常会在遍历epochs的过 ...

  9. python中fit_Python sklearn中的.fit与.predict的用法说明

    我就废话不多说了,大家还是直接看代码吧~ clf=KMeans(n_clusters=5) #创建分类器对象 fit_clf=clf.fit(X) #用训练器数据拟合分类器模型 clf.predict ...

最新文章

  1. 常见面试题:重写strcpy() 函数原型
  2. 数据结构与算法--线性表(顺序表)
  3. C语言中变量的储存类别
  4. java imap收邮件_android pop3与imap方式接收邮件(javamail)
  5. azure web应用部署_使用Visual Studio Code将Python应用程序部署到Azure Functions
  6. 【数据结构和算法】拓扑排序(附leetcode题 207/210 课程表)
  7. 计算机论文期中小结,毕业论文中期小结
  8. 学习新浪微博计数服务
  9. 元组定义 元组运算符
  10. 【每日经典】李嘉诚:赚钱可以无处不在、无时不有
  11. 微信获取公众号二维码
  12. 在代码里设置view的android:layout_marginTop
  13. HTML5系列代码:注册商标reg_和版权商标copy
  14. ant app 心电监测_电话传输心电图监测在心血管疾病及远程医疗中的应用
  15. 安利这几款软件给需要的你
  16. 动植物代谢最新研究进展(2021年7月)
  17. 多益网络2018春季校园招聘研发岗笔试经验
  18. Motor Back-drive电机反驱
  19. 从数据仓库到大数据,数据平台这25年是怎样进化的?[转]
  20. 计算机投影到数字电视的方法,怎么把电脑投屏到电视有线(电脑无线投屏到电视机的方法)...

热门文章

  1. 8-2 Android 摄像头和相册
  2. fzoj Problem 2190 非提的救赎 【单调栈】
  3. 【C语言】初识指针(终篇)
  4. 影响到网站收录有哪些方面呢
  5. 期刊论文左下角横线的添加方法
  6. 国二c语言考试要点,全国计算机二级考试c语言考试要点
  7. 1字节是多少位,汉字utf-8又占多少。
  8. Typora使用总结
  9. 你问我答:小匠,如何像你一样,做一个订阅号挣它 100 W?
  10. OnePlus一加手机:测试