本文以研究报告形式描述了使用TensorFlow Slim 提供的预训练模型 Inception-ResNet-V2 进行猫狗图片分类的研究。

题图来自于网络,如侵权请留言,立即删除。

一、 问题定义

(一)项目概述

本项目是机器学习竞赛平台Kaggle上的一个图片分类项目Dogs vs. Cats[1]——猫狗大战,项目要解决的问题实际是一个计算机视觉领域的图像分类问题,图像分类一般的工作模式为给定一张图片,判断其属于某个有限类别集合中的哪一类。这个领域不仅非常有趣,而且具有非常大的应用价值和商业价值。猫狗大战项目图片数据来自于微软研究院的一个CAPTCHA[2](for Completely Automated Public Turing test to tell Computers and Humans Apart)项目Asirra[3] 的子数据集。

(二)问题陈述

猫狗大战项目要求对一个混合了猫和狗的图片数据集进行二分类,项目提供了用于训练、测试的两部分数据,要求使用算法程序在训练集上对已分类的猫和狗的图片进行建模,然后利用建立的模型对测试集上多张打乱顺序的未标记猫和狗的图片进行推断,输出图片是狗的概率,使用交叉熵损失值作为模型好坏的评估分数,最终分数需要进入Public Leaderboard 10%。

此毕业项目要求使用深度学习方法进行建模,将采用常见的CNN模型解决这个猫狗图片分类问题。

(三)评价指标

使用交叉熵损失loss值作为评估指标,交叉熵损失函数是神经网络分类算法中常用的损失函数,其值越小说明模型拟合的分布约接近真实分布,模型表现越好。交叉熵损失函数公式定义如下:

公式1 交叉熵损失函数[4]。

为数据集中的图片数;

为预测概率;

在图片是狗时为1,否则为0;

是以

为底的自然对数。

使用交叉熵作为损失函数一方面是其适用于分类问题,二是题目得分评判标准也是交叉熵loss值,与得分评判标准一致使得训练阶段的评估指标与题目推断阶段评估指标比较是有意义的。

二、 分析

(一)数据的探索

输入数据包含训练集、测试集两部分,图片格式为JPEG。

训练集:train.zip,包含25000张已标记的图片文件,文件名格式为“类别.图片id.jpg”,类别为cat或dog,图片id为数字,如cat.0.jpg、dog.12247.jpg。

测试集:test.zip,包含12500张未标记的图片文件,文件名格式为“图片id.jpg”, 图片id为数字,如1.jpg、11605.jpg。

训练集数据中标记为猫、狗的图片分别有125000张,比例1:1,训练集、测试集比例为1:2。

数据集中图片尺寸大小不一,训练集、测试集图片尺寸分别有8513、4888种,在训练和推断时需要统一尺寸。数据中图像不一定完整包含完整猫或狗的身体,有的主体在图片中很小,图片背景复杂,图片里会出现人或其他物体,如图1。训练集中包含少量非猫或狗的图片,如图2,这些异常数据大约占训练集的5.6 ‱,需要被清理掉。图1图2

图片数值异常可能导致训练时模型不收敛,经检查验证集、测试集数据图片RGB值都在[0,255]区间内,数值正常,可直接进行归一化。

(二)探索性可视化

训练集、测试集如果尺寸分布差异很大,训练出来的模型可能在预测时表现不佳。为了确认这一点,对训练集、测试集图片尺寸分布进行了可视化,如图3、图4。由可视化结果可见同一个集合里图片尺寸差异非常大,各尺寸分布比较平衡,训练集中存在两个离群点。训练集和测试集尺寸分布形状是相似的,并且除离群点外尺寸分布区间也是一致的,图片宽和高约在[30,500]像素之间,因此训练集训练出的模型是可以用于测试集的。图3 训练集尺寸分布散点图图4 测试集尺寸分布散点图

(三)算法和技术

图像分类目前最流行的解决方案是使用CNN(卷积神经网络),CNN是一种多层神经网络,使用多个卷积层、池化层堆叠提取图片特征(称之为feature map),末端用多个全连接层堆叠得到概率,使用softmax归一化输出最终概率。常见的CNN有很多,如InceptionNet[5]、VGG[6]、ResNet[7]、Inception-ResNet[8]。

