文章目录

  • 前言
  • 第一步 加载预训练模型并修改类别数
  • 第二步 选择模型所有层/最后一层进行反向传播优化
  • 探讨:如何确定模型最后一层的名字是什么
    • 方法一: 查询源代码
    • 方法二: 查询模型的子模块名字

前言

先,这里不讲迁移学习的理论,只讲实践,因为理论已经全网飞了~~,不懂得大家先去学理论,理论学了再来实操。
今天,在这里只想给大家介绍一种代码写法,适用于基于pytorch的迁移学习。

迁移学习主要用在分类模型上,把原本在ImageNet或其他数据集上训练好的模型,迁移到自己的项目上来。所以对于分类模型,我们要把模型最后一层(通常是全连接层)的输出分类类别数量改了。比如,原本在ImageNet上是分1000类,而我们的目标是分3类,就要把最后的类别数改为3。

整个过程分为两步:

第一步 加载预训练模型并修改类别数

from torchvision.models import densenet169, resnet50
import torch.nn as nn
import torch.optim as optim### 加载模型, 并修改模型的最后一层  ####
model = densenet169(pretrained=True)
# 设定 pretrained = True 就会加载训练好的模型# 修改模型的最后一层。不同的模型,最后一层的修改略有差异
arch = 'densenet169'
classes = 2  # 分类的数量
if 'resnet' in arch:# for param in model.layer4.parameters():model.fc = nn.Linear(2048, classes)if 'dense' in arch:if '121' in arch:# (classifier): Linear(in_features=1024)model.classifier = nn.Linear(1024, classes)elif '169' in arch:# (classifier): Linear(in_features=1664)model.classifier = nn.Linear(1664, classes)

第二步 选择模型所有层/最后一层进行反向传播优化

迁移学习有两种模型:一种是对预训练好的模型,再次从头训练,模型的每一层都要重新优化。另一种是只重新训练模型最后一层,其余层的参数固定。第一种方法适用于大多数迁移学习,预训练好的模型是在自然图像上训练的,如果迁移到医学图像上来,那么特征之间的差异很大,这时就选择第一种,重头训练。但假如有一个任务是分别猫和狗这种自然图像,且在ImageNet这个数据集中已经包含的,那么就只优化最后一层参数即可。

####  选择模型所有层都要进行反向传播优化 还是 只优化最优一层  #####
fullretrain = True  # True: 表示所有层都要进行优化,为False: 只优化最后一层if fullretrain:print("=> optimizing all layers")for param in model.parameters():param.requires_grad = Trueoptimizer = optim.Adam(model.parameters(), lr=0.03, weight_decay=1e-4)# model.parameters(): 把模型所有参数都传进去
else:print("=> optimizing fc/classifier layers")optimizer = optim.Adam(model.module.fc.parameters(), lr=0.03, weight_decay=1e-4)# model.module.fc.parameters(): 只传最后一个分类层的参数进去# 注意: 不同模型,最后一层的名字不一样

我们从代码里面可以发现,两种方法的区别就是优化器(optimizer)接收的参数不一样,第一种方法是把模型的所有参数都传进去,第二种是只传模型最后一层的参数。

迁移学习这部分的代码就讲完了。其余的就跟平时训练模型一样的写法。
如果是做分类实验,经验来看,采用预训练的模型都比你自己从0开始训练的效果好。不信的话,可以自己对比对比。

接下来,对其中的部分细节进行进一步的探讨~~

探讨:如何确定模型最后一层的名字是什么

如上述代码里,在resnet这个模型中,它最后一层叫: fc
在Densenet模型中,最后一层叫: classifier
那我怎么知道它最后一层叫什么呢?

方法一: 查询源代码

最直接的办法就是进源代码里面去查看。

方法二: 查询模型的子模块名字

model = densenet169(pretrained=False)
for name in model.named_modules():print(name)


这种方法不仅可以知道模型最后一层的名字。
另外,nn.Linear(1664, classes),这括号里的1664是怎么知道的,这种方法还可以查到模型最后一层的输入节点是1664。

觉得有用点赞,关注,你的鼓励才是我坚持更新的动力♥️♥️♥️♥️

