作者&编辑 | 郭冰洋

1 简介

小伙伴们在利用公共数据集动手搭建图像分类模型时,有没有注意到这样一个问题呢——每个数据集不同类别的样本数目几乎都是一样的。这是因为不同类别的样例数目差异较小,对分类器的性能影响不大,可以在避免其他因素的影响下,充分反映分类模型的性能。反之,如果类别间的样例数目相差过大,会对学习过程造成一定的影响,从而导致分类模型的性能变差。这就是本篇文章将要讨论的类别不平衡问题(Class Imbalance)。

类别不平衡是指分类任务中不同类别的训练样本数目相差较大的情况,通常是由于样本较难采集或样本示例较少而引起的,经常出现在疾病类别诊断、欺诈类型判别等任务中。

尽管在传统机器学习领域内,有关类别不平衡的问题已经得到了详尽的研究,但在深度学习领域内,其相关探索随着深度学习的发展,经历了一个先抑后扬的过程。

在反向传播算法诞生初期,有关深度学习的研究尚未成熟,但仍有相关科研人员研究过类别样例的数目对梯度传播的影响,并得出样例数目较多的类别在反向传播时对权重占主导地位。这一现象会使网络训练初期,快速的降低数目较多类别的错误率,但随着训练的迭代次数增加,数目较少类的错误率会随之上升[1]。

随后的十余年里,由于深度学习受到计算资源的限制、数据集采集的难度较大等影响,相关研究并没有得到进一步的探索,直到近年来才大放异,而深度学习领域内的类别不平衡问题,也得到了更加深入的研究。

本篇文章将对目前涉及到的相关解决方案进行汇总,共分为数据层面、算法层面、数据和算法混合层面三个方面,仅列举具有代表性的方案阐述,以供读者参考。

2 方法汇总

1、基于数据层面的方法

基于数据层面的方法主要对参与训练的数据集进行相应的处理,以减少类别不平衡带来的影响。

Hensman等[2]提出了

提升样本(over sampling)的方法,即对于类别数目较少的类别,从中随机选择一些图片进行复制并添加至该类别包含的图像内,直到这个类别的图片数目和最大数目类的个数相等为止。通过实验发现,这一方法对最终的分类结果有了非常大的提升。

Lee等[3]提出了一种

两阶段(two-phase)训练法。首先根据数据集分布情况设置一个阈值N,通常为最少类别所包含样例个数。随后对样例个数大于阈值的类别进行随机抽取,直到达到阈值。此时根据阈值抽取的数据集作为第一阶段的训练样本进行训练,并保存模型参数。最后采用第一阶段的模型作为预训练数据,再在整个数据集上进行训练,对最终的分类结果有了一定的提升.

Pouyanfar等[4]则提出了一种

动态采样(dynamic sampling)的方法。该方法借鉴了提升样本的思想,将根据训练结果对数据集进行动态调整,对结果较好的类别进行随机删除样本操作,对结果较差的类别进行随机复制操作,以保证分类模型每次学习都能学到相关的信息。

2、基于算法层面的方法

基于算法层面的方法主要对现有的深度学习算法进行改进,通过修改损失函数或学习方式的方法来消除类别不平衡带来的影响。

Wang等[5]提出

mean squared false error (MSFE) loss。这一新的损失函数是在mean false error (MFE) loss的基础上进行改进,具体公式如下图所示:

MSFE loss能够很好地平衡正反例之间的关系,从而实现更好的优化结果。

Buda等[6]提出

输出阈值(output thresholding)的方法,通过调整网络结果的输出阈值来改善类别不平衡的问题。模型设计者根据数据集的构成和输出的概率值,人工设计一个合理的阈值,以降低样本数目较少的类别的输出要求,使得其预测结果更加合理。

3、基于数据和算法的混合方法

上述两类层面的方法均能取得较好的改善结果,如果将两种思想加以结合,能否有进一步的提升呢?

Huang等[7]提出

