Pytorch通用图像分类模型(支持20+分类模型),直接带入数据就可训练自己的数据集,包括模型训练、推理、部署。
Pytorch-Image-Classifier-Collection
介绍
==============================
支持多模型工程化的图像分类器
==============================
软件架构
Pytorch+opencv
模型支持架构
模型
- | - | - | - |
---|---|---|---|
resnet18 | resnet34 | resnet50 | resnet101 |
resnet152 | resnext101_32x8d | resnext50_32x4d | wide_resnet50_2 |
wide_resnet101_2 | densenet121 | densenet161 | densenet169 |
densenet201 | vgg11 | vgg13 | vgg13_bn |
vgg19 | vgg19_bn | vgg16 | vgg16_bn |
inception_v3 | mobilenet_v2 | mobilenet_v3_small | mobilenet_v3_large |
shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 |
alexnet | googlenet | mnasnet0_5 | mnasnet1_0 |
mnasnet1_3 | mnasnet0_75 | squeezenet1_0 | squeezenet1_1 |
efficientnet-b0(0-7) |
损失函数
- | - | - | - |
---|---|---|---|
mse | l1 | smooth_l1 | cross_entropy |
优化器
- | - | - | - |
---|---|---|---|
SGD | ASGD | Adam | AdamW |
Adamax | Adagrad | Adadelta | SparseAdam |
LBFGS | Rprop | RMSprop |
安装教程
pytorch>=1.5即可,其余库自行安装即可。
使用说明
配置文件config/config.yaml
data_dir: "./data/" #数据集存放地址 train_rate: 0.8 #数据集划分,训练集比例 image_size: 128 #输入网络图像大小 net_type: "shufflenet_v2_x1_0" pretrained: True #是否添加预训练权重 batch_size: 4 #批次 init_lr: 0.01 #初始学习率 optimizer: 'Adam' #优化器 class_names: [ 'cat','dog' ] #你的类别名称,必须和data文件夹下的类别文件名一样 epochs: 10 #训练总轮次 loss_type: "mse" # mse / l1 / smooth_l1 / cross_entropy #损失函数 model_dir: "./shufflenet_v2_x1_0/weight/" #权重存放地址 log_dir: "./shufflenet_v2_x1_0/logs/" # tensorboard可视化文件存放地址
模型训练
# 第一次训练 python train.py # 接着自己未训练完成的模型继续训练 python train.py --weights_path 模型保存路径
模型推理
# 检测图片 python infer.py image --image_path 图片地址 # 检测视频 python infer.py video --video_path 图片地址 # 检测摄像头 python infer.pu camera --camera_id 摄像头id
部署
onnx打包部署
# onnx打包 python pack_tools/pytorch_to_onnx.py --config_path 配置文件地址 --weights_path 模型权重存放地址 # onnx推理部署 # 检测图片 python pack_tools/pytorch_onnx_infer.py image --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --image_path 图片地址 # 检测视频 python pack_tools/pytorch_onnx_infer.py video --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --video_path 图片地址 # 检测摄像头 python pack_tools/pytorch_onnx_infer.py camera --config_path 配置文件地址 --onnx_path 打包完成的onnx包地址 --camera_id 摄像头id,默认为0
模型剪枝、量化压缩加速
模型剪枝微调
# 模型剪枝微调 python prune_model/pruning_model.py --weight_path 已训练好的模型权重地址 --prune_type 修剪模型的方式,支持:l1filter,l2filter,fpgm --sparsity 模型稀疏化比例 --finetune_epoches 微调模型的轮次数 --dummy_input 输入模型的形状,例如:(10,3,128,128) # onnx推理部署 # 检测图片 python infer_prune_model.py image --prune_weights_path 剪枝后的模型权重路径 --image_path 图片地址 # 检测视频 python infer_prune_model.py video --prune_weights_path 剪枝后的模型权重路径 --video_path 图片地址 # 检测摄像头 python infer_prune_model.py camera --prune_weights_path 剪枝后的模型权重路径 --camera_id 摄像头id,默认为0
参与贡献
作者:qiaofengsheng
B站地址:深度学习麋了鹿的个人空间_哔哩哔哩_Bilibili
github地址:https://github.com/qiaofengsheng/Pytorch-Image-Classifier-Collection.git
gitee地址:Pytorch-Image-Classifier-Collection: 支持多模型工程化的图像分类器
Pytorch通用图像分类模型(支持20+分类模型),直接带入数据就可训练自己的数据集,包括模型训练、推理、部署。相关推荐
- 搭建并训练多标签数据集的模型并将结果可视化
#搭建并训练多标签数据集的模型并将结果可视化(tensorflow2) 1.数据集的介绍 该数据为拥有颜色与衣服类别两个标签的衣服识别,对于这样的数据集要求我们的神经网络需要两个输出,一个是类别,另一 ...
- Html显示3D Obj模型(支持mtl纹理)的源码方案(2:一张图创建人脸模型)
不要瞧不起html脚本语言,自从浏览器取消了flash,你是不是觉得html做不了炫酷的内容,错了!现在html支持OpenGL,支持直接渲染3D游戏.今天我们就来讲如何在html渲染3D 的Obj模 ...
- Windows下使用Darknet训练自己的数据集(模型:yolov4-tiny、数据集:垃圾分类)
本文章主要介绍如何使用Darknet在windows下训练自己的数据集,其中模型使用的是yolov4-tiny,数据集使用的是自己垃圾分类数据集(需要的自取:在我上传的资源中有) PS:这是我的第一篇 ...
- 【深度学习】mask_rcnn训练自己的数据集以及模型使用(实践结合GitHub项目)
根据requirements - 开源项目默认的.txt进行库安装 环境:WIN10 + Anoconda + Pycharm + python3.6.2 mask_rcnn基本流程1.训练 1)la ...
- 搭建基于飞桨的OCR工具库,总模型仅8.6M的超轻量级中文OCR,单模型支持中英文数字组合识别、竖排文本识别、长文本识别的PaddleOCR
介绍 基于飞桨的OCR工具库,包含总模型仅8.6M的超轻量级中文OCR,单模型支持中英文数字组合识别.竖排文本识别.长文本识别.同时支持多种文本检测.文本识别的训练算法. 相关链接 PaddleOCR ...
- 从零开始编写一个宠物识别系统(爬虫、模型训练和调优、模型部署、Web服务)
心血来潮,想从零开始编写一个相对完整的深度学习小项目.想到就做,那么首先要考虑的问题是,写什么? 思量再三,我决定写一个宠物识别系统,即给定一张图片,判断图片上的宠物是什么.宠物种类暂定为四类--猫. ...
- 调用“抱抱脸团队打造的Transformers pipeline API” 通过预训练模型,快速训练和微调自己的模型
本文章根据官方文件总结而成,根据第三方库Transformers and pytorch快速搭建自己的神经网络架构,可以直接下载预训练模型,涉及的数据集包括音频.文字.图像等,实用性非常强! 官方链接 ...
- 基于LightGBM分类实现英雄联盟数据预测(二)
基于LightGBM分类实现英雄联盟数据预测(二) 这里写目录标题 基于LightGBM分类实现英雄联盟数据预测(二) Step5:利用 LightGBM 进行训练与预测 plt.figure(fig ...
- Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练)
Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练) 目录 Pytorch基础训练库Pytorch-Base-Trainer(PBT)(支持分布式训练) 1.I ...
最新文章
- 还在用Logback?Log4j2的异步性能已经无敌了,还不快试试
- Matlab之Kalman:用线性系统状态方程,通过系统输入输出观测数据,对系统状态进行最优估计的算法
- 微分算子为什么也是空间滤波器
- android自定义图片文本,Android 实现文字与图片的混排
- 信息系统项目管理师论文评分标准
- comsol固体传热_【 COMSOL 知识库】如何解决 COMSOL 软件“内存不足” 的问题
- J2SE理解之一:声明和访问控制
- Extjs 4.2 MVC+ThreeJs学习笔记(二)一个简单的ThreeJS场景
- Android 充电LED控制
- 静态的通讯录(C语言)
- luogu 1337
- 电脑突然调节不了亮度?让我教你来恢复
- 推荐一款专业串烧歌曲的音乐合并软件
- Linux下关闭udhcpc客户端时,通知服务器释放租约
- python 画图工具——matplotlib命令式函数
- HSV和RGB相互转换
- 维护最短路径条数和途径点的权值累加
- win10 1607 密匙
- VMware Workstation 无法连接到虚拟机
- 教你亲手制作一个虚拟数字人,超全步骤详解
热门文章
- 远程管理计算机用户账户限制,用户帐户控制和远程限制 - Windows Server | Microsoft Docs...
- 灵隐寺招聘:没有KPI,佛系上班……
- SEC起诉瑞波,中本聪早有论断
- GridView文本自动换行
- Elasticsearch Search API
- 第十八届全国大学生智能汽车竞赛网络报名方法
- 摆脱流量依赖,“心智营销”是玄学吗?
- VMware检测不到vulnhub靶机IP地址解决办法
- 新手必看——微软认证考试
- java编写股票交易软件有哪些,java开发程序源代码_炒股软件说明-小S股票