开源代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template​​​​​

使用pytorch框架搭建一个图像分类模型通常包含以下步骤:

1、数据加载DataSet,DataLoader,数据转换transforms

2、构建模型、模型训练

3、模型误差分析

下面依次来看一下上述几个步骤的实现方法:

一、数据加载、数据增强

a)、有时候torchvision.transform中提供的数据转换方法不能满足项目需要,需要自定义数据转换方法进行数据增强,以下InvertTransform类实现了__init__和__call__方法对图像的像素值进行翻转:

b)、数据加载可以参考pytorch数据集加载之DataSet和DataLoader。

c)、使用torchvision.transforms.Compose组合多种数据转换方法进行数据转换:

其中,train_transform用于训练集,val_transform用于验证集,训练集和验证集要进行相同的数据转换操作。并且在pytorch中提供的transform数据转换方法有的处理目标是Image对象,有的处理目标是tensor,并且经过处理后的数据维度变为[N,C,H,W]。

d)、使用Dataset和DataLoader进行数据加载

  二、模型构建

在pytorch中提供了常用的分类网络的创建接口以及预训练权重,一般情况下直接使用预训练权重来初始化backbone网络,修改网络的输出层来适配自己的数据集,仅训练网络输出层或者输出层及其之前几层就可以。模型构建方法实例如下:

在上述代码中,创建了MobileNet_V2网络模型,并且同时下载了预训练模型参数,同时修改了网络中的classifier模块,将输出层的维度修改为自己的数据集类别数量,同时将classifier模块的参数和模型的其他参数进行区分,模型除classifier 模块参数之外的参数使用预训练模型参数进行初始化,classifier模块参数使用随机初始化。然后在优化器中对参数进行分组训练,classifier模块的参数需要重点训练,使用较大的学习率,其他模块的参数是预训练模型参数,并且处于浅层网络中,参数仅需要稍微修改就可以,使用较小的学习率。

三、模型训练

模型训练的主要流程是,从DataLoader中分批加载数据送入模型,将模型预测结果与真实结果使用定义的loss计算损失,基于计算的损失进行梯度反向传播进行参数优化。

四、模型评估,模型准确性评估

五、学习曲线

根据迭代训练过程中保存训练集、验证集的准确率、损失值、学习率绘制学习曲线,来判断模型的学习情况,是否出现过拟合、欠拟合等。

六、误差分析

吴恩达大佬不止一次强调过误差分析的重要性,在分类模型训练过程中进行误差分析,可以清晰的看到误差来源,哪些样本容易被误识别,这些样例有什么规律,从而进行数据调整提升模型性能。在最近吴恩达大佬的一次讲座中,再次强调了误差分析的重要性,讲座中提出当你的模型性能遇到瓶颈时,是以模型为中心调整模型呢?还是以数据为中心调整数据呢?大佬的结论是调整模型几乎不会带来性能的提升,而调整数据能带来模型性能的大幅提升。下面就看一下图像分类模型的误差分析示例代码:

该方法根据模型预测结果以及样本的真实标签,计算样本的正确与否,对于预测错误的样本,保存样本的真实标签,预测标签,从而能够清晰的看到那些被分类错误的样本都被误分类成了什么。

就写到这吧!有疑问欢迎随时交流。

