文章目录

  • 引入
  • 1 微调
  • 2 热狗识别
    • 2.1 数据集载入
    • 2.2 数据集预处理
    • 2.3 定义和初始化模型
    • 2.4 微调模型
  • 致谢

引入

  场景
  从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。
  方案1: 找出100种常见的椅子,为每种椅子拍摄1000张不同角度的图像,并训练模型。
  缺点: 数据集看似庞大,样本类型和数据集仍不足;适用于更大规模数据集的模型,可能会在该数据集上过拟合。
  方案2: 收集更多的数据。
  缺点: 收集和标注的成本巨大。
  方案3: 应用迁移学习 (transfer learning):将从源数据集学到的知识迁移到目标数据集。例如,虽然ImageNet数据集的图像大多和椅子无关,但是该数据集上训练的模型可以抽取较通用的图像特征,从而帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

1 微调

  微调 (fine tuning)是迁移学习的一种常用技术,分为以下四步,具体如下图:


  1)在源数据集上训练一个神经网络模型,即源模型
  2)创建一个新的神经网络模型,即目标模型:它复制了源模型上除输出层外的所有模型设计及其参数;
  3)为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数;
  4)在目标数据集上训练目标模型。我们将从新训练输出层,其余层的参数则是微调获得。
  该方案的假设如下:
  1)源模型参数包含了源数据集上学到的知识,并能够适用于目标数据集;
  2)源模型的输出层与源数据集的标签紧密相关,因此在目标模型中不予采用。
  优点:
  当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

2 热狗识别

  本例基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集包含数千张包含热狗和不包含热狗的图像的图像。最终将判断一个图像中是否包含热狗。

2.1 数据集载入

  数据集的说明和载入可以参照:
  https://blog.csdn.net/weixin_44575152/article/details/118901491

2.2 数据集预处理

  训练: 从图像中随机裁剪出随机大小和随机高宽比的一块区域,并将该区域缩放为 224 × 224 224 \times 224 224×224。
  测试: 将图像的高和宽缩放为 256 × 256 256 \times 256 256×256,并裁剪出 224 × 224 224 \times 224 224×224的中心区域。
  RGB标准化: 每个通道的数值减去当前通道所有数值的平均值,再除以该通道所有数值的标准差作为输出。

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.datasets import ImageFolderdef get_format_image(save_home="D:/Data/Image/hotdog/"):"""图像载入"""# 训练集和测试集normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])aug_tr = transforms.Compose([transforms.RandomResizedCrop(size=224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),normalize,])aug_te = transforms.Compose([transforms.Resize(size=256),transforms.CenterCrop(size=224),transforms.ToTensor(),normalize,])image_tr = DataLoader(ImageFolder(save_home + "train", transform=aug_tr), batch_size=128, shuffle=True)image_te = DataLoader(ImageFolder(save_home + "train", transform=aug_te), batch_size=128)return image_tr, image_te

2.3 定义和初始化模型

  对于预训练的模型,可以设置较小的学习率进行微调;而新加入的用于适应输出的全连接层,学习率则设置为较大值:

def get_model(lr=0.01):"""获取预训练模型"""net_pretrained = models.resnet18(pretrained=True)# fc是一个全连接层,通过设置可以适应指定数据集,例如二分类net_pretrained.fc = nn.Linear(512, 2)params_output = list(map(id, net_pretrained.fc.parameters()))params_feature = filter(lambda p: id(p) not in params_output, net_pretrained.parameters())# 这里fc层是从头学,所以其学习率设置的较大,而预训练层的则相对较小optimizer = optim.SGD([{"params": params_feature},{"params": net_pretrained.fc.parameters(),"lr": lr * 10}],lr=lr, weight_decay=0.001,)return net_pretrained, optimizer

2.4 微调模型

