1 背景知识

在了解 torch.optim.swa_utils.AverageModel() 前, 我们先了解以下 SWA(随机加权平均)

1.1 SWA

SWA 全称 : Stochastic Weight Averaging,

  • SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果

  • 随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点, 经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失

  • SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好

2 AverageModel() 介绍

AveragedModel 类用于计算SWA模型的权重。可以通过运行以下命令创建一个averaged model:

from torch.optim.swa_utils import AverageModel
swa_model = AverageModel(model)

这里的模型Model可以是任意的torch.nn.Module对象。swa_model将跟踪模型参数的运行平均值。要更新这些平均值,你可以使用update_parameters()函数:

swa_model.update_parameters(model)

Pytorch 中的 torch.optim.swa_utils.AverageModel() 及其原理总结相关推荐

  1. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  2. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  3. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  4. Pytorch中的torch.where函数

    首先我们看一下Pytorch中torch.where函数是怎样定义的: @overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...

  5. Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  6. Pytorch中的 torch.as_tensor() 和 torch.from_numpy() 的区别

    之前我写过一篇文章,比较了 torch.Tensor() 和 torch.tensor() 的区别,而这两者都是深拷贝的方法,返回张量的同时,会在内存中创建一个额外的数据副本,与原数据不共享内存,所以 ...

  7. pytorch中的torch.tensor.repeat以及torch.tensor.expand用法

    文章目录 torch.tensor.expand torch.tensor.repeat torch.tensor.expand 先看招 import torch x = torch.tensor([ ...

  8. pytorch中的torch.nn.LSTM解析

    文章目录 前言 多层LSTM 权重形状 batch_first 输入形状 输出形状 参考 前言 本文记录一下使用LSTM的一些心得. 多层LSTM 多层LSTM是这样: 而不是这样: 我们可以控制如下 ...

  9. Pytorch中的 torch.Tensor() 和 torch.tensor() 的区别

    直接在搜索引擎里进行搜索,可以看到官方文档中两者对应的页面: 分别点击进去,第一个链接解释了什么是 torch.Tensor: torch.Tensor 是一个包含单一数据类型元素的多维矩阵(数组). ...

最新文章

  1. boost安装_【环境搭建】源码安装Boost
  2. arcgis api for flex 开发入门(九)webservices 的使用
  3. Maven 版 JPA 最佳实践(转)
  4. python xpath语法-Python爬虫 | 解析库Xpath的使用
  5. STM32 ADC 同步规则模式 ADC1与ADC2同用一个DMA
  6. Rtx userlist.php,【图片】【C语言】【控制台】提取腾讯通用户信息(id,用户名,手机)【erbi_lucifer吧】_百度贴吧...
  7. [转]Spark能否取代Hadoop?
  8. 聚类分析在用户行为中的实例_用户关注行为数据分析过程详解-描述统计+聚类...
  9. 和python高级知识_Python中的5个高阶概念属性的知识点!你要了解明白哦!
  10. selenium 获取不了标签文本的解决方法
  11. Vue-注册全局组件的两种方法
  12. jquery easyui datagrid 获取Checked选择行(勾选行)数据
  13. emc测试e3软件系数导入,EMC测试标准
  14. XRF与ICP比较的差异
  15. ADKAR模型简介(转)
  16. 如何用photoshop做24色环_如何制作出Ps色环?
  17. android在体检报告叫什么,体检报告检测分析app
  18. 编程语言理解3-目前主流的编程语言有哪些,分别的应用场景是什么
  19. Django之stark组件1
  20. 解惑“可观测性”与“监控”的不同

热门文章

  1. Android直播播放器+弹幕使用总结
  2. Storm学习一集群安装
  3. kmp算法例题 登山
  4. Dreamweaver下拉菜单全攻略
  5. 如何切换笔记本键盘的功能键?
  6. AE基础教程第一阶段——15质量图标和效果开关
  7. 自动点击大师(AUTO CLICKER)
  8. 基于微信小程序的家政服务预约系统的设计与实现
  9. 在SQLServer处理中的一些问题及解决方法
  10. kvm与openvz等不同的虚拟化技术有什么区别