目录

前言

一、什么是迁移学习?

二、特征提取介绍

三、实例介绍

1.获取预训练的网络模型

2.使用数据增强

3.冻结模型参数

4.修改最后一层的输出类别数

5.定义损失函数和优化器

6.训练及验证模型

7.完整的代码:

总结



前言

深度学习一般需要大数据、深网络,但很多情况下我们并不能同时获取这些条件,但我们又想获得一个高性能的网络模型,这个时候,迁移学习就是一个很好的方法。本文将对迁移学习进行简单介绍,并运用实例讲解如何实现迁移学习。


一、什么是迁移学习?

迁移学习是机器学习的一种学习方法,其主要是把在任务A中训练得到的网络模型,通过调整和处理,使用到任务B中。迁移学习在人类生活中很常见,最简单的例子就是我们学完C语言后,再学习python就会感到很容易。

在神经网络迁移学习过程中,主要有两个应用场景:

  • 特征提取:冻结除最后的全连接层之外的所有网络的权重,最后一个全连接层被替换为具有随机权重的新层,并且仅训练新层
  • 微调:使用预训练好的初始化网络,用新数据训练部分或整个网络

在本文中,我们主要将使用特征提取进行实例训练。

二、特征提取介绍

特征提取是迁移学习的一个重要方法,其主要是先引入预训练好的网络模型,在预训练好的网络模型中添加一个简单的分类器,将原网络作为新任务的特征提取器,只对最后增加的分类器参数进行重新学习。

三、实例介绍

1.获取预训练的网络模型

pytorch中,有很多已经训练好的网络模型,我们可以使用model命令来下载这部分模型

from torchvision import modelsnet = models.resnet18(pretrained=True)

这里我下载的是ResNet18模型。除了ResNet18,pytorch还提供的已经训练好的模型有:“

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3
  • GoogLeNet
  • ShuffleNet v2”

关于ResNet的网络模型结构如图:

2.使用数据增强

在加载数据集之前,我们可以先使用一些数据增强手段对图像进行处理,例如随机的缩放、反转等

