pytorch基础知识整理(五) 优化器
深度学习网络必须通过优化器进行训练。在pytorch中相关代码位于torch.optim模块中。
1, 常规用法
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for data,target in train_loader:...optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()
2, optimizer的方法和属性
- optimizer.zero_grad()的作用是清空模型中所有参数的梯度。
注:pytorch默认是进行梯度自动累加的,所以要使用optimizer.zero_grad()对梯度进行清零,如果遇到多个loss相加时,或者用多次循环再更新梯度法放大batch_size时,就不用optimizer.zero_grad()了。
- optimizer.step()的作用是根据梯度更新参数,所以放在loss.backward()计算梯度之后。
- state_dict()获得optimizer的状态字典,里面存放有param_groups和state。可通过torch.save()来保存到硬盘。
- load_state_dict()用来加载保存的状态字典,以继续训练。
- param_groups是一个list列表,存放各参数分组param_group。各参数分组param_group中包括学习率、momentum、weight_decay、dampening、nesterov以及各参数张量。
- 以及add_param_group、defaults等
3,对不同的参数分组设定不同的学习率
optimizer = torch.optim.SGD([{'params': other_params}, {'params': first_params, 'lr': 0.01*args.learning_rate},{'params': second_params, 'weight_decay': args.weight_decay}],lr=args.learning_rate,momentum=args.momentum,
)
4,各种优化器
torch.optim.
Adam(), SGD(), Adadelta(), Adagrad(), LBFGS(), AdamW(), SparseAdam(), Adamax(), ASGD(), RMSprop(), Rprop()
还有一些非pytorch自带的优化器,详见torch_optimizer(通过pip install torch-optimizer安装)。
5,lr_scheduler
根据epoch数或其他条件实现动态学习率,模板如下:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min')
for epoch in range(10):train(...)val_loss = validate(...)# Note that step should be called after validate()scheduler.step(val_loss)
学习率调整方法有:
ReduceLROnPlateau(),
LambdaLR(),
StepLR(),
MultiStepLR(),
CosineAnnealingLR(),
CosineAnnealingWarmRestarts(),
CyclicLR(),
OneCycleLR()
6,手动设计lr_scheduler
不使用pytorch自带的lr_scheduler,手动设计也不难:
for epoch in range(N): if epoch > 10:for param_group in optimizer.param_groups:param_group['lr'] = lr/10train(epoch)
pytorch基础知识整理(五) 优化器相关推荐
- 【PyTorch基础教程9】优化器optimizer和训练过程
学习总结 (1)每个优化器都是一个类,一定要进行实例化才能使用,比如: class Net(nn.Moddule):··· net = Net() optim = torch.optim.SGD(ne ...
- Pytorch基础(十)——优化器(SGD,Adagrad,RMSprop,Adam,LBFGS等)
一.概念 Pytorch中优化器的目的:将损失函数计算出的差值Loss减小. 优化过程:优化器计算网络参数的梯度,然后使用一定的算法策略来对参数进行计算,用新的参数来重新进行训练,最终降低Loss. ...
- pytorch基础知识整理(三)模型保存与加载
1, torch.save(); troch.load() torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型.张量.字典等,文件后缀名一般用pth或pt或p ...
- pytorch基础知识整理(二)数据加载
pytorch数据加载组件位于torch.utils.data中. from torch.utils.data import DataLoader, Dataset, Sampler 1, torch ...
- Pytorch基础知识整理(六)参数初始化
参数初始化的目的是限定网络权重参数的初始分布,试图让权重参数更接近参数空间的最优解,从而加速训练.pytorch中网络默认初始化参数为随机均匀分布,设定额外的参数初始化并非总能加速训练. 1,模板 在 ...
- pytorch基础知识整理(一)自动求导机制
torch.autograd torch.autograd是pytorch最重要的组件,主要包括Variable类和Function类,Variable用来封装Tensor,是计算图上的节点,Func ...
- pytorch基础知识整理(四) 模型
1,模型构造模板 torch.nn.Module()是所有网络模型的基类,所有网络都需要继承此类,模板如下: import torch.nn as nn import torch.nn.functio ...
- centos7创建asm磁盘_Oracle ASM 磁盘组基础知识整理(收藏版)
为什么要写这么一篇基础知识呢?还是有那么一点点原因的,不是胡编乱造还真是有真实存在的事件的,前两周里因一套生产环境数据库磁盘不足无法对其进行表空间扩容,需要向存储岗申请存储资源,当存储岗划好资源加完存 ...
- HTML5的基础知识整理
HTML5 概述:HTML5是HTML最新的修订版本,2014年10月由万维网联盟(W3C)完成标准制定. HTML5的设计目的是为了在移动设备上支持多媒体. 文章目录 HTML5 前言 一.HTMl ...
最新文章
- 一种三维结构化导航的思路
- 引领潮流云电视机遇与挑战并现
- 手把手教你使用FineUI开发一个b/s结构的取送货管理信息系统(附源码+视频教程(第9节))...
- Java获取文件路径
- java中pagex_Java/5_get和post比较.md at master · zaoshangyaochifan/Java · GitHub
- 安卓手机 python控制_PyAndroidControl:使用python脚本控制你的安卓设备
- 网站服务器中病毒该如何处理,网站被中了木马无法删除怎么办? 解决网站中病毒的办法...
- mdk系列 Adsl 成功上网指南(非USB ADSL)
- 顺应大数据时代创新社会治理模式
- 常用Keytool 命令
- java如何开根号?
- kindle资源网址
- vscode配置运行php项目完整版
- css特效实例——纯css实现带边角卷边阴影的纸
- Flex在线文档阅读器::pdf、doc、docx、xls、xlsx、ppt、pptx、htm、txt、rtf、epub、csv、xdoc等
- Android Gradle进阶配置指南 1
- Java对象内存空间大小计算
- 【IT运维小知识】安全组是什么意思?
- 下载网易云课堂和B站的视频
- uniapp使用地图
热门文章
- jCryptoJS 、C#互通加密(MD5版)
- MyEclipse优化设置(最详细版本)
- 利用koa实现mongodb数据库的增删改查
- 普通用户Mysql 5.6.13 主从,主主及nagios的mysql slave监控
- ASP.NET中常用的26个优化性能方法
- 国自然申请初审中的注意事项
- shp设置utf8格式_shp文件格式说明
- python定义变量并赋值_Python 变量类型及变量赋值
- Ubuntu16.04+ROS+ORB-SLAM2测试(转载)
- 北斗导航 | 坐标转换:ECEF转LLA:GPS坐标系:WGS84(matlab代码)