def train(train_iter, test_iter, net, loss, optimizer, num_epochs):import timedevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")net = net.to(device)print("training on ", device)batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print("epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec"% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))def evaluate_accuracy(data_iter, net, device=None):if device is None and isinstance(net, torch.nn.Module):# 如果没指定device就使用net的devicedevice = list(net.parameters())[0].deviceacc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval()  # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train()  # 改回训练模式else:  # 自定义的模型, 3.13节之后不会用到, 不考虑GPUif ('is_training' in net.__code__.co_varnames):  # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum / ndef train_fine_tuning():"""模型微调测试"""net, optimizer = get_model()image_tr, image_te = get_format_image()loss = torch.nn.CrossEntropyLoss()train(image_tr, image_te, net, loss, optimizer, num_epochs=5)if __name__ == '__main__':train_fine_tuning()

致谢

  感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书,为鄙人学习深度学习提供了很大的帮助。本文一系列关于深度学习的博客均无侵权之意,只为记录自己的深度学习历程。
  项目Github地址:https://github.com/ShusenTang/Dive-into-DL-PyTorch

torch学习 (三十四):迁移学习之微调相关推荐

  1. Java多线程学习三十四:使用 Future 有哪些注意点?Future 产生新的线程了吗

    Future 的注意点 1. 当 for 循环批量获取 Future 的结果时容易 block,get 方法调用时应使用 timeout 限制 对于 Future 而言,第一个注意点就是,当 for ...

  2. Java开发学习(三十四)----Maven私服(二)本地仓库访问私服配置与私服资源上传下载

    一.本地仓库访问私服配置 我们通过IDEA将开发的模块上传到私服,中间是要经过本地Maven的 本地Maven需要知道私服的访问地址以及私服访问的用户名和密码 私服中的仓库很多,Maven最终要把资源 ...

  3. shell学习三十四天----printf详解

    http://blog.csdn.net/shanyongxu/article/details/46744055

  4. OpenCV学习三十四:watershed 分水岭算法

    1. watershed void watershed( InputArray image, InputOutputArray markers ); 第一个参数 image,必须是一个8bit 3通道 ...

  5. GUI学习之十四——QAbstractSpinBox学习总结

    QAbstractSpinBox是一个抽象类,是将所有步长调节器的通用的功能抽象出了一个父类.虽然QAbstractSpinBox是一个抽象类,但是可以直接实例化使用.QAbstractSpinBox ...

  6. 推荐系统遇上深度学习(三十九)-推荐系统中召回策略演进!

    推荐系统中的核心是从海量的商品库挑选合适商品最终展示给用户.由于商品库数量巨大,因此常见的推荐系统一般分为两个阶段,即召回阶段和排序阶段.召回阶段主要是从全量的商品库中得到用户可能感兴趣的一小部分候选 ...

  7. 回溯法采用的搜索策略_强化学习基础篇(三十四)基于模拟的搜索算法

    强化学习基础篇(三十四)基于模拟的搜索算法 上一篇Dyna算法是基于真实经验数据和模拟经验数据来解决马尔科夫决策过程的问题.本篇将结合前向搜索和采样法,构建更加高效的搜索规划算法,即基于模拟的搜索算法 ...

  8. JavaScript学习(三十四)—事件委托

    JavaScript学习(三十四)-事件委托 (一).什么是事件委托? 所谓的事件委托就是指将事件添加到祖先元素身上,依据事件冒泡的原理(就是指事件的执行顺序是从当前元素逐步扩展到祖先元素,直到扩展到 ...

  9. OpenCV学习(二十四 ):角点检测(Corner Detection):cornerHarris(),goodFeatureToTrack()

    OpenCV学习(二十四 ):角点检测(Corner Detection):cornerHarris(),goodFeatureToTrack() 参考博客: Harris角点检测原理详解 Harri ...

最新文章

  1. mysql 按照范围选择_选择MySQL范围内的特定行?
  2. 1098 Insertion or Heap Sort 需再做
  3. Datatable中对某列求和,三种不同情况下的方法 .
  4. 2能不用cuda_cuda学习-1-cufft的使用
  5. virtualBox下安装Linux6.4
  6. ARC106E-Medals【hall定理,高维前缀和】
  7. 学PyTorch还是TensorFlow?
  8. Bootstrap导航条所支持的组件
  9. 编译时如何看到每个文件的编译选项_导出 Clang 可视化编译耗时分析报告 —— ftimetrace 的使用...
  10. 禁用 ssh agent_如何修复“禁用Agent XP”错误
  11. 安兔兔html5测试35000,安兔兔评测 8.4.3 安卓版
  12. 界面开发用qt还是java,做windows界面,用QT还是MFC?
  13. SECURITY:补丁
  14. 第二课 MC9S08DZ60之多功能时钟发生器S08MCGV1
  15. SECS协议的SML表示
  16. 利用OneDNS同步chrome数据
  17. 数字图像处理:线性和非线性滤波的平滑空间滤波器(Smoothing Spatial Filters)
  18. 工具及方法 - 查看飞机信息
  19. 帆软10.0服务器Tomcat 下通过 IP 直接访问数据决策系统出错
  20. 如何评价唐卫国公李靖的战功、军事才能、政治才能?

热门文章

  1. linux 解压rar格式的文件怎么打开,linux服务器怎么解压rar格式的文件
  2. 衍射微透镜 设计 matlab,亚波长衍射微透镜色散的数值分析
  3. python --cpca(从文本中提取省市区)
  4. 在C4D中使用Python快速导角/将贴图投射模式设置为平直
  5. 发动机双可变气门正时信号示波器测量
  6. Python在自动化运维中的应用之批量配置交换机
  7. 保定计算机专业中专学校排名,保定计算机中专技校排名_城铁轨道
  8. 一篇文章带你搞定 SpringBoot 上传文件(单文件/多文件/Ajax上传)
  9. Android 一键开启手电筒
  10. 手把手实操系列|信贷风控中的额度管理和额度模型设计