Lookahead 优化算法是Adam的作者继Adam之后的又一力作,论文可以参见https://arxiv.org/abs/1907.08610

这篇博客先不讲述Lookahead具体原理,先介绍如何将Lookahead集成到现有的代码中。

本人在三个项目中(涉及风格转换、物体识别)使用该优化器,最大的感受就是使用该优化器十分有利于模型收敛,原本不收敛或者收敛过慢的模型在使用lookahead后可以看到明显的收敛情况,并且最终的效果能够满足最初设计的要求。

总所周知,Adam因为其具有较好的适应性,被广泛用于各类模型的优化;其参数简单,调参方便的特点一直为大家所喜爱,尤其对于初学者较为友好。Lookahead 也继承了Adam的优点。lookahead的Pytorch版本代码如下所示:后续会针对代码进行原理讲解,该代码在Github上可以找到。

from collections import defaultdict
from torch.optim import Optimizer
import torchclass Lookahead(Optimizer):def __init__(self, optimizer, k=5, alpha=0.5):self.optimizer = optimizerself.k = kself.alpha = alphaself.param_groups = self.optimizer.param_groupsself.state = defaultdict(dict)self.fast_state = self.optimizer.statefor group in self.param_groups:group["counter"] = 0def update(self, group):for fast in group["params"]:param_state = self.state[fast]if "slow_param" not in param_state:param_state["slow_param"] = torch.zeros_like(fast.data)param_state["slow_param"].copy_(fast.data)slow = param_state["slow_param"]slow += (fast.data - slow) * self.alphafast.data.copy_(slow)def update_lookahead(self):for group in self.param_groups:self.update(group)def step(self, closure=None):loss = self.optimizer.step(closure)for group in self.param_groups:if group["counter"] == 0:self.update(group)group["counter"] += 1if group["counter"] >= self.k:group["counter"] = 0return lossdef state_dict(self):fast_state_dict = self.optimizer.state_dict()slow_state = {(id(k) if isinstance(k, torch.Tensor) else k): vfor k, v in self.state.items()}fast_state = fast_state_dict["state"]param_groups = fast_state_dict["param_groups"]return {"fast_state": fast_state,"slow_state": slow_state,"param_groups": param_groups,}def load_state_dict(self, state_dict):slow_state_dict = {"state": state_dict["slow_state"],"param_groups": state_dict["param_groups"],}fast_state_dict = {"state": state_dict["fast_state"],"param_groups": state_dict["param_groups"],}super(Lookahead, self).load_state_dict(slow_state_dict)self.optimizer.load_state_dict(fast_state_dict)self.fast_state = self.optimizer.statedef add_param_group(self, param_group):param_group["counter"] = 0self.optimizer.add_param_group(param_group)

将lookahead集成在现有代码中如下操作即可:

base_optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
opt = Lookahead(base_optimizer, k=5, alpha=0.5)

此时直接将opt作为正常的优化器使用即可,就像直接使用Adam一样的步骤使用opt

