前言

 使用PyTorch官方提供的权重或者其他第三方提供的权重对相同模型的参数进行初始化,在数据量较少的前提下,可以帮助模型更快地收敛到最优点,达到更好的效果,即迁移学习。

 在大部分的迁移学习场景中,我们一般沿用之前模型的相关参数,这是因为卷积神经网络认为大部分的特征提取模式是一致的,即卷积神经网络中的归纳偏置能力强。在使用别人训练好的权重的过程中,一般冻结/保留提供权重模型中浅层的权重参数,只修改跟当前任务相关的层数。

 在本文中,主要讲解如何修改提供的权重并将其迁移到当前的任务上,例如:如何将PyTorch官方提供的ResNet权重迁移到别的分类任务上。本文将围绕此需求进行相关方法的介绍。

一、PyTorch官方ResNet权重下载链接

 PyTorch提供的ResNet权重文件下载链接如下:

model_urls = {'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth','resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth','resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth','resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth','resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth','resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth','resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth','wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth','wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

本文通过ResNet网络使用num_classes代表该任务的分类类别,例如要将PyTorch官方在ImageNet数据集上训练的参数迁移到一个二分类的小任务中,即num_classes=2

 其他的网络预训练权重下载链接可以查看torchvision.models下对应模型的python文件或在官网进行查找。

  PyTorch官方提供的两种加载权重方法说明如下:

二、方法1

 第一种方法是:根据模型结构,删除权重模型最后的分类层,替换成属于该分类任务对应的全连接层,如将ImageNet最后输出的1000个神经元(1000分类)替换成2个输出神经元(二分类)。具体代码如下:

# 使用官方给定的api进行网络结构和权重的加载
from torchvision.models.resnet import resnet18
net = resnet18(pretrained=True)
# 冻结所有参数 使其不更新
for param in net.parameters():param.requires_grad = False
# 替换全连接层
in_channel = net.fc.in_features    # 获得原模型全连接层的输入特征大小
net.fc = nn.Linear(in_channel, num_classes)  # num_classes代表分类器的类别

这里先设置所有参数为不可训练,而新的nn.Linear是可训练的。

 如果不使用官方定义的模型结构,也可以使用自己定义好的,前提是自身定义的模型在网络结构的定义上跟官方是一致的(层的名称和网络的参数),自己定义的模型加载预训练权重的方式如下:

# 使用自定义的网络结构和权重的加载
net = resnet18(num_classes=1000)      # ***
# 载入预训练权重 model_weight_path:含权重名的模型路径 device:设备
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# 冻结所有参数 使其不更新
for param in net.parameters():param.requires_grad = False
# 替换全连接层
in_channel = net.fc.in_features    # 获得原模型全连接层的输入特征大小
net.fc = nn.Linear(in_channel, num_classes)  # num_classes代表分类器的类别

注意:这里是在原有的模型基础上进行增删网络结构的操作,如想一开始就定义针对该任务的方式请看方法2。

三、方法2

 第二种方式是:首先读取权重到变量中但不载入(读取的参数是以一个字典的形式存储),然后找到预训练权重中不需要的权重将它删除,最后再载入到网络中。

from torchvision.models.resnet import resnet18
net = resnet18(pretrained=True)
# 将定义的网络中参数读取到net_weights字典变量中 (key: name, val: weights)
net_weights = net.state_dict()
# 读取预训练的权重
weights = torch.load(model_weight_path, map_location=device)
del_key = []# 遍历预训练权重,如果存在key值包含“fc”的权重就将它从预训练权重字典中删除
for key, _ in weights.items():if "fc" in key:del_key.append(key)
for key in del_key:del weights[key]# 载入权重
net.load_state_dict(weights, strict=False)

注意这里要设置strict=False。默认为True会严格按key值载入权重,如果出现缺失的权重会报错,这里因为我们已经删除了一部分预训练权重所以要设置为False。但需要注意的是使用strict=False时,即使模型和权重的关系不对应,载入也不会报错,如将ResNet34上的预训练权重以此方式载入到ResNet18模型时,会根据网络结构对可加载的权重进行加载,如果这样设置有可能会造成预训练权重没有加载到要求的网络中,造成无效地预训练/迁移。

 使用自定义结构进行权重加载的方式如下:

net = resnet18(num_classes=2)    # ***
# 将定义的网络中参数读取到net_weights字典变量中 (key: name, val: weights)
net_weights = net.state_dict()
# 读取预训练的权重
weights = torch.load(model_weight_path, map_location=device)
del_key = []# 遍历预训练权重,如果存在key值包含“fc”的权重就将它从预训练权重字典中删除
for key, _ in weights.items():if "fc" in key:del_key.append(key)
for key in del_key:del weights[key]# 载入权重
net.load_state_dict(weights, strict=False)

 还有比较麻烦的情况就是定义的网络中某一层的key值和预训练权重对应层的key值不一样,这个比较麻烦,一般需要手动修改字典中的key值,本文不再详述。

四、参考链接

[1] https://www.csdn.net/tags/MtTaEgzsMTk3NjAxLWJsb2cO0O0O.html

[2] https://pytorch.org/vision/stable/models.html

PyTorch 加载预训练权重相关推荐

  1. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

  2. 深度学习加载预训练权重好处

    深度学习加载预训练权重好处: 在模型开始训练前,使模型参数得到一个好的初始化,对于后面的训练学习有非常大的帮助.

  3. Pytorch加载预训练网络,替换分类层并重新训练

    定义网络时,在网络类的构造函数网络结构定义中添加如下语句: for p in self.parameters():p.requires_grad = False 该语句的功能是固定定义在该语句之前的网 ...

  4. pytorch加载预训练 加载部分参数

    最简单的: state_dict = torch.load(weight_path)    self.load_state_dict(state_dict,strict=False) 加载cpu: m ...

  5. 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次都特别慢

    欢迎大家关注笔者,你的关注是我持续更博的最大动力 原创文章,转载告知,盗版必究 使用torchvision.models.inception_v3(pretrained=True)加载预训练的模型每次 ...

  6. 实践:jieba分词和pkuseg分词、去除停用词、加载预训练词向量

    一:jieba分词和pkuseg分词 原代码文件 链接:https://pan.baidu.com/s/1J8kmTFk8lec5ubfwBaSnLg 提取码:e4nv 目录: 1:分词介绍: 目标: ...

  7. paddlepaddle加载预训练词向量

    文章目录 1.一些用到的api文档 2.加载预训练词向量 2.1小数据 2.2核心代码 2.3验证结果 3.可能有用的 tensorflow的加载方法可以看我之前写的: tensorflow加载词向量 ...

  8. pytorch加载预训练模型_Pytorch-Transformers 1.0发布,支持六个预训练框架,含27个预训练模型...

    AI 科技评论按:刚刚在Github上发布了开源 Pytorch-Transformers 1.0,该项目支持BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM等,并包含 ...

  9. Pytorch中更改预训练权重文件的下载位置

    目录 1. 参考链接 2. 更改方法 3. 一个小技巧 1. 参考链接 Pytorch更改预训练权重下载位置 pytorch---修改预训练模型下载路径 2. 更改方法 在线加载的预训练权重默认存放位 ...

最新文章

  1. [转]memset用法详解
  2. 如何检测 SAP 电商云 Spartacus UI 当前正处于导航状态
  3. iOS框架介绍之coreImage
  4. Uniswap 24h交易量约11.2亿美元涨23.91%
  5. 数据结构与算法之-----队列(Queue)
  6. Win10安装乌班图18双系统
  7. c语言16进制与字符串互转,C语言版的16进制与字符串互转函数
  8. java公告栏按月查询_求java公告栏特效代码
  9. 从正射到倾斜,Mavic 3E详细使用报告
  10. 30 个 Python 的最佳实践、小贴士和技巧
  11. 程序员去大公司面试,阿里P8面试官都说太详细了,社招面试心得
  12. IBM__P系列 小型机 故障定位 故障排除
  13. 老男孩-筷子兄弟(歌词)
  14. python 保留浮点数为两位小数
  15. 预防新型冠状病毒感染的肺炎口罩使用指南
  16. linux下修改ext3硬盘为nst,Linux服务器数据备份恢复策略(3)
  17. int和String类型的转换
  18. WINDOWS图像编程
  19. 蓝牙技术|蓝牙智能笔方案简介
  20. 计算机操作基本知识题库,计算机操作基础知识题库.doc

热门文章

  1. iiOS 6 新特性
  2. 中国最大的“隐形首富”,掌舵中国最大汽车集团,身价高达760亿
  3. 小重山 【南宋】 岳飞
  4. adb连接的2种方式,有线(USB线)和无线
  5. 黑马程序员--银行以及交通系统项目个人理解
  6. es拼音分词 大帅哥_elasticsearch实现中文分词和拼音分词混合查询+CompletionSuggestion...
  7. freeswitch通过limit限制cps
  8. Jupyter Notebook 输出有颜色的文字
  9. DayDayUp:佛说:有果必有因。 黑格尔说:世界上没有无缘无故的爱,也没有无缘无故的恨。
  10. 使用互传APP实现Android手机投屏到windows电脑