上述这些知名的CNN由于网络层次比较深,参数量大,导致计算量非常大,所需训练时间比较长,通常需要使用GPU加速训练。实际使用时一般采用在大规模数据集上训练过的预训练模型进行微调(fine tune),即迁移学习,这样做不仅可以利用其已训练的参数,而且还可以大幅减少训练时间。这里使用tensorflow slim[9] 模块里提供的Inception-ResNet-V2预训练模型,Inception-ResNet-V2架构如图5所示,其使用ResNet中的残差连接与Inception思想结合,在模型精度和训练速度都有提升。Inception-ResNet-V2在ImageNet数据集上Top-1、Top-5准确率分别达到了80.4、95.3,在ILSVRC[10]图像分类基准测试中实现了当时(2016年8月31日)最好的成绩。

ImageNet数据集有1400多万幅图像,涵盖2万多个类别,类别中也包含猫和狗,ImageNet数据集和本项目相似性比较大,进行迁移学习时可以利用预训练模型卷积层参数提取特征,并调整预训练模型输出1000个分类为2个分类,训练全连接层以拟合本项目数据集。图5 Inception-ResNet-V2网络架构图[11]

(四)基准模型

题目中要求得分进入Public Leaderboard 前10%,得分以交叉熵损失loss值计算,总参赛人数1314,前10%即排名在1~131名,131名得分为0.06127,即最终得分需要在(0.00000, 0.06127)之间。

三、 方法

(一)数据预处理

本项目数据预处理主要包含异常数据清理、训练集划分验证集、图片数据读取、标签独热编码、图片变换、数值归一化、分批、打乱顺序。

异常数据清理:训练集中大约包含了15张非猫或狗的图像,这些图片属于离群数据,可能会影响模型精度,需要移除。可以利用ImageNet预训练模型可以找出非猫或狗的图片以确定要清理哪些图片。鉴于时间原因且数据量不是太大,采用了人工挑选异常图片,这大约花费了1小时,这种方法并不推荐用于大数据集上,这些异常图片文件名如下:

cat.4688.jpg,cat.5418.jpg,cat.7377.jpg,cat.7564.jpg,cat.8100.jpg,cat.8456.jpg,cat.10029.jpg,cat.12272.jpg,dog.1259.jpg,dog.1895.jpg,dog.4367.jpg,dog.8736.jpg,dog.9517.jpg,dog.10190.jpg,dog.11299.jpg。

训练集划分验证集:通常模型训练时需要在验证集上观察评估指标是否达到要求,而本项目只提供了训练集和测试集,所以需要从训练集中划分一部分数据作为验证集,使用4:1的比例从训练集中随机挑选出一部分图片作为验证集。

图片数据读取:训练时需要将图片读入内存,并转换为tensorflow使用的张量格式。读取图片数据时采用了tensorflow的文件读取管线[12],这样不需要一次性将所有数据读入内存,可以明显感受到训练逻辑等待读取数据时间比一次性读入要短。图片读入内存后进行JPEG解码,图片是彩色图片,解码时指定通道数为3(3表示Red、Green、Blue三个颜色通道),解码后图片数据形状为299×299×3的数组。

标签独热编码:分类问题通常需要将离散的标签值转换为独热编码的onehot向量,通常可以采用tensorflow或scikit-learn提供的API进行转换,但是这里没必要,因为只有两个类别,根据图片文件名称判断是dog还是cat决定编码为[0,1]或[1,0]即可,这里[0,1]表示狗,[1,0]表示猫。为了保证图片编码与标签对应关系是正确的,这里做了可视化以验证,如图6:图6 图片标签与图片对应关系

由图6可看出标签与图片对应关系是正确的,读入的图片数据和实际文件也是对应的。

图片变换:使用的图片变换主要为尺寸调整,图像尺寸根据使用的CNN模型决定,Inception-ResNet-V2使用的输入图片尺寸为299×299,调整尺寸时图片宽度或高度小于299则填充黑色像素至299,大于299则从图片中心位置开始裁剪至299,裁剪前后效果示例如图7、图8:图7 dog.8556.jp裁切前图8 dog.8556.jpg裁切后

示例中图7、图8高度一致是为了文档美观进行了等比缩放,图7裁剪前实际尺寸为499 × 375,图8裁剪后实际尺寸为299×299。

数值归一化:图片数值在[0,255]区间,为了训练时收敛速度快,避免激活函数饱和,需要将图片数值转换至[0,1]区间。这里采用了tensorflow内置的tensorflow.image.convert_image_dtype操作,不仅执行了归一化,还可以转换数据类型为需要的tensorflow.float32数据类型。归一化一张图片部分数据示例如下:

