系列文章 \text{\bf 系列文章} 系列文章

  1. 计算机视觉系列(一)——CNN基础
  2. 计算机视觉系列(二)——AlexNet
  3. 计算机视觉系列(三)——VGG与NiN
  4. 计算机视觉系列(四)——GoogLeNet
  5. 计算机视觉系列(五)——ResNet的实现
  6. 计算机视觉系列(六)——图像增强
  7. 计算机视觉系列(七)——迁移学习

目录

  • 一、迁移学习与微调
  • 二、如何寻找预训练的模型?
  • 三、初始化模型
  • 四、将 ResNet 迁移到 CIFAR-10 上

一、迁移学习与微调

ImageNet 数据集大约有 120w 个样本,类别数为 1000;MNIST 数据集只有 6w 个样本,类别数为 10。然而,我们平常接触到的数据集的规模通常在这两者之间。

假如我们想识别图片中不同类型的椅子,然后向用户推荐购买链接。 一种可能的方法是首先识别 100 把普通椅子,为每把椅子拍摄 1000 张不同角度的图像,然后在收集的图像数据集上训练一个分类模型。 尽管这个椅子数据集可能大于 Fashion-MNIST 数据集,但实例数量仍然不到 ImageNet 中的十分之一。 适合 ImageNet 的复杂模型可能会在这个椅子数据集上过拟合。 此外,由于训练样本数量有限,训练模型的准确性可能无法满足实际要求。

为了解决上述问题,一个显而易见的解决方案是收集更多的数据。 但是,收集和标记数据可能需要大量的时间和金钱。 例如,为了收集 ImageNet 数据集,研究人员花费了数百万美元的研究资金。 尽管目前的数据收集成本已大幅降低,但这一成本仍不能忽视。

另一种解决方案是应用迁移学习(transfer learning)将从源数据集学到的知识迁移到目标数据集。 例如,尽管 ImageNet 数据集中的大多数图像与椅子无关,但在此数据集上训练的模型可能会提取更通用的图像特征,这有助于识别边缘、纹理、形状和对象组合。 这些类似的特征也可能有效地识别椅子。

迁移学习中的一个常见的技巧是微调(fine-tuning),它包括以下四个步骤:

  • 在源数据集(例如 ImageNet)上预训练神经网络模型,即源模型
  • 创建一个新的神经网络模型,即目标模型。这将复制源模型上的所有模型设计及其参数(输出层除外)。我们假定这些模型参数包含从源数据集中学到的知识,这些知识也将适用于目标数据集。
  • 向目标模型添加输出层,其输出数是目标数据集中的类别数。然后随机初始化该层的模型参数。
  • 在目标数据集(如椅子数据集)上训练目标模型。输出层将从头开始进行训练,而所有其他层的参数将根据源模型的参数进行微调。

因为是微调,我们通常选用较小的学习率,例如 5 × 1 0 − 4 5\times10^{-4} 5×10−4,且训练的 epoch 数也要少。

二、如何寻找预训练的模型?

这里提供两种方案。

  • 使用 PyTorch 官方的预训练模型(链接),这些模型均在 ImageNet 上完成了预训练。使用时需要设置 pretrained 参数为 True
  • 使用 timm 包(github链接、timm文档)。

本文接下来的部分将使用第一种方案。

三、初始化模型

我们使用在 ImageNet 数据集上预训练的 ResNet-18 作为源模型,重新设置输出层并将其随机初始化:

net = torchvision.models.resnet18(pretrained=True)
net.fc = nn.Linear(512, 10)  # 设置为10是因为接下来要面对十分类任务
nn.init.xavier_uniform_(net.fc.weight)

四、将 ResNet 迁移到 CIFAR-10 上

接下来,我们将 ResNet-18 迁移到 CIFAR-10 数据集上并进行微调。

需要注意的是,所有预训练的模型在接收输入时,必须将它们以下面的方式进行归一化:

transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

处理 CIFAR-10 数据集:

normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize,
])
test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize,
])
train_data = torchvision.datasets.CIFAR10('/mnt/mydataset', train=True, transform=train_augs, download=True)
test_data = torchvision.datasets.CIFAR10('/mnt/mydataset', train=False, transform=test_augs, download=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=6)
test_loader = DataLoader(test_data, batch_size=128, num_workers=6)

设置学习率为 5 × 1 0 − 4 5\times10^{-4} 5×10−4,只训练 10 个 epoch

e = E(train_loader, test_loader, net, 10, 5e-4)
e.main()

在 NVIDIA GeForce RTX 3090 上的训练/测试结果如下:

Epoch 10
--------------------------------------------------
Train Avg Loss: 0.866763, Train Accuracy: 0.696940
Test  Avg Loss: 0.498987, Test  Accuracy: 0.828300--------------------------------------------------
3273.3 samples/sec
--------------------------------------------------Done!

