深度学习网络必须通过优化器进行训练。在pytorch中相关代码位于torch.optim模块中。

1, 常规用法

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for data,target in train_loader:...optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()  

2, optimizer的方法和属性

  • optimizer.zero_grad()的作用是清空模型中所有参数的梯度。

注:pytorch默认是进行梯度自动累加的,所以要使用optimizer.zero_grad()对梯度进行清零,如果遇到多个loss相加时,或者用多次循环再更新梯度法放大batch_size时,就不用optimizer.zero_grad()了。

  • optimizer.step()的作用是根据梯度更新参数,所以放在loss.backward()计算梯度之后。
  • state_dict()获得optimizer的状态字典,里面存放有param_groups和state。可通过torch.save()来保存到硬盘。
  • load_state_dict()用来加载保存的状态字典,以继续训练。
  • param_groups是一个list列表,存放各参数分组param_group。各参数分组param_group中包括学习率、momentum、weight_decay、dampening、nesterov以及各参数张量。
  • 以及add_param_group、defaults等

3,对不同的参数分组设定不同的学习率

optimizer = torch.optim.SGD([{'params': other_params}, {'params': first_params, 'lr': 0.01*args.learning_rate},{'params': second_params, 'weight_decay': args.weight_decay}],lr=args.learning_rate,momentum=args.momentum,
)

4,各种优化器

torch.optim.
Adam(), SGD(), Adadelta(), Adagrad(), LBFGS(), AdamW(), SparseAdam(), Adamax(), ASGD(), RMSprop(), Rprop()
还有一些非pytorch自带的优化器,详见torch_optimizer(通过pip install torch-optimizer安装)。

5,lr_scheduler

根据epoch数或其他条件实现动态学习率,模板如下:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):train(...)val_loss = validate(...)# Note that step should be called after validate()scheduler.step(val_loss)

学习率调整方法有:
ReduceLROnPlateau(),
LambdaLR(),
StepLR(),
MultiStepLR(),
CosineAnnealingLR(),
CosineAnnealingWarmRestarts(),
CyclicLR(),
OneCycleLR()

6,手动设计lr_scheduler

不使用pytorch自带的lr_scheduler,手动设计也不难:

for epoch in range(N):   if epoch > 10:for param_group in optimizer.param_groups:param_group['lr'] = lr/10train(epoch)

pytorch基础知识整理(五) 优化器相关推荐

  1. 【PyTorch基础教程9】优化器optimizer和训练过程

    学习总结 (1)每个优化器都是一个类,一定要进行实例化才能使用,比如: class Net(nn.Moddule):··· net = Net() optim = torch.optim.SGD(ne ...

  2. Pytorch基础(十)——优化器(SGD,Adagrad,RMSprop,Adam,LBFGS等)

    一.概念 Pytorch中优化器的目的:将损失函数计算出的差值Loss减小. 优化过程:优化器计算网络参数的梯度,然后使用一定的算法策略来对参数进行计算,用新的参数来重新进行训练,最终降低Loss. ...

  3. pytorch基础知识整理(三)模型保存与加载

    1, torch.save(); troch.load() torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型.张量.字典等,文件后缀名一般用pth或pt或p ...

  4. pytorch基础知识整理(二)数据加载

    pytorch数据加载组件位于torch.utils.data中. from torch.utils.data import DataLoader, Dataset, Sampler 1, torch ...

  5. Pytorch基础知识整理(六)参数初始化

    参数初始化的目的是限定网络权重参数的初始分布,试图让权重参数更接近参数空间的最优解,从而加速训练.pytorch中网络默认初始化参数为随机均匀分布,设定额外的参数初始化并非总能加速训练. 1,模板 在 ...

  6. pytorch基础知识整理(一)自动求导机制

    torch.autograd torch.autograd是pytorch最重要的组件,主要包括Variable类和Function类,Variable用来封装Tensor,是计算图上的节点,Func ...

  7. pytorch基础知识整理(四) 模型

    1,模型构造模板 torch.nn.Module()是所有网络模型的基类,所有网络都需要继承此类,模板如下: import torch.nn as nn import torch.nn.functio ...

  8. centos7创建asm磁盘_Oracle ASM 磁盘组基础知识整理(收藏版)

    为什么要写这么一篇基础知识呢?还是有那么一点点原因的,不是胡编乱造还真是有真实存在的事件的,前两周里因一套生产环境数据库磁盘不足无法对其进行表空间扩容,需要向存储岗申请存储资源,当存储岗划好资源加完存 ...

  9. HTML5的基础知识整理

    HTML5 概述:HTML5是HTML最新的修订版本,2014年10月由万维网联盟(W3C)完成标准制定. HTML5的设计目的是为了在移动设备上支持多媒体. 文章目录 HTML5 前言 一.HTMl ...

最新文章

  1. 一种三维结构化导航的思路
  2. 引领潮流云电视机遇与挑战并现
  3. 手把手教你使用FineUI开发一个b/s结构的取送货管理信息系统(附源码+视频教程(第9节))...
  4. Java获取文件路径
  5. java中pagex_Java/5_get和post比较.md at master · zaoshangyaochifan/Java · GitHub
  6. 安卓手机 python控制_PyAndroidControl:使用python脚本控制你的安卓设备
  7. 网站服务器中病毒该如何处理,网站被中了木马无法删除怎么办? 解决网站中病毒的办法...
  8. mdk系列 Adsl 成功上网指南(非USB ADSL)
  9. 顺应大数据时代创新社会治理模式
  10. 常用Keytool 命令
  11. java如何开根号?
  12. kindle资源网址
  13. vscode配置运行php项目完整版
  14. css特效实例——纯css实现带边角卷边阴影的纸
  15. Flex在线文档阅读器::pdf、doc、docx、xls、xlsx、ppt、pptx、htm、txt、rtf、epub、csv、xdoc等
  16. Android Gradle进阶配置指南 1
  17. Java对象内存空间大小计算
  18. 【IT运维小知识】安全组是什么意思?
  19. 下载网易云课堂和B站的视频
  20. uniapp使用地图

热门文章

  1. jCryptoJS 、C#互通加密(MD5版)
  2. MyEclipse优化设置(最详细版本)
  3. 利用koa实现mongodb数据库的增删改查
  4. 普通用户Mysql 5.6.13 主从,主主及nagios的mysql slave监控
  5. ASP.NET中常用的26个优化性能方法
  6. 国自然申请初审中的注意事项
  7. shp设置utf8格式_shp文件格式说明
  8. python定义变量并赋值_Python 变量类型及变量赋值
  9. Ubuntu16.04+ROS+ORB-SLAM2测试(转载)
  10. 北斗导航 | 坐标转换:ECEF转LLA:GPS坐标系:WGS84(matlab代码)