个人理解,迁移学习可以分为三类:

第一类:以训练好的模型参数为基础,对所有参数进行继续优化.

即,先在别的训练数据集上训练模型,达到一定训练标准之后,用当前的数据集继续进行训练.

第二类:将已经训练好的模型当作特征提取器,仅对模型中的部分层的参数进行优化(或继续优化).

此时,在加载模型后,需要通过对模型的所有参数的"requires_grad"属性进行设置,使其变成不再被优化的参数,即设置模型的所有参数的"requires_grad=False"

然后,指定需要训练的层的参数为"requires_grad=True",然后进行训练.

这一类迁移常见的情况是对模型最后部分的全连接层进行重新训练.除了对模型最后部分的全连接层进行重新训练之外,还可以重新定义最后的全连接层(或其它结构).

第三类:是将训练好的模型完全不动,仅用其输出作为新的模型的输入,此时的pretrained_model被完全当作特征提取器.

ps:模型参数的"requires_grad"属性的改变是容易的,相对难的是如何准确地获取需要修改的参数.这需要自己对pretrained_model的网络结构的熟悉与理解,然后才能调用相关参数.

----------------------------------------------------------------------------------------------------------------------

在pytorch中,实现迁移学习可按以下伪代码执行:

其中2) 3) 6)为关键步骤

伪代码:

1) new_model = pretrained_model  # 加载已有模型
2) for param in new_model.parameters():  # 将已有模型的已有参数设置为不可修改param.requires_grad = False
3) new_model.fc = nn.Linear(dims_input_feature,dims_output_feature)  # 重新构造/修改某一层的网络结构
4) new_model = new_model.to(device)#将新的模型放到device上(如GPU上)
5) criterion = nn.CrossEntropyLoss() # 定义损失函数
6) optimizer_new_model = optim.SGD(new_model.fc.parameters(),lr=0.001,momentum=0.9)#定义优化器,在定义优化器的时候指定了需要优化的参数(new_model.fc.parameters())
7) new_model = train_model(new_model,criterion,optimizer_new_model,num_epochs=25)#训练模型.train_model是自己定义的训练函数

疑问:如何获取某一层上的所有参数

解答: 用 model.fc.parameters(),其中fc可以是任意一层

疑问: fc是如何定义的?"fc"是在定义网络结构时人为给定的"层的名称"吗?如果是这样,那就可以直接用"层的名称"直接调用该层了,确实很方便,那么,如何给某一层网络"命名"呢?

解答:在pytorch中与在tensorflow中在命名网络层时策略是不一样的.在tensorflow中需要有专门的命令语句来对某一层或某几层进行命名,并通过命名来管理和调用相关的层的相关参数.

而在pytorch中不是这样的,pytorch是通过定义一个nn.Module的子类来定义网络结构以及前向算法.

在pytorch中每种网络结构都已经有了其对应的类,如:torch.nn.Conv2d,torch.nn.Linear等,都已经被定义成一个类.

因此自己定义的网络的每一层都会被定义成对应结构类的一个实例化对象.

因此,在pytorch中,网络结构的每一层是以类对象的形式存在.可以在自己的网络结构的类的__init__函数中,声明所有的需要用到的层,当作自己的网络结构类的一个属性(元素),

因此,pytorch调用每一层的技术手段是:用类对象调用类的属性元素的方法.不是通过string命名空间的方法.

参考:https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

