文章目录

  • 引言
  • 一、Transfer Learning & Model Finetune
  • 二、PyTorch中模型的Finetune

引言

  本节学习模型微调(Finetune)的方法,以及认识Transfer Learning(迁移学习)与Model Finetune之间的关系。

一、Transfer Learning & Model Finetune

  Transfer Learning(迁移学习):机器学习分支,研究源域(source domain)的知识如何应用到目标域(target domain)。我们将源任务所学习到的知识(权值)应用到目标任务当中,用来提升目标任务中的性能。

  那么,深度学习中的模型微调与迁移学习之间有什么关系?模型微调就是模型的迁移学习。为什么倾向于采用模型微调这个trick呢?这是由于在新任务当中数据量较小不足以训练一个较大的模型,因此,我们采用模型微调这个trick来辅助我们在新任务上训练一个较好的模型,让我们训练过程更快。这就类比于一个人学会了骑电动车,那么,他学自行车就比较快。那么,模型该如何迁移呢?以卷积神经网路为例,我们将特征提取部分认为是非常有共性的地方,我们可以原封不动的进行迁移,分类器的参数认为与具体的任务有关,通常我们需要去改变,分类器中的输出层通常需要进行改变。
  模型微调步骤:

  • 获取预训练模型参数
  • 加载模型(load_state_dict)
  • 修改输出层

  模型微调训练trick:

  • 固定预训练的参数(requires_grad =False or lr=0)

    # 冻结卷积层
    flag_m1 = 0
    # flag_m1 = 1
    if flag_m1:for param in resnet18_ft.parameters():param.requires_grad = False
    

    在非常小的数据量上,我们认为卷积核参数不能在更新了,因为数据量过小,如果继续更新,容易导致过拟合。

  • 将Features Extractor设置较小学习率,在分类器中的学习率比较大(params_group),优化器可以对不同的参数组设置不同的超参数,这里,我们就可以在不同部分设置不同的学习率

    # conv 小学习率
    flag = 0
    # flag = 1
    if flag:fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())# 优化器设置不同的参数组,优化器中的元素是一个list,list中的每一个元素是字典optimizer = optim.SGD([{'params': base_params, 'lr': LR*0.1},   # 0{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
    

二、PyTorch中模型的Finetune

  以自然语言处理中BERT模型为例,这里展示模型的微调,具体见代码
BERT模型—3.BERT模型在ner任务上的微调

BERT模型—4.BERT模型在关系分类任务上的微调
BERT模型—7.BERT模型在句子分类任务上的微调(对抗训练)


如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


PyTorch学习—20.模型的微调(Finetune)相关推荐

  1. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

  2. 迁移学习和模型的微调

    迁移学习 对于新的数据,需要进行分类或者回归时,常用的方法是在以个大的数据集上(ImageNet包含120万张来自1000类别的数据集)进行预训练一个CNN模型,然后用这个预训练好的模型作为特征提取部 ...

  3. pytorch学习--UNet模型

    详细Unet网络结构可以查看Unet算法原理详解 深度网络训练之中需要大量的有标样本,Unet作者提供了一种新的训练方法,可以更有效的运用相应的有标样本,使网络即使通过少量的训练图片也可以进行更精确的 ...

  4. pytorch模型微调(Finetune)

    Transfer Learning & Model Finetune 模型微调 **Transfer Learning:**机器学习分支,研究源域(source domain)的知识如何应用到 ...

  5. 速成pytorch学习——7天模型层layers

    深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.Flatten, nn ...

  6. 深度学习【使用pytorch实现基础模型、优化算法介绍、数据集的加载】

    文章目录 一 Pytorch完成基础模型 1. Pytorch完成模型常用API 1.1 `nn.Module` 1.2 优化器类 1.3 损失函数 1.4 线性回归完整代码 2. 在GPU上运行代码 ...

  7. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  8. 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...

  9. Pytorch通用图像分类模型(支持20+分类模型),直接带入数据就可训练自己的数据集,包括模型训练、推理、部署。

    Pytorch-Image-Classifier-Collection 介绍 ============================== 支持多模型工程化的图像分类器 =============== ...

  10. 系统学习大模型的20篇论文

    [引子]"脚踏实地,仰望星空", 知其然还要知其所以然.读论文是一条重要的途径,这一篇文章https://magazine.sebastianraschka.com/p/under ...

最新文章

  1. 4一20ma电流有源与无源区别_信号隔离安全栅与信号隔离器区别!
  2. Putty中文乱码的解决方法
  3. 一行Java代码竟能获取tomcat的绝对路径
  4. AI理论知识整理(10)-向量空间与矩阵(1)
  5. 用python处理excel 数据分析_Python应用实现处理excel数据过程解析
  6. AMD 发布第二代EPYC处理器,重新定义数据中心新标准
  7. MySQL之Variables(变量)
  8. python32什么意思_“python2”和“python3”有什么区别?
  9. 01-10 Linux-bash编程
  10. 古代的碎银子是怎么来的?
  11. 火狐浏览器中打开java_Ubuntu下通过Firefox Opera Chromium浏览器直接执行java应用程序(打开java jnlp文件)实现在服务器远程虚拟控制台完成远程管理的方法...
  12. [PHP] 数据结构-二叉树的创建PHP实现
  13. FAQ系列 | mysqldump选项之skip-opt
  14. UVA 10651 - Pebble Solitaire
  15. 强化学习(RL)QLearning算法详解
  16. QA | 关于手持式频谱仪,您想了解的那些技术问题(一)
  17. Vue中使用axio跨域请求外部WebService接口
  18. opencv思维导图
  19. 清华、商汤提出SIM方法,让自监督学习兼顾语义对齐与空间分辨能力
  20. Iterator 的用法

热门文章

  1. bootstrap 树
  2. phpmailer 与 mail
  3. HDU 4786 生成树 并查集+极大极小值 黑白边 确定选择白边的数量
  4. zTree树形控件讲解
  5. Part2-HttpClient官方教程-Chapter5-流利的API
  6. (dp)openjudge 复杂的整数划分问题
  7. 安装10gR2的硬件要求
  8. Handler机制使用时候一些问题记录
  9. SVN server
  10. CentreonMonitoringEvent Logs没有结果的解决方法