参考

9.2 微调

在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。

假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常用的椅子,为椅子拍摄1000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。这个椅子数据集虽然可能比Fashion-MNIST数据集要庞大,但样本仍然不及ImageNet数据集中样本数的十分之一。这可能会导致适用于ImageNet数据集的复杂模型在这个椅子数据集上过拟合。同时,因为数据量有限,但其成本仍热不可忽略。

另一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

本节我们介绍迁移学习中的一种常用技术: 微调(fine tuning)。如图9.1所示,微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

9.2.1 热狗识别

接下来我们来实践一个具体的例子: 热狗识别。我们将基于一个小数据集在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张包含热狗和不包含热狗的图像。我们使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需要的包或模块。torchvision的models包提供了常用的预训练模型。如果希望获取更多的预训练模型,可以使用pretrained-models.pytorch仓库.

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import osimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

9.2.1.1 获取数据集

我们使用的热狗数据集是从网上抓取的,它包含1400张含热狗的正类图像,和同样多包含其他食品的负类图像。各类的1000张图像被用于训练,其余则用于测试。

我们首先将压缩后的数据集下载到路径data_dir之下,然后在该路径将下载好的数据集解压,得到两个文件夹hotdog/trainhotdog/test。这两个文件夹下面均有hotdognot-hotdog两个类别文件夹,每个类别文件夹里面是图像文件。

data_dir = "C:/Users/1/Datasets"
os.listdir(os.path.join(data_dir, 'hotdog'))

我们创建两个ImageFolder实例来分别读取训练数据集和测试数据集中的所有图像文件

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

下面画出前8张正类图像和最后8张负类图像。

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i- 1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)


在训练时,我们先从图像中裁剪随机大小和随机宽高比的一块随机区域,然后将该区域缩放为高和宽均为224像素的输入。测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高和宽均为224像素的中心区域作为输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化:每个数值减去通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。

注: 使用pretrained-models仓库时,一定要对图像进行相应的预处理

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406 ], std = [0.229, 0.224, 0.225])
train_augs = transforms.Compose([transforms.RandomResizedCrop(size= 224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize
])
test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize
])

9.2.1.2 定义和初始化模型

我们使用在ImageNet数据集上预训练的ResNet-18作为源模型。这里指定pretrained=True来自动下载并记载预训练的模型参数。在第一次使用时需联网下载模型参数

pretrained_net  = models.resnet18(pretrained=True)

打印源模型的成员变量fc。作为一个全连接层,它将ResNet最终的全局平均池化层输出变成ImageNet数据集上1000类的输出

print(pretrained_net.fc)


可见此时pretrained_net最后的输出个数等于目标数据集的类别数1000。所以我们应该将最后的fc修改成我们需要输出类别数:

pretrained_net.fc = nn.Linear(512, 2)

此时,pretrained_netfc层就随机初始化了,但是其他层依然保存着预训练得到的参数。由于是在很大的ImageNet数据集上预训练的,所以参数已经足够好,因此一般只需使用较小的学习率来微调这些参数,而fc中的随机参数一般需要更大的学习率从头训练。PyTorch可以方便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设置为已经预训练过的部分的10倍

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())lr = 0.01
optimizer = optim.SGD([{'params': feature_params},{'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],lr = lr, weight_decay=0.001)

9.2.1.3 微调模型

我们先定义一个使用微调的训练函数train_fine_tuning以便多次调用。

def train_fine_tuning(net, optimizer, batch_size = 128, num_epochs = 15):train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform = train_augs), batch_size, shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs), batch_size)loss = torch.nn.CrossEntropyLoss()d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)

根据前面的设置,我们将以10倍的学习率从头训练目标模型的输出层参数。

train_fine_tuning(pretrained_net, optimizer)


作为对比,我们定义一个相同的模型,但将它的所有模型参数都初始化为随机值。由于整个模型都需要从头训练,我们可以使用较大的学习率。

scratch_net = models.resnet18(pretrained=False, num_classes=2)
lr = 0.1
optimizer  = optim.SGD(scratch_net.parameters(), lr = lr, weight_decay = 0.001)
train_fine_tuning(scratch_net, optimizer)

[pytorch、学习] - 9.2 微调相关推荐

  1. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  2. PyTorch学习笔记(六):PyTorch进阶训练技巧

    PyTorch实战:PyTorch进阶训练技巧 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: P ...

  3. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  4. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  5. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  6. pytorch 学习中安装的包

    记录pytorch学习遇到的包 1.ImportError: cannot import name 'PILLOW_VERSION' torchvision 模块内import pillow的时候发现 ...

  7. pytorch学习笔记(二):gradien

    pytorch学习笔记(二):gradient 2017年01月21日 11:15:45 阅读数:17030

  8. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

  9. 2_初学者快速掌握主流深度学习框架Tensorflow、Keras、Pytorch学习代码(20181211)

    初学者快速掌握主流深度学习框架Tensorflow.Keras.Pytorch学习代码 一.TensorFlow 1.资源地址: 2.资源介绍: 3.配置环境: 4.资源目录: 二.Keras 1.资 ...

  10. Pytorch学习 - 保存模型和重新加载

    Pytorch学习 - 保存和加载模型 1. 3个函数 2. 模型不同后缀名的区别 3. 保存和重载模型 参考资料: Pytorch官方文档链接 某博客 1. 3个函数 torch.save() : ...

最新文章

  1. 7.3. UUID()
  2. sublime text3中文文件名显示为框框,怎么解决
  3. oracle exp执行失败,Oracle EXP-EXP-00091的错误原因及处理方法
  4. 心理医生给女人的忠告
  5. Web 高效开发必备的 PHP 框架
  6. python面向对象有什么用_Python 中的面向对象没有意义
  7. 二十九、PHP框架Laravel学习笔记——Debugbar 调试器
  8. 【译】数据结构中关于树的一切(java版)
  9. MySQL常用函数 一
  10. unity 中文_Unity无情大爆料时间Unity3D的脚本语言
  11. svn 同步备份的所有问题,亲测可用
  12. PHP判断pc和移动端跳转,JS判断是PC还是移动端浏览器,并根据不同的终端跳转到不同的网址...
  13. oracle突然挂了服务全部消失,OracleService服务不见了|OracleServiceXE服务没有了
  14. 洛谷P1079 Vigenère 密码
  15. 关于mysql叙述中错误的是什么_以下关于MySQL的叙述中,错误的是( )。_学小易找答案...
  16. Scala笔记整理(二):Scala数据结构—数组、map与tuple
  17. 这个世界是怎么了?做商业软件的怎么越来越流氓了?
  18. Linux内核加载f2fs,安装f2fs工具以使用f2fs文件系统作为引导deepin系统分区
  19. linux tex文件编译,用latexmk编译XeLaTeX tex文件
  20. RS-232通信接口

热门文章

  1. 项目中使用 java函数式编程_函数式编程在Java8中使用Lambda表达式进行开发
  2. python中的类装饰器应用场景_Python 自定义装饰器使用写法及示例代码
  3. java复制的函数会报错,2 面试题之面向对象
  4. oracle ref游标用法,[置顶] Oracle 参照游标(SYS_REFCURSOR)使用
  5. java set去重复元素_java List去掉重复元素的几种方式
  6. OpenCV学习——轮廓检测
  7. 一步一步教你实现iOS音频频谱动画(一)
  8. PAT 1007 Maximum Subsequence Sum
  9. 《剑指offer》第四十三题(从1到n整数中1出现的次数)
  10. HDU 4035 Maze