转载自  训练集样本不平衡问题对CNN的影响

训练集样本不平衡问题对CNN的影响

本文首发于知乎专栏“ai insight”!

卷积神经网络(CNN)可以说是目前处理图像最有力的工具了。

而在机器学习分类问题中,样本不平衡又是一个经常遇到的问题。最近在使用CNN进行图片分类时,发现CNN对训练集样本不平衡问题很敏感。

在网上搜索了一下,发现http://www.diva-portal.org/smash/get/diva2:811111/FULLTEXT01.pdf这篇文章对这个问题已经做了比较细致的探索。于是就把它简单整理了一下,相关的记录如下。

1、实验数据与使用的网络

所谓样本不平衡,就是指在分类问题中,每一类对应的样本的个数不同,而且差别较大。

这样的不平衡的样本往往使机器学习算法的表现变得比较差。那么在CNN中又有什么样的影响呢?作者选用了CIFAR-10作为数据源来生成不平衡的样本数据。

CIFAR-10是一个简单的图像分类数据集。共有10类(airplane,automobile,bird,cat,deer,dog, frog,horse,ship,truck),每一类含有5000张训练图片,1000张测试图片。

CIFAR-10样例如图:

训练时,选择的网络是这里的CIFAR-10训练网络和参数(来自Alex Krizhevsky)。这个网络含有3个卷积层,还有10个输出结点。

之所以不选用效果更好的CNN网络,是因为我们的目的是在实验时训练很多次进行比较,而不是获得多么好的性能。

而这个CNN网络因为比较浅,训练速度比较快,比较符合我们的要求。

2、类别不平衡数据的生成

直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。如下表所示:

这里的每一行就表示“一份”训练数据。而每个数字就表示这个类别占这“一份”训练数据的百分比。

Dist. 1:类别平衡,每一类都占用10%的数据。

Dist. 2、Dist. 3:一部分类别的数据比另一部分多。

Dist. 4、Dist 5:只有一类数据比较多。

Dist. 6、Dist 7:只有一类数据比较少。

Dist. 8: 数据个数呈线性分布。

Dist. 9:数据个数呈指数级分布。

Dist. 10、Dist. 11:交通工具对应的类别中的样本数都比动物的多

对每一份训练数据都进行训练,测试时用的测试集还是每类1000个的原始测试集,保持不变。

3、类别不平衡数据的训练结果

以上数据经过训练后,每一类对应的预测正确率如下:

第一列Total表示总的正确率,下面是每一类分别的正确率。

从实验结果中可以看出:

  1. 类别完全平衡时,结果最好。

  2. 类别“越不平衡”,效果越差。比如Dist. 3就比Dist. 2更不平衡,效果就更差。同样的对比还有Dist. 4和Dist. 5,Dist. 8和Dist. 9。其中Dist. 5和Dist. 9更是完全训练失败了。

4、过采样训练的结果

作者还实验了“过采样”(oversampling)这种平衡数据集的方法。

这里的过采样方法是:对每一份数据集中比较少的类,直接复制其中的图片增大样本数量直至所有类别平衡。

再次训练,进行测试,结果为:

可以发现过采样的效果非常好,基本与平衡时候的表现一样了。

过采样前后效果对比,可以发现过采样效果非常好:

5、总结

CNN确实对训练样本中类别不平衡的问题很敏感。

平衡的类别往往能获得最佳的表现,而不平衡的类别往往使模型的效果下降。如果训练样本不平衡,可以使用过采样平衡样本之后再训练。

这确实是一个“经验主义”的结论,但多少给我们平常训练CNN模型带来一些启发和帮助。

