【图像分类】如何使用 mmclassification 训练自己的分类模型
文章目录
- 一、数据准备
- 二、模型修改
- 三、模型训练
- 四、模型效果可视化
- 五、如何分别计算每个类别的精确率和召回率
MMclassification 是一个分类工具库,这篇文章是简单记录一下如何用该工具库来训练自己的分类模型,包括数据准备,模型修改,模型训练,模型测试等等。
MMclassification链接:https://github.com/open-mmlab/mmclassification
安装:https://mmclassification.readthedocs.io/en/latest/install.html
训练:https://mmclassification.readthedocs.io/en/latest/getting_started.html
一、数据准备
MMclassification 支持 ImageNet 和 cifar 两种数据格式,我们以 ImageNet 为例来看看数据结构:
|- imagenet
| |- classmap.txt
| |- train
| | |- cls1
| | |- cls2
| | |- cls3
| | |- ...
| |- train.txt
| |- val
| | |- images
| |- val.txt
假设我们要训练一个猫狗二分类模型,则需要组织的形式如下:
|- dog_cat_dataset
| |- classmap.txt
| |- train
| | |- dog
| | |- cat
| |- train.txt
| |- val
| | |- images
| |- val.txt
其中,classmap.txt 中的内容如下:
dog 0
cat 1
二、模型修改
假设使用 resnet18 来训练,则我们需要修改的内容主要集中在 config 文件里边,修改后的config文件 resnet18_b32x8_dog_cat_cls.py
如下:
- 修改类别:将 1000 类改为 2 类
- 修改数据路径:data
- 如果数据前处理需要修改的话,也可以在config里边修改
- 因为config是最高级的,所以在这里修改后会覆盖模型从mmcls库中读出来的参数
_base_ = ['../_base_/models/resnet18.py', '../_base_/datasets/imagenet_bs32.py','../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
model = dict(head=dict(type='LinearClsHead',num_classes=2,in_channels=512,loss=dict(type='CrossEntropyLoss', loss_weight=1.0),topk=(1, ),))data = dict(samples_per_gpu=32,workers_per_gpu=1,train=dict(data_prefix='data/dog_cat_dataset/train',ann_file='data/dog_cat_dataset/train.txt',classes='data/dog_cat_dataset/classmap.txt'),val=dict(data_prefix='data/dog_cat_dataset/val',ann_file='data/dog_cat_dataset/val.txt',classes='data/dog_cat_dataset/classmap.txt'),test=dict(# replace `data/val` with `data/test` for standard testdata_prefix='data/dog_cat_dataset/val',ann_file='data/dog_cat_dataset/val.txt',classes='data/dog_cat_dataset/classmap.txt'))
evaluation = dict(interval=1, metric='accuracy', metric_options={'topk': (1, )})
三、模型训练
python tools/train.py configs/resnet/resnet18_b32x8_dog_cat_cls.py
四、模型效果可视化
python tools/test.py configs/resnet/resnet18_b32x8_dog_cat_cls.py ./models/epoch_99.pth --out result.pkl --show-dir output_cls
使用 gradcam 可视化:
python tools/visualizations/vis_cam.py visual_img/4.jpg configs/resnet/resnet18_b32x8_door.py ./models/epoch_99.pth --s
ave-path visual_path/4.jpg
五、如何分别计算每个类别的精确率和召回率
先进行测试,得到 result.pkl
文件,然后运行下面的程序即可:
python tools/cal_precision.py configs/resnet/resnet18_b32x8_imagenet.py
import mmcv
import argparse
from mmcls.datasets import build_dataset
from mmcls.core.evaluation import calculate_confusion_matrix
from sklearn.metrics import confusion_matrixdef parse_args():parser = argparse.ArgumentParser(description='calculate precision and recall for each class')parser.add_argument('config', help='test config file path')args = parser.parse_args()return argsdef main():args = parse_args()cfg = mmcv.Config.fromfile(args.config)dataset = build_dataset(cfg.data.test)pred = mmcv.load("./result.pkl")['pred_label']matrix = confusion_matrix(pred, dataset.get_gt_labels())print('confusion_matrix:', matrix)cat_recall = matrix[0,0]/(matrix[0,0]+matrix[1,0])dog_recall = matrix[1,1]/(matrix[0,1]+matrix[1,1])cat_precision = matrix[0,0]/sum(matrix[0])dog_precision = matrix[1,1]/sum(matrix[1])print(' cat_precision:{} \n dog_precison:{} \n cat_recall:{} \n dog_recall:{}'.format(cat_precision, dog_precison, cat_recall, dog_recall))if __name__ == '__main__':main()
【图像分类】如何使用 mmclassification 训练自己的分类模型相关推荐
- tensorflow.js在nodejs训练猫狗分类模型在浏览器上使用
目录 本人系统环境 注意事项 前言 数据集准备 处理数据集 数据集初步处理 将每一张图片数据转换成张量数据(tensor) 将图片转换成张量数组的代码和运行效果 将图片的标注转换成张量数据(tenso ...
- C++利用opencv调用pytorch训练好的分类模型
pytorch保存模型 import torch.onnxd = torch.rand(1, 3, 224, 224,dtype=torch.float,device = 'cuda') m = mo ...
- 用英特尔独立显卡训练AI智能收银机分类模型
作者:罗宏裕,张晶 英特尔独立显卡技术指导:唐文凯 本文将介绍在英特尔独立显卡上训练AI智能收银机分类模型的全流程,在下一篇中将介绍基于OpenVINOTM在AIxBoard上部署训练好的模型,快速实 ...
- 详细讲解分类模型评价指标(混淆矩阵)python示例
前言 1.回归模型(regression): 对于回归模型的评估方法,通常会采用平均绝对误差(MAE).均方误差(MSE).平均绝对百分比误差(MAPE)等方法. 2.聚类模型(clustering) ...
- scikit-learn工具包中分类模型predict_proba、predict、decision_function用法详解
在使用sklearn训练完分类模型后,下一步就是要验证一下模型的预测结果,对于分类模型,sklearn中通常提供了predict_proba.predict.decision_function三种方法 ...
- 【超详细】MMLab分类任务mmclassification:环境配置说明、训练、预测及模型结果可视化展示
本文详细介绍了使用MMLab的mmclassification进行分类任务的环境配置.训练与预测流程. 目录 文件配置说明 下载源码 配置文件 基于预训练模型微调或者续训练自己模型的方式 配置文件说明 ...
- 星星模型 维度_用模型“想象”出来的target来训练,可以提高分类的效果!
LearnFromPapers系列--用模型"想象"出来的target来训练,可以提高分类的效果 作者:郭必扬 时间:2020年最后一天 前言:今天是2020年最后一天,这篇文章也 ...
- pytorch 训练过程acc_深度学习Pytorch实现分类模型
今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...
- 神经网络学习小记录19——微调VGG分类模型训练自己的数据(猫狗数据集)
神经网络学习小记录19--微调VGG分类模型训练自己的数据(猫狗数据集) 注意事项 学习前言 什么是VGG16模型 VGG模型的复杂程度 训练前准备 1.数据集处理 2.创建Keras的VGG模型 3 ...
最新文章
- vb破解万能断点816c24
- 原来Rproj还可以这么使用
- Linux内核同步 - Read/Write spin lock
- WSL:ssh 本地与阿里云数据交互
- 详细理解中缀表达式并实现
- matlab gui简单教程
- 教你100%成功安装Mathcad15
- 一款好用的Windows引导项管理工具BOOTICE
- flutter 的像素尺寸
- [4G5G专题-124]:5G培训部署篇-2-主要信令流程
- unit英语读音_unit是什么意思_unit翻译_读音_用法_翻译
- 计算机管理系统功能模块,设备管理系统功能模块
- JPA ERROR: value too long for type character varying(100)
- 面试官这样,面试就有戏了!
- 推荐四款自用的电脑神器
- 微星GS65 英雄联盟崩溃
- 谷粒商城九商品服务之商品属性及仓储服务todo
- 异常报错原因及解决方案
- 实用epub阅读器分享
- 计算机软考 下午试题,2011年计算机软考程序员考试(下午题)模拟试题及答案(2)...
热门文章
- [六省联考2017]组合数问题
- 苹果公司的新的编程语言 Swift 高级语言()两--基本数据类型
- sql server 2005 32位+64位、企业版+标准版、CD+DVD 下载地址大全
- java nio改造io,java – 将NIO与IO混合
- winform打开cad图纸_为什么CAD图纸打开后会显示很多问号“???”,该怎么解决...
- WAP2.0开发规范及原则
- php遍历一个目录 并重命名
- odbc spoon连接postgre_ado、odbc连接Postgre SQL
- python短信发送查询数据库结果_向Django数据库中的每个号码发送短信
- 58端口使用技巧跟推送_Kindle使用技巧:定时推送