Pytorch 版本的lookahead 优化函数使用(附代码)相关推荐

  1. 详解pytorch实现猫狗识别98%附代码

    详解pytorch实现猫狗识别98%附代码 前言 一.为什么选用pytorch这个框架? 二.实现效果 三.神经网络从头到尾 1.来源:仿照人为处理图片的流程,模拟人们的神经元处理信息的方式 2.总览 ...

  2. NLP【05】pytorch实现glove词向量(附代码详解)

    上一篇:NLP[04]tensorflow 实现Wordvec(附代码详解) 下一篇:NLP[06]RCNN原理及文本分类实战(附代码详解) 完整代码下载:https://github.com/ttj ...

  3. 教你用PyTorch实现“看图说话”(附代码、学习资源)

    作者:FAIZAN SHAIKH 翻译:和中华 校对:白静 本文共2200字,建议阅读10分钟. 本文用浅显易懂的方式解释了什么是"看图说话"(Image Captioning), ...

  4. 快速上手笔记,PyTorch模型训练实用教程(附代码)

    机器之心发布 作者:余霆嵩 前言 自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超 TensorFlow 的趋势.PyTorch 能在短时间内被众多研究人员和工程师接受并 ...

  5. 微信支付分,APIv3版本接口对接过程(附代码)

    刚对接完微信支付分,对接过程还是有点小坑,微信官方的接口文档写的比较粗略,代码示例比较少,网上的相关技术博客少之又少,前期还是有点小困难的,所以决定把对接过程梳理一下,希望能帮到需要的人. APIv3 ...

  6. 使用PyTorch+OpenCV进行人脸识别(附代码演练)

    人脸识别是一种用于从图像或视频中识别人脸的系统.它在许多应用程序和垂直行业中很有用.如今,我们看到这项技术可帮助新闻机构在重大事件报道中识别名人,为移动应用程序提供二次身份验证,为媒体和娱乐公司自动索 ...

  7. pytorch生成对抗网络GAN的基础教学简单实例(附代码数据集)

    1.简介 这篇文章主要是介绍了使用pytorch框架构建生成对抗网络GAN来生成虚假图像的原理与简单实例代码.数据集使用的是开源人脸图像数据集img_align_celeba,共1.34G.生成器与判 ...

  8. 简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码)

     简单有趣的 NLP 教程:手把手教你用 PyTorch 辨别自然语言(附代码) 雷锋网(公众号:雷锋网)按:本文作者甄冉冉,原载于作者个人博客,雷锋网已获授权. 最近在学pyTorch的实际应用 ...

  9. 毕设日志——pytorch版本faster rcnn运行代码前的环境配置2019.4.9

    准备测试的代码是: https://github.com/jwyang/faster-rcnn.pytorch 讲解:https://hellozhaozheng.github.io/z_post/P ...

  10. PyTorch 模型训练实用教程(附代码)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用 Py ...

最新文章

  1. 2021年春季学期-信号与系统-第五次作业参考答案-第七小题
  2. 摘自《Java工程师成神之路》2018修订版,自我勉励
  3. 【Python】Python Mako模板使用
  4. IBM发布JumpGate 连接OpenStack和公有云
  5. 计算机组成原理树状图,数据结构
  6. EJB(RMI学习)
  7. 将Numpy数组保存为图像
  8. 推荐12个最好的 JavaScript 图形绘制库
  9. 【AICC】2019训练营笔记
  10. VS 2005/2008 Web Setup Project
  11. Windows系统怎么查看电脑的系统位数?
  12. npm查看依赖包报错:npm ERR! extraneous解决!!
  13. Maven中不能引入ojdbc解决方法:com.oracle:ojdbc6:jar:11.2.0.3
  14. 【Excel 教程系列第 11 篇】Excel 如何快速下拉填充序列至 10000 行
  15. Ureport2导出内容加入PDF文件
  16. Linux zip 7z效率比较,linux 下面的 7z和7za的区别
  17. 2021-09-08-EMMC启动命令备份-设置bootargs-编译内核存放的路径-SD卡uboot启动备份,print打印的信息
  18. iOS 删除 SceneDelegeta.h
  19. JS 烧脑面试题大赏
  20. 小程序: 代码包过大

热门文章

  1. 【技巧】vscode快速生成html结构
  2. Ionic系列——Ionic介绍
  3. 大规模定制(Mass Customization,MC)
  4. BootStrap之导航navigation
  5. Human Muscles/Musculature (人体肌肉组织)
  6. 【自然语言处理】词性标注
  7. iOS应用崩溃日志分析
  8. 解决Win10系统下运行unity游戏闪退报错问题 包含 人类一败涂地 波西亚时光等
  9. ARM基础学习-寄存器寻址方式和指令
  10. 波浪线html,js中的波浪线符号作用(按位非(~)符号)