迁移学习训练分类模型实践第一篇

  • 前言
  • 数据获取、预处理
  • 构建模型
  • 查看模型参数量和FLOPs
  • 测试模型

前言

为了简洁,本文不包含任何训练过程,仅介绍处理数据、构建模型、使用随机初始化权重推断;
关于如何使用预训练模型,训练整个流程,后面继续介绍。

数据获取、预处理

数据集:102 Category Flower Dataset
点击下载

包括102种花卉。每个类别包含40到258张图片。这些图像有很大的尺度,姿势和光线变化。此外,还有一些类别有很大的变化,以及一些非常相似的类别。

!unzip flower_data.zip
# 导入必要的库
from collections import OrderedDict
import numpy as np
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
import torchvision.transforms.functional as TF
from torch.utils.data import Subset
from thop import profile, clever_format
from torchsummary import summary
from PIL import Image
data_dir = 'flower_data'
input_size = 224
# 用来归一化的均值和标准差
normalize_mean = np.array([0.485, 0.456, 0.406])
normalize_std = np.array([0.229, 0.224, 0.225])

构建模型

使用torchvision提供的resnet,并根据数据集修改模型的分类器,因为所提供的模型是基于ImageNet设计的,分类器是1000类,并不适用与这个数据集。将原始分类器改为102类的分类器。
这里也可以使用其他模型,后续将根据效果和需求适当调整模型

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f'Running on: {str(device).upper()}')
    Running on: CUDA
