声明:该部分为转载参考

简单来说,预训练模型(pre-trained model)是前人为了解决类似问题所创造出来的模型。你在解决问题的时候,不用从零开始训练一个新模型,可以从在类似问题中训练过的模型入手。

场景一:数据集小,数据相似度高(与pre-trained model的训练数据相比而言)
在这种情况下,因为数据与预训练模型的训练数据相似度很高,因此我们不需要重新训练模型。我们只需要将输出层改制成符合问题情境下的结构就好。

我们使用预处理模型作为模式提取器。

比如说我们使用在ImageNet上训练的模型来辨认一组新照片中的小猫小狗。在这里,需要被辨认的图片与ImageNet库中的图片类似,但是我们的输出结果中只需要两项——猫或者狗。

在这个例子中,我们需要做的就是把dense layer和最终softmax layer的输出从1000个类别改为2个类别。

场景二:数据集小,数据相似度不高

在这种情况下,我们可以冻结预训练模型中的前k个层中的权重,然后重新训练后面的n-k个层,当然最后一层也需要根据相应的输出格式来进行修改。

因为数据的相似度不高,重新训练的过程就变得非常关键。而新数据集大小的不足,则是通过冻结预训练模型的前k层进行弥补。

场景三:数据集大,数据相似度不高

在这种情况下,因为我们有一个很大的数据集,所以神经网络的训练过程将会比较有效率。然而,因为实际数据与预训练模型的训练数据之间存在很大差异,采用预训练模型将不会是一种高效的方式。

因此最好的方法还是将预处理模型中的权重全都初始化后在新数据集的基础上重头开始训练。

场景四:数据集大,数据相似度高

这就是最理想的情况,采用预训练模型会变得非常高效。最好的运用方式是保持模型原有的结构和初始权重不变,随后在新数据集的基础上重新训练。

预训练模型的方法

A,特征提取

我们可以将预训练模型当做特征提取装置来使用。具体的做法是,将输出层去掉,然后将剩下的整个网络当做一个固定的特征提取机,从而应用到新的数据集中。

B,采用预训练模型的结构

我们还可以采用预训练模型的结构,但先将所有的权重随机化,然后依据自己的数据集进行训练。

C,训练特定层,冻结其他层

另一种使用预训练模型的方法是对它进行部分的训练。具体的做法是,将模型起始的一些层的权重保持不变,重新训练后面的层,得到新的权重。在这个过程中,我们可以多次进行尝试,从而能够依据结果找到frozen layers和retrain layers之间的最佳搭配。

如何使用与训练模型,是由数据集大小和新旧数据集(预训练的数据集和我们要解决的数据集)之间数据的相似度来决定的。

实现预训练模型的加载(pytorch)

直接加载预训练模型


import torchvision.models as modelsmodel = models.resnet101(pretrained=True)

修改某一层


import torchvision.models as modelsmodel = models.resnet101(pretrained=True)model.fc = nn.Linear(2048, 120)  #120为样本分类数目,修改最后的分类的全连接层
model.conv1 = nn.Conv2d(3, 64,kernel_size=5, stride=2, padding=3, bias=False)   #修改中间层

加载部分预训练模型

#加载model,model是自己定义好的模型
resnet50 = models.resnet50(pretrained=True)
model =Net(...) #读取参数
pretrained_dict =resnet50.state_dict()
model_dict = model.state_dict() #将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新现有的model_dict
model_dict.update(pretrained_dict) # 加载我们真正需要的state_dict
model.load_state_dict(model_dict)

