4.2 优化器

PyTorch将深度学习中常用的优化方法全部封装在torch.optim中,其设计十分灵活,能够很方便的扩展成自定义的优化方法。

所有的优化方法都是继承基类optim.Optimizer,并实现了自己的优化步骤。下面就以最基本的优化方法——随机梯度下降法(SGD)举例说明。这里需重点掌握:

  • 优化方法的基本使用方法
  • 如何对模型的不同部分设置不同的学习率
  • 如何调整学习率

In [32]:

# 首先定义一个LeNet网络
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(2,2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(2,2))self.classifier = nn.Sequential(nn.Linear(16 * 5 * 5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, 10))def forward(self, x):x = self.features(x)x = x.view(-1, 16 * 5 * 5)x = self.classifier(x)return xnet = Net()

In [33]:

from torch import  optim
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad() # 梯度清零,等价于net.zero_grad()input = t.randn(1, 3, 32, 32)
output = net(input)
output.backward(output) # fake backwardoptimizer.step() # 执行优化

In [43]:

# 为不同子网络设置不同的学习率,在finetune中经常用到
# 如果对某个参数不指定学习率,就使用最外层的默认学习率
optimizer =optim.SGD([{'params': net.features.parameters()}, # 学习率为1e-5{'params': net.classifier.parameters(), 'lr': 1e-2}], lr=1e-5)
optimizer

Out[43]:

SGD (
Parameter Group 0dampening: 0lr: 1e-05momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.01momentum: 0nesterov: Falseweight_decay: 0
)

In [44]:

# 只为两个全连接层设置较大的学习率,其余层的学习率较小
special_layers = nn.ModuleList([net.classifier[0], net.classifier[3]])
special_layers_params = list(map(id, special_layers.parameters()))
base_params = filter(lambda p: id(p) not in special_layers_params,net.parameters())optimizer = t.optim.SGD([{'params': base_params},{'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001 )
optimizer

Out[44]:

SGD (
Parameter Group 0dampening: 0lr: 0.001momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.01momentum: 0nesterov: Falseweight_decay: 0
)

对于如何调整学习率,主要有两种做法。一种是修改optimizer.param_groups中对应的学习率,另一种是更简单也是较为推荐的做法——新建优化器,由于optimizer十分轻量级,构建开销很小,故而可以构建新的optimizer。但是后者对于使用动量的优化器(如Adam),会丢失动量等状态信息,可能会造成损失函数的收敛出现震荡等情况。

In [48]:

# 方法1: 调整学习率,新建一个optimizer
old_lr = 0.1
optimizer1 =optim.SGD([{'params': net.features.parameters()},{'params': net.classifier.parameters(), 'lr': old_lr*0.1}], lr=1e-5)
optimizer1

Out[48]:

SGD (
Parameter Group 0dampening: 0lr: 1e-05momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.010000000000000002momentum: 0nesterov: Falseweight_decay: 0
)

In [49]:

# 方法2: 调整学习率, 手动decay, 保存动量
for param_group in optimizer.param_groups:param_group['lr'] *= 0.1 # 学习率为之前的0.1倍
optimizer

Out[49]:

SGD (
Parameter Group 0dampening: 0lr: 1.0000000000000002e-06momentum: 0nesterov: Falseweight_decay: 0Parameter Group 1dampening: 0lr: 0.0010000000000000002momentum: 0nesterov: Falseweight_decay: 0
)

Pytorch:优化器相关推荐

  1. Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

    目录 写在前面 一.牛顿法 1.看图理解牛顿法 2.公式推导-三角函数 3.公式推导-二阶泰勒展开 二.BFGS公式推导 三.L-BFGS 四.算法迭代过程 五.代码实现 1.torch.optim. ...

  2. Pytorch优化器全总结(四)常用优化器性能对比 含代码

    目录 写在前面 一.优化器介绍 1.SGD+Momentum 2.Adagrad 3.Adadelta 4.RMSprop 5.Adam 6.Adamax 7.AdaW 8.L-BFGS 二.优化器对 ...

  3. Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

    目录 写在前面 一. torch.optim.SGD 随机梯度下降 SGD代码 SGD算法解析 1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法 2.Moment ...

  4. Pytorch优化器

    Pytorch优化器 了解不同优化器 构建一个优化器 差别 PyTorch种优化器选择 了解不同优化器 神经网络优化器,主要是为了优化我们的神经网络,使他在我们的训练过程中快起来,节省社交网络训练的时 ...

  5. pytorch优化器与学习率设置详解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://a.3durl.cn/Yr ...

  6. Pytorch —— 优化器Optimizer(二)

    1.learning rate学习率 梯度下降:wi+1=wi−LR∗g(wi)w_{i+1}=w_{i}-LR*g\left(w_{i}\right)wi+1​=wi​−LR∗g(wi​)梯度是沿着 ...

  7. pytorch优化器学习率调整策略以及正确用法

    优化器 optimzier优化器的作用:优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用. 从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西: ...

  8. pytorch优化器详解:Adam

    目录 说明 Adam原理 梯度滑动平均 偏差纠正 Adam计算过程 pytorch Adam参数 params lr betas eps weight_decay amsgrad 说明 模型每次反向传 ...

  9. pytorch优化器详解:SGD

    目录 说明 SGD参数 params lr momentum dampening weight_decay nesterov 举例(nesterov为False) 第1轮迭代 第2轮迭代 说明 模型每 ...

最新文章

  1. 广东安网2016:重拳挥出 打造安宁互联网环境
  2. 红帽子linux6.6内核版本,RedHat/CentOS发行版本号及内核版本号对照表
  3. springboot 使用interceptor 返回前端http状态码为0
  4. python输入年份月份输出天数_6.2(输入年份 月份 输出该月天数)
  5. CSS多列布局(实例)
  6. [css] 请说说*{box-sizing: border-box;}的作用及好处有哪些?
  7. 爬虫最基本的工作流程:内涵社区网站为例
  8. 阿里下一代云分析型数据库AnalyticDB入选Forrester云化数仓象限
  9. 破解 找回 lockdir 加密的文件
  10. 《数字电路与逻辑设计》笔记及经典问答题
  11. 移动通信网络规划:D2D通信技术
  12. 题解 CF 1413B A New Technique
  13. macOS根目录上无法写入文件和创建目录的问题
  14. 【debug】汇编跳转指令: JMP、JECXZ、JA、JB、JG、JL、JE、JZ、JS、JC、JO、JP 等
  15. SQLite简介,C#调用SQLite
  16. 在线播放bt php,yunBT:一个基于TP3.1的多用户BT离线下载程序,支持在线播放
  17. 【vijos】1164 曹冲养猪(中国剩余定理)
  18. 猿如意中的【格式工厂】工具的安装与使用教程,格式转换这个工具就够了
  19. 2021 年 7 个较佳的免费电子商务平台(比较)
  20. PTA-输出倒三角图案

热门文章

  1. 调参必备---GridSearch网格搜索
  2. 正则表达式 re模块
  3. python学习之旅(入门)
  4. css3鼠标悬停图片抖动效果
  5. JAVA简单选择排序算法原理及实现
  6. C# DataTable怎么合计字段
  7. hdu 2184 01背包变形
  8. 唯品会高级副总裁 唐倚智:电商精细化运营
  9. 返回行javascript比较时间大小
  10. C盘空间越来越小怎么办,教你27招