目录

前言

一、什么是迁移学习?

二、特征提取介绍

三、实例介绍

1.获取预训练的网络模型

2.使用数据增强

3.冻结模型参数

4.修改最后一层的输出类别数

5.定义损失函数和优化器

6.训练及验证模型

7.完整的代码:

总结



前言

深度学习一般需要大数据、深网络,但很多情况下我们并不能同时获取这些条件,但我们又想获得一个高性能的网络模型,这个时候,迁移学习就是一个很好的方法。本文将对迁移学习进行简单介绍,并运用实例讲解如何实现迁移学习。


一、什么是迁移学习?

迁移学习是机器学习的一种学习方法,其主要是把在任务A中训练得到的网络模型,通过调整和处理,使用到任务B中。迁移学习在人类生活中很常见,最简单的例子就是我们学完C语言后,再学习python就会感到很容易。

在神经网络迁移学习过程中,主要有两个应用场景:

  • 特征提取:冻结除最后的全连接层之外的所有网络的权重,最后一个全连接层被替换为具有随机权重的新层,并且仅训练新层
  • 微调:使用预训练好的初始化网络,用新数据训练部分或整个网络

在本文中,我们主要将使用特征提取进行实例训练。

二、特征提取介绍

特征提取是迁移学习的一个重要方法,其主要是先引入预训练好的网络模型,在预训练好的网络模型中添加一个简单的分类器,将原网络作为新任务的特征提取器,只对最后增加的分类器参数进行重新学习。

三、实例介绍

1.获取预训练的网络模型

pytorch中,有很多已经训练好的网络模型,我们可以使用model命令来下载这部分模型

from torchvision import modelsnet = models.resnet18(pretrained=True)

这里我下载的是ResNet18模型。除了ResNet18,pytorch还提供的已经训练好的模型有:“

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3
  • GoogLeNet
  • ShuffleNet v2”

关于ResNet的网络模型结构如图:

2.使用数据增强

在加载数据集之前,我们可以先使用一些数据增强手段对图像进行处理,例如随机的缩放、反转等