Large Margin Local Embedding (LMLE)的方法,采用五倍抽样法(quintuplet sampling )和tripleheader hinge loss函数,可以更好地提取样本特征,随后将特征送入改进的K-NN分类模型,能够实现更好的聚类效果。除此之外,Dong等[8]则融合了难例挖掘和类别修正损失函数的思想,同样是在数据和损失函数进行改进。

由于篇幅和时间有限,本文只列取了每个类别的典型解决方案。同时也搜集了关于解决类别不平衡问题的相关综述文献,截图如下:

具体名称可以借鉴参考文献[9]。

3 参考文献

[1] Anand R, Mehrotra KG, Mohan CK, Ranka S. An improved algorithm for neural network classification of imbalanced training sets. IEEE Trans Neural Netw. 1993;4(6):962–9.

[2] Hensman P, Masko D. The impact of imbalanced training data for convolutional neural networks. 2015.

[3] Lee H, Park M, Kim J. Plankton classification on imbalanced large scale database via convolutional neural networks with transfer learning. In: 2016 IEEE international conference on image processing (ICIP). 2016. p. 3713–7.

[4] Pouyanfar S, Tao Y, Mohan A, Tian H, Kaseb AS, Gauen K, Dailey R, Aghajanzadeh S, Lu Y, Chen S, Shyu M. Dynamic sampling in convolutional neural networks for imbalanced data classification. In: 2018 IEEE conference on multimedia information processing and retrieval (MIPR). 2018. p. 112–7.

[5] Wang S, Liu W, Wu J, Cao L, Meng Q, Kennedy PJ. Training deep neural networks on imbalanced data sets. In: 2016 international joint conference on neural networks (IJCNN). 2016. p. 4368–74.

[6] Buda M, Maki A, Mazurowski MA. A systematic study of the class imbalance problem in convolutional neural

networks. Neural Netw. 2018;106:249–59.

[7] Huang C, Li Y, Loy CC, Tang X. Learning deep representation for imbalanced classification. In: 2016 IEEE conference on computer vision and pattern recognition (CVPR). 2016. p. 5375–84.

[8] Dong Q, Gong S, Zhu X. Imbalanced deep learning by minority class incremental rectification. In: IEEE transactions on pattern analysis and machine intelligence. 2018. p. 1–1

[9] Justin M. Johnson and Taghi M. Khoshgoftaar.Survey on deep learning with class imbalance.Johnson and Khoshgoftaar J Big Data.(2019) 6:27

总结

以上就是关于类别不平衡问题的相关解决方案,详细内容可以阅读参考文献综述9,相信通过更加详细的文章阅读,你会收获更多的经验!

https://www.toutiao.com/a6727841366342107655/

