1.应用

import torch
import torch.nn as nnoptimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

概念

最简单的更新规则是Stochastic Gradient Descent (SGD):

weight = weight - learning_rate * gradient

手动实现

learning_rate = 0.01
for f in net.parameters(): # 遍历图中每个节点的参数f.data.sub_(f.grad.data * learning_rate) # 将节点的参数-(学习速率*梯度),单下划线表示替换

pytorch中已经实现了SGD等一系列的更新方法

import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)
loss.backward()
optimizer.step()    # Does the update

API

1.类

stochastic gradient descent (optionally with momentum).

CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)
参数 描述
params (iterable) iterable of parameters to optimize or dicts defining parameter groups
lr (float) 学习速率
momentum (float, optional) momentum factor (default: 0)
weight_decay (float, optional) weight decay (L2 penalty) (default: 0)
dampening (float, optional) dampening for momentum (default: 0)
nesterov (bool, optional) enables Nesterov momentum (default: False)

对象

参数 描述
step(closure=None) Performs a single optimization step.

参考:
https://pytorch.org/docs/stable/optim.html?highlight=sgd#torch.optim.SGD

pytorch optim.SGD相关推荐

  1. pytorch优化器: optim.SGD optimizer.zero_grad()

        在神经网络优化器中,主要为了优化我们的神经网络,使神经网络在我们的训练过程中快起来,节省时间.在pytorch中提供了 torch.optim方法优化我们的神经网络,torch.optim 是 ...

  2. torch.optim.SGD()

    其中的SGD就是optim中的一个算法(优化器):随机梯度下降算法 PyTorch 的优化器基本都继承于 "class Optimizer",这是所有 optimizer 的 ba ...

  3. pytorch中SGD源码解读

    调用方法: torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0 ...

  4. torch.optim.sgd参数详解

    SGD(随机梯度下降)是一种更新参数的机制,其根据损失函数关于模型参数的梯度信息来更新参数,可以用来训练神经网络.torch.optim.sgd的参数有:lr(学习率).momentum(动量).we ...

  5. pytorch .item_pytorch + SGD

    梯度下降是模型优化常用方法,原理也比较简单,简言之就是参数沿负梯度方向更新,参数更新公式如下. ,其中 表示的是步长,用于控制每次更新移动的步伐. 我们将使用pytorch来试验下这个方法. 首先先生 ...

  6. 基于深度学习的简单二分类(招聘信息的真假)

    招聘数据真假分类 此次机器学习课程大作业-招聘数据真假分类,是一个二分类问题.训练集中共有14304个样本,每个样本有18个特征,目标是判断不含有标签的招聘信息的真假性. 利用Pandas读取训练集和 ...

  7. Pytorch实现MNIST(附SGD、Adam、AdaBound不同优化器下的训练比较) adabound实现

     学习工具最快的方法就是在使用的过程中学习,也就是在工作中(解决实际问题中)学习.文章结尾处附完整代码. 一.数据准备   在Pytorch中提供了MNIST的数据,因此我们只需要使用Pytorch提 ...

  8. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  9. PyTorch基础(三)-----神经网络包nn和优化器optim

    前言 torch.nn是专门为神经网络设计的模块化接口.nn构建于Autograd之上,可用来定义和运行神经网络.这里我们主要介绍几个一些常用的类. 约定:torch.nn 我们为了方便使用,会为他设 ...

最新文章

  1. Android 自定义控件开发入门(一)
  2. kali 安装java jdk
  3. Null return value from advice does not match primitive return type for: public abstract boolean
  4. JMeter学习(六)集合点
  5. [转]要有梦----送给自己,希望自己能尽快走出当前的痛苦期
  6. python中e-r图_E-R图基本步骤
  7. android studio 手动安装gradle,Android Studio 如何安装Gradle?
  8. linux ftp mysql_linux下ftp和ftps以及ftp基于mysql虚拟用户认证服务器的搭建
  9. 朗锐智科PoE图像采集卡助力机器视觉应用
  10. 二十.激光、视觉和惯导LVIO-SLAM框架学习之相机内参标定
  11. rust怎么建柱子_小报:捷达VS5安全带卡扣向里?敲B柱?怎么掰回来? 第191220期...
  12. HDU 1861 游船出租(模拟)
  13. yoloV3运行速度测试报告
  14. 信号预处理电路(三角波和正弦波转换成方波)
  15. 计算机以太网,局域网,互联网,令牌网,ATM网络
  16. windows电脑cmd命令查看网卡的物理地址(mac地址)
  17. excel中读取数据拟合幂律分布
  18. web前端项目开发流程
  19. sourcetree远端 红色叹号
  20. k8s集群搭建-1mater2node

热门文章

  1. c语言结果输出10遍,C语言 如何实现输出这样一系列输出结果
  2. linux 怎么看w7分区,如何查看widows7系统和Linux端口被占用
  3. xadmin与mysql数据库_django和xadmin打造后台管理系统(一)-xadmin安装及使用
  4. 从零开始学前端:CSS元素模式的转换和CSS三大特性 --- 今天你学习了吗?(CSS:Day12)
  5. 使用线性回归拟合平面最佳直线及预测之Python+sklearn实现
  6. java条件触发_java – 当给定75:android时,条件不会触发
  7. 全国计算机二级c语言和江苏教材一样吗,计算机二级省级和全国计算机二级考试内容一样吗...
  8. mysql更改安装路径6_关于mysql安装后更改数据库路径方法-Centos6环境
  9. 下划线间隔数字 排序_面试必备:经典算法动画解析之希尔排序
  10. qml入门学习(六):Component组件