全网都在讲迁移学习,可你会写代码了吗?收藏我这个,10分钟开始你的迁移学习训练相关推荐

  1. java编程石头剪刀布图片_石头、剪刀、布!10分钟带你打开深度学习大门,代码已开源...

    原标题:石头.剪刀.布!10分钟带你打开深度学习大门,代码已开源 沉沉 发自 宇宙中心 量子位 出品 | 公众号 QbitAI 深度学习技术的不断普及,越来越多的语言可以用来进行深度学习项目的开发,即 ...

  2. 在学习计算机编程不写代码_使用代码创建:通过制作游戏来学习和教授计算机编程

    在学习计算机编程不写代码 Create with Code is Unity Education's new, free program for teaching and learning compu ...

  3. 在学习js的然后写代码的过程中我老是找不到思路怎么办?

    在学习js的然后写代码的过程中我老是找不到思路怎么办? 写的少了,边写边思考: 刚刚学习的阶段,还是要多写,多借鉴别人的代码. 转载于:https://www.cnblogs.com/helloy/p ...

  4. 初级java开发学习路线_成为初级全栈Web开发人员的10分钟路线图

    初级java开发学习路线 So you have started your journey into the world of web development. But what do you lea ...

  5. 白话AI:看懂深度学习真的那么难吗?初中数学,就用10分钟

    如果在这个人工智能的时代,作为一个有理想抱负的程序员,或者学生.爱好者,不懂深度学习这个超热的话题,似乎已经跟时代脱节了. 但是,深度学习对数学的要求,包括微积分.线性代数和概率论与数理统计等,让大部 ...

  6. 打游戏学习人工智能!不写代码|湾区人工智能

    栗子 乾明 发自 凹非寺  量子位 报道 | 公众号 QbitAI 撸猫.咖啡,玩游戏. 但我其实是在入门机器学习. 2019年最简单有趣的入门方式,就在这里: Steam高赞游戏,极度易上手. 现在 ...

  7. 石头、剪刀、布!10分钟带你打开深度学习大门,代码已开源

    本文首发于量子位. 随着深度学习技术的不断普及,越来越多的语言可以用来进行深度学习项目的开发,即使是JavaScript这样曾经只是在浏览器中运行的用于处理轻型任务的脚本语言. TensorFlow. ...

  8. python在哪里写代码比较适合-适合练习的10个Python项目,每个项目都不到500行代码...

    以下10个练手项目均摘录自一本尚未出版的 Python 神书<500 Lines or Less>,尽管没有出版,但其 review 版已在官方博客放出. 1. 实现一个网络爬虫 不多说, ...

  9. 人脸检测和识别(中文标记)完整项目源代码(基于深度学习+python3.6+dlib+PIL+CNN+(tensorflow、keras)10分钟实现 区分欢乐颂中人物详细图文教程和完整项目代码)

    转载请注明:https://blog.csdn.net/wyx100/article/details/80428424 效果展示 未完待续... 环境配置 win7sp1 python         ...

最新文章

  1. 提高C++性能的编程技术笔记:编码优化+测试代码
  2. 【总结整理】登录模块---摘自《人人都是产品经理》
  3. RedHat linux服务器对外开放指定端口
  4. Python做文本挖掘的情感极性分析
  5. jquery.autocomplete修改 实现键盘上下键 自动填充
  6. 螺旋测微器 flash_使用测微计收集应用程序指标
  7. 计算机操作员可以免考自考吗,计算机《职业资格证书》可以免考高
  8. Chapter7-10_Deep Learning for Question Answering (1/2)
  9. suse12安装详解
  10. Scala 与设计模式(六):Bridge 桥接模式
  11. 机器学习-吴恩达-笔记-2-逻辑回归
  12. ASP.NET连接数据库实现登录和注册
  13. 【优化求解】基于matlab差分进化算法求解函数极值问题【含Matlab源码 1199期】
  14. 会员分享几个平时看榜单常去的网站
  15. 干货分享:vue2.0做移动端开发用到的相关插件和经验总结
  16. SPR:SUPERVISED PERSONALIZED RANKING BASED ON PRIOR KNOWLEDGE FOR RECOMMENDATION
  17. uni-app实现文件管理器(Android)
  18. 【BZOJ4027】【HEOI2015】兔子与樱花 贪心
  19. 图解RAM结构与原理,系统内存的Channel、Chip与Bank
  20. 【云和恩墨大讲堂】高凯 | Oracle 12c 新特性-多租户的维护管理

热门文章

  1. 定位教程0---定位初介绍之均方根误差
  2. Cesium官方教程9--粒子系统
  3. Linux中图形用户界面与命令行模式互相切换
  4. VBA入门到进阶常用知识代码总结77
  5. java << 、>>理解
  6. Python3网络爬虫1:初识Scrapy
  7. codeblocks:: frotran 调用dll(详细)
  8. matlab数字音频处理实验报告,数字信号处理MATLAB实验1
  9. 考试系统怎么用?如何安装到电脑?
  10. 五种方案解决幂等问题