自己在进行人脸识别测试过程,开始利用自己的照片进行训练,由于开始准确率低,就开始增加自己照片的数量,开始是准确率提升,而后就开始降低,以前了解过这个方面知识,因此在网上找一些相关资料进行验证,后来发现有人进行过详细的测试,于是自己进行一些梳理。

实验数据与使用的网络

所谓样本不平衡,就是指在分类问题中,每一类对应的样本的个数不同,而且差别较大。这样的不平衡的样本往往使机器学习算法的表现变得比较差。那么在CNN中又有什么样的影响呢?作者选用了CIFAR-10作为数据源来生成不平衡的样本数据。

CIFAR-10是一个简单的图像分类数据集。共有10类(airplane,automobile,bird,cat,deer,dog, frog,horse,ship,truck),每一类含有5000张训练图片,1000张测试图片。

CIFAR-10样例如图:

训练时,选择的网络是这里的CIFAR-10训练网络和参数(来自Alex Krizhevsky)。这个网络含有3个卷积层,还有10个输出结点。

之所以不选用效果更好的CNN网络,是因为我们的目的是在实验时训练很多次进行比较,而不是获得多么好的性能。而这个CNN网络因为比较浅,训练速度比较快,比较符合我们的要求。

类别不平衡数据的生成

直接从原始CIFAR-10采样,通过控制每一类采样的个数,就可以产生类别不平衡的训练数据。如下表所示:

这里的每一行就表示“一份”训练数据。而每个数字就表示这个类别占这“一份”训练数据的百分比。

Dist. 1:类别平衡,每一类都占用10%的数据。

Dist. 2、Dist. 3:一部分类别的数据比另一部分多。

Dist. 4、Dist 5:只有一类数据比较多。

Dist. 6、Dist 7:只有一类数据比较少。

Dist. 8: 数据个数呈线性分布。

Dist. 9:数据个数呈指数级分布。

Dist. 10、Dist. 11:交通工具对应的类别中的样本数都比动物的多

对每一份训练数据都进行训练,测试时用的测试集还是每类1000个的原始测试集,保持不变。

类别不平衡数据的训练结果

以上数据经过训练后,每一类对应的预测正确率如下:

第一列Total表示总的正确率,下面是每一类分别的正确率。

从实验结果中可以看出:

  • 类别完全平衡时,结果最好。
  • 类别“越不平衡”,效果越差。比如Dist. 3就比Dist. 2更不平衡,效果就更差。同样的对比还有Dist. 4和Dist. 5,Dist. 8和Dist. 9。其中Dist. 5和Dist. 9更是完全训练失败了。

过采样训练的结果

作者还实验了“过采样”(oversampling)这种平衡数据集的方法。这里的过采样方法是:对每一份数据集中比较少的类,直接复制其中的图片增大样本数量直至所有类别平衡。

再次训练,进行测试,结果为:

可以发现过采样的效果非常好,基本与平衡时候的表现一样了。

过采样前后效果对比,可以发现过采样效果非常好:

总结

CNN确实对训练样本中类别不平衡的问题很敏感。平衡的类别往往能获得最佳的表现,而不平衡的类别往往使模型的效果下降。如果训练样本不平衡,可以使用过采样平衡样本之后再训练。自己系统也按照这个思路进行改造,确实效果明显