[[[0.5529412 0.4156863 0.26666668]

[0.5372549 0.4039216 0.25882354]

[0.53333336 0.40000004 0.2509804 ]

...

[0.43921572 0.3137255 0.16470589]

[0.43921572 0.3137255 0.16470589]

[0.4431373 0.31764707 0.16862746]]

...

[0.6862745 0.4901961 0.25882354]

[0.76470596 0.5686275 0.3372549 ]

[0.7568628 0.56078434 0.32941177]]]

分批与打乱顺序:分批是为了保证一次处理的数据量不超过内存容量或显存容量,这里受制于显存容量,批次大小设置为32。打乱训练数据顺序是防止模型过拟合或欠拟合的一种方法,采用随机打乱即可。

(二)执行过程

训练阶段代码采用tensorflow.slim模块中预置的Inception-ResNet-V2模块构建网络,使用tensorflow.slim模块提供的ImageNet预训练模型参数的ckpt文件恢复网络参数。通过使用tensorflow.slim.learning模块train函数控制训练迭代,可简化读取数据的线程控制和保存模型参数。训练阶段训练集batch_size设置为32,验证集批次大小batch_size设置为160,learning_rate尝试0.01、0.001,迭代次数epochs尝试5、10次。

推断阶段需要载入保存的模型,对测试集进行预测,输出每个图片。

tensorflow.slim提供的文档相对简陋,很多情况下需要阅读源代码解决问题,实践中遇到主要问题及解决方案如下:

1).由于Inception-ResNet-V2预训练模型使用ImageNet数据集训练时是1000个类别,而本项目是2个类别,训练时需重新训练全连接层,恢复参数时需要排除“Logits”和“AuxLogits”这两层的参数。

2). 训练时要固定卷积层参数,训练全连接层参数,即训练“Logits”和“AuxLogits”这两层的参数,这需要从所有可训练变量中查找上述两层的所有变量,并传递给tensorflow.slim.learning.train函数的variables_to_train参数。如果未固定“Logits”和“AuxLogits”这两层参数,会出现训练很久不收敛的现象。

3).由于tensorflow.slim.learning.train函数内部控制了Tensorflow Session的创建,恢复变量时需要通过指定初始化函数给tensorflow.slim.learning.train函数的init_fn参数,在这个初始化函数中实现恢复参数,幸运的是tensorflow.contrib.framework中提供了assign_from_checkpoint_fn函数来实现此功能。

4).通常在训练时需要在日志中输出训练集和验证集上的准确率、loss、step等信息以便观察模型训练的情况。同样由于tensorflow.slim.train函数对外屏蔽了Session,需要通过传入回调函数的方式给train方法的train_step_fn参数来实现这一功能。

尝试多次后,最终在learning_rate=0.001,epochs=5时获得最佳结果,得分为0.06584。

(三)完善

最终获得的最优的得分为0.06584,这个得分大于0.06127,未能进入Public满Leaderboard前10%的要求。尝试数据提升,通随机水平翻转、随机调整色调、对比度、饱和度、明亮度,但对模型的提升不明显。

考虑到最终得分使用log loss的交叉熵,对于正确的样本0.995和1相差不大,对于错误的样本,0和0.005差距非常大[13],所以此处可采用一个clip trick,将测试集预测的数值限制在[0.005, 0.995]这个区间。结果证明采用这个trick之后,得分提升到0.04845,排名进入51名,详见表1。由表1可看出,使用了clip trick对两个模型的最终得分都是有明显提升的。表1 不同学习速率和是否使用clip的得分对比,score为测试集上的交叉熵loss,epoch为5

四、 结果

(一)模型的评价与验证

通过保存训练时验证集loss和准确率可以绘制训练次数与loss和准确率的曲线,这里使用了tensorflow自带的可视化工具tensorboard进行可视化,如图9。从图上可看出,统一在训练次数为2000次时,learning_rate=0.001时准确率和loss抖动较小,更加稳健,表现更好。图9 不同学习速率表现对比

(二)合理性分析

学习速率是一个超参数,设置过小会导致收敛较慢,设置过大会导致准确率和loss振荡较大难以收敛。由图9可见,学习速率在0.01时,准确率和loss振幅和振荡频率都比较大,从0.01调整至到0.001后,准确率和loss振荡幅度和频率有明显降低,模型更加稳健,所以在尝试的两个学习速率中,学习速率为0.001时,模型得到最优分数是合理的。

五、 项目结论

(一)结果可视化