pytorch与resnet(六) 预训练模型使用的场景相关推荐

  1. 【小白学PyTorch】5.torchvision预训练模型与数据集全览

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyTorch | 3 浅谈Dataset和Da ...

  2. Pytorch 加载部分预训练模型并冻结某些层

    目录 1  pytorch的版本: 2  数据下载地址: 3  原始版本代码下载: 4  直接上代码: 1  pytorch的版本: 2  数据下载地址: <https://download.p ...

  3. pytorch:加载预训练模型(多卡加载单卡预训练模型,多GPU,单GPU)

    在pytorch加载预训练模型时,可能遇到以下几种情况. 分为以下几种 在pytorch加载预训练模型时,可能遇到以下几种情况. 1.多卡训练模型加载单卡预训练模型 2. 多卡训练模型加载多卡预训练模 ...

  4. 【PyTorch】下载的预训练模型的保存位置(Windows)

    保存位置 C:\Users\xxx\.cache\torch\hub\checkpoints\ xxx替换为你的用户名. 项目场景 迁移学习的时候一般需要用到预训练模型,那么预训练模型的保存位置是在哪 ...

  5. 【PyTorch】如何取得预训练模型的标签label列表(以 Alexnet 在 ImageNet 上的预训练模型为例)

    PyTorch 预训练模型 PyTorch 提供过了大量的预训练模型可以直接拿来使用,或者进行增量训练和微调. 拿 Alexnet 的预训练模型为例 import torch import torch ...

  6. Pytorch中使用Bert预训练模型,并给定句子得到对应的向量

    写在前面 本次的需求是:通过预训练好的Bert模型,得到不同语境下,不同句子的句向量.相比于word2vec.glove这种静态词向量,会含有更丰富的语义,并能解决不同场景不同意思的问题. 建议大家先 ...

  7. 迁移学习:如何为您的机器学习问题选择正确的预训练模型

    https://www.toutiao.com/a6687923187298075144/ 在这篇文章中,我们将简要介绍一下迁移学习是什么,以及如何使用它. 什么是迁移学习? 迁移学习是使用预训练模型 ...

  8. 百度大脑 EasyDL 专业版最新上线自研超大规模视觉预训练模型

    在学习与定制AI模型的过程中,开发者会面对各种各样的概念,在深度学习领域,有一个名词正在被越来越频繁地得到关注:迁移学习.它相比效果表现好的监督学习来说,可以减去大量的枯燥标注过程,简单来说就是在大数 ...

  9. 【NLP】一文速览 | 对话生成预训练模型

    作者 | 惠惠惠惠惠惠然 整理 | NewBeeNLP 大规模预训练言模型在生成式对话领域近年来有非常多的工作,如百度PLATO系列(PLATO[1]/PLATO-2[2]/PLATO-XL[3]), ...

  10. 一文速览 | 对话生成预训练模型

    作者 | 惠惠惠惠惠惠然 整理 | NewBeeNLP 大规模预训练言模型在生成式对话领域近年来有非常多的工作,如百度PLATO系列(PLATO[1]/PLATO-2[2]/PLATO-XL[3]), ...

最新文章

  1. Fresco源码分析之Hierarchy
  2. 尝鲜 workerize 源码
  3. 报工提示错误:“没有内部作业价格可被确认”的解决方法
  4. 「Ubuntu」系统常用命令
  5. 教你一招快速打开idea的秘诀
  6. queuedeclare参数说明_MQ 学习笔记之RabbitMQ
  7. unicode 转换
  8. 【MySQL数据库开发之四】MySQL 处理模式/常用查询/模式匹配等(下)
  9. 如何使用PHP中的字符串函数
  10. 【操作系统/OS笔记12】同步互斥的三种实现方法:禁用硬件中断、基于软件的解决方案、更高级的抽象
  11. List、Set、Map比较
  12. 牛顿插值 | MATLAB源码
  13. 由系统函数求零极点图、频率响应(幅频特性、相频特性)的 Matlab 和 Python 方法
  14. 微信公众号如何添加附件链接
  15. Windows Server 2016 搭建DHCP服务器(踩坑后总结)
  16. 潘金莲改变了历史吗 - PostgreSQL舆情事件分析应用
  17. canvas画任意角度的扇形,弧形,及扇形弧形填纯色渐变色
  18. 小白组装电脑详细教程
  19. table标签及排版详解(一)
  20. 单片机模拟计算机课设,《单片机课程设计实例》.doc

热门文章

  1. mysql插入时unique字段重复插入失败
  2. PHP实现折半查询算法
  3. html5有哪些优点,HTML5真正的优势优点有什么?
  4. Flink Kafka 端到端 Exactly-Once 分析
  5. python自动化测试实战 虫师_Page Object 1 百度搜索实例 (虫师《selenium3自动化测试实战--基于Python语言笔记40》)...
  6. java读取手机崩溃日志_Android抓取崩溃日志
  7. 计算机应用基础第3次平时作业,计算机应用基础第3次作业.doc
  8. linux系统中文乱码的问题
  9. JS实现键盘事件上下翻页
  10. 帆软JS获取控件扩展的值的两种方法