文章目录

  • 一、数据准备
  • 二、模型修改
  • 三、模型训练
  • 四、模型效果可视化
  • 五、如何分别计算每个类别的精确率和召回率

MMclassification 是一个分类工具库,这篇文章是简单记录一下如何用该工具库来训练自己的分类模型,包括数据准备,模型修改,模型训练,模型测试等等。

MMclassification链接:https://github.com/open-mmlab/mmclassification

安装:https://mmclassification.readthedocs.io/en/latest/install.html

训练:https://mmclassification.readthedocs.io/en/latest/getting_started.html

一、数据准备

MMclassification 支持 ImageNet 和 cifar 两种数据格式,我们以 ImageNet 为例来看看数据结构:

|- imagenet
|    |- classmap.txt
|    |- train
|    |   |- cls1
|    |   |- cls2
|    |   |- cls3
|    |   |- ...
|    |- train.txt
|    |- val
|    |   |- images
|    |- val.txt

假设我们要训练一个猫狗二分类模型,则需要组织的形式如下:

|- dog_cat_dataset
|    |- classmap.txt
|    |- train
|    |   |- dog
|    |   |- cat
|    |- train.txt
|    |- val
|    |   |- images
|    |- val.txt

其中,classmap.txt 中的内容如下:

dog 0
cat 1

二、模型修改

假设使用 resnet18 来训练,则我们需要修改的内容主要集中在 config 文件里边,修改后的config文件 resnet18_b32x8_dog_cat_cls.py 如下:

  • 修改类别:将 1000 类改为 2 类
  • 修改数据路径:data
  • 如果数据前处理需要修改的话,也可以在config里边修改
  • 因为config是最高级的,所以在这里修改后会覆盖模型从mmcls库中读出来的参数
