迁移学习就是载入别人预训练好的权重,拿别人的训练好的参数作为我们自己模型的初始化参数,再在这个基础上继续优化。比起从头开始一点一点随机初始化,让模型胡乱地找梯度最优的方向,肯定是迁移学习快啦。

目录

  • 1.载入训练权重
  • 2.载入预训练权重后两种情况
    • 2.1冻结全部的特征提取层,微调全连接层
    • 2.2冻结部分的特征层
    • 2.3不冻结特征提取层
  • 3.直观上的结果对比

1.载入训练权重

载入别人预训练权重的时候,由于别人的数据预处理的方式与我们自己的可能不同、全连接层中最后分类的结点个数和你的数据集类别个数不同等情况,都会产生各种报错,我就说一下我会的方法,后面有会进行补充哒。

net = MobileNetV2(class_nums=5)
pre_weights = torch.load(pre_trained_pth)  //(字典文件)这里载入别人训练好的预训练权重,此时只是导入内存中,还没加载到我们的网络中

下面的两种方法都是:载入预训练权重后,删除全连接层的参数
好处就是:在创建网络的时候,可以直接根据我们自己的数据集类别个数更改模型中最后一个全连接节点个数。

//方法一:
pre_dict = {k: v for k, v in pre_weights.items() if "fc" not in k}
//
//方法二:
//这里的想法是:拿出预训练权重(字典)的key和value,通过获取我们自己网络中与预训练权重中网络层名称一样的的层,拿到相同个数的网络层,删除不一致的
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)  //将特征提取层的权重送进网络,这里的strict设置为False后,就不用预训练权重的网络结构和我们自己的网络完全key值一致

而第三种方法,在创建网络时候,更改最后的全连接层节点个数,直接net.load_state_dict()方法载入会报错的

net = MobileNetV2()
net.load_state_dict(torch.load(pre_trained_pth), strict=False)
//方法三:
in_channel = net.fc.in_feacture
net.fc = nn.Linear(in_channel, 5)  //这里为什么是.fc,道理同下,下面具体讲解了

这里另外说一下为什么删除预训练权重中全连接层的参数,上方代码的判断语句中必须是“fc”???
其实不然,取决于你搭建网络时类变量名称的定义,可以通过来查看每个网络层的名称:

net = MobileNetV2(class_nums=5)
print(net)
//或者
for name in net.named_modules():print(name)

例如:上面的红框中网络的全连接层名称是classifier,那么我们在删除全连接层的参数的时候就是if “classifier” not in k。
或者把网络结构中的classifier改成fc。这里全连接层的名称完全取决于我们自己。

2.载入预训练权重后两种情况

特征层的权重参数在反向传播过程中会求导,冻结特征层,可以缩短训练时间。

2.1冻结全部的特征提取层,微调全连接层

我们如果想要在短时间内将我们的模型达到一个相对理想的效果,可以将特征层全部冻结(全部使用别人预训练权重),然后只训练全连接层,根据我们自己的数据集类别进行fine-tuning。

for param in net.parameters():  //这里的就遍历每个特征层上的权重参数了param.requires_grad = False

2.2冻结部分的特征层

在卷积神经网络中,特征提取层就是卷积层的部分,因为一开始卷积得到的特征图上的信息多为一些简单的信息:边缘、转角等,这类简单的特征对于大多数的对象都是通用的,所以我们可以冻结低层的权重,以减少训练时间。

for param in net.xx特征层名称.parameters():  //这里的xx特征层名称取决于搭建网络时的定义param.requires_grad = False

2.3不冻结特征提取层

我认为这能够取得比较好的效果,在以别人的预训练权重参数为基础,继续寻找梯度最优,但可能比起上面冻结特征层,会增加时间上的负担。

3.直观上的结果对比

红色曲线是:载入别人预训练权重后,冻结特征层的loss和精确率;
蓝色曲线是:载入别人预训练权重后,没有冻结特征层,在此基础上继续梯度优化;
黑色曲线是:没有采用迁移学习,让搭建的网络模型随机初始化权重开始训练。

conclusion:
载入预训练权重,拿别人的参数作为一开始的初始化,更容易达到最优;而比起从头开始一点一点随机初始化,让模型胡乱地找梯度最优的方向。肯定迁移学习更好。
而冻结了特征层后,载入别人预训练权重后就不再继续优化了,效果肯定不如随机初始化后梯度下降后的效果,好比你原地踏步,别人在慢慢进步。

