基于特征提取的迁移学习
目录
前言
一、什么是迁移学习?
二、特征提取介绍
三、实例介绍
1.获取预训练的网络模型
2.使用数据增强
3.冻结模型参数
4.修改最后一层的输出类别数
5.定义损失函数和优化器
6.训练及验证模型
7.完整的代码:
总结
前言
深度学习一般需要大数据、深网络,但很多情况下我们并不能同时获取这些条件,但我们又想获得一个高性能的网络模型,这个时候,迁移学习就是一个很好的方法。本文将对迁移学习进行简单介绍,并运用实例讲解如何实现迁移学习。
一、什么是迁移学习?
迁移学习是机器学习的一种学习方法,其主要是把在任务A中训练得到的网络模型,通过调整和处理,使用到任务B中。迁移学习在人类生活中很常见,最简单的例子就是我们学完C语言后,再学习python就会感到很容易。
在神经网络迁移学习过程中,主要有两个应用场景:
- 特征提取:冻结除最后的全连接层之外的所有网络的权重,最后一个全连接层被替换为具有随机权重的新层,并且仅训练新层
- 微调:使用预训练好的初始化网络,用新数据训练部分或整个网络
在本文中,我们主要将使用特征提取进行实例训练。
二、特征提取介绍
特征提取是迁移学习的一个重要方法,其主要是先引入预训练好的网络模型,在预训练好的网络模型中添加一个简单的分类器,将原网络作为新任务的特征提取器,只对最后增加的分类器参数进行重新学习。
三、实例介绍
1.获取预训练的网络模型
pytorch中,有很多已经训练好的网络模型,我们可以使用model命令来下载这部分模型
from torchvision import modelsnet = models.resnet18(pretrained=True)
这里我下载的是ResNet18模型。除了ResNet18,pytorch还提供的已经训练好的模型有:“
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3
- GoogLeNet
- ShuffleNet v2”
关于ResNet的网络模型结构如图:
2.使用数据增强
在加载数据集之前,我们可以先使用一些数据增强手段对图像进行处理,例如随机的缩放、反转等
trains_train = transforms.Compose( [transforms.RandomResizedCrop(224), # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224 transforms.RandomHorizontalFlip(), # 图像按照一定概率水平翻转 transforms.ToTensor(), # 转成tensor数据类型 transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化处理 std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), # 在图片中间区域进行裁剪 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化处理 std=[0.229, 0.224, 0.225])])
3.冻结模型参数
将原网络模型中的部分参数冻结,冻结的参数在反向传播中将不会被 更新
for param in net.parameters(): param.requires_grad = False
4.修改最后一层的输出类别数
员阿里输出为512*1000,如果是使用CIFAR10数据集, 该数据集一共有10个类别,需要把最后的1000改成10
net.fc = nn.Linear(512, 10)
5.定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)
6.训练及验证模型
for i in range(epoch): print("第{}轮训练开始".format(i+1)) #训练步骤开始 for data in train_loader: imgs, targets = data imgs = imgs.to(device) targets = targets.to(device) outputs = net(imgs) loss = loss_fn(outputs, targets) #优化器优化模型 optimizer.zero_grad() loss.backward() optimizer.step() total_train_step = total_train_step + 1 print("训练次数: {}, Loss {}".format(total_train_step, loss))
7.完整的代码:
import torchfrom torch import nnimport torchvisionimport torchvision.transforms as transformsfrom torchvision import modelsfrom torch.utils.data import DataLoaderdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')trains_train = transforms.Compose( [transforms.RandomResizedCrop(224), # 将给定图像随机裁剪为不同的大小和宽高比,然后缩放所裁剪得到的图像为指定的大小;改成224*224 transforms.RandomHorizontalFlip(), # 图像按照一定概率水平翻转 transforms.ToTensor(), # 转成tensor数据类型 transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化处理 std=[0.229, 0.224, 0.225])])trains_valid = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), # 在图片中间区域进行裁剪 transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], # 归一化处理 std=[0.229, 0.224, 0.225])])train_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=True, download=True, transform=trains_train)test_set = torchvision.datasets.CIFAR10("D:\pytorch_train_log\pytorch_train", train=False, download=True, transform=trains_valid)#获得数据集的长度train_data_size = len(train_set)test_data_size = len(test_set)train_loader = DataLoader(train_set, batch_size=64)test_loader = DataLoader(test_set, batch_size=64)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = models.resnet18(pretrained=True)for param in net.parameters(): param.requires_grad = Falsenet.fc = nn.Linear(512, 10)net = net.to(device)loss_fn = nn.CrossEntropyLoss()loss_fn = loss_fn.to(device)#loss_fn = loss_fn.to(device),imgs和targets同理#定义优化器learning_rate = 0.01optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.5)#设置训练网络的参数total_train_step = 0 #记录训练的次数total_test_step = 0 #记录测试的次数epoch = 20 #训练的轮数for i in range(epoch): print("第{}轮训练开始".format(i+1)) #训练步骤开始 for data in train_loader: imgs, targets = data imgs = imgs.to(device) targets = targets.to(device) outputs = net(imgs) loss = loss_fn(outputs, targets) #优化器优化模型 optimizer.zero_grad() loss.backward() optimizer.step() total_train_step = total_train_step + 1 print("训练次数: {}, Loss {}".format(total_train_step, loss))#测试步骤total_test_loss = 0total_accuracy = 0with torch.no_grad(): for data in test_loader: imgs, targets = data imgs = imgs.to(device) targets = targets.to(device) outputs = net(imgs) loss = loss_fn(outputs, targets) total_test_loss = total_test_loss + loss accuracy = (outputs.argmax(1) == targets).sum() total_accuracy = total_accuracy + accuracy print("整体测试集上的Loss:{}".format(total_test_loss)) print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size)) total_test_step = total_test_step + 1 torch.save(net, "net_{}.pth".format(i)) print("模型已保存")
总结
本文主要是对迁移学习进行了简单的介绍,通过使用迁移学习,可以极大的加快训练速度,并且相对普通的模型,迁移学习后的网络模型的精度也有较高的提升。
基于特征提取的迁移学习相关推荐
- keras冻结_【连载】深度学习第22讲:搭建一个基于keras的迁移学习花朵识别系统(附数据)...
在上一讲中,和大家探讨了迁移学习的基本原理,并利用 keras 基于 VGG16 预训练模型简单了在 mnist 数据集上做了演示.鉴于大家对于迁移学习的兴趣,本节将继续基于迁移学习利用一些花朵数据搭 ...
- 基于 CNN 和迁移学习的农作物病害识别方法研究
基于 CNN 和迁移学习的农作物病害识别方法研究 1.研究思路 采用互联网公开的 ImageNet 图像大数据集和PlantVillage 植物病害公共数据集, 以实验室的黄瓜和水稻病害数据集 AES ...
- 平潭迁移库是什么意思_迁移学习》第四章总结---基于模型的迁移学习
基于模型的迁移学习可以简单理解为就是基于模型参数的迁移学习,如何使我们构建的模型可以学习到域之间的通用知识. 1. 基于共享模型成分的迁移学习 在模型中添加先验知识. 1.1 利用高斯过程的迁移学习 ...
- 机器学习工程师 — Udacity 基于CNN和迁移学习创建狗品种分类器
卷积神经网络(Convolutional Neural Network, CNN) 项目:实现一个狗品种识别算法App 推荐你阅读以下材料来加深对 CNN和Transfer Learning的理解: ...
- Nat. Mach. Intell. | 基于神经网络的迁移学习用于单细胞RNA-seq分析中的聚类和细胞类型分类...
今天给大家介绍由美国宾夕法尼亚大学佩雷尔曼医学院生物统计学,流行病学和信息学系Jian Hu等人在<Nature Machine Intelligence>上发表了一篇名为"It ...
- ML.NET 示例:图像分类模型训练-首选API(基于原生TensorFlow迁移学习)
ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法 Microsoft.ML 1.5.0 动态API 最新 控制台应用程序和Web应用程序 图片文件 图像分类 基 ...
- 竞赛获奖系统解读:远场说话人确认中基于两阶段迁移学习解决域不匹配问题
作为Interspeech2022的赛事活动,远场说话人验证挑战赛 (FFSVC) 由昆山杜克大学.新加坡国立大学.南加州大学和希尔贝壳联合组织,主要关注极具挑战性的远场说话人确认任务.2020年举办 ...
- 基于MK-MMD度量迁移学习的轴承故障诊断方法研究
摘要 上一篇文章实验是基于凯斯西厨大学轴承数据集,使用同一负载情况下的6种轴承数据进行故障诊断,并没有进行不同负载下轴承故障诊断.之前没做这块迁移学习实验,主要是对于迁移学习理解不到位,也没有不知道从 ...
- 基于卷积神经网络迁移学习的手写体汉字识别
作者将基于经典的VGG-Net进行改进,构建出适合手写体汉字识别的浅层卷积神经网络模型,这个模型将适用于对3755个常用一级汉字的识别.在tensorflow环境下完成实验,得到的模型在验证集中的ac ...
最新文章
- 综合布线施工中的不规范现象
- Python代码块批量添加Tab缩进
- 安卓入门系列-03安卓的开发方式(逻辑与视图分离)
- 5904.刺客信条(AC)
- git报错:fatal: remote origin already exists
- C++学习札记(1)
- merkle tree(hash tree)
- Function.prototype.bind、call与apply方法简介
- 【运动学】基于matlab Singer模型算法机动目标跟踪【含Matlab源码 1157期】
- 超标量体系结构_计算机体系结构——以多发射和静态调度来开发ILP
- python 输出语句
- Android音视频三-AndroidStudio整合FFmpeg项目+FFmpeg视频解码
- 如何利用python计算即期利率_即期利率的定义_即期利率的计算公式_即期利率和远期利率...
- 微信小程序--监听对象属性变化
- 浅谈CPU 硬盘 内存关系
- 今年,你义务植树了吗?
- 计算机毕业设计SSM大学生社团管理系统【附源码数据库】
- 数字视音频处理知识点小结
- Python爬取热门微博,并存储到MySQL中
- Tech Lead(技术经理) 带人之道
热门文章
- 2017年英语六级作文(附翻译)
- 微信小程序之如何实现一寸照片换底色(附小程序成品)
- Havel算法-Python实现
- 现代农业智能温室种植系统方案
- 【数论】GDKOI day1 讲座(数论基本知识 详)
- 算法改进有多快?是否比迭代硬件收益更大?这是 MIT 的结论
- 图匹配(Graph Matching)入门学习笔记——以《Factorized Graph Matching》为例(一)
- 硬件知识:电源开关上的“1“和“0“分别是什么意思
- 年底不要慌,这个EXCEL模板帮你打赢Q4收官战
- jetson nano yolov5部署及USB摄像头实时检测 初次尝试