1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():if name 满足某些条件:value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},{'params':model.decoder.parameters()}],lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。 
在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有paramslr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

参考:

  1. pytorch官方文档
  2. https://blog.csdn.net/u012759136/article/details/65634477

pytorch载入预训练模型后,训练指定层相关推荐

  1. Pytorch提取预训练模型特定中间层的输出

    如果是你自己构建的模型,那么可以再forward函数中,返回特定层的输出特征图. 下面是介绍针对预训练模型,获取指定层的输出的方法. 如果你只想得到模型最后全连接层之前的输出,那么只需要将最后一个全连 ...

  2. 学习笔记26-解决:载入预训练模型时Pytorch遇到权重不匹配的问题(附+修改后的预训练模型载入和冻结特征权重完整代码)

    在pytorch微调mobilenetV3模型时遇到的问题 1.KeyError: 'features.4.block.2.fc1.weight' 这个是因为模型结构修改了,没有正确修改预训练权重,导 ...

  3. Pytorch 加载部分预训练模型并冻结某些层

    目录 1  pytorch的版本: 2  数据下载地址: 3  原始版本代码下载: 4  直接上代码: 1  pytorch的版本: 2  数据下载地址: <https://download.p ...

  4. linux载入pytorch的预训练模型时遇到_pickle.UnpicklingError: unpickling stack underflow

    linux试图载入pytorch的预训练模型resnet101时遇到如下报错: Traceback (most recent call last): File "train_baseline ...

  5. Pytorch——BERT 预训练模型及文本分类(情感分类)

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  6. CV之NS之VGG16:基于预训练模型VGG16训练COCO的train2014数据集实现训练《神奈川冲浪里》风格配置yml文件

    CV之NS之VGG16:基于预训练模型VGG16训练COCO的train2014数据集实现训练<神奈川冲浪里>风格配置yml文件 目录 一.训练 1.<神奈川冲浪里>风格 2. ...

  7. 替换骨干网络之后使用预训练模型进行训练

    前言 最近看了几篇使用transformer的文章,于是想用其中的一个transformer模块来替换另一个方法的骨干网络(backbone),替换完之后跑起来感觉没有什么效果,想着可能是transf ...

  8. PyTorch载入预训练权重方法和冻结权重方法

    载入预训练权重 1. 直接载入预训练权重 简单粗暴法: pretrain_weights_path = "./resnet50.pth" net.load_state_dict(t ...

  9. Pytorch使用预训练模型进行图像分类

    在本文中,我们将介绍一些使用预训练网络的实际例子,这些网络出现在TorchVision模块的图像分类中. Torchvision包包括流行的数据集,模型体系结构,和通用的图像转换为计算机视觉.基本上, ...

最新文章

  1. .NET 即时通信,WebSocket服务端实例
  2. php和python区别-python与php比较
  3. 定义系统消息 Specify system messages
  4. leaflet 的 marker 弹框 iframe 嵌套代码
  5. 前端手册-CSS3 属性手册
  6. 电商公司ERP管理软件与旺店通、第三方仓库以及云仓的贯通解决方案
  7. # 数学基础task 01 函数极限与连续性
  8. 使用再生龙镜像备份还原linux,以及遇到的问题和解决方法
  9. 新世纪10年100个好东西 淘宝、QQ、伟哥入选
  10. cassandra java cql_Cassandra CQL v3.3中文文档(上)
  11. 马云被骗十亿?最后却被百倍奉还。
  12. Android 人脸识别签到(二)
  13. Pycharm2018永久破解方法
  14. chapter3 动态分析基础技术-01在线沙箱 微步云沙箱
  15. HTTP请求的过程与TCP连接的过程
  16. Uncaught ReferenceError: Mustache is not defined
  17. 服务器普通硬盘,服务器硬盘和普通硬盘区别
  18. HDMI硬件设计要求及CTS要求
  19. 当心,别被微信小程序火爆的假象所欺骗!
  20. elasticsearch7.9操作必看结合官方文档 es head的操作必看 es增删改查全详解

热门文章

  1. boost::serialization模块实现测试 shared_ptr 序列化的测试程序
  2. boost::range模块heap算法相关的测试程序
  3. boost::multiprecision::float128用法的测试程序
  4. boost::log::filter用法的测试程序
  5. boost::hana::remove_range用法的测试程序
  6. boost::coroutine模块实现非对称协程的测试程序
  7. GDCM:VRDS的测试程序
  8. GDCM:gdcm::Reader的测试程序
  9. VTK:Rendering之PhysicalBasedRendering
  10. VTK:PolyData之AlignFrames