optimizer.zero_grad()意思是把梯度置零,也就是把loss关于weight的导数变成0.

在学习pytorch的时候注意到,对于每个batch大都执行了这样的操作:

        # zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()

对于这些操作我是把它理解成一种梯度下降法,贴一个自己之前手写的简单梯度下降法作为对照:

    # gradient descentweights = [0] * nalpha = 0.0001max_Iter = 50000for i in range(max_Iter):loss = 0d_weights = [0] * nfor k in range(m):h = dot(input[k], weights)d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)] loss += (label[k] - h) * (label[k] - h) / 2d_weights = [d_weights[k]/m for k in range(n)]weights = [weights[k] + alpha * d_weights[k] for k in range(n)]if i%10000 == 0:print "Iteration %d loss: %f"%(i, loss/m)print weights

可以发现它们实际上是一一对应的:

optimizer.zero_grad()对应d_weights = [0] * n

即将梯度初始化为零(因为一个batch的loss关于weight的导数是所有sample的loss关于weight的导数的累加和)

outputs = net(inputs)对应h = dot(input[k], weights)

即前向传播求出预测的值

loss = criterion(outputs, labels)对应loss += (label[k] - h) * (label[k] - h) / 2

这一步很明显,就是求loss(其实我觉得这一步不用也可以,反向传播时用不到loss值,只是为了让我们知道当前的loss是多少)
loss.backward()对应d_weights = [d_weights[j] + (label[k] - h) * input[k][j] for j in range(n)]

即反向传播求梯度
optimizer.step()对应weights = [weights[k] + alpha * d_weights[k] for k in range(n)]

即更新所有参数

如有不对,敬请指出。欢迎交流

作者:scut_salmon
来源:CSDN
原文:https://blog.csdn.net/scut_salmon/article/details/82414730
版权声明:本文为博主原创文章,转载请附上博文链接!

torch代码解析 为什么要使用optimizer.zero_grad()相关推荐

  1. Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解

    引言 一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.s ...

  2. unet模型及代码解析

    什么是unet 一个U型网络结构,2015年在图像分割领域大放异彩,unet被大量应用在分割领域.它是在FCN的基础上构建,它的U型结构解决了FCN无法上下文的信息和位置信息的弊端 Unet网络结构 ...

  3. 梯度值与参数更新optimizer.zero_grad(),loss.backward、和optimizer.step()、lr_scheduler.step原理解析

    在用pytorch训练模型时,通常会在遍历epochs的过程中依次用到optimizer.zero_grad(),loss.backward.和optimizer.step().lr_schedule ...

  4. Baidu Apollo代码解析之EM Planner中的QP Speed Optimizer 1

    大家好,我已经把CSDN上的博客迁移到了知乎上,欢迎大家在知乎关注我的专栏慢慢悠悠小马车(https://zhuanlan.zhihu.com/duangduangduang).希望大家可以多多交流, ...

  5. [GCN] 代码解析 of GitHub:Semi-supervised classification with graph convolutional networks

    本文解析的代码是论文Semi-Supervised Classification with Graph Convolutional Networks作者提供的实现代码. 原GitHub:Graph C ...

  6. 自然语言处理(三):传统RNN(NvsN,Nvs1,1vsN,NvsM)pytorch代码解析

    文章目录 1.预备知识:深度神经网络(DNN) 2.RNN出现的意义与基本结构 3.根据输入和输出数量的网络结构分类 3.1 N vs N(输入和输出序列等长) 3.2 N vs 1(多输入单输出) ...

  7. Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析

    论文地址点这里 一. 介绍 联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的.在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的 ...

  8. GraphSAGE算法 和 代码解析

    聚合邻居 GraphSAGE研究了聚合邻居操作所需的性质,并且提出了几种新的聚合操作(aggregator),需满足如下条件: (1)聚合操作必须要对聚合节点的数量做到自适应.不管节点的邻居数量怎么变 ...

  9. pytorch之model.zero_grad() 与 optimizer.zero_grad()

    转自 https://cloud.tencent.com/developer/article/1710864 1. 引言 在PyTorch中,对模型参数的梯度置0时通常使用两种方式:model.zer ...

  10. python grad_PyTorch中model.zero_grad()和optimizer.zero_grad()用法

    废话不多说,直接上代码吧~ model.zero_grad() optimizer.zero_grad() 首先,这两种方式都是把模型中参数的梯度设为0 当optimizer = optim.Opti ...

最新文章

  1. 在Mac上控制Alt Delete-如何在Macbook上打开任务管理器
  2. 课时 28:理解容器运行时接口 CRI(知谨)
  3. c语言case后面多字符,多SWITCH-CASE结构时的C语言对象方式化解
  4. Linux入门学习(八)
  5. C++参考的翻译或校对
  6. 新修订未成年人保护法6月1日正式实施
  7. 看了这个有趣的例子,你就能秒懂Java中的多线程同步了!
  8. 欢迎加入互联网架构师群
  9. 三维扫描仪行业调研报告 - 市场现状分析与发展前景预测
  10. UDA: A user-difference attention for group recommendation
  11. 金山云智能营销平台再升级,AI 投放助力游戏厂商精准到达;微医发布 AI 解决方案,提升县域医疗服务能力...
  12. win10 右键卡顿问题
  13. Python学习笔记:通过Headers字段模拟浏览器访问亚马逊界面爬取
  14. 若依管理系统RuoYi-Vue(前后端分离版)项目启动教程
  15. 跑赢新趋势 | 未来3-5年,运维人的机会在哪里?
  16. 测度论--长度是怎样炼成的[zz]
  17. 单片机原理及应用知识总结(持续更新)
  18. 游戏设计模式之策略模式(二)
  19. extjs图片放大缩小实现
  20. 泊松融合实现图片拼接

热门文章

  1. 7.0高等数学五-高斯公式
  2. 关于数学计算机手抄报简单的,数学手抄报简单又漂亮图片
  3. 老徐WEB:js入门学习 - javascript函数和闭包
  4. java计算机毕业设计婴幼儿玩具共享租售平台源码+数据库+系统+lw文档+mybatis+运行部署
  5. running_mean和running_var
  6. 我的团长我的团可能的故事原现
  7. Android 模拟器 Root 和 SuperSU 安装
  8. 前端 JS 根据日期查询周几 星期几
  9. 台式机创建文件服务器,如何将台式机做成云存储服务器
  10. 安装andriod studio