微调(fine tuning)

首先举一个例子,假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。这个椅子数据集虽然可能比Fashion-MNIST数据集要庞大,但样本数仍然不及ImageNet数据集中样本数的十分之一。这可能会导致适用于ImageNet数据集的复杂模型在这个椅子数据集上过拟合。同时,因为数据量有限,最终训练得到的模型的精度也可能达不到实用的要求

为了应对上述问题,一个显而易见的解决办法是收集更多的数据。然而,收集和标注数据会花费大量的时间和资金。例如,为了收集ImageNet数据集,研究人员花费了数百万美元的研究经费。虽然目前的数据采集成本已降低了不少,但其成本仍然不可忽略。

另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

接下来就需要介绍迁移学习中的一种常用技术:微调(fine tuning)。如下图所示,微调由以下4步构成。

  • 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  • 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  • 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  • 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层而其余层的参数都是基于源模型的参数微调得到的

值得注意,但是并不难理解的是,当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

fine tuning的具体例子

接下来我们来实践一个具体的例子:热狗识别。我们将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张包含热狗和不包含热狗的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的包或模块。torchvision的models包提供了常用的预训练模型。如果希望获取更多的预训练模型,可以使用使用pretrained-models.pytorch仓库。

%matplotlib inline
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
import os
from matplotlib import pyplot as plt
import timedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
获取数据集

点击下载热狗数据集,它含有1400张包含热狗的正类图像,和同样多包含其他食品的负类图像。各类的1000张图像被用于训练,其余则用于测试。

我们首先将压缩后的数据集下载到路径data_dir之下,然后在该路径将下载好的数据集解压,得到两个文件夹hotdog/train和hotdog/test。这两个文件夹下面均有hotdog和not-hotdog两个类别文件夹,每个类别文件夹里面是图像文件。

data_dir = './data_set'
os.listdir(os.path.join(data_dir, "hotdog")) # ['train', 'test']

创建两个ImageFolder实例来分别读取训练数据集和测试数据集中的所有图像文件。

train_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/train'))
test_imgs = ImageFolder(os.path.join(data_dir, 'hotdog/test'))

画出前8张正类图像和最后8张负类图像。可以看到,它们的大小和高宽比各不相同。

def show_images(imgs, num_rows, num_cols, scale=2):figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)for i in range(num_rows):for j in range(num_cols):axes[i][j].imshow(imgs[i * num_cols + j])axes[i][j].axes.get_xaxis().set_visible(False)axes[i][j].axes.get_yaxis().set_visible(False)return axeshotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]
show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4);

  • 在训练时,我们先从图像中裁剪出随机大小和随机高宽比的一块随机区域,然后将该区域缩放为高和宽均为224像素的输入。
  • 测试时,我们将图像的高和宽均缩放为256像素,然后从中裁剪出高和宽均为224像素的中心区域作为输入。
  • 此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化:每个数值减去该通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。
    指定RGB三个通道的均值和方差来将图像通道归一化
# 指定RGB三个通道的均值和方差来将图像通道归一化
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_augs = transforms.Compose([transforms.RandomResizedCrop(size=224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize])test_augs = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize])

需要注意的是:在使用预训练模型时,一定要和预训练时作同样的预处理

定义和初始化模型

使用在ImageNet数据集上预训练的ResNet-18作为源模型。这里指定pretrained=True来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

pretrained_net = models.resnet18(pretrained=True)

不管你是使用的torchvision的models还是pretrained-models.pytorch仓库,默认都会将预训练好的模型参数下载到你的home目录下.torch文件夹。你可以通过修改环境变量TORCH_MODEL_ZOO来更改下载目录: export TORCH_MODEL_ZOO="/local/pretrainedmodels" 。

另外比较常使用的方法是,在其源码中找到下载地址直接浏览器输入地址下载,下载好后将其放到环境变量$TORCH_MODEL_ZOO所指文件夹即可,这样比较快

打印源模型的成员变量fc。作为一个全连接层,它将ResNet最终的全局平均池化层输出变换成ImageNet数据集上1000类的输出。

print(pretrained_net.fc)

如果你使用的是其他模型,那可能没有成员变量fc(比如models中的VGG预训练模型),所以正确做法是查看对应模型源码中其定义部分,这样既不会出错也能加深我们对模型的理解。

可见此时pretrained_net最后的输出个数等于目标数据集的类别数1000。所以我们应该将最后的fc成修改我们需要的输出类别数:

pretrained_net.fc = nn.Linear(512, 2)
print(pretrained_net.fc)

此时,pretrained_net的fc层就被随机初始化了,但是其他层依然保存着预训练得到的参数。由于是在很大的ImageNet数据集上预训练的,所以参数已经足够好,因此一般只需使用较小的学习率来微调这些参数,而fc中的随机初始化参数一般需要更大的学习率从头训练

PyTorch可以方便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设为已经预训练过的部分的10倍。

output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())lr = 0.01
optimizer = optim.SGD([{'params': feature_params},{'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],lr=lr, weight_decay=0.001)
微调模型

先定义一个使用微调的训练函数train_fine_tuning以便多次调用。

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y) optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
def train_fine_tuning(net, optimizer, batch_size=128, num_epochs=5):train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform=train_augs),batch_size, shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs),batch_size)loss = torch.nn.CrossEntropyLoss()train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)
开始训练(微调)
train_fine_tuning(pretrained_net, optimizer)