使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)相关推荐

  1. Pytorch通用图像分类模型(支持20+分类模型),直接带入数据就可训练自己的数据集,包括模型训练、推理、部署。

    Pytorch-Image-Classifier-Collection 介绍 ============================== 支持多模型工程化的图像分类器 =============== ...

  2. pytorch 模型可视化_高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Eugene Khvedchenya 编译:ronghuaiyang 导读 ...

  3. pytorch 模型可视化_【深度学习】高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力...

    作者:Eugene Khvedchenya   编译:ronghuaiyang 导读 只报告模型的Top-1准确率往往是不够的. 将train.py脚本转换为具有一些附加特性的强大pipeline 每 ...

  4. gpu处理信号_在PyTorch中使用DistributedDataParallel进行多GPU分布式模型训练

    先进的深度学习模型参数正以指数级速度增长:去年的GPT-2有大约7.5亿个参数,今年的GPT-3有1750亿个参数.虽然GPT是一个比较极端的例子但是各种SOTA模型正在推动越来越大的模型进入生产应用 ...

  5. Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练)

    Pytorch基础训练库Pytorch-Base-Trainer(支持模型剪枝 分布式训练) 目录 Pytorch基础训练库Pytorch-Base-Trainer(PBT)(支持分布式训练) 1.I ...

  6. 结构化数据建模——titanic数据集的模型建立和训练(Pytorch版)

    本文参考<20天吃透Pytorch>来实现titanic数据集的模型建立和训练 在书中理论的同时加入自己的理解. 一,准备数据 数据加载 titanic数据集的目标是根据乘客信息预测他们在 ...

  7. 把一个dataset的表放在另一个dataset里面_现在开始:用你的Mac训练和部署一个图像分类模型...

    可能有些同学学习机器学习的时候比较迷茫,不知道该怎么上手,看了很多经典书籍介绍的各种算法,但还是不知道怎么用它来解决问题,就算知道了,又发现需要准备环境.准备训练和部署的机器,啊,好麻烦. 今天,我来 ...

  8. 【金融】【pytorch】使用深度学习预测期货收盘价涨跌——LSTM模型构建与训练

    [金融][pytorch]使用深度学习预测期货收盘价涨跌--LSTM模型构建与训练 LSTM 创建模型 模型训练 查看指标 LSTM 创建模型 指标函数参考<如何用keras/tf/pytorc ...

  9. 【神经网络与深度学习】CIFAR10数据集介绍,并使用卷积神经网络训练图像分类模型——[附完整训练代码]

    [神经网络与深度学习]CIFAR-10数据集介绍,并使用卷积神经网络训练模型--[附完整代码] 一.CIFAR-10数据集介绍 1.1 CIFAR-10数据集的内容 1.2 CIFAR-10数据集的结 ...

最新文章

  1. 我马上会重新利用这个博客的
  2. Eclipse配置Android开发环境
  3. 面向对象编程 封装 继承 多态(三大特征)(第三篇)
  4. 小G的项链(Manacher)
  5. windows 终端查看python位置
  6. 打通高德、UC、微博,支付宝小程序组建“阿里联盟军”对抗微信小程序?| 技术头条...
  7. 开课吧Java课堂:字符串如何处理?
  8. mybatis批量删除提示类型错误
  9. stochastic noise and deterministic noise
  10. Jmeter学习之旅(四)——各类型的HTTP接口功能测试
  11. 字符串全排列 java实现
  12. Wordpress 5.2 beta 2 发布,支持 Emoji 12
  13. UE4 android开发
  14. Apple Magic Mouse 卡顿的问题
  15. ai面试的优缺点_找工作时让AI给你面试,你愿意吗?
  16. 带你深入剖析TCP/IP协议、TCP协议和UDP协议、IP协议
  17. 【python+selenium】自动登陆青果教务系统
  18. 通信工程是计算机类还是电子信息类公考,通信工程属于电子信息类吗
  19. 一刀工具箱 - 成语查询工具
  20. LABVIEW绘制等高线

热门文章

  1. 程序员之路:sublime使用技巧
  2. excel 右键打不开表格修复以及excel打开独立窗口的修复
  3. nvm node版本切换无效
  4. matlab 判断语句是否为真,matlab 条件判断语句不生效
  5. linux 运行class文杰,Linux详细教程
  6. Git 指令,看这个就够了,赶紧收藏,方便查阅
  7. VxWorks设备驱动程序开发指南---驱动程序的分类
  8. C++ 中左值和右值引用的讲解
  9. Android gradle配置签名文件
  10. PyAutoGui图像操作(二):图像定位不稳定解决方案