模型最终预测结果见图9,随机展示了5张图片及其预测为够的概率,可以看出与图片中实际的动物是一致的,猫的图片预测为狗的概率都是0.005(由于trick限制了数据范围最小为0.005,实际计算出的概率<=0.005),狗的图片都在90%以上。图10 模型预测是狗概率与实际图片对照

结果提交至Kaggle上最终得分0.04845,见图11。图11 Kaggle得分

(二)对项目的思考

通过本项目实现了猫狗图像的二分类,首先对数据进行预处理,然后利用tensorflow slim模块搭建了深度卷积神经网络Inception-RestNet-V2,并利用迁移学习加速了模型训练的过程,根据交叉熵评估函数的特点采用了数值计算上的trick,最终得分进入Public Leaderboard 前4%,达到题目10%的要求。

项目中对于第一次使用缺乏完善文档的tensorflow slim来说是有困难的,幸运的是所有问题都解决了。项目中最难的地方在于提升模型的得分,要多次调整超参数,尝试多个方法后没有取得更好的成绩非常令人沮丧。最有意思的是使用了一个小的trick居然提升非常大,这是万万没想到的。

最终的模型和结果在这个猫狗分类问题上是符合期望的,没有使用trick时,单模型精度已经非常高了。这个模型无法用于通用场景,模型参数和数据集是非常相关的,通用场景的数据与本项目数据集分布不一致,所以无法适用,但网络架构是可以通用的。

(三)需要做出的改进

如果不用trick就能否冲进前10%,这值得一试,使用集成学习可能是一个好办法。可考虑Inception-ResNet-V2、Inception-V4、ResNet-V2三个预训练模型集成,集成时使用三个模型的预训练权重进行特征提取,提取特征后进行拼接,再dropout加全连接进行训练,这样通过不同模型学到的不同特征组合很可能获得更高的精度、更低的loss。

六、 引用

[1] Kaggle Inc. Dogs vs. Cats Redux: Kernels Edition | Kaggle [EB/OL]. https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition, 2018-03-23

[2] Wikipedia. CAPHCHA-Wikipedia[EB/OL]. https://en.wikipedia.org/wiki/CAPTCHA, 2018-03-23

[3] Microsoft. Asirra: A CAPTCHA that Exploits Interest-Aligned Manual Image Categorization - Microsoft Research

[4] Kaggle Inc. Dogs vs. Cats Redux: Kernels Edition | Kaggle [EB/OL]. https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition#evaluation, 2018-03-23

[5] Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V. and Rabinovich, A. (2018). Going Deeper with Convolutions. [EB/OL]. http://arxiv.org/abs/1409.4842v1,2018-03-23

[6] Simonyan, K. and Zisserman, A. (2018). Very Deep Convolutional Networks for Large-Scale Image Recognition[EB/OL]. http://arxiv.org/abs/1409.1556.pdf, 2018-03-23.

[7] He, K., Zhang, X., Ren, S. and Sun, J. (2018). Deep Residual Learning for Image Recognition[EB/OL] https://arxiv.org/abs/1512.03385, 2018-03-23

[8] Szegedy, C., Ioffe, S., Vanhoucke, V. and Alemi, A. (2018). Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning. [EB/OL] http://arxiv.org/abs/1602.07261, 2018-03-23

[9] Google. models/research/slim at master · tensorflow/models[EB/OL]. https://github.com/tensorflow/models/tree/master/research/slim, 2018-03-23

[10] image-net.org. ImageNet Large Scale Visual Recognition Competition 2012 (ILSVRC2012) [EB/OL]. http://image-net.org/challenges/LSVRC/2012/results.html,2018-03-23

[11] yeephycho. A Note to Techniques in Convolutional Neural Networks and Their Influences III (paper summary) | yeephycho

[13] ypwhs. ypwhs/dogs_vs_cats: 猫狗大战[EB/OL]. https://github.com/ypwhs/dogs_vs_cats, 2017-07-26