理论上可以执行了哈,但是我的pc显存太小,溢出了,建议使用显存10G以上的GPU进行训练

  • 迁移学习将从源数据集学到的知识迁移到目标数据集上。微调是迁移学习的一种常用技术。
  • 目标模型复制了源模型上除了输出层外的所有模型设计及其参数,并基于目标数据集微调这些参数。而目标模型的输出层需要从头训练。
  • 一般来说,微调参数会使用较小的学习率,而从头训练输出层可以使用较大的学习率。

pytorch深度学习-微调(fine tuning)相关推荐

  1. 计算机视觉之迁移学习中的微调(fine tuning)

    现在的数据集越来越大,都是大模型的训练,参数都早已超过亿级,面对如此大的训练集,绝大部分用户的硬件配置达不到,那有没有一种方法让这些训练好的大型数据集的参数,迁移到自己的一个目标训练数据集当中呢?比如 ...

  2. 实战例子_Pytorch官方力荐新书《Pytorch深度学习实战指南》pdf及代码分享

    PyTorch是目前非常流行的机器学习.深度学习算法运算框架.它可以充分利用GPU进行加速,可以快速的处理复杂的深度学习模型,并且具有很好的扩展性,可以轻松扩展到分布式系统.PyTorch与Pytho ...

  3. pytorch深度学习_用于数据科学家的深度学习的最小pytorch子集

    pytorch深度学习 PyTorch has sort of became one of the de facto standards for creating Neural Networks no ...

  4. PyTorch深度学习笔记之四(深度学习的基本原理)

    本文探讨深度学习的基本原理.取材于<PyTorch深度学习实战>一书的第5章.也融入了一些自己的内容. 1. 深度学习基本原理初探 1.1 关于深度学习的过程的概述 给定输入数据和期望的输 ...

  5. pytorch深度学习_了解如何使用PyTorch进行深度学习

    pytorch深度学习 PyTorch is an open source machine learning library for Python that facilitates building ...

  6. PyTorch深度学习训练可视化工具tensorboardX

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 之前笔者提到了PyTorch的专属可视化工具visdom,参看Py ...

  7. PyTorch深度学习入门与实战(案例视频精讲)

    作者:孙玉林,余本国 著 出版社:中国水利水电出版社 品牌:智博尚书 出版时间:2020-07-01 PyTorch深度学习入门与实战(案例视频精讲)

  8. PyTorch深度学习入门

    作者:曾芃壹 出版社:人民邮电出版社 品牌:iTuring 出版时间:2019-09-01 PyTorch深度学习入门

  9. PyTorch深度学习

    作者:[印度] 毗湿奴·布拉马尼亚(Vishnu Subramanian) 著,王海玲,刘江峰 译 出版社:人民邮电出版社 品牌:异步图书 出版时间:2019-04-01 PyTorch深度学习

最新文章

  1. 黑白世界,感受不同的旅行...
  2. Mac OS—苹果搭建Android开发环境
  3. 中国互联网+政务建设发展现状及市场规模预测报告2022-2027年版
  4. GridView格式化数据失效
  5. EasyUI中Window窗口的简单使用
  6. Redis02_数据模型初识
  7. Java 集合系列目录(Category)
  8. Win32窗体应用程序如何添加资源文件?
  9. python绘制汉字_OpenCV Python 绘制中文字
  10. 微信发虎年新春贺词领福袋:游戏皮肤、QQ音乐VIP、现金红包等
  11. 第二章:09流程控制[2switch]
  12. c# WinForm开发 DataGridView控件的各种操作总结(单元格操作,属性设置)
  13. VMware-NAT连接网络
  14. 仿高德地图点亮城市html,高德地图怎么点亮城市_高德地图点亮城市教程_3DM手游...
  15. ios 推送通知服务证书不受信任(Apple Push Service certificate is not trusted)
  16. html基础、h5c3高级c3动画 、 JavaScript初高级、css预处理器和git 部分面试题
  17. 设备描述符请求失败解决
  18. mysql table plugin,MySql报错Table mysql.plugin doesn’t exist的解决方法
  19. 2019-12-21(98)
  20. redisTemplate执行lua脚本

热门文章

  1. hash和一致性hash
  2. hbase动态更改行键设计_Hadoop HBase概念学习系列之优秀行键设计(十六)
  3. linux 切换root账号_Linux 服务器的安全保障,看看这些
  4. JAVA入门级教学之(布尔型数据类型)
  5. android光传感实现摩斯密码,根据莫尔斯代码 - Android的闪烁闪光。 如何避免ANR次数由于睡觉? (火炬APP)...
  6. java 中的doit(n)_CoreJava测试题(含答案).docx
  7. oracle复制一个表的结构图,Oracle复制表结构
  8. hive left outer join 子查询临时表_基于历史数据的用户访问次数,每天新老用户,日活,周活,月活的hive计算...
  9. java c 客户端_java基于C/S模式实现聊天程序(客户端)
  10. python字典导入mongodb_Python语言生成内嵌式字典(dict)-案例从python提取内嵌json写入mongodb...