迁移学习、载入预训练权重和冻结权重相关推荐

  1. 迁移学习-使用预训练的Inception v3进行宠物分类

    个人博客:http://www.chenjianqu.com/ 原文链接:http://www.chenjianqu.com/show-53.html 迁移学习 迁移学习(Transfer Learn ...

  2. PyTorch载入预训练权重方法和冻结权重方法

    载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...

  3. 学习笔记26-解决:载入预训练模型时Pytorch遇到权重不匹配的问题(附+修改后的预训练模型载入和冻结特征权重完整代码)

    在pytorch微调mobilenetV3模型时遇到的问题 1.KeyError: 'features.4.block.2.fc1.weight' 这个是因为模型结构修改了,没有正确修改预训练权重,导 ...

  4. 迁移学习实战 | 快速训练残差网络 ResNet-101,完成图像分类与预测,精度高达 98%!...

    作者 | AI 菌 出品 | CSDN博客 头图 | CSDN付费下载自视觉中国 前言 笔者在实现ResNet的过程中,由于电脑性能原因,不得不选择层数较少的ResNet-18进行训练.但是很快发现, ...

  5. 直播预告 | AAAI 2022论文解读:基于对比学习的预训练语言模型剪枝压缩

    「AI Drive」是由 PaperWeekly 和 biendata 共同发起的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和 ...

  6. 无需在数据集上学习和预训练,这种图像修复新方法效果惊人 | 论文

    林鳞 编译自 Github 量子位 出品 | 公众号 QbitAI Reddit上又炸了,原因是一个无需在数据集上学习和预训练就可以超分辨率.修补和去噪的方法:Deep image prior. 帖子 ...

  7. pytorch迁移学习载入部分权重

    载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习.使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即 ...

  8. 深度学习--使用预训练的卷积神经网络

    文章目录 前言 一.使用预训练网络 二.将VGG16卷积基实例化 三.使用卷积基进行特征提取 1.不使用数据增强的快速特征提取 2.使用数据增强的特征提取 四.微调模型 前言 想要将深度学习应用于小型 ...

  9. 【深度学习】预训练的卷积模型比Transformer更好?

    引言 这篇文章就是当下很火的用预训练CNN刷爆Transformer的文章,LeCun对这篇文章做出了很有深意的评论:"Hmmm".本文在预训练微调范式下对基于卷积的Seq2Seq ...

最新文章

  1. 抱歉,我觉得有些人做副业并不靠谱
  2. python yield与递归
  3. 大中型网站集群架构企业级高标准全自动实战项目征集
  4. 电脑开机进入不了XP界面
  5. 换种方法学操作系统,轻松入门Linux内核
  6. git commit命令
  7. js生成随机密码,密码位数自定
  8. 【PyQt5】PyQt5 安装 以及使用 designer 开发 python GUI 界面
  9. java学习(六)多线程 下
  10. 小程序丨canvas内容自适应不同尺寸屏幕
  11. Spring 之bean的注入
  12. Excel:合并两个单元格内容
  13. 利用计算机名称共享打印机步骤,如何连接共享打印机?共享打印机连接方法介绍...
  14. 树莓派驱动数码管c 语言,用树莓派驱动八段数码管实现倒计时
  15. python计算偶数平方和_如何使用Python和Numpy计算r平方?
  16. 用jQuery做一个简单的用户注册页面
  17. vue项目中引入阿里云滑动验证
  18. 1.find如何快速查找、搜索文件
  19. Doves and bombs UVA - 10765
  20. Excel 2010 VBA 入门 013 导入或导出VBA代码

热门文章

  1. 有十个按钮点击按钮aler按钮的序号
  2. validForm结合layer制作表单验证提示
  3. 南大通用GBase8s 常用SQL语句(233)
  4. Day51 前端开发 浮动、定位 、js入门
  5. X86逆向教程9:通过关键常量破解
  6. 讲故事的人写的谈判手册——Leo锦书64
  7. nyoj1170 最大的数
  8. 中国足球队输球的原因大总结
  9. vue-axios demo_user 展示数据 删除 修改 新增
  10. 为什么有些人能用一年获得你三年的工作经验?