我的需求是,由于我在不停的尝试各种模型,导致模型木块一直会变。如果每次重复重新开始训练要花费大把时间。

我之前运行的模型 ResNet ->                                      三个ResNet参数共享。

ResNet ->           中间模块 -> 结果

ResNet ->

现在我要改成        ResNet 1->                                       三个ResNet不参数共享来重新训练,我想导入之前模型中间模块的参数,

ResNet 2->           中间模块 -> 结果

ResNet 3->

并且冻结中间模块的参数使训练速度加快。

参考了两位大神的两篇博文:加载部分参数https://blog.csdn.net/weixin_41519463/article/details/101604662,冻结部分参数https://blog.csdn.net/jdzwanghao/article/details/83239111。

具体代码如下:

net = MY_Net( )
######导入部分参数model_dict = net.state_dict()for k, v in model_dict.items():print(k)pretrained_dict = torch.load(model_file1)#model_file1是之前模型的模型保存路径,这里只是加载参数而已for k, v in pretrained_dict.items():print(k)pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}model_dict.update(pretrained_dict)  # 用预训练模型参数更新new_model中的部分参数net.load_state_dict(model_dict)  # 将更新后的model_dict加载进new model中##### 冻结部分参数for param in net.parameters():param.requires_grad = False#设置所有参数不可导,下面选择设置可导的参数for param in net.ResNet1.parameters():param.requires_grad = Truefor param in net.ResNet2.parameters():param.requires_grad = Truefor param in net.ResNet3.parameters():param.requires_grad = Trueoptimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr = 0.0001, momentum=0.90,weight_decay=0.0005)#关键是优化器中通filter来过滤掉那些不可导的参数

pytorch加载之前训练模型中的部分参数以及冻结部分参数(实测,自己实际项目代码中的)相关推荐

  1. pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

    问题 最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下.     KeyError: 'layer1.0.bn1.num_batches_tracked' 其实 ...

  2. pytorch加载预训练模型_Pytorch-Transformers 1.0发布,支持六个预训练框架,含27个预训练模型...

    AI 科技评论按:刚刚在Github上发布了开源 Pytorch-Transformers 1.0,该项目支持BERT, GPT, GPT-2, Transfo-XL, XLNet, XLM等,并包含 ...

  3. Pytorch 加载预训练模型参数时出现size mismatch错误

    目录 1 不妨先研究一下' resnet18-5c106cde.pth'里面存了什么东西以及它的数据类型 (1_1)' resnet18-5c106cde.pth'的数据类型

  4. pytorch:加载预训练模型(多卡加载单卡预训练模型,多GPU,单GPU)

    在pytorch加载预训练模型时,可能遇到以下几种情况. 分为以下几种 在pytorch加载预训练模型时,可能遇到以下几种情况. 1.多卡训练模型加载单卡预训练模型 2. 多卡训练模型加载多卡预训练模 ...

  5. HuggingFace学习3:加载预训练模型完成机器翻译(中译英)任务

    加载模型页面为:https://huggingface.co/liam168/trans-opus-mt-zh-en 文章目录 整理文件 跑通程序,测试预训练模型 拆解Pipeline,逐步进行翻译任 ...

  6. Pytorch加载torchvision从本地下载好的预训练模型的简单解决方案

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.喜 ...

  7. Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率

    前言 在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型.本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用. 权重初始化 常见的初始化方法 PyTorc ...

  8. 用pytorch加载训练模型

    用pytorch加载.pth格式的训练模型 在pytorch/vision/models网页上有很多现成的经典网络模型可以调用,其中包括alexnet.vgg.googlenet.resnet.inc ...

  9. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

最新文章

  1. 07/11/13 资料整理
  2. is this mysql server_Mysql:is not allowed to connect to this MySQL server
  3. svn: Checksum mismatch while updating 'D:\workspace\demo\test\.svn\text-base\test.php.svn-base'
  4. 最全、最详细的配置jdk十步法!
  5. Shell编程:简洁的 Bash Programming 技巧
  6. C#调用C和C++函数的一点区别
  7. 2018蓝桥杯A组:方格计数(3种方法)
  8. YUI:globle object
  9. 适用于软件工程的定律Augustine's laws
  10. DEL: 华为无线modem变无线路由器 2
  11. Qt Https http 请求案例
  12. 进化的系统需要进化的系统工程
  13. android即时通讯ui框架,android IM即时通信之聊天界面UI框架
  14. 火狐打不开qq空间,说“建立安全连接失败”,解决方案
  15. Android12 apk安装失败 安装包异常 安装包大小显示1k
  16. 读懂常见IRP:IRP_MJ_CLEANUP\IRP_MJ_CLOSE\IRP_MJ_CREATE
  17. SOLIDWORKS怎么把STEP曲面转换成实体
  18. web前端入门到实战:CSS text-decoration
  19. EV1527解码函数,看网上人家写的不好使,贡献一下,定时器中断形式解码!
  20. python2安装tensorflow,tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)

热门文章

  1. 支持IPsec / L2TP / EtherIP测试版下载
  2. mybatis返回Date类型数据 格式化
  3. NPOI设置Excel中的单元格识别为日期
  4. Swift4.0复习协议
  5. ThinkPHP文件目录说明
  6. log4net根据日志类型写入到不同的文件中
  7. 数据库执行sql报错Got a packet bigger than 'max_allowed_packet' bytes及重启mysql
  8. 【原创】一款符合当前主流审美的Swing外观(Look and Feel)_测试版发布
  9. 接口测试工具--apipost预/后执行脚本
  10. 接口测试工具--apipost脚本讲解