output_size = 102
model = models.resnet18()
# 替换分类器为102类
classifier = OrderedDict()
classifier['layer0'] = nn.Linear(model.fc.in_features, output_size)
classifier['output_function'] = nn.LogSoftmax(dim=1)
model.fc = nn.Sequential(classifier)model.to(device)
    ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(1): BasicBlock((conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer2): Sequential((0): BasicBlock((conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer3): Sequential((0): BasicBlock((conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(layer4): Sequential((0): BasicBlock((conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): BasicBlock((conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Sequential((layer0): Linear(in_features=512, out_features=102, bias=True)(output_function): LogSoftmax(dim=1)))

查看模型参数量和FLOPs

区分一下FLOPSFLOPs

FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

FLOPS:什么是FLOPS

参考:知乎

查看模型及其参数量以及FLOPs有助于我们对模型进一步了解,对以后部署也是可以提供优化方向的:

# model.to(device)
_input = torch.randn(1, 3, input_size, input_size).to(device)
flops, params = profile(model, inputs=(_input,))  # 自定义模块需要:custom_ops={YourModule: count_your_model}
flops, params = clever_format([flops, params], '%.6f')
print('FLOPs:', flops, '\tparams:', params )
# FLOPs: 1.819066G  params: 11.689512M
  [INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.[INFO] Register count_bn() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.[WARN] Cannot find rule for <class 'torchvision.models.resnet.BasicBlock'>. Treat it as zero Macs and zero Params.[WARN] Cannot find rule for <class 'torch.nn.modules.container.Sequential'>. Treat it as zero Macs and zero Params.[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.[WARN] Cannot find rule for <class 'torch.nn.modules.activation.LogSoftmax'>. Treat it as zero Macs and zero Params.[WARN] Cannot find rule for <class 'torchvision.models.resnet.ResNet'>. Treat it as zero Macs and zero Params.FLOPs: 1.818607G  params: 11.228838M

自定义模块需要自己添加hook去计算
custom_ops={ YourModule: count_your_model }
YourModule:自定义模块
count_your_model:自定义模块的计算hook函数
参考:thop.profile.py

# model.to(device)
summary(model, (3, input_size, input_size))
    ----------------------------------------------------------------Layer (type)               Output Shape         Param #================================================================Conv2d-1         [-1, 64, 112, 112]           9,408BatchNorm2d-2         [-1, 64, 112, 112]             128ReLU-3         [-1, 64, 112, 112]               0MaxPool2d-4           [-1, 64, 56, 56]               0Conv2d-5           [-1, 64, 56, 56]          36,864BatchNorm2d-6           [-1, 64, 56, 56]             128ReLU-7           [-1, 64, 56, 56]               0Conv2d-8           [-1, 64, 56, 56]          36,864BatchNorm2d-9           [-1, 64, 56, 56]             128ReLU-10           [-1, 64, 56, 56]               0BasicBlock-11           [-1, 64, 56, 56]               0Conv2d-12           [-1, 64, 56, 56]          36,864BatchNorm2d-13           [-1, 64, 56, 56]             128ReLU-14           [-1, 64, 56, 56]               0Conv2d-15           [-1, 64, 56, 56]          36,864BatchNorm2d-16           [-1, 64, 56, 56]             128ReLU-17           [-1, 64, 56, 56]               0BasicBlock-18           [-1, 64, 56, 56]               0Conv2d-19          [-1, 128, 28, 28]          73,728BatchNorm2d-20          [-1, 128, 28, 28]             256ReLU-21          [-1, 128, 28, 28]               0Conv2d-22          [-1, 128, 28, 28]         147,456BatchNorm2d-23          [-1, 128, 28, 28]             256Conv2d-24          [-1, 128, 28, 28]           8,192BatchNorm2d-25          [-1, 128, 28, 28]             256ReLU-26          [-1, 128, 28, 28]               0BasicBlock-27          [-1, 128, 28, 28]               0Conv2d-28          [-1, 128, 28, 28]         147,456BatchNorm2d-29          [-1, 128, 28, 28]             256ReLU-30          [-1, 128, 28, 28]               0Conv2d-31          [-1, 128, 28, 28]         147,456BatchNorm2d-32          [-1, 128, 28, 28]             256ReLU-33          [-1, 128, 28, 28]               0BasicBlock-34          [-1, 128, 28, 28]               0Conv2d-35          [-1, 256, 14, 14]         294,912BatchNorm2d-36          [-1, 256, 14, 14]             512ReLU-37          [-1, 256, 14, 14]               0Conv2d-38          [-1, 256, 14, 14]         589,824BatchNorm2d-39          [-1, 256, 14, 14]             512Conv2d-40          [-1, 256, 14, 14]          32,768BatchNorm2d-41          [-1, 256, 14, 14]             512ReLU-42          [-1, 256, 14, 14]               0BasicBlock-43          [-1, 256, 14, 14]               0Conv2d-44          [-1, 256, 14, 14]         589,824BatchNorm2d-45          [-1, 256, 14, 14]             512ReLU-46          [-1, 256, 14, 14]               0Conv2d-47          [-1, 256, 14, 14]         589,824BatchNorm2d-48          [-1, 256, 14, 14]             512ReLU-49          [-1, 256, 14, 14]               0BasicBlock-50          [-1, 256, 14, 14]               0Conv2d-51            [-1, 512, 7, 7]       1,179,648BatchNorm2d-52            [-1, 512, 7, 7]           1,024ReLU-53            [-1, 512, 7, 7]               0Conv2d-54            [-1, 512, 7, 7]       2,359,296BatchNorm2d-55            [-1, 512, 7, 7]           1,024Conv2d-56            [-1, 512, 7, 7]         131,072BatchNorm2d-57            [-1, 512, 7, 7]           1,024ReLU-58            [-1, 512, 7, 7]               0BasicBlock-59            [-1, 512, 7, 7]               0Conv2d-60            [-1, 512, 7, 7]       2,359,296BatchNorm2d-61            [-1, 512, 7, 7]           1,024ReLU-62            [-1, 512, 7, 7]               0Conv2d-63            [-1, 512, 7, 7]       2,359,296BatchNorm2d-64            [-1, 512, 7, 7]           1,024ReLU-65            [-1, 512, 7, 7]               0BasicBlock-66            [-1, 512, 7, 7]               0AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0Linear-68                  [-1, 102]          52,326LogSoftmax-69                  [-1, 102]               0================================================================Total params: 11,228,838Trainable params: 11,228,838Non-trainable params: 0----------------------------------------------------------------Input size (MB): 0.57Forward/backward pass size (MB): 62.79Params size (MB): 42.83Estimated Total Size (MB): 106.20----------------------------------------------------------------

测试模型

使用一张图片作为测试,验证整个过程有没有问题,这里只输出了模型的推断置信度,但是这是随机值,所以并没有将其可是化,因为没有任何参考意义。后续对模型进行训练,对测试图片进行推断,可视化可以直观的了解推断的效果以评价模型的好坏。

def process_image(image):''' 预处理图片,返回numpy数组'''image = TF.resize(image, 256)upper_pixel = (image.height - 224) // 2left_pixel = (image.width - 224) // 2image = TF.crop(image, upper_pixel, left_pixel, 224, 224)image = TF.to_tensor(image)image = TF.normalize(image, normalize_mean, normalize_std)return image
def predict(image_path, model, topk=5):''' 读取图片预测结果,返回Top5'''image = Image.open(image_path)image = process_image(image)with torch.no_grad():model.eval()image = image.view(1,3,224,224)image = image.to(device)predictions = model.forward(image)predictions = torch.exp(predictions)top_ps, top_class = predictions.topk(topk, dim=1)return top_ps, top_class
category = 30
image_name = 'image_03475.jpg'
image_path = data_dir + f'/valid/{category}/{image_name}'probs, classes = predict(image_path, model)
print(probs)
print(classes)
    tensor([[0.0301, 0.0275, 0.0264, 0.0257, 0.0219]], device='cuda:0')tensor([[73, 76,  9, 62, 32]], device='cuda:0')

迁移学习训练分类模型实践第一篇相关推荐

  1. R语言基于Bagging算法(融合多个决策树)构建集成学习Bagging分类模型、并评估模型在测试集和训练集上的分类效果(accuray、F1、偏差Deviance):Bagging算法与随机森林对比

    R语言基于Bagging算法(融合多个决策树)构建集成学习Bagging分类模型.并评估模型在测试集和训练集上的分类效果(accuray.F1.偏差Deviance):Bagging算法与随机森林对比 ...

  2. 基于深度学习和迁移学习的识花实践

    深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 TensorFlow 和 Keras 等框架的出现大 ...

  3. 迁移学习1——基于深度学习和迁移学习的识花实践

    参考博客:https://cosx.org/2017/10/transfer-learning/ 下面的例子中将示范如何将一个图像识别的深度卷积网络,VGG,迁移到识别花朵类型的新任务上,在原先的任务 ...

  4. 分类 迁移学习_迁移学习时间序列分类

    迁移学习时间序列分类 题目: Transfer learning for time series classification 作者: Hassan Ismail Fawaz, Germain For ...

  5. 分类模型 第1篇:分类模型概述

    机器学习主要用于解决分类.回归和聚类问题,分类属于监督学习算法,是指根据已有的数据和标签(分类的类别)进行学习,预测未知数据的标签.分类问题的目标是预测数据的类别标签(class label),可以把 ...

  6. R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification)

    R使用LSTM模型构建深度学习文本分类模型(Quora Insincere Questions Classification) Long Short Term 网络-- 一般就叫做 LSTM --是一 ...

  7. Pytorch模型迁移和迁移学习,导入部分模型参数

    Pytorch模型迁移和迁移学习 目录 Pytorch模型迁移和迁移学习 1. 利用resnet18做迁移学习 2. 修改网络名称并迁移学习 3.去除原模型的某些模块 1. 利用resnet18做迁移 ...

  8. pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类

    只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...

  9. matlab实现鬼波信号压制算法(附鬼波算法压制工具包)  代码实践--第一篇 频率-空间域自适应鬼波压制

    matlab实现鬼波信号压制算法(附鬼波算法压制工具包)  代码实践 涵盖了频率-空间域.频率-波数域.拉东域鬼波压制算法     建议实践之前熟练掌握各个域鬼波压制方法的原理,才能对代码有更深入的了 ...

最新文章

  1. 将数字转化为特殊符号的密码
  2. UVa11038 - How Many O's?(统计0的个数)
  3. Python_切割和查找
  4. 计算开始到结束的时间_阿里钉钉首次战胜微信,云计算的涨停潮只是开始,远未结束...
  5. Zookeeper - 简述ZAB 协议和zookeeper
  6. SQL SERVER 事务例子
  7. 使用Beautiful Soup 中遇到的小问题-----只能提取网页上第一页信息
  8. 【转】string.Format对C#字符串格式化
  9. T-SQL数据类型的细微差别(四)
  10. android高仿ios11系统,安卓仿ios11桌面全套文件
  11. scratch素材准备
  12. Apache Pulsar 中文社区先锋奖与年度优秀案例出炉!
  13. MySQL数据库知识点大全
  14. VMware安装教程
  15. 细聊智能家居开发中必备的通信协议
  16. CentOS7下安装和配置MySQL5.7亲测有效(附图文)
  17. 三维建筑动画的制作流程
  18. 世界有时特别吝啬【摘自《青年文摘》】
  19. CH340驱动(含各平台)
  20. office提示“office未获得合适的许可,你可能是盗版软件的受害者。”解决方法

热门文章

  1. 最大似然估计法(MLE)
  2. Ticwatch2_3G版省电优化设置
  3. 欧拉测试网,测试埋伏空投
  4. 有道云协作支持Markdown了,云笔记也快了吧,哈哈
  5. Edge浏览器移除桔梗网页的方法
  6. 平面印刷品三折页使用规范
  7. 清北计算机专业研究生在哪里读研,本科清北去普通985甚至211读研,你会反向读研吗?...
  8. matlab rgb2gray() 的坑
  9. GET_PERS_LIST_4_CONFIG_ID (UI2CL_WD_CFG_UTILS)
  10. Dev C++调试模式出现蓝条按下一步卡住的解决办法