摘要

论文链接:https://arxiv.org/abs/1803.05407.pdf

官方代码:https://github.com/timgaripov/swa

论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客

SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过程第iii个epoch的checkpoint为wiw_{i}wi​,一般情况下我们会选择训练过程中最后的一个epoch的模型wnw_{n}wn​或者在验证集上效果最好的一个模型wi∗w^{*}_{i}wi∗​作为最终模型。但SWA一般在最后采用较高的固定学习速率或者周期式学习速率额外训练一段时间,取多个checkpoints的平均值。

pytorch使用举例:

from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):for batch_idx, (data, target) in enumerate(train_loader):   data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)# 在反向传播前要手动将梯度清零optimizer.zero_grad()output = model(data)#计算losssloss = train_criterion(output, targets)# 反向传播求解梯度loss.backward()optimizer.step()lr = optimizer.state_dict()['param_groups'][0]['lr']   swa_model.update_parameters(model)swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")

上面的代码展示了SWA的主要代码,实现的步骤:

1、定义SGD优化器。

2、定义SWA。

3、定义SWALR,调整模型的学习率。

4、开始训练,等待训练完成。

5、在每个epoch中更新模型的参数,更新学习率。

6、等待训练完成后,更新BN层的参数。

详细实现过程

环境

pyotrch:1.10

准备

在开始今天的代码前,我们要准备好训练好的模型。然后才能开始今天的代码。

实现过程

定义模型,并将训练好的模型载入,代码如下:

    model_ft = efficientnet_b1(pretrained=True)print(model_ft)num_ftrs = model_ft.classifier.in_featuresmodel_ft.classifier = nn.Linear(num_ftrs, classes)model_ft.to(DEVICE)model_ft = torch.load(model_path)print(model_ft)fine_epoch = 80fine_tune(model_ft, DEVICE, train_loader, test_loader, criterion_train, criterion_val, fine_epoch, mixup_fn,use_amp)

定义模型为efficientnet_b1,这里要和训练的模型保持一致。

如果保存的整个模型,则使用torch.load(model_path)载入模型,如果只保存了权重信息,则要使用model_ft=load_state_dict(torch.load(model_path)),载入模型。

然后,设置fine的epoch为80。

接下来,我们一起去看fine_tune函数中的内容。

 # 采用SGD优化器optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3, momentum=0.9)if use_amp:model, optimizer = amp.initialize(model_ft, optimizer, opt_level="O1")  # 这里是“欧一”,不是“零一”

定义优化器为SGD。

如果使用混合精度,则对amp初始化。

 # 随机权重平均SWA,实现更好的泛化swa_model = AveragedModel(model).to(device)# SWA调整学习率swa_scheduler = SWALR(optimizer, swa_lr=1e-6)

初始化SWA。

使用SWALR调整学习率。

接下来循环epoch,这里都是比较通用的逻辑。

 for epoch in range(1, epoch + 1):model.train()train_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):if len(data) % 2 != 0:print(len(data))data = data[0:len(data) - 1]target = target[0:len(target) - 1]print(len(data))data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)samples, targets = mixup_fn(data, target)output = model(samples)loss = train_criterion(output, targets)optimizer.zero_grad()if use_amp:with amp.scale_loss(loss, optimizer) as scaled_loss:scaled_loss.backward()else:loss.backward()optimizer.step()lr = optimizer.state_dict()['param_groups'][0]['lr']print_loss = loss.data.item()train_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))swa_model.update_parameters(model)swa_scheduler.step()

主要步骤有:

1、计算loss。

2、是否使用amp混合精度,如果使用混合精度则使用scaled_loss反向传播求梯度,否则直接loss反向传播求梯度。

3、 swa_model.update_parameters(model)更新swa_model的参数。

4、 swa_scheduler.step()更新学习率。

等待所有的epoch执行完成后。

torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), "last.pt")

更新BN层参数。

然后保存模型的权重。注意:这里只能保存模型的权重,不能保存整个模型。

完成之后就可以测试了,执行代码:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
from torchvision.models.mobilenetv3 import mobilenet_v3_large
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR
from timm.models.efficientnet import efficientnet_b1
import numpy as npdef show_outputs(output):output_sorted = sorted(output, reverse=True)top5_str = '-----TOP 5-----\n'for i in range(5):value = output_sorted[i]index = np.where(output == value)for j in range(len(index)):if (i + j) >= 5:breakif value > 0:topi = '{}: {}\n'.format(index[j], value)else:topi = '-1: 0.0\n'top5_str += topiprint(top5_str)transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = efficientnet_b1(pretrained=True)num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 8)
swa_model = AveragedModel(model)
swa_model.load_state_dict(torch.load("last.pt"))
swa_model.to(DEVICE)
swa_model.eval()path = 'test/'
testList = os.listdir(path)
for file in testList:img = Image.open(path + file)img = transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out = swa_model(img)out = out.data.cpu().numpy()[0]print(file)show_outputs(out)

这里测试代码和以前的写法没有啥区别,唯一不同的地方:

重新定义模型,然后载入权重。
运行结果:

