Pytorch:优化器
4.2 优化器
PyTorch将深度学习中常用的优化方法全部封装在torch.optim
中,其设计十分灵活,能够很方便的扩展成自定义的优化方法。
所有的优化方法都是继承基类optim.Optimizer
,并实现了自己的优化步骤。下面就以最基本的优化方法——随机梯度下降法(SGD)举例说明。这里需重点掌握:
- 优化方法的基本使用方法
- 如何对模型的不同部分设置不同的学习率
- 如何调整学习率
In [32]:
# 首先定义一个LeNet网络 class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(2,2))self.classifier = nn.Sequential(nn.Linear(16 * 5 * 5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, 10))def forward(self, x):x = self.features(x)x = x.view(-1, 16 * 5 * 5)x = self.classifier(x)return xnet = Net()
In [33]:
from torch import optim optimizer = optim.SGD(params=net.parameters(), lr=1) optimizer.zero_grad() # 梯度清零,等价于net.zero_grad()input = t.randn(1, 3, 32, 32) output = net(input) output.backward(output) # fake backwardoptimizer.step() # 执行优化
In [43]:
# 为不同子网络设置不同的学习率,在finetune中经常用到 # 如果对某个参数不指定学习率,就使用最外层的默认学习率 optimizer =optim.SGD([{'params': net.features.parameters()}, # 学习率为1e-5{'params': net.classifier.parameters(), 'lr': 1e-2}], lr=1e-5) optimizer
Out[43]:
SGD ( Parameter Group 0dampening: 0lr: 1e-05momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.01momentum: 0nesterov: Falseweight_decay: 0 )
In [44]:
# 只为两个全连接层设置较大的学习率,其余层的学习率较小 special_layers = nn.ModuleList([net.classifier[0], net.classifier[3]]) special_layers_params = list(map(id, special_layers.parameters())) base_params = filter(lambda p: id(p) not in special_layers_params,net.parameters())optimizer = t.optim.SGD([{'params': base_params},{'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001 ) optimizer
Out[44]:
SGD ( Parameter Group 0dampening: 0lr: 0.001momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.01momentum: 0nesterov: Falseweight_decay: 0 )
对于如何调整学习率,主要有两种做法。一种是修改optimizer.param_groups中对应的学习率,另一种是更简单也是较为推荐的做法——新建优化器,由于optimizer十分轻量级,构建开销很小,故而可以构建新的optimizer。但是后者对于使用动量的优化器(如Adam),会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。
In [48]:
# 方法1: 调整学习率,新建一个optimizer old_lr = 0.1 optimizer1 =optim.SGD([{'params': net.features.parameters()},{'params': net.classifier.parameters(), 'lr': old_lr*0.1}], lr=1e-5) optimizer1
Out[48]:
SGD ( Parameter Group 0dampening: 0lr: 1e-05momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.010000000000000002momentum: 0nesterov: Falseweight_decay: 0 )
In [49]:
# 方法2: 调整学习率, 手动decay, 保存动量 for param_group in optimizer.param_groups:param_group['lr'] *= 0.1 # 学习率为之前的0.1倍 optimizer
Out[49]:
SGD ( Parameter Group 0dampening: 0lr: 1.0000000000000002e-06momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.0010000000000000002momentum: 0nesterov: Falseweight_decay: 0 )
Pytorch:优化器相关推荐
- Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码
目录 写在前面 一.牛顿法 1.看图理解牛顿法 2.公式推导-三角函数 3.公式推导-二阶泰勒展开 二.BFGS公式推导 三.L-BFGS 四.算法迭代过程 五.代码实现 1.torch.optim. ...
- Pytorch优化器全总结(四)常用优化器性能对比 含代码
目录 写在前面 一.优化器介绍 1.SGD+Momentum 2.Adagrad 3.Adadelta 4.RMSprop 5.Adam 6.Adamax 7.AdaW 8.L-BFGS 二.优化器对 ...
- Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad
目录 写在前面 一. torch.optim.SGD 随机梯度下降 SGD代码 SGD算法解析 1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法 2.Moment ...
- Pytorch优化器
Pytorch优化器 了解不同优化器 构建一个优化器 差别 PyTorch种优化器选择 了解不同优化器 神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时 ...
- pytorch优化器与学习率设置详解
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://a.3durl.cn/Yr ...
- Pytorch —— 优化器Optimizer(二)
1.learning rate学习率 梯度下降:wi+1=wi−LR∗g(wi)w_{i+1}=w_{i}-LR*g\left(w_{i}\right)wi+1=wi−LR∗g(wi)梯度是沿着 ...
- pytorch优化器学习率调整策略以及正确用法
优化器 optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西: ...
- pytorch优化器详解:Adam
目录 说明 Adam原理 梯度滑动平均 偏差纠正 Adam计算过程 pytorch Adam参数 params lr betas eps weight_decay amsgrad 说明 模型每次反向传 ...
- pytorch优化器详解:SGD
目录 说明 SGD参数 params lr momentum dampening weight_decay nesterov 举例(nesterov为False) 第1轮迭代 第2轮迭代 说明 模型每 ...
最新文章
- 广东安网2016:重拳挥出 打造安宁互联网环境
- 红帽子linux6.6内核版本,RedHat/CentOS发行版本号及内核版本号对照表
- springboot 使用interceptor 返回前端http状态码为0
- python输入年份月份输出天数_6.2(输入年份 月份 输出该月天数)
- CSS多列布局(实例)
- [css] 请说说*{box-sizing: border-box;}的作用及好处有哪些?
- 爬虫最基本的工作流程:内涵社区网站为例
- 阿里下一代云分析型数据库AnalyticDB入选Forrester云化数仓象限
- 破解 找回 lockdir 加密的文件
- 《数字电路与逻辑设计》笔记及经典问答题
- 移动通信网络规划:D2D通信技术
- 题解 CF 1413B A New Technique
- macOS根目录上无法写入文件和创建目录的问题
- 【debug】汇编跳转指令: JMP、JECXZ、JA、JB、JG、JL、JE、JZ、JS、JC、JO、JP 等
- SQLite简介,C#调用SQLite
- 在线播放bt php,yunBT:一个基于TP3.1的多用户BT离线下载程序,支持在线播放
- 【vijos】1164 曹冲养猪(中国剩余定理)
- 猿如意中的【格式工厂】工具的安装与使用教程,格式转换这个工具就够了
- 2021 年 7 个较佳的免费电子商务平台(比较)
- PTA-输出倒三角图案