Optimizer

optimizer.param_groups用法的示例分析

日期:2022年7月25日

pytorch版本: 1.11.0

对于param_groups的探索

optimizer.param_groups: 是一个list,其中的元素为字典;

optimizer.param_groups[0]:长度为7的字典,包括[‘params’, ‘lr’, ‘betas’, ‘eps’, ‘weight_decay’, ‘amsgrad’, ‘maximize’]这7个参数;

下面用的Adam优化器创建了一个optimizer变量:

>>> optimizer.param_groups[0].keys()
>>> dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad', 'maximize'])

可以自己把训练参数分别赋予不同的学习率,这样子list里就不止一个元素了,而是多个字典了。

  • params 是一个list[…],里面存放参数

    >>> len(optimizer.param_groups[0]['params'])
    >>> 48
    >>> optimizer.param_groups[0]['params'][0]
    >>>
    Parameter containing:
    tensor([[ 0.0212, -0.1151,  0.0499,  ..., -0.0807, -0.0572,  0.1166],[-0.0356, -0.0397, -0.0980,  ...,  0.0690, -0.1066, -0.0583],[ 0.0238,  0.0316, -0.0636,  ...,  0.0754, -0.0891,  0.0258],...,[ 0.0603, -0.0173,  0.0627,  ...,  0.0152, -0.0215, -0.0730],[-0.1183, -0.0636,  0.0381,  ...,  0.0745, -0.0427, -0.0713],
    
  • lr 是学习率

    >>> optimizer.param_groups[0]['lr']
    >>> 0.0005
    
  • betas 是一个元组(…),与动量相关

    >>> optimizer.param_groups[0]['betas']
    >>> (0.9, 0.999)
    
  • eps

    >>> optimizer.param_groups[0]['eps']
    >>> 1e-08
    
  • weight_decay 是一个int变量

    >>> optimizer.param_groups[0]['weight_decay']
    >>> 0
    
  • amsgrad是一个bool变量

    >>> optimizer.param_groups[0]['amsgrad']
    >>> False
    
  • maximize 是一个bool变量

    >>> optimizer.param_groups[0]['maximize']
    >>> False
    

以网上的例子来继续试验:

import torch
import torch.optim as optimw1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)# 输出
>>>
[{'params': [tensor([[-0.1002,  0.3526, -1.2212],[-0.4659,  0.0498, -0.2905],[ 1.1862, -0.6085,  0.4965]], requires_grad=True)],'lr': 0.001, 'betas': (0.9, 0.999),'eps': 1e-08,'weight_decay': 0,'amsgrad': False,'maximize': False}]

以下主要是Optimizer这个类有个add_param_group的方法

# Per the docs, the add_param_group method accepts a param_group parameter that is a dict. Example of use:import torch
import torch.optim as optimw1 = torch.randn(3, 3)
w1.requires_grad = True
w2 = torch.randn(3, 3)
w2.requires_grad = True
o = optim.Adam([w1])
print(o.param_groups)# 输出
>>> [{'params': [tensor([[-1.5916, -1.6110, -0.5739],[ 0.0589, -0.5848, -0.9199],[-0.4206, -2.3198, -0.2062]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}]o.add_param_group({'params': w2})
print(o.param_groups)# 输出
>>> [{'params': [tensor([[-1.5916, -1.6110, -0.5739],[ 0.0589, -0.5848, -0.9199],[-0.4206, -2.3198, -0.2062]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}, {'params': [tensor([[-0.5546, -1.2646,  1.6420],[ 0.0730, -0.0460, -0.0865],[ 0.3043,  0.4203, -0.3607]], requires_grad=True)], 'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False}]

平时写代码如何动态修改学习率(常规操作)

for param_group in optimizer.param_groups:param_group["lr"] = lr

补充:pytorch中的优化器总结

SGD优化器为例:

from torch import nn as nn
import torch as t
from torch.autograd import Variable as V
from torch import optim  # 优化器# 定义一个LeNet网络
class LeNet(t.nn.Module):def __init__(self):super(LeNet, self).__init__()self.features = t.nn.Sequential(t.nn.Conv2d(3, 6, 5),t.nn.ReLU(),t.nn.MaxPool2d(2, 2),t.nn.Conv2d(6, 16, 5),t.nn.ReLU(),t.nn.MaxPool2d(2, 2))# 由于调整shape并不是一个class层,# 所以在涉及这种操作(非nn.Module操作)需要拆分为多个模型self.classifiter = t.nn.Sequential(t.nn.Linear(16*5*5, 120),t.nn.ReLU(),t.nn.Linear(120, 84),t.nn.ReLU(),t.nn.Linear(84, 10))def forward(self, x):x = self.features(x)x = x.view(-1, 16*5*5)x = self.classifiter(x)return xnet = LeNet()# 通常的step优化过程
optimizer = optim.SGD(params=net.parameters(), lr=1)
optimizer.zero_grad()  # 梯度清零,相当于net.zero_grad()input = V(t.randn(1, 3, 32, 32))
output = net(input)
output.backward(output)
optimizer.step()  # 执行优化

为不同的子网络参数不同的学习率,finetune常用,使分类器学习率参数更高,学习速度更快(理论上)。

1.经由构建网络时划分好的模组进行学习率设定,

# 为不同子网络设置不同的学习率,在finetune中经常用到
# 如果对某个参数不指定学习率,就使用默认学习率
optimizer = optim.SGD([{'params': net.features.parameters()},  # 学习率为1e-5{'params': net.classifiter.parameters(), 'lr': 1e-2}], lr=1e-5
)

2.以网络层对象为单位进行分组,并设定学习率

# 只为两个全连接层设置较大的学习率,其余层的学习率较小
# 以层为单位,为不同层指定不同的学习率# 提取指定层对象
special_layers = nn.ModuleList([net.classifiter[0], net.classifiter[3]])
# 获取指定层参数id
special_layers_params = list(map(id, special_layers.parameters()))
# 获取非指定层的参数id
base_params = filter(lambda p: id(p) not in special_layers_params, net.parameters())optimizer = t.optim.SGD([{'params': base_params},{'params': special_layers.parameters(), 'lr': 0.01}], lr=0.001)

参考:
https://blog.csdn.net/weixin_43593330/article/details/108490956
https://www.cnblogs.com/hellcat/p/8496727.html
https://www.yisu.com/zixun/456082.html

有关optimizer.param_groups用法的示例分析相关推荐

  1. TVM开发三个示例分析

    TVM开发三个示例分析 把自主生成的代码生成TVM 把自主生成的代码生成TVM 目录 简介 要生成C代码. 要生成任何其它图形表示. 实现一个C代码生成器 实现[CodegenC] 运算符代码生成 输 ...

  2. osgEarth各个示例分析目录

    概述 由于数字地球项目需要osgEarth的代码知识,故决定学习osgEarth的示例,示例中有很多可以借鉴的内容.以下是分析目录,完全是随机进行的,并没有什么上下文逻辑. 每一篇代码边学习边分析,如 ...

  3. php关键词匹配度排序,MySQL_mysql 关键词相关度排序方法详细示例分析,小项目有时需要用到关键词搜 - phpStudy...

    mysql 关键词相关度排序方法详细示例分析 小项目有时需要用到关键词搜索相关性排序,用sphinx显得杀鸡用牛刀,就用mysql的order by对付下. 方法一: select * from ar ...

  4. Dorado用法与示例

    Dorado用法与示例 dorado用后总结 一.dorado概念 dorado的产品全名是"dorado展现中间件".从产品形态上dorado由两部分组成,第一部分是一个具有AJ ...

  5. Python进阶之递归函数的用法及其示例

    作者 | 程序员adny 责编 | 徐威龙 封图| CSDN│下载于视觉中国 出品 |  AI科技大本营(ID:rgznai100) 本篇文章主要介绍了Python进阶之递归函数的用法及其示例,现在分 ...

  6. python和R数据类型查看、赋值、列表、for循环、函数用法对比示例

    python和R数据类型查看.赋值.列表.for循环.函数用法对比示例 很多工程师可能刚开始的时候只熟悉python或者R其中的一个进行数据科学相关的任务. 那么如果我们对比这学习可以快速了解语言设计 ...

  7. Android涂鸦技术及刮刮乐示例分析

    概述: 很早之前就想研究一下Android中的涂鸦,其实也说不上是研究了,毕竟都是一些相对比较简单的知识点.下面就对基于画布(Canvas)和触摸事件(onTouchEvent)来实现涂鸦和刮刮乐. ...

  8. 计算机网络时延图,计算机网络中网站性能延迟加载图像的示例分析

    计算机网络中网站性能延迟加载图像的示例分析 发布时间:2021-06-09 11:38:56 来源:亿速云 阅读:95 作者:小新 这篇文章给大家分享的是有关计算机网络中网站性能延迟加载图像的示例分析 ...

  9. nodejs ajax进度条,Ajax异步文件上传与NodeJS express服务端处理的示例分析

    Ajax异步文件上传与NodeJS express服务端处理的示例分析 发布时间:2021-07-24 11:17:21 来源:亿速云 阅读:79 作者:小新 这篇文章主要介绍Ajax异步文件上传与N ...

最新文章

  1. 让资源管理器不显示最近常用文件夹
  2. Java多线程断点下载
  3. MPLS、SD-WAN孰优孰劣?
  4. STM32F4 HAL库开发 -- STM32F407引脚图
  5. uvalive4842(AC自动机+DP)
  6. Mp3tag(MP3文件信息修改器) V2.79a 多语绿色版
  7. 在SAP CRM webclient ui右上角显示系统时间
  8. Hbase Import导入数据异常处理-RetriesExhaustedWithDetailsException
  9. SpringBoot @Cacheable缓存入门程序
  10. 也乱弹Book.Save而引OO对话
  11. STL 格式解析--文本以及二进制格式
  12. zcu106 固化_ZCU106的PYNQ移植
  13. 【前端学习笔记】微信小程序vue 组件式开发
  14. 将RT-Thread Nano移植到STM32F401CCU6
  15. 史上最好听的十首纯音乐推荐
  16. 对话韩寒父子:“韩寒是我得意的笔名”
  17. 黑客比程序员高在哪里?
  18. 各大网站和app是如何实现黑白页面效果?
  19. 初学者应从文件目录结构理解import的过程,并创建自已的代码库
  20. python的基本数据类型关键字_Python3 基本数据类型

热门文章

  1. Delphi中对Excel表格文件的导入和导出操作。
  2. 有道云笔记的快速剪报
  3. c语言 system(quot;pausequot;);,c++中system(quot;pausequot;)的作用和含义,systempause
  4. cornerRadius属性
  5. Hadoop 电影评分数据统计分析实验
  6. 双11购物优惠劵 满减计算程序
  7. 开着mysql是不是很耗电_空调经常开和关费电,还是一直开启更费电?今天总算知道了...
  8. 载20(S)-人参皂苷/细胞穿膜肽-单克隆抗体-载丝裂霉素白蛋白纳米微球的制备
  9. Android开发 SQLite数据库
  10. python 背景音乐程序代码_用Python演奏音乐