通过与这篇文章进行比较可以看出,使用迁移学习的方法后,ResNet-18 在测试集上的精度更胜一筹(虽然胜的不多)。并且这仅仅是训练了 10 个 epoch,从测试集的损失函数曲线变化来看,继续训练可以进一步提升精度。

计算机视觉系列(七)——迁移学习相关推荐

  1. 4个计算机视觉领域用作迁移学习的模型

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达本文转自|AI公园 导读 使用SOTA的预训练模型来通过迁移学习解决 ...

  2. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  3. 计算机视觉3.3 :迁移学习之图像特征向量提取与运用

    迁移学习之图像特征向量提取与运用 ​ 本篇文章将要讨论的是关于计算机视觉中迁移学习的概念,一种能够利用预先训练好的模型,从它之前训练的数据集之外的数据集进行学习的能力. ​ 举个例子来说: ​ 现有A ...

  4. 迁移学习全面指南:概念、项目实战、优势、挑战

    https://www.toutiao.com/a6685571761766924803/ 2019-04-30 14:32:56 前言 希望大家从头到尾的去看,而不是看完前面一点点感觉有用然后收藏, ...

  5. 迁移学习之深度迁移学习

    深度迁移学习即采用深度学习的方法进行迁移学习,这是当前深度学习的一个比较热门的研究方向. 深度学习方法对非深度方法两个优势: 一.自动化地提取更具表现力的特征: 二.满⾜了实际应用中的端到端 (End ...

  6. Tensorflow2.0实战练习之猫狗数据集(包含自定义训练和迁移学习)

    最近在学习使用Tenforflow2.0,写下这篇文章,用来帮助和我一样的初学者,文章中如果存在某些问题,还希望各位指出. 目录 数据集介绍 数据处理及增强 VGG模型介绍 模型搭建 训练及结果展示 ...

  7. 深度学习与计算机视觉系列(9)_串一串神经网络之动手实现小例子

    深度学习与计算机视觉系列(9)_串一串神经网络之动手实现小例子 作者:寒小阳  时间:2016年1月.  出处:http://blog.csdn.net/han_xiaoyang/article/de ...

  8. 深度学习与计算机视觉系列(8)_神经网络训练与注意点

    深度学习与计算机视觉系列(8)_神经网络训练与注意点 作者:寒小阳  时间:2016年1月.  出处:http://blog.csdn.net/han_xiaoyang/article/details ...

  9. 深度学习与计算机视觉系列(4)_最优化与随机梯度下降\数据预处理,正则化与损失函数

    1. 引言 上一节深度学习与计算机视觉系列(3)_线性SVM与SoftMax分类器中提到两个对图像识别至关重要的概念: 用于把原始像素信息映射到不同类别得分的得分函数/score function 用 ...

最新文章

  1. 045_CSS3过渡
  2. H5解码H264实时视频流
  3. python掷骰子_用于掷骰子的Python程序(2人骰子游戏)
  4. 数据科学学习心得_学习数据科学
  5. android 设置setmultichoiceitems设置初始化勾选_Linux内核启动:虚拟盘空间设置和内存管理结构初始化...
  6. Boostnote跨平台 Markdown 编辑器
  7. SharePoint 2010 中有个新的列表模板“导入电子表格”可以直接导入Excel数据并创建为列表 ....
  8. maven环境、本地仓储配置(下载安装)idea配置maven
  9. 遥感常用数据下载链接
  10. 毕向东java笔记ppt,毕向东java学习笔记.doc
  11. 我们从UNIX之父丹尼斯身上学到了什么
  12. c语言浮点型变量字母表示,C语言基础学习基本数据类型-浮点型
  13. 用 emacs 浏览 C/C++ 项目
  14. JavaScript事件函数
  15. xeon e5-2400 系列处理器能做四路服务器吗?,至强处理器E5-2400系列双路云服务器推出...
  16. linux中 777,755等用户权限说明
  17. 抓包技术(浏览器APP小程序PC应用)
  18. Vuforia提高识别以及稳定性方法总结
  19. sas数据的中国地图 湿地
  20. red hat linux的phythmbox音乐播放器乱码,Outlook中设置hotmail

热门文章

  1. React中的PureComponent,refs
  2. 搜狗大变动!搜狗收录接下来怎么做?
  3. 原来变压器可以用来调节阻抗匹配!
  4. 华为虚拟化FusionCompute知识点总结
  5. 抖音关键词排名优化技巧,手把手教你怎样优化抖音关键词
  6. Qt之 QStringLiteral
  7. 计算机二级考试公网入口和教育网入口,考试入口
  8. 传说中的补丁比较...很好玩啊..
  9. 【OAK开源项目教程】opencv+python实现测量包装盒尺寸和体积
  10. Freeswitch 常用命令