机器学习9:关于pytorch中的zero_grad()函数

本文参考了博客Pytorch 为什么每一轮batch需要设置optimizer.zero_grad。

1.zero_grad()函数的应用:

在pytorch中做随机梯度下降时往往会用到zero_grad()函数,相关代码如下。

optimizer.zero_grad()                       # 将模型的参数梯度初始化为0

outputs=model(inputs)              # 前向传播计算预测值

loss = cost(outputs, y_train)           # 计算当前损失

loss.backward()                               # 反向传播计算梯度

optimizer.step()                               # 更新所有参数

2.zero_grad()函数的作用:

根据pytorch中backward()函数的计算,当网络参量进行反馈时,梯度是累积计算而不是被替换,但在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()将参数梯度置0.

另外,如果不是处理每个batch清除一次梯度,而是两次或多次再清除一次,相当于提高了batch_size,对硬件要求更高,更适用于需要更高batch_size的情况。

机器学习9:关于pytorch中的zero_grad()函数相关推荐

  1. 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

    关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...

  2. Pytorch中的collate_fn函数用法

    Pytorch中的collate_fn函数用法 官方的解释:   Puts each data field into a tensor with outer dimension batch size ...

  3. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  4. pytorch 中 利用自定义函数 get_mask_from_lengths(lengths, max_len)获取每个batch的mask

    在pytorch中,经常会需要通过batch进行批量处理数据,由于每个batch中各个样本之间存在差异,经常会需要进行先padding后mask的操作. 尤其是在自然语言处理任务中,每个batch中的 ...

  5. pytorch 中的topk函数

    pytorch中topk() 函数用法 1. 函数介绍 最近在代码中看到这两个语句 maxk = max(topk) _, pred = output.topk(maxk, 1, True, True ...

  6. pytorch中的MSELoss函数

    基本概念 均方误差(mean square error, MSE),是反应估计量与被估计量之间差异程度的一种度量,设ttt是根据子样确定的总体参数θ\thetaθ的一个估计量,(θ−t)2(\thet ...

  7. Pytorch中的contiguous()函数

    这个函数主要是为了辅助pytorch中的一些其他函数,主要包含 在PyTorch中,有一些对Tensor的操作不会真正改变Tensor的内容,改变的仅仅是Tensor中字节位置的索引.这些操作有: n ...

  8. PyTorch中的matmul函数详解

    PyTorch中的两个张量的乘法可以分为两种: 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者∗*∗运算符)实现 两个张量矩阵相乘(Matr ...

  9. PyTorch中F.cross_entropy()函数

    对PyTorch中F.cross_entropy()的理解 PyTorch提供了求交叉熵的两个常用函数: 一个是F.cross_entropy(), 另一个是F.nll_entropy(), 是对F. ...

最新文章

  1. 刚子扯谈:微信 今天你打飞机了嘛吗?
  2. 本科毕业的互联网女主管,却被迫要嫁给开挖掘机的高中毕业生!这是咋回事?...
  3. Deno 正式发布,彻底弄明白和 node 的区别
  4. 互联网大脑的发育与元宇宙的兴起
  5. Linux终端关闭屏幕显示,使用命令行关闭监视器
  6. 计算机视觉(一)——深度学习
  7. UIPIckerView现实城市选择
  8. linux kill
  9. java项目管理工具
  10. 别找了,这就是你心心念念想要的年会活动抽奖软件
  11. AT89C51单片机8位竞赛抢答器_倒计时可调仿真设计
  12. 西安理工大学计算机科学与技术分数线,2017西安理工大学各专业录取分数线
  13. 11. Zigbee应用程序框架开发指南 - 命令行接口(CLI)
  14. 基于51单片机的简易计算器proteus仿真 数码管显示
  15. 聊聊我遇到的那些贵人
  16. 处理png图片为透明
  17. 及时尽孝,别枉读了大学
  18. PCB布局、布线总结(持续进行中。。。。。。)
  19. scipy Matlab-style IIR 滤波器设计上(Butterworth\Chebyshev type I \Chebyshev type II )
  20. 别在@官方加国旗啦,3分钟30行Python代码帮你搞定!还加鸡腿,加IPhone11!

热门文章

  1. 弘辽科技:月订单量超5亿单的背后,标志着快手已成电商第四极
  2. 计算游泳时间-第10届蓝桥杯Scratch省赛真题第3题
  3. 【8086汇编】cmp指令和条件转移指令jxxx
  4. html动态加载js方法,原生JS实现动态加载js文件并在加载成功后执行回调函数的方法...
  5. spark3.3.1通过hbase-connectors连接CDH6.3.2自带hbase
  6. 统计任意字符串中回文字符串的个数
  7. 简单实现在线更新系统
  8. 微信小程序界面设计小程序中CSS3样式精通课程-边框-box-shadow 盒阴影
  9. 添加打印样式的三种方式
  10. Pytorch squeeze() unsqueeze() 用法