训练集样本不平衡问题对深度学习的影响相关推荐

  1. 训练集样本不平衡问题对CNN的影响

    转载自  训练集样本不平衡问题对CNN的影响 训练集样本不平衡问题对CNN的影响 本文首发于知乎专栏"ai insight"! 卷积神经网络(CNN)可以说是目前处理图像最有力的工 ...

  2. 【学界】深度学习如何影响运筹学?

    来源:运筹OR帷幄 前言 最近看到一篇回答,YouTube 已将视频推荐全面改用深度学习实现.但传统上,推荐系统落在运筹学的范畴,可以归结为一个矩阵补全(matrix completion)问题,用半 ...

  3. 主编推荐 | 深度学习如何影响运筹学?

    作者:郝井华等四人 作者简介: @郝井华:清华大学运筹学博士,现任美团配送算法架构师,美团点评研究员.@成丰:北京大学智能科学系 硕士 中国国际金融贸易创新发展战略合作研究中心 · 特聘研究员.胖骁: ...

  4. 【深度学习】深度学习如何影响运筹学?

    『运筹OR帷幄』原创 作者:郝井华等四人 作者简介: @郝井华:清华大学运筹学博士,现任美团配送算法架构师,美团点评研究员.@成丰:北京大学智能科学系 硕士 中国国际金融贸易创新发展战略合作研究中心 ...

  5. alexnet训练多久收敛_如何将深度学习训练速度提升一百倍?PAISoar 来了

    阿里妹导读:得力于数据规模增长.神经网络结构的演进和计算能力的增强,深度学习的图像处理.语音识别等领域取得了飞速发展.随着训练数据规模和模型复杂度的不断增大,如何充分利用分布式集群的计算资源加快训练速 ...

  6. 干货合集 | 带你深入浅出理解深度学习(附资源打包下载)

    作者:Shashank Gupta 翻译:倪骁然 校对:卢苗苗 本文约2300字,建议阅读10分钟. 本文提供资源帮助你在放置一个conv2d层或者在Theano里调用T.grad的时候,了解到在代码 ...

  7. 训练 GPT-3,为什么原有的深度学习框架吃不消?

    本文梳理了深度学习框架在支持大规模预训练模型时面临的技术挑战,以及当前各类框架的基本解决思路,帮助算法工程师对业界各类框架的分布式训练能力有更清晰的认知. 作者 | 一流科技CEO袁进辉 头图 | 下 ...

  8. 2.10 m 个样本的梯度下降-深度学习-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 2.9 Logistic 回归的梯度下降法 回到目录 2.11 向量化 mmm 个样本的梯度下降 (Gradient Descent on mmm example) 在之前的 ...

  9. python 训练识别验证码_python使用tensorflow深度学习识别验证码

    本文介绍了python使用tensorflow深度学习识别验证码 ,分享给大家,具体如下: 除了传统的PIL包处理图片,然后用pytessert+OCR识别意外,还可以使用tessorflow训练来识 ...

最新文章

  1. **CI两种方式查询所返回的结果数量
  2. 【问题解决】Processing库安装方法简介
  3. linux下Eclipse+CDT开发环境配置与使用
  4. 您的第一个简单的机器学习项目
  5. PLSQL个性化设置
  6. 4.01~ios开发常用的宏
  7. 《白帽子讲web安全》读书笔记
  8. Python版C语言词法分析器
  9. java shell_jshell – Java Shell
  10. 蓝桥杯2019年第十届C/C++省赛A组第二题-数列求值
  11. Mac系统访问Windows共享文件的详细步骤
  12. 安装libvirt管理套件(C/S架构模式,用户管理kvm虚拟机)
  13. 【C/C++】转义字符大全
  14. selenium-模拟登录QQ空间(附模拟滑动验证码)
  15. java 时间处理工具类
  16. 如何使用手机打开CAJ文件?
  17. 这些年我的不足(不够专注,不善于推迟满足感,阅读量不够……-无网不剩 http://t.cn/zOe1RPz)
  18. 在CentOS 7最小环境下安装Cinnamon桌面环境
  19. 什么事MVC?什么是MVC!
  20. 前端日期选择器--只选择年或者年月的My97

热门文章

  1. @Cacheable和@CacheEvict的学习使用
  2. 元学习 迁移学习_元学习就是您所需要的
  3. 华为设备配置Smart Link负载分担
  4. 【证明】矩阵特征值之和等于主对角线元素之和
  5. android通知栏显示,通知栏点击事件监听
  6. asp.net mvc 网站生成二维码
  7. java 6u45 no sni 2_sjscxz.taobao.com
  8. Excel如何如何比较两列同行内容是否一致
  9. 如何获取 ChatGPT OpenAI API Key
  10. Marvin java图像处理