完整代码:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85223146

SWA实战:使用SWA进行微调,提高模型的泛化相关推荐

  1. 量纲与无量纲、标准化、归一化、正则化【能够帮助梯度下降中学习进度收敛的更快、提升模型的收敛速度提升模型的精度、防止模型过拟合,提高模型的泛化能力】

    目录 1 量纲与无量纲 1.1 量纲 1.2 无量纲 2 标准化 3 归一化 归一化的好处 4 正则化 5 总结 1 量纲与无量纲 1.1 量纲 物理量的大小与单位有关.就比如1块钱和1分钱,就是两个 ...

  2. 如何提高模型的泛化能力

    本博客纯属个人观点,不喜勿喷,也欢迎大神们留言补充. 我们把提高泛化能力的方法分为4类 数据端 模型端 训练过程 后处理 下面分别从这4个部分进行分类 一.数据端 方法: 1.  data augme ...

  3. 提高模型泛化能力的几大方法

    作者:OpenMMLab 链接:https://www.zhihu.com/question/540433389/answer/2629056736 来源:知乎 著作权归作者所有.商业转载请联系作者获 ...

  4. 【深度学习】常见的提高模型泛化能力的方法

    前言 模型的泛化能力是其是否能良好地应用的标准,因此如何通过有限的数据训练泛化能力更好的模型也是深度学习研究的重要问题.仅在数据集上高度拟合而无法对之外的数据进行正确的预测显然是不行的.本文将不断总结 ...

  5. AI体统中提高模型泛化能力的两个思路

    近几天做模式识别实验时遇到了一个问题.在A环境下采集的数据所训练出的模型,在B环境下几乎丧失了识别能力.很明显,该模型的泛化能力太差. 考虑两个思路:第一,在不同的环境中采集多组数据重新模型训练,以此 ...

  6. 【AI初识境】如何增加深度学习模型的泛化能力​​​​​​​

    文章首发于微信公众号<有三AI> [AI初识境]如何增加深度学习模型的泛化能力 这是专栏<AI初识境>的第9篇文章.所谓初识,就是对相关技术有基本了解,掌握了基本的使用方法. ...

  7. 【AI初识境】如何增加深度学习模型的泛化能力

    这是专栏<AI初识境>的第9篇文章.所谓初识,就是对相关技术有基本了解,掌握了基本的使用方法. 今天来说说深度学习中的generalization问题,也就是泛化和正则化有关的内容. 作者 ...

  8. 数据增强,扩充了数据集,增加了模型的泛化能力

    数据增强(Data Augmentation)是在不实质性的增加数据的情况下,从原始数据加工出更多的表示,提高原数据的数量及质量,以接近于更多数据量产生的价值. 其原理是,通过对原始数据融入先验知识, ...

  9. 使用折外预测(oof)评估模型的泛化性能和构建集成模型

    机器学习算法通常使用例如 kFold等的交叉验证技术来提高模型的准确度.在交叉验证过程中,预测是通过拆分出来的不用于模型训练的测试集进行的.这些预测被称为折外预测(out-of-fold predic ...

最新文章

  1. Basic的Json与Xml
  2. Kafka 副本OffsetOutOfRangeException
  3. java生成pdf_JAVA 生成PDF 并导出
  4. 生产环境遇到难题,你是如何解决的?
  5. display:none和visibility:hidden区别
  6. java迷宫_java实现迷宫算法--转
  7. Jenkins插件开发(四)-- 插件发布
  8. .NPT 扩展名格式文件类型及打开方式分析:首次渗入 XR 内容领域
  9. Mysql查询性能优化
  10. Shiro的详细简介解释(快速搭建官网解释代码)
  11. 如何用云计算提高员工工作效率
  12. 特斯拉是如何训练自动驾驶的?
  13. Windows 系统添加 VirtIO 驱动(Windows ISO 安装镜像添加驱动)
  14. Airflow实战--获取REST参数并通过Xcom传递给Bash算子
  15. 微PE安装系统 不显示U盘中镜像文件 的解决方法
  16. java的char类型
  17. 单片机c语言表达式与的关系,单片机c语言教程第七章--运算符和表达式(关系运算符)...
  18. 自定义checkbox
  19. 挖潜无极限---数据挖掘技术与应用热点扫描
  20. 利用机队数据训练的性能模型检测飞机异常

热门文章

  1. nginx配置插件压缩(切)图片
  2. java怎么求方程的虚根_java解一元二次方程 运行出错?
  3. 用java实现:生成13位条形码
  4. php网站老掉线,电脑网络不稳定老掉线的两种解决方法
  5. crt显存试题计算机,2008年9月全国计算机三级考试《PC技术》笔试真题
  6. matlab分布式电源储能系统配置优化研究 面向新能源储能容量配置 储能系统定容和电力系统优化调度双层决策优化模型
  7. 利用阿里云搭建NFS服务器
  8. 蓝桥杯寒假作业——python
  9. Microsoft Teams Voice语音落地系列-3 实战:拨号计划的配置
  10. 机器学习什么显卡_为什么机器学习模型在生产中会退化?