训练集样本不平衡问题对CNN的影响相关推荐

  1. 训练集样本不平衡问题对深度学习的影响

    自己在进行人脸识别测试过程,开始利用自己的照片进行训练,由于开始准确率低,就开始增加自己照片的数量,开始是准确率提升,而后就开始降低,以前了解过这个方面知识,因此在网上找一些相关资料进行验证,后来发现 ...

  2. 2*2矩阵训练集比例对BP神经网络分类性能影响

    r1 r2     <1 <1 吸引子 c >1 >1 排斥子 p >1 <1 鞍点 a <1 >1 反鞍点 fa 构造一个网络 d2( x(a||fa ...

  3. ML基础 : 训练集,验证集,测试集关系及划分 Relation and Devision among training set, validation set and testing set...

    首先三个概念存在于 有监督学习的范畴 Training set: A set of examples used for learning, which is to fit the parameters ...

  4. 神经网络训练集两张图片之间的相互作用

    在<神经网络分类训练集的图片到底是如何相互影响的?>中得到了一个经验关系,如果(两张图片)混合后的迭代次数变小了,分类准确率可能变大:如果二者混合后迭代次数变大,分类准确率可能会变小.但上 ...

  5. 【NLP项目-文本分类】划分测试集,训练集,验证集

    目录 一.不分词划分数据集 1.划分数据集 2.将各数据集写入txt文件 二.分词划分数据集 1.分词 2.完整代码 本篇文章的主要任务是将自己的数据集使用在Chinese-Text-Classifi ...

  6. 神经网络调参训练集噪音比例对网络性能的影响

    这次用于实验训练集噪音比例对网络性能的影响,网络结构81*60*2,训练集用的是mnist的训练集的0和1,测试集用的mnist的测试集的0和1,学习率固定位0.1,batchsize=20,试验了训 ...

  7. 腾讯提超强少样本目标检测算法,公开1000类检测训练集FSOD | CVPR 2020

    作者 | VincentLee 来源 | 晓飞的算法工程笔记 不同于正常的目标检测任务,few-show目标检测任务需要通过几张新目标类别的图片在测试集中找出所有对应的前景.为了处理好这个任务,论文主 ...

  8. 绘制测试集、训练集的每一个病人或者样本的raidomics signiture图(绘制raidomics signature图),以及ROC曲线图

    绘制测试集.训练集的每一个病人或者样本的raidomics signiture图(绘制raidomics signature图),以及ROC曲线图 受试者工作特征曲线 (receiver operat ...

  9. matlab pca 测试样本,matlab_PCA,训练集与测试集分开,原理和用法

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 PCA基本流程: 1.训练集矩阵算协方差矩阵A; 2.算协方差矩阵特征值与特征向量; 3.按特征值的大小排列特征矩阵,得B,对应的特征值(按从大到小排列) ...

最新文章

  1. CTFshow 文件包含 web87
  2. C++ 之类的静态成员
  3. xtrabackup压缩备份多线程备份(lz4,pigz)全详解
  4. 基于BPMN2.0的工单系统架构设计(上)
  5. MySQL 5.7系列之sys schema(2)
  6. BZOJ 2763: [JLOI2011]飞行路线 【SPFA】
  7. hadoop 文件介绍
  8. 权限系统设计学习总结(3)——多账户的统一登录
  9. 主流H.264编码器对比测试 (MSU出品)
  10. php开源框架和平台(XAMPP、Wamp5和AppServ)简述
  11. EventThread线程对VSync的接收
  12. 关于多数据源(除自己数据库外,另一部分数据需通过接口调取第三方获取)的查询问题...
  13. 视频工具下载(m3u8、MP4)
  14. 数学----向量点积公式推导
  15. 爬虫python技术分享_Python技术分享:爬虫
  16. Google Earth Engine(GEE)——R 语言图像概览
  17. 在eclipse中修改tomcat端口
  18. 如何使用 mps 开发原生小程序
  19. MFC界面库BCGControlBar v30.0新功能详解:Desktop Alert Window
  20. 钢材规格解读的软件_钢材规格表及软件下载

热门文章

  1. [数据结构]二叉树的性质
  2. [蓝桥杯][历届试题]网络寻路-dfs,图的遍历
  3. ROADS POJ - 1724(最短路+邻接表+dfs)
  4. java继承层次结构,在状态模式中实现继承层次结构 - java
  5. IP地址与MAC地址的区别
  6. ffmpeg加环境变量
  7. 2016-2017 ACM-ICPC CHINA-Final(7 / 12)
  8. 与Min_25筛有关的一些模板
  9. P2467 [SDOI2010]地精部落
  10. 24dian(牛客多校第三场)