python猫狗大战游戏_Kaggle猫狗大战图片分类项目研究相关推荐

  1. python猫狗大战讲解_Kaggle猫狗大战图片分类项目研究

    本文以研究报告形式描述了使用TensorFlow Slim 提供的预训练模型 Inception-ResNet-V2 进行猫狗图片分类的研究. 题图来自于网络,如侵权请留言,立即删除. 一. 问题定义 ...

  2. Python扫雷游戏源代码及图片素材

    Python扫雷游戏源代码.源程序共有两个文件及一个资源包:main.py及mineblock.py,资源包请前往百度网盘下载, https://pan.baidu.com/s/1u-qsJhAaCJ ...

  3. Python 扫雷游戏 完整源代码+图片素材

    代码的下载地址 截图 设计需求 基础功能 实现windows扫雷游戏初级的所有功能 扫雷尺寸99 方格 每个方格尺寸3030 游戏初始化时,随机分布10个地雷 当左键点击雷区任意方格时,则游戏开始 鼠 ...

  4. python猫狗大战游戏下载_猫狗大战RPG最新下载-猫狗大战RPG游戏安卓版-Minecraft中文分享站...

    <猫狗大战RPG>游戏是一款回合制角色扮演手游,玩家们可以在这款游戏中自由的选择角色去进行游戏.卡通的游戏风格,高清的而游戏画质,丰富多样的游戏玩法,大大的提高了游戏的可玩性,会给玩家们带 ...

  5. ubuntu下使用python将ppt转成图片_Ubuntu下使用Python实现游戏制作中的切分图片功能...

    本文实例讲述了Ubuntu下使用Python实现游戏制作中的切分图片功能.分享给大家供大家参考,具体如下: why 拿到一个人物行走的素材,要用TexturePacker打包.TexturePacke ...

  6. python小游戏编程arcade----坦克动画图片合成

    python小游戏编程arcade----坦克动画图片合成 前言 坦克动画图片合成 1.PIL image 1.1 读取文件并转换 1.2 裁切,粘贴 1.3 效果图 1.4 代码实现 2.处理图片的 ...

  7. Python外星人入侵游戏——添加飞船和外星人图片

    Python外星人入侵游戏是自己在<Python编程从入门到实践>在本书里学到的.本篇主要介绍该游戏中所需要的两个图片.分别为飞船和外星人图片. 1.首先去到 http://www.itu ...

  8. Python拼图游戏源代码,可定制拼图图片,支持多种难度,可九宫格、十六宫格、二十五宫格

    配置环境 安装pygame模块 pip install pygame 引入资源 将照片,添加到resources/pictures路径下 照片.jpg格式 主函数代码 pintu.py 一个配置文件c ...

  9. 少儿编程python线上课程-少儿编程课堂|python – 用游戏学编程

    学习编程是很快乐的事情.当我们自己开发出一套时下流行的游戏时,这满满的成就感比玩儿游戏本身高出了不知道会有多少倍. 接下来一段时间我们就python从0开始学习怎么开发 flappy brid 游戏. ...

最新文章

  1. [Oracle] CPU/PSU补丁安装教程
  2. SQL Server 查询表备注信息的语句
  3. for语句 2017-03-17
  4. QStardict移植到i.MX-287开发板
  5. js 数组、对象转json 以及json转 数组、对象
  6. python实现的好玩的小程序--利用wxpy实现的微信可检测僵尸粉机器人
  7. Eucalyptus云平台搭建
  8. 深度学习教程(2) | 神经网络基础(吴恩达·完整版)
  9. Windows下批量修改文件名称
  10. 计算机显示错误屏幕,如何解决显示器分辨率错误
  11. mysql键值_如何在MySQL中存储键值对?
  12. 【java】列表导出excel(支持单元格内容换行展示)
  13. html怎么改变图片整体大小,css怎么改变图片大小?
  14. 逻辑回归的参数计算:牛顿法,梯度下降法,随机梯度下降法
  15. rk3288 linux 编译,注意了!VS-RK3288Ubuntu编译环境错误小结
  16. Mysql之视图的创建、修改、查看、删除
  17. 深度学习backbone是什么意思_什么是深度学习,深度学习是热门词
  18. java 同或_java语言中同或运算的实现
  19. Java管理扩展JMX入门学习
  20. 什么是HEVC?解释了高效视频编码,H.265和4K压缩

热门文章

  1. 21世纪发展最快的数据科学的总结
  2. 相似图片搜索的三种算法
  3. 教你把gps服务器修改为中国加速搜星,手机导航-GPS搜星速度优化(android手机)...
  4. 如何应对学习知识、技能不用就会忘(节选自《穷查理宝典》第十一讲:人类误判心理学之不用就忘倾向)
  5. 【ppt制作软件】Focusky教程 | 怎样实现表格的行列转换?
  6. 扯淡 | 如何判断一家公司是否靠谱
  7. unity解压缩文件踩坑记录
  8. UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tenso
  9. 他是互联网顶尖大佬,创办并掌舵Facebook多年,现在却被要求辞职
  10. 使用idea快速创建maven项目