torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([

{'params': model.features12.parameters(), 'lr': 1e-2},

{'params': model.features22.parameters()},

{'params': model.features32.parameters()},

{'params': model.features42.parameters()},

{'params': model.features52.parameters()},

], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()

Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):

for para in optimizer.param_groups:

para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch

from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):

def __init__(self, params, lr=required, momentum=0, dampening=0,

weight_decay1=0, weight_decay2=0, nesterov=False):

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,

weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)

if nesterov and (momentum <= 0 or dampening != 0):

raise ValueError("Nesterov momentum requires a momentum and zero dampening")

super(SGD, self).__init__(params, defaults)

def __setstate__(self, state):

super(SGD, self).__setstate__(state)

for group in self.param_groups:

group.setdefault('nesterov', False)

def step(self, closure=None):

"""Performs a single optimization step.

Arguments:

closure (callable, optional): A closure that reevaluates the model

and returns the loss.

"""

loss = None

if closure is not None:

loss = closure()

for group in self.param_groups:

weight_decay1 = group['weight_decay1']

weight_decay2 = group['weight_decay2']

momentum = group['momentum']

dampening = group['dampening']

nesterov = group['nesterov']

for p in group['params']:

if p.grad is None:

continue

d_p = p.grad.data

if weight_decay1 != 0:

d_p.add_(weight_decay1, torch.sign(p.data))

if weight_decay2 != 0:

d_p.add_(weight_decay2, p.data)

if momentum != 0:

param_state = self.state[p]

if 'momentum_buffer' not in param_state:

buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)

buf.mul_(momentum).add_(d_p)

else:

buf = param_state['momentum_buffer']

buf.mul_(momentum).add_(1 - dampening, d_p)

if nesterov:

d_p = d_p.add(momentum, buf)

else:

d_p = buf

p.data.add_(-group['lr'], d_p)

return loss

一个使用的例子:

optimizer = SGD([

{'params': model.features12.parameters()},

{'params': model.features22.parameters()},

{'params': model.features32.parameters()},

{'params': model.features42.parameters()},

{'params': model.features52.parameters()},

], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

sgd 参数 详解_关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)相关推荐

  1. mysql innodb 设置详解_【mysql】mysql innodb 配置详解

    MySQLinnodb 配置详解 innodb_buffer_pool_size:这是InnoDB最重要的设置,对InnoDB性能有决定性的影响.默认的设置只有8M,所以默认的数据库设置下面InnoD ...

  2. python之sys模块详解_(转)python之os,sys模块详解

    python之sys模块详解 原文:http://www.cnblogs.com/cherishry/p/5725184.html sys模块功能多,我们这里介绍一些比较实用的功能,相信你会喜欢的,和 ...

  3. cat命令详解_好程序员Python培训之详解eval好与坏

    好程序员Python培训之详解eval好与坏,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,下面我们一起来看一下吧. eval是Python的一个内置函数,这个函数的作用 ...

  4. SpringAOP描述及实现_AspectJ详解_基于注解的AOP实现_SpringJdbcTemplate详解

    AOP AOP特点: 面向切面编程, 利用AOP对业务逻辑的各个部分进行抽取公共代码, 降低耦合度, 提高代码重用性, 同时提高开发效率. 采取横向抽取, 取代传统纵向继承体系重复性代码 解决事务管理 ...

  5. mysql查询语句详解_基于mysql查询语句的使用详解

    1> 查询数据表除了前三条以外的数据. 起初我想到的是这条语句 SELECT * FROM admin WHERE userid NOT IN (SELECT userid FROM admin ...

  6. python input函数详解_对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...

  7. vlan配置实例详解_网工知识角|MUXVLAN技术详解,基本原理一篇搞定

    学网络,就在IE-LAB 国内高端网络工程师培养基地 MUX VLAN(Multiplex VLAN )提供了一种通过VLAN进行网络资源控制的机制.通过MUX VLAN提供的二层流量隔离的机制可以实 ...

  8. c fread 快读 详解_梨的简笔画画法教程详解【彩色】__水果_水果简笔画图片_学画画网...

    2020-10-20 11:43:58 来源:简笔画教程 作者: 小西 导读: 美味的梨的简笔画怎么画?梨的简单的画法教程,手把手的教你画梨的简笔画,喜欢就跟着教程一起来学吧. 梨的简笔画画法教程详解 ...

  9. mysql连接数详解_查看mysql当前连接数的方法详解

    1.查看当前所有连接的详细资料: ./mysqladmin -uadmin -p -h10.140.1.1 processlist2.只查看当前连接数(Threads就是连接数.): ./mysqla ...

最新文章

  1. 曲线数学NURBS之bezier曲线
  2. Leangoo看板工具做敏捷故事地图看板示例
  3. 测量接线导通问题解决方案
  4. Chika and Friendly Pairs
  5. 要不要赶个时髦,去建设一个「 中台 」?
  6. Linux的概念与体系 7. Linux进程基础(转载)
  7. lepus监控oracle数据库_实用脚本一键监控oracle数据库索引使用状况
  8. 突然间~两年悄然而过
  9. java基础案例7-4升级日记本
  10. 微信开发 完美微信自动转发朋友圈-flutterAndroid
  11. 期货交易常用术语中英文对照表
  12. 英语四六级必备资料(全网最全)
  13. iOS从零开始学习socket编程——高并发多线程服务器
  14. VMware虚拟机上共享主机网络的设置方法
  15. USTCOJ 1382 毛毛虫
  16. MockMvc案例实战调用Controller层API接⼝
  17. 使用PLAN法提升执行力——笔记与答案
  18. 游戏背景音乐的两个特殊类型
  19. 「SaaS金羊毛」BI平台_Domo
  20. Twipstopixels java_Access量度单位缇与像素,厘米等的换算关系

热门文章

  1. net.ipv4.tcp_max_tw_buckets=10
  2. 人工智能放大插件Topaz Gigapixel AI
  3. IIS配置映射磁盘虚拟目录
  4. 癌症的根源----《细胞叛逆者》笔记
  5. 学习数据结构与算法心得
  6. 【MySQL】安装check requirement里面的很多选项左边多出现红叉叉解决方案
  7. 在App Store日进斗金的推广秘诀
  8. zookeeper实现负载均衡
  9. 网络ping不通是什么原因?那是因为你没掌握这些!
  10. win10装mysql哪个版本好用吗_win10安装两个不同版本的mysql(mysql5.7和mysql-8.0.19)