trains_train = transforms.Compose(    [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224     transforms.RandomHorizontalFlip(),  # 图像按照一定概率水平翻转     transforms.ToTensor(),  # 转成tensor数据类型     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose(    [transforms.Resize(256),     transforms.CenterCrop(224),  # 在图片中间区域进行裁剪     transforms.ToTensor(),     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])

3.冻结模型参数

将原网络模型中的部分参数冻结,冻结的参数在反向传播中将不会被 更新

for param in net.parameters():    param.requires_grad = False

4.修改最后一层的输出类别数

员阿里输出为512*1000,如果是使用CIFAR10数据集, 该数据集一共有10个类别,需要把最后的1000改成10

net.fc = nn.Linear(512, 10)

5.定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)

6.训练及验证模型

for i in range(epoch):    print("第{}轮训练开始".format(i+1))    #训练步骤开始    for data in train_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        #优化器优化模型        optimizer.zero_grad()        loss.backward()        optimizer.step()        total_train_step = total_train_step + 1        print("训练次数: {}, Loss {}".format(total_train_step, loss))

7.完整的代码:

import torchfrom torch import nnimport torchvisionimport torchvision.transforms as transformsfrom torchvision import modelsfrom torch.utils.data import DataLoaderdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')trains_train = transforms.Compose(    [transforms.RandomResizedCrop(224),  # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224     transforms.RandomHorizontalFlip(),  # 图像按照一定概率水平翻转     transforms.ToTensor(),  # 转成tensor数据类型     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose(    [transforms.Resize(256),     transforms.CenterCrop(224),  # 在图片中间区域进行裁剪     transforms.ToTensor(),     transforms.Normalize(mean=[0.485, 0.456, 0.406],  # 归一化处理                          std=[0.229, 0.224, 0.225])])train_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=True, download=True,                                         transform=trains_train)test_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=False, download=True,                                        transform=trains_valid)#获得数据集的长度train_data_size = len(train_set)test_data_size = len(test_set)train_loader = DataLoader(train_set, batch_size=64)test_loader = DataLoader(test_set, batch_size=64)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = models.resnet18(pretrained=True)for param in net.parameters():    param.requires_grad = Falsenet.fc = nn.Linear(512, 10)net = net.to(device)loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)#设置训练网络的参数total_train_step = 0    #记录训练的次数total_test_step = 0     #记录测试的次数epoch = 20   #训练的轮数for i in range(epoch):    print("第{}轮训练开始".format(i+1))    #训练步骤开始    for data in train_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        #优化器优化模型        optimizer.zero_grad()        loss.backward()        optimizer.step()        total_train_step = total_train_step + 1        print("训练次数: {}, Loss {}".format(total_train_step, loss))#测试步骤total_test_loss = 0total_accuracy = 0with torch.no_grad():    for data in test_loader:        imgs, targets = data        imgs = imgs.to(device)        targets = targets.to(device)        outputs = net(imgs)        loss = loss_fn(outputs, targets)        total_test_loss = total_test_loss + loss        accuracy = (outputs.argmax(1) == targets).sum()        total_accuracy = total_accuracy + accuracy    print("整体测试集上的Loss:{}".format(total_test_loss))    print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))    total_test_step = total_test_step + 1    torch.save(net, "net_{}.pth".format(i))    print("模型已保存")

总结

本文主要是对迁移学习进行了简单的介绍,通过使用迁移学习,可以极大的加快训练速度,并且相对普通的模型,迁移学习后的网络模型的精度也有较高的提升。

基于特征提取的迁移学习相关推荐

  1. keras冻结_【连载】深度学习第22讲:搭建一个基于keras的迁移学习花朵识别系统(附数据)...

    在上一讲中,和大家探讨了迁移学习的基本原理,并利用 keras 基于 VGG16 预训练模型简单了在 mnist 数据集上做了演示.鉴于大家对于迁移学习的兴趣,本节将继续基于迁移学习利用一些花朵数据搭 ...

  2. 基于 CNN 和迁移学习的农作物病害识别方法研究

    基于 CNN 和迁移学习的农作物病害识别方法研究 1.研究思路 采用互联网公开的 ImageNet 图像大数据集和PlantVillage 植物病害公共数据集, 以实验室的黄瓜和水稻病害数据集 AES ...

  3. 平潭迁移库是什么意思_迁移学习》第四章总结---基于模型的迁移学习

    基于模型的迁移学习可以简单理解为就是基于模型参数的迁移学习,如何使我们构建的模型可以学习到域之间的通用知识. 1. 基于共享模型成分的迁移学习 在模型中添加先验知识. 1.1 利用高斯过程的迁移学习 ...

  4. 机器学习工程师 — Udacity 基于CNN和迁移学习创建狗品种分类器

    卷积神经网络(Convolutional Neural Network, CNN) 项目:实现一个狗品种识别算法App 推荐你阅读以下材料来加深对 CNN和Transfer Learning的理解: ...

  5. Nat. Mach. Intell. | 基于神经网络的迁移学习用于单细胞RNA-seq分析中的聚类和细胞类型分类...

    今天给大家介绍由美国宾夕法尼亚大学佩雷尔曼医学院生物统计学,流行病学和信息学系Jian Hu等人在<Nature Machine Intelligence>上发表了一篇名为"It ...

  6. ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)

    ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...

  7. 竞赛获奖系统解读:远场说话人确认中基于两阶段迁移学习解决域不匹配问题

    作为Interspeech2022的赛事活动,远场说话人验证挑战赛 (FFSVC) 由昆山杜克大学.新加坡国立大学.南加州大学和希尔贝壳联合组织,主要关注极具挑战性的远场说话人确认任务.2020年举办 ...

  8. 基于MK-MMD度量迁移学习的轴承故障诊断方法研究

    摘要 上一篇文章实验是基于凯斯西厨大学轴承数据集,使用同一负载情况下的6种轴承数据进行故障诊断,并没有进行不同负载下轴承故障诊断.之前没做这块迁移学习实验,主要是对于迁移学习理解不到位,也没有不知道从 ...

  9. 基于卷积神经网络迁移学习的手写体汉字识别

    作者将基于经典的VGG-Net进行改进,构建出适合手写体汉字识别的浅层卷积神经网络模型,这个模型将适用于对3755个常用一级汉字的识别.在tensorflow环境下完成实验,得到的模型在验证集中的ac ...

最新文章

  1. Linux那些事儿 之 戏说USB(30)驱动的生命线(二)
  2. sql 字符串比较大小_SQL简单查询
  3. 05. 取SQL分组中的某几行数据
  4. python 定义一个插入数据(可以插入到每个表中)通用的方法
  5. Python_socketserver
  6. (80)Verilog HDL测试激励:保存波形文件
  7. PHP生成海报 文字描边,php实现图片添加描边字和马赛克的方法
  8. 利用python os模块搜索指定目录下包含指定字符的文件
  9. 数据分析用这样的可视化报表,秒杀Excel,再也不怕被说low
  10. 导入jasperreports出现Cannot resolve com.lowagie:itext:2.1.7.js6异常、生成PDF中文不显示中文解决方法、使用命令安装jar包
  11. python数据建模优缺点_Python数据分析\建模入门建议
  12. 新能源汽车厂四大派系
  13. 软件工程(速成)——第四章 总体设计
  14. 易点天下深度解决方案Predicted Payer正式上线,让ROI更有保障
  15. fastjson使用
  16. worldPress数据库
  17. 从零一起学Spring Boot之LayIM项目长成记(二) LayIM初体验
  18. 利用python实现判断两条直线是否平行,若相交,输出交点。
  19. 微信小程序时区时间转换
  20. 如何开始做一个开源项目?他的亲身经历值得参考

热门文章

  1. js实现的复制和粘贴
  2. 无法启动此应用因为计算机丢失,开机无法启动此程序因为计算机中丢失怎么回事...
  3. Python 之 Anaconda
  4. 一文讲清:对象存储、文件存储、块存储。绝对好文
  5. ArcGIS Pro教程 | 1#数据准备
  6. https 请求需要证书,忽略安全证书
  7. Havel-Hakimi定理(判断是否可图序列)
  8. dell服务器设置CPU高性能,DellR720服务器提示cpu1 internal error (IERR)
  9. 淘宝一月上钻是这样操作的
  10. win10禁止访问某网站