_base_ = ['../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py','../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(head=dict(type='LinearClsHead',num_classes=2,in_channels=512,loss=dict(type='CrossEntropyLoss', loss_weight=1.0),topk=(1, ),))data = dict(samples_per_gpu=32,workers_per_gpu=1,train=dict(data_prefix='data/dog_cat_dataset/train',ann_file='data/dog_cat_dataset/train.txt',classes='data/dog_cat_dataset/classmap.txt'),val=dict(data_prefix='data/dog_cat_dataset/val',ann_file='data/dog_cat_dataset/val.txt',classes='data/dog_cat_dataset/classmap.txt'),test=dict(# replace `data/val` with `data/test` for standard testdata_prefix='data/dog_cat_dataset/val',ann_file='data/dog_cat_dataset/val.txt',classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={'topk': (1, )})

三、模型训练

python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py

四、模型效果可视化

python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls

使用 gradcam 可视化:

python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py  ./models/epoch_99.pth --s
ave-path visual_path/4.jpg

五、如何分别计算每个类别的精确率和召回率

先进行测试,得到 result.pkl 文件,然后运行下面的程序即可:

python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py
import mmcv
import argparse
from mmcls.datasets import build_dataset
from mmcls.core.evaluation import calculate_confusion_matrix
from sklearn.metrics import confusion_matrixdef parse_args():parser = argparse.ArgumentParser(description='calculate precision and recall for each class')parser.add_argument('config', help='test config file path')args = parser.parse_args()return argsdef main():args = parse_args()cfg = mmcv.Config.fromfile(args.config)dataset = build_dataset(cfg.data.test)pred = mmcv.load("./result.pkl")['pred_label']matrix = confusion_matrix(pred, dataset.get_gt_labels())print('confusion_matrix:', matrix)cat_recall = matrix[0,0]/(matrix[0,0]+matrix[1,0])dog_recall = matrix[1,1]/(matrix[0,1]+matrix[1,1])cat_precision = matrix[0,0]/sum(matrix[0])dog_precision = matrix[1,1]/sum(matrix[1])print(' cat_precision:{} \n dog_precison:{} \n cat_recall:{} \n dog_recall:{}'.format(cat_precision, dog_precison, cat_recall, dog_recall))if __name__ == '__main__':main()

【图像分类】如何使用 mmclassification 训练自己的分类模型相关推荐

  1. tensorflow.js在nodejs训练猫狗分类模型在浏览器上使用

    目录 本人系统环境 注意事项 前言 数据集准备 处理数据集 数据集初步处理 将每一张图片数据转换成张量数据(tensor) 将图片转换成张量数组的代码和运行效果 将图片的标注转换成张量数据(tenso ...

  2. C++利用opencv调用pytorch训练好的分类模型

    pytorch保存模型 import torch.onnxd = torch.rand(1, 3, 224, 224,dtype=torch.float,device = 'cuda') m = mo ...

  3. 用英特尔独立显卡训练AI智能收银机分类模型

    作者:罗宏裕,张晶 英特尔独立显卡技术指导:唐文凯 本文将介绍在英特尔独立显卡上训练AI智能收银机分类模型的全流程,在下一篇中将介绍基于OpenVINOTM在AIxBoard上部署训练好的模型,快速实 ...

  4. 详细讲解分类模型评价指标(混淆矩阵)python示例

    前言 1.回归模型(regression): 对于回归模型的评估方法,通常会采用平均绝对误差(MAE).均方误差(MSE).平均绝对百分比误差(MAPE)等方法. 2.聚类模型(clustering) ...

  5. scikit-learn工具包中分类模型predict_proba、predict、decision_function用法详解

    在使用sklearn训练完分类模型后,下一步就是要验证一下模型的预测结果,对于分类模型,sklearn中通常提供了predict_proba.predict.decision_function三种方法 ...

  6. 【超详细】MMLab分类任务mmclassification:环境配置说明、训练、预测及模型结果可视化展示

    本文详细介绍了使用MMLab的mmclassification进行分类任务的环境配置.训练与预测流程. 目录 文件配置说明 下载源码 配置文件 基于预训练模型微调或者续训练自己模型的方式 配置文件说明 ...

  7. 星星模型 维度_用模型“想象”出来的target来训练,可以提高分类的效果!

    LearnFromPapers系列--用模型"想象"出来的target来训练,可以提高分类的效果 作者:郭必扬 时间:2020年最后一天 前言:今天是2020年最后一天,这篇文章也 ...

  8. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  9. 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)

    神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...

最新文章

  1. vb破解万能断点816c24
  2. 原来Rproj还可以这么使用
  3. Linux内核同步 - Read/Write spin lock
  4. WSL:ssh 本地与阿里云数据交互
  5. 详细理解中缀表达式并实现
  6. matlab gui简单教程
  7. 教你100%成功安装Mathcad15
  8. 一款好用的Windows引导项管理工具BOOTICE
  9. flutter 的像素尺寸
  10. [4G5G专题-124]:5G培训部署篇-2-主要信令流程
  11. unit英语读音_unit是什么意思_unit翻译_读音_用法_翻译
  12. 计算机管理系统功能模块,设备管理系统功能模块
  13. JPA ERROR: value too long for type character varying(100)
  14. 面试官这样,面试就有戏了!
  15. 推荐四款自用的电脑神器
  16. 微星GS65 英雄联盟崩溃
  17. 谷粒商城九商品服务之商品属性及仓储服务todo
  18. 异常报错原因及解决方案
  19. 实用epub阅读器分享
  20. 计算机软考 下午试题,2011年计算机软考程序员考试(下午题)模拟试题及答案(2)...

热门文章

  1. [六省联考2017]组合数问题
  2. 苹果公司的新的编程语言 Swift 高级语言()两--基本数据类型
  3. sql server 2005 32位+64位、企业版+标准版、CD+DVD 下载地址大全
  4. java nio改造io,java – 将NIO与IO混合
  5. winform打开cad图纸_为什么CAD图纸打开后会显示很多问号“???”,该怎么解决...
  6. WAP2.0开发规范及原则
  7. php遍历一个目录 并重命名
  8. odbc spoon连接postgre_ado、odbc连接Postgre SQL
  9. python短信发送查询数据库结果_向Django数据库中的每个号码发送短信
  10. 58端口使用技巧跟推送_Kindle使用技巧:定时推送