Pytorch 的迁移学习的理解相关推荐

  1. 使用PyTorch进行迁移学习

    概述 迁移学习可以改变你建立机器学习和深度学习模型的方式 了解如何使用PyTorch进行迁移学习,以及如何将其与使用预训练的模型联系起来 我们将使用真实世界的数据集,并比较使用卷积神经网络(CNNs) ...

  2. pytorch之迁移学习

    文章目录 1.导入相关的包 2.加载数据 3.可视化部分图像数据 4.训练模型 5.可视化模型的预测结果 6.场景1:微调ConvNet 7.场景2:ConvNet作为固定特征提取器 实际中,基本没有 ...

  3. Pytorch实现迁移学习

    迁移学习 迁移学习是一种机器学习的方法,指的是一个预训练的模型被重新用在另一个任务中,它专注于存储已有问题的解决模型,并将其利用在其他不同但相关问题上.例如我在A的场景下训练了一个模型,而B.C.D等 ...

  4. 【PyTorch】迁移学习:基于 VGG 实现的光明哨兵与破败军团分类器

    文章目录 简述. 环境配置. PyTorch代码. 导入第三方库. 使用 GPU. 加载数据. 定义可视化函数. 加载预训练模型. 冻结特征层. 修改输出层. 定义优化器. 定义训练函数. 训练过程. ...

  5. Resnet152对102种花朵图像分类(PyTorch,迁移学习)

    目录 1.介绍 1.1.项目数据及源码 1.2.数据集介绍 1.3.任务介绍 1.4.ResNet网络介绍 2.数据预处理 3.展示数据 4.进行迁移学习 4.1.训练全连接层 4.2.训练所有层 5 ...

  6. 【IM】关于迁移学习的理解

    迁移学习:应用其他学习任务的信息来提升当前学习任务的求解精度,前提是两个学习任务之间存在关联. 这里对于半监督迁移学习给出了协变量移位和类别不平衡两种场景.

  7. pytorch添加迁移学习

    # 参数设置(指定用第几轮的预训练权重) parser = argparse.ArgumentParser(description="PyTorch Net") parser.ad ...

  8. 【Pytorch实战6】一个完整的分类案例:迁移学习分类蚂蚁和蜜蜂(Res18,VGG16)

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch官方文档 本文是采用pytorch进行迁移学习的实战演练,实战目的是为了进一步学习和熟悉pyt ...

  9. PyTorch迁移学习

    PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...

最新文章

  1. Linux bind-utils
  2. 给力!一行代码躺赚普通程序员 10 年薪资!
  3. python/pytorch中的一些函数介绍
  4. Matlab图像标题_title
  5. 鲸鱼优化算法_鲸鱼优化算法:一种群体智能最优化方法
  6. windows10 快速切换网络适配器
  7. 2022韦莱韬悦人力资源管理季刊
  8. 微博长图快速排版生成工具
  9. mysql报1032_MySQL SQL_ERROR 错误号 1032解决办法
  10. nexus仓库数据完整迁移到新的nexus仓库
  11. 落地SQL审核的迭代思路
  12. mysql存emoji_如何在MySQL中存储emoji?
  13. 利用nodejs对接232接口电子秤
  14. Python计算图片之间的相似度
  15. HTML实战案例素材1:制作树形菜单页面
  16. 山东省第八届acm大赛 F题 (SDUT 3898)
  17. 移动GPU:高通Adreno图形处理器全解析
  18. vue新窗口打开路由
  19. vcm服务器如何修改端口,产品技术-iMC VCM虚拟连接管理-新华三集团-H3C
  20. ​Beaglebone Black教程Beaglebone Black的引脚分配

热门文章

  1. php能连接动易吗,动易CMS数据转成dedecms的php程序
  2. OpenCV图像处理(18)——文件夹下所有图像转灰度(14-15综合)
  3. 对字符串进行折半查找c语言,C语言:编写折半查找函数
  4. excel表格vlookup函数怎么用_只会Vlookup函数Out了!Excel所有查找公式全在这儿(共16大类)...
  5. linux内核mtd分区,嵌入式Linux MTD分区调整(MX28)
  6. confluence启动不起来_“一键启动”只能点火?还有这5个“隐藏”功能,你都知道吗?...
  7. Python数据结构与算法(4.1)——递归
  8. linux平均负载什么意思_在Linux中什么是平均负载?
  9. jave类命名_Java重命名文件– Jave移动文件
  10. c ++递归算法数的计数_计数排序算法–在C / C ++中实现的想法