深度学习分类类别不平衡_「图像分类」 关于图像分类中类别不平衡那些事相关推荐

  1. 专访香港大学罗平:师从汤晓鸥、王晓刚,最早将深度学习应用于计算机视觉的「先行者」

    虽然 ICCV 2019 落幕已近两周,但是这场对于华人研究者而言具备「转折点」意义的国际学术顶会在大家心中掀起的波澜,想必依旧未了. 在今年这场 CV 领域的学术盛宴中,我们一如既往地看到了不少长期 ...

  2. 如何实现高速卷积?深度学习库使用了这些「黑魔法」

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源:公众号 机器之心 授权转载 使用深度学习库可以大幅加速CNN ...

  3. vivado中bit文件怎么没有生成_「干货」FPGA设计中深度约束技巧及调试经验总结...

    今天跟大家分享的内容很重要,也是我们调试FPGA经验的总结.随着FPGA对时序和性能的要求越来越高,高频率.大位宽的设计越来越多.在调试这些FPGA样机时,需要从写代码时就要小心谨慎,否则写出来的代码 ...

  4. oracle split函数用法_「干货」Python字符串中的split方法

    前面的文章我们有简单的介绍过什么是字符串.Python字符串的输入方式.Python字符串的拼接方法等今天我们主要分享一下Python字符串中split方法! Python字符串中的方法有很多种,其中 ...

  5. 深度学习分类任务常用评估指标

    摘要:这篇文章主要向大家介绍深度学习分类任务评价指标,主要内容包括基础应用.实用技巧.原理机制等方面,希望对大家有所帮助. 本文分享自华为云社区<深度学习分类任务常用评估指标>,原文作者: ...

  6. 深度学习用于视频检测_视频如何用于检测您的个性?

    深度学习用于视频检测 视频是新的第一印象! (Videos are the New First Impressions!) Think about the approximate number of ...

  7. 深度学习背后的数学_深度学习背后的简单数学

    深度学习背后的数学 Deep learning is one of the most important pillars in machine learning models. It is based ...

  8. 深度学习分类pytorch_pytorch使用转移学习的狗品种分类器

    深度学习分类pytorch So have you heard the word deep learning before? Or have you just started learning it? ...

  9. 使用深度学习分类猫狗图片

    使用深度学习分类猫狗图片 前言 一.下载数据 二.构建网络 三.数据预处理 四.使用数据增强 总结 前言 本文将介绍如何使用较少的数据从头开始训练一个新的深度学习模型.首先在一个2000个训练样本上训 ...

  10. 卷积神经网络经典论文集合(深度学习分类篇)

    卷积神经网络经典论文集合 为方便撰写深度学习分类网络综述,现将近年以来经典论文做一个整理.文章时间大部分参考arXiv分享时间为准,小部分为期刊的出版日期. 下载地址 CSDN:https://dow ...

最新文章

  1. 《数学之美》第17章 由电视剧《暗算》所想到的—谈谈密码学的数学原理
  2. 区块链BaaS云服务(21)腾讯CCGP“跨链服务”
  3. 筛法求素数 素数打表
  4. (数据结构)前缀,后缀以及中缀表达式
  5. linux安装easy php,Linux php安装
  6. 渗透测试之Nmap命令(二)
  7. Java操作数据库详解
  8. maya多边形建模怎样做曲面_一名合格的模型师,不得不学习掌握的几种建模方法,你会了吗?...
  9. mac如何安装python_手把手教你安装Python开发环境(二)之Mac电脑安装Python解释器...
  10. Andriod UI设计之度量单位说明(DIP,DP,PX,SP)
  11. 利用正则表达式 替换字符串中多个 URL
  12. Python操作DB2数据库
  13. dex2jar 和 jd-gui 的安装与使用
  14. 吉林大学珠海学院计算机录取分数线,大学介绍 | 吉林大学珠海学院(附录取分数线,重点专业)...
  15. Python的三目表达式and简短语法
  16. web点播VOD m3u8播放识别为live流 播放几个.ts切片停止播放 排错
  17. oracle+soacs,第 3 章 使用 C++ 编译器选项
  18. 如何搭建自己的wiki
  19. PP实施经验分享(17)——S4 PP与ME标准接口报工函数“CO_MES_PRODORDCONF_CREATE_TT”
  20. Mysql中,order by + limt的大坑

热门文章

  1. python3-matplotlib绘制散点图、绘制条形图
  2. android翻盘效果,行情艰难,Android初中级面试题助你逆风翻盘,每题都有详细答案...
  3. c语言switch caseh(op),switch语句求教
  4. greenplum配置高可用_高可用hadoop集群配置就收藏这一篇,动手搭建Hadoop(5)
  5. C语言 do while 和 while 循环 - C语言零基础入门教程
  6. Pycharm 提示:this license * has been cancelled - Python零基础入门教程
  7. BugkuCTF-WEB题web16备份是个
  8. mysql ssh错误_通过SSH隧道连接时,MySQL访问被拒绝错误
  9. C语言实现单链表面试题汇总
  10. 优秀自我简介200字_全球战疫 翰墨传情——东方盛世杯网络公益书画展优秀作品【二】...