trains_train = transforms.Compose(    [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224     transforms.RandomHorizontalFlip(),  # 图像按照一定概率水平翻转     transforms.ToTensor(),  # 转成tensor数据类型     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose(    [transforms.Resize(256),     transforms.CenterCrop(224),  # 在图片中间区域进行裁剪     transforms.ToTensor(),     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])

3.冻结模型参数

将原网络模型中的部分参数冻结,冻结的参数在反向传播中将不会被 更新

for param in net.parameters():    param.requires_grad = False

4.修改最后一层的输出类别数

员阿里输出为512*1000,如果是使用CIFAR10数据集, 该数据集一共有10个类别,需要把最后的1000改成10

net.fc = nn.Linear(512, 10)

5.定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)

6.训练及验证模型

for i in range(epoch):    print("第{}轮训练开始".format(i+1))    #训练步骤开始    for data in train_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        #优化器优化模型        optimizer.zero_grad()        loss.backward()        optimizer.step()        total_train_step = total_train_step + 1        print("训练次数: {}, Loss {}".format(total_train_step, loss))

7.完整的代码:

import torchfrom torch import nnimport torchvisionimport torchvision.transforms as transformsfrom torchvision import modelsfrom torch.utils.data import DataLoaderdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')trains_train = transforms.Compose(    [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224     transforms.RandomHorizontalFlip(),  # 图像按照一定概率水平翻转     transforms.ToTensor(),  # 转成tensor数据类型     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose(    [transforms.Resize(256),     transforms.CenterCrop(224),  # 在图片中间区域进行裁剪     transforms.ToTensor(),     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])train_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=True, download=True,                                         transform=trains_train)test_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=False, download=True,                                        transform=trains_valid)#获得数据集的长度train_data_size = len(train_set)test_data_size = len(test_set)train_loader = DataLoader(train_set, batch_size=64)test_loader = DataLoader(test_set, batch_size=64)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = models.resnet18(pretrained=True)for param in net.parameters():    param.requires_grad = Falsenet.fc = nn.Linear(512, 10)net = net.to(device)loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)#设置训练网络的参数total_train_step = 0    #记录训练的次数total_test_step = 0     #记录测试的次数epoch = 20   #训练的轮数for i in range(epoch):    print("第{}轮训练开始".format(i+1))    #训练步骤开始    for data in train_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        #优化器优化模型        optimizer.zero_grad()        loss.backward()        optimizer.step()        total_train_step = total_train_step + 1        print("训练次数: {}, Loss {}".format(total_train_step, loss))#测试步骤total_test_loss = 0total_accuracy = 0with torch.no_grad():    for data in test_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        total_test_loss = total_test_loss + loss        accuracy = (outputs.argmax(1) == targets).sum()        total_accuracy = total_accuracy + accuracy    print("整体测试集上的Loss:{}".format(total_test_loss))    print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))    total_test_step = total_test_step + 1    torch.save(net, "net_{}.pth".format(i))    print("模型已保存")

总结

本文主要是对迁移学习进行了简单的介绍,通过使用迁移学习,可以极大的加快训练速度,并且相对普通的模型,迁移学习后的网络模型的精度也有较高的提升。

基于特征提取的迁移学习相关推荐

  1. keras冻结_【连载】深度学习第22讲:搭建一个基于keras的迁移学习花朵识别系统(附数据)...

    在上一讲中,和大家探讨了迁移学习的基本原理,并利用 keras 基于 VGG16 预训练模型简单了在 mnist 数据集上做了演示.鉴于大家对于迁移学习的兴趣,本节将继续基于迁移学习利用一些花朵数据搭 ...

  2. 基于 CNN 和迁移学习的农作物病害识别方法研究

    基于 CNN 和迁移学习的农作物病害识别方法研究 1.研究思路 采用互联网公开的 ImageNet 图像大数据集和PlantVillage 植物病害公共数据集, 以实验室的黄瓜和水稻病害数据集 AES ...

  3. 平潭迁移库是什么意思_迁移学习》第四章总结---基于模型的迁移学习

    基于模型的迁移学习可以简单理解为就是基于模型参数的迁移学习,如何使我们构建的模型可以学习到域之间的通用知识. 1. 基于共享模型成分的迁移学习 在模型中添加先验知识. 1.1 利用高斯过程的迁移学习 ...

  4. 机器学习工程师 — Udacity 基于CNN和迁移学习创建狗品种分类器

    卷积神经网络(Convolutional Neural Network, CNN) 项目:实现一个狗品种识别算法App 推荐你阅读以下材料来加深对 CNN和Transfer Learning的理解: ...

  5. Nat. Mach. Intell. | 基于神经网络的迁移学习用于单细胞RNA-seq分析中的聚类和细胞类型分类...

    今天给大家介绍由美国宾夕法尼亚大学佩雷尔曼医学院生物统计学,流行病学和信息学系Jian Hu等人在<Nature Machine Intelligence>上发表了一篇名为"It ...

  6. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  7. 竞赛获奖系统解读:远场说话人确认中基于两阶段迁移学习解决域不匹配问题

    作为Interspeech2022的赛事活动,远场说话人验证挑战赛 (FFSVC) 由昆山杜克大学.新加坡国立大学.南加州大学和希尔贝壳联合组织,主要关注极具挑战性的远场说话人确认任务.2020年举办 ...

  8. 基于MK-MMD度量迁移学习的轴承故障诊断方法研究

    摘要 上一篇文章实验是基于凯斯西厨大学轴承数据集,使用同一负载情况下的6种轴承数据进行故障诊断,并没有进行不同负载下轴承故障诊断.之前没做这块迁移学习实验,主要是对于迁移学习理解不到位,也没有不知道从 ...

  9. 基于卷积神经网络迁移学习的手写体汉字识别

    作者将基于经典的VGG-Net进行改进,构建出适合手写体汉字识别的浅层卷积神经网络模型,这个模型将适用于对3755个常用一级汉字的识别.在tensorflow环境下完成实验,得到的模型在验证集中的ac ...

最新文章

  1. debian9宝塔面板安装php失败,宝塔面板安装php失败:提示No package 'libjpeg' found的解决办法...
  2. iOS架构-xcodebuild常用命令(22)
  3. C#和Java详细描述
  4. 索引 - 数据结构 - BTREE
  5. 【解题报告】Leecode 372. 超级次方——Leecode每日一题系列
  6. H.264边缘块进行帧内预测时,上边缘和左边缘块的预测情况。
  7. 什么是Spring Boot以及为什么它是用于创建微服务的首选框架
  8. 【机器学习】 - 关于图像质量评价IQA(Image Quality Assessment)
  9. matlab约当消去法,Gauss消去法解线性方程组(Matlab)
  10. USB外接摄像头不能用怎么办
  11. 在微信小程序中使用字体图标
  12. ubuntn 16.04 安装fabric 1.0
  13. JavaScriptCore全面解析
  14. 定时关机win10_长按电源键强制关机,真的会弄坏电脑吗?
  15. Spring:pom.xml中引入依赖发红解决方案
  16. 为VSCode 设置好看的字体:Operator Mono
  17. loadrunner批量添加压力
  18. s3c2440 linux3.2.65 uda134x声卡卡顿,杂音修复
  19. 微信账户冻结怎么解除
  20. PC比电脑好玩的秘密是什么?答案就是因为有这些神奇的网站!

热门文章

  1. 对重装系统彻底说再见——电脑C盘备份
  2. zblog php getlist,zblog函数:GetArticleList()
  3. ACO-OFDM与DCO-OFDM的区别
  4. uni-app watch事件监听三种用法
  5. 电影票房爬取到MySQL中_爬取最热电影及票房统计
  6. 2016Android公司面试题
  7. 主要的数据交换格式XML与JASON
  8. 为什么看P1dB压缩,而不是2dB,3dB压缩
  9. 【记录】螺纹连接与螺旋传动
  10. python中time是什么意思_python中time的基本介绍