前沿

本篇笔记主要介绍torch.optim模块,主要包含模型训练的优化器Optimizer

基础知识

PyTorch由4个主要包装组成:
1.Torch:类似于Numpy的通用数组库,可以在将张量类型转换为(torch.cuda.TensorFloat)并在GPU上进行计算。
2.torch.autograd:用于构建计算图形并自动获取渐变的包
3.torch.nn:具有共同层和成本函数的神经网络库
4.torch.optim:具有通用优化算法(如SGD,Adam等)的优化包

Optimizer模块

  1. 优化器主要是在模型训练阶段对模型可学习参数进行更新, 常用优化器有 SGD,RMSprop,Adam等
  2. 优化器初始化时传入模型的可学习参数,以及其他超参数如 lr,momentum等
  3. 在训练过程中先调用 optimizer.zero_grad() 清空梯度,再调用 loss.backward() 反向传播,最后调用optimizer.step()更新模型参数

举例

import torch
import numpy as np
import warnings
warnings.filterwarnings('ignore') #ignore warningsx = torch.linspace(-np.pi, np.pi, 2000)
y = torch.sin(x)p = torch.tensor([1, 2, 3])
xx = x.unsqueeze(-1).pow(p)model = torch.nn.Sequential(torch.nn.Linear(3, 1),torch.nn.Flatten(0, 1)
)
loss_fn = torch.nn.MSELoss(reduction='sum')learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
for t in range(1, 1001):y_pred = model(xx)loss = loss_fn(y_pred, y)if t % 100 == 0:print('No.{: 5d}, loss: {:.6f}'.format(t, loss.item()))optimizer.zero_grad() # 梯度清零loss.backward() # 反向传播计算梯度optimizer.step() # 梯度下降法更新参数No.  100, loss: 26215.714844No.  200, loss: 11672.815430No.  300, loss: 4627.826172No.  400, loss: 1609.388062No.  500, loss: 677.805115No.  600, loss: 473.932159No.  700, loss: 384.862396No.  800, loss: 305.365143No.  900, loss: 229.774719No. 1000, loss: 161.483841

python torch.optim模块相关推荐

  1. Python torch 模块,randperm() 实例源码

    参考Python torch 模块,randperm() 实例源码 - 云+社区 - 腾讯云 torch.randperm(n, *, out=None, dtype=torch.int64, lay ...

  2. Python:机器学习模块PyTorch【上】

    点击访问:PyTorch中文API应用具体代码地址 自动求导机制 本说明将概述Autograd如何工作并记录操作.了解这些并不是绝对必要的,但我们建议您熟悉它,因为它将帮助您编写更高效,更简洁的程序, ...

  3. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  4. class torch.optim.lr_scheduler.ExponentialLR

    参考链接: class torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False) 配 ...

  5. class torch.optim.lr_scheduler.StepLR

    参考链接: class torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose= ...

  6. class torch.optim.lr_scheduler.LambdaLR

    参考链接: class torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False) 配套 ...

  7. python torch exp_Python torch.diag方法代码示例

    本文整理汇总了Python中torch.diag方法的典型用法代码示例.如果您正苦于以下问题:Python torch.diag方法的具体用法?Python torch.diag怎么用?Python ...

  8. python torch exp_Python torch.add方法代码示例

    本文整理汇总了Python中torch.add方法的典型用法代码示例.如果您正苦于以下问题:Python torch.add方法的具体用法?Python torch.add怎么用?Python tor ...

  9. python:Json模块dumps、loads、dump、load介绍

    20210831 https://www.cnblogs.com/bigtreei/p/10466518.html json dump dumps 区别 python:Json模块dumps.load ...

  10. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

最新文章

  1. Apache Shiro 使用手册(三)Shiro 授权
  2. HTTP协议基础知识点点滴滴
  3. java异常_Java线程池「异常处理」正确姿势:有病就得治
  4. [HTML]HTML5实现可编辑表格
  5. Webpack搭建React开发环境
  6. WORD如何自动给标题添加编号?
  7. gstreamer插件特别要注意事件处理(含代码范例)
  8. Qtalk 0.2.0版本(基于Qt的局域网聊天软件)
  9. Motrix全能下载神器 无限制版 支持下载HTTP、磁力、FTP、BT、百度网盘等
  10. win10各版本的历史记录
  11. k8s-身份认证与权限
  12. 【一】ArcGIS API for JavaScript 4.x之地图显示
  13. 云原生中间件RocketMQ-核心原理之高可用机制
  14. 2019-02-13 思考:1000瓶药水,1瓶有毒,老鼠毒发24h,如何用最少的老鼠在24h内找出毒药?
  15. 递归方法实现最大公约数
  16. Java数组可变长参数详解
  17. 【CF979D】 Kuro and GCD and XOR and SUM
  18. 河源实验室建设合理化细节探讨
  19. Unity 播放本地视频
  20. 半监督学习 图像分类_自我监督学习的图像分类。

热门文章

  1. uni-app小程序刷新当前页面的两种方法
  2. html字体图标显示不出来,h5页面字体图标显示不正常
  3. 计算机办公软件基础知识题库,办公软件基础知识试题试卷--题库.doc
  4. stm32h750电路_STM32H750开发板
  5. Unity 苹果内购
  6. 如何压缩ppt大小的方法不减画质?
  7. mysql两种事务管理器_MyBatis事务管理的两种方式
  8. 入驻shopee平台后,选择哪一个站点作为首站?
  9. lcd驱动移植的分析linux3.2内核,chipsee为例,液晶屏AT070TN92
  10. 软路由的介绍及安装和配置