猫狗大战是著名的竞赛网站kaggle几年前的一个比赛,参赛者得到猫狗各12500张图片,作为训练集,另外还会得到12500张猫和狗的图片,作为验证。最后提交结果至kaggle平台,获得评测分数。

本篇博文基于python3.7,Tensorflow2.1GPU版本,运行在Win10,pycharm作为IDE。相应的环境搭建,可以看博主之前的博文。
深度学习系列笔记——壹(深度学习环境的搭建及填坑之旅,基于windows)

kaggle猫狗大战官方链接
kaggle的注册和简单使用,本博客不再赘述,自行百度即可。

1、下载与整理数据集

(1)下载数据集
(下载时建议使用科学上网)
训练所用到的Train Dataset
测试所用到的Test Dataset
提交所用到的submission.csv文件

下载完成后,分别解压到train和test目录,得到如下结构

其中test有12500张不带标签的猫狗图片,train中含有以名字为前缀标签的猫狗图片各12500张。
猫的图片

狗的图片

为了后续训练的方便,我将猫狗的图片分别放到train目录下的cats,dogs文件夹中。以便后续程序读取文件时能够直接对类别分类。
(2)整理数据集


由于官方没有给出验证集,因此博主直接将一部分训练集的数据分隔为验证集,博主这里设置猫和狗图片的验证集为猫狗各1100张。这样划分之后,猫狗的训练集将会剩下各11400张。
设置验证集,其实是在检测模型的泛化能力,检验模型对于未训练过的例子是否有很好的鲁棒性(健壮性),这样将有助于我们了解模型训练是否出现了过拟合现象。

另外,在本例中测试集是没有标签的,这里需要注意,且猫和狗的图片混合在一起。

(3)观察数据
这里的观察数据,其实是观察一下是否出现了很多脏数据,训练数据的质量很大程度会影响到模型最终的准确率
另一方面,本例很显然可以利用CNN实现二分类,因此需要对数据的大小进行统一后送入模型,大小尺寸不一的图片对于训练好的模型是不能进行预测的,需要全部统一为模型设置的尺寸。
为方便起见以及考虑到过大的图片尺寸会导致模型训练变得极其缓慢(最终还是归结到硬件性能不足)。
官方给出的数据集,随便拿一张图,就有接近500x500的像素,这种尺度对于硬件性能不足的初学者来说简直是灾难。因为模型很可能在很长一段时间无法达到收敛且精度还很难达到要求。

因此博主这里将图片设置为200x200的尺寸
由于部分图片的尺寸和1:1略有不同,因此调整后图片可能变扁或者变长,但是对于图中的猫狗识别其实影响没有很大。此处和经典的AlexNet使用的224x224,227x227以及VGG之类的模型略有不同,本篇作为初次尝试CNN的实际应用,从简考虑,后期可在此基础上进行一定的调整。另外读者若硬件性能较佳,可尝试更大尺寸的图片喂入CNN,构造出更高精度的模型

2、利用Tensorflow处理数据

(1)数据增强
卷积神经网络有一个特点,在一定范围内(resnet便很好的诠释了为什么),增加网络的层数,往往能获得更好的效果,因为深层次的卷积层能够提取出更多表层网络无法提取出的特点,而网络层数的加深,也就意味着模型参数将会变得非常多,几千万甚至上亿的模型都是家常便饭,而随着网络层数的增加,待训练参数的增加,更意味着需要更加充足的数据去训练模型。
博主印象较为深刻的是吴恩达老师在深度学习课程中提到的,当训练数据不足时,太深的模型往往是在尝试 “记住” 数据,而不是学习数据。如何使得模型无法记住数据,而是强迫模型去学习数据内在的联系呢?这个时候,数据增强便派上了用场。这篇知乎的文章讲得不错,数据增强。
简言之,数据增强,就是在数据量不足的情况下,对已有的图片进行微小的改变。比如旋转(flips)、移位(translations)、旋转(rotations)等微小的改变。这些操作作用在像素层面,会使得我们的网络会认为这是不同的图片。从而迫使它去学习而不是生硬的记住图片。

(图片来源于网络,侵删)

(2)代码实现
keras为我们提供了很好的辅助,帮助我们在简单的数据上快速扩充出大量的数据,利用数据增强,我们能够让模型更好的学会如何去学习数据,而不是单纯的记住数据。

# 导入包
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#ImageDataGenerator是用于实时数据增强的一个类,帮助我们快速实现数据的翻转,平移,旋转等操作
from tensorflow.keras import layers
#这里的layers是后面用到的,暂且不表
import os
import time
import matplotlib.pyplot as pltstart = time.time()
#用于后期模型训练总时间的计算,按下不表# 设置目录路径,定位到前面设置的猫狗数据存放位置
PATH = os.path.join('D:/Desktop/catVSdog/data')  # 图片数据集的根目录
# 将目录区分位猫狗训练集和验证集
train_dir = os.path.join(PATH, 'train')  # train数据集 相对于根目录
validation_dir = os.path.join(PATH, 'validation')  # validation数据集 相对于根目录
train_cats_dir = os.path.join(train_dir, 'cats')  # train目录下的文件夹 每个会在之后分为一类
train_dogs_dir = os.path.join(train_dir, 'dogs')
validation_cats_dir = os.path.join(validation_dir, 'cats')  # validation目录下的文件夹 每个会在之后分为一类
validation_dogs_dir = os.path.join(validation_dir, 'dogs')num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))num_cats_val = len(os.listdir(validation_cats_dir))
num_dogs_val = len(os.listdir(validation_dogs_dir))total_train = num_cats_tr + num_dogs_tr  # 文件数量求和方便后续训练图片并设置步长时复用
total_val = num_cats_val + num_dogs_val# 为方便起见,设置变量以在预处理数据集和训练网络时使用
batch_size = 32
epochs = 6
IMG_HEIGHT = 200 #图片的尺寸设置
IMG_WIDTH = 200# 使用实时数据增强生成一批张量图像数据。 通过通道方式获取图片
#设置训练集
train_image_generator = ImageDataGenerator(rescale=1. / 255, rotation_range=5, horizontal_flip=True)
"""
rotation_range=5 指的是将图片在0-5°内随机旋转, horizontal_flip=True指的是将图片随机进行镜像翻转ImageDataGenerator还有很多其他的操作,可以前往官网查询对应的API使用
[官网教程](https://tensorflow.google.cn/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator)
"""#设置验证集
validation_image_generator = ImageDataGenerator(rescale=1. / 255)
#验证集这里不再进行数据增强,有需要也可以按照训练集的方式进行设置"""
此处设置训练集的生成器,用于对训练数据的实时生成,由于训练集较大,直接加载至显存或者内存中不显示
因此采用一边读取数据至内存一边训练的方式,二者互不干扰
"""
#训练集数据生成器
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size, directory=train_dir, shuffle=True,target_size=(IMG_HEIGHT, IMG_WIDTH), class_mode='binary')
#验证集数据生成器
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size, directory=validation_dir,target_size=(IMG_HEIGHT, IMG_WIDTH), class_mode='binary')

3、总结

至此,我们已经完成了对于数据的处理操作,下一篇博文博主将继续介绍如何利用Keras搭建出猫狗大战的CNN模型。

深度学习系列笔记——贰 (基于Tensorflow Keras搭建的猫狗大战模型 二)

在三中,将总结整个项目的全部代码,并利用训练得到的模型,对测试图片进行相应的预测。
深度学习系列笔记——贰 (基于Tensorflow Keras搭建的猫狗大战模型 三)

深度学习系列笔记——贰 (基于Tensorflow Keras搭建的猫狗大战模型 一)相关推荐

  1. 深度学习系列笔记——贰 (基于Tensorflow2 Keras搭建的猫狗大战模型 三)

    深度学习系列笔记--贰 (基于Tensorflow Keras搭建的猫狗大战模型 一) 深度学习系列笔记--贰 (基于Tensorflow Keras搭建的猫狗大战模型 二) 前面两篇博文已经介绍了如 ...

  2. 《深度学习案例精粹:基于TensorFlow与Keras》深度学习常用训练案例合集

    #好书推荐##好书奇遇季#<深度学习案例精粹:基于TensorFlow与Keras>,京东当当天猫都有发售.本书配套示例源码.PPT课件.思维导图.数据集.开发环境与答疑服务. <深 ...

  3. 深度学习系列2:框架tensorflow

    1. 背景 tensorflow是一套可以通过训练数据的计算结果来反馈修改模型参数的一套框架,由谷歌公司于2015年11月开源,可以点击playground来可视化的尝试操作tensorflow,随便 ...

  4. vs2017 开始自己的第一个深度学习例子——MNIST分类(基于TensorFlow框架)

    这是针对于博客vs2017安装和使用教程(详细)的深度学习例子--MNIST分类项目新建示例 目录 一.新建项目 二.运行代码 三.生成结果 一.新建项目 1.项目创建参照博主文章:vs2017 开始 ...

  5. 【深度学习】SETR:基于视觉 Transformer 的语义分割模型

    Visual Transformer Author:louwill Machine Learning Lab 自从Transformer在视觉领域大火之后,一系列下游视觉任务应用研究也随之多了起来.基 ...

  6. Tensorflow深度学习入门(1)——Tensorflow环境搭建

    Tensorflow深度学习入门--环境搭建 自测以下的环境搭建方式是行得通的,目前我用的就是这些 1.        下载安装Ubuntu 14.04 虚拟机 https://github.com/ ...

  7. 深度学习系列笔记之统计基础

    研究方法入门 进行调查 信任调查结果吗? 调查了谁? 调查了多少人? 调查是怎么进行的? 影响结果的是潜在变量 抽样总体的平均得分叫作总体参数μ 抽样数据和总体数据之间有个叫抽样误差 μ -x 怎样使 ...

  8. 深度学习入门笔记系列 ( 二 )——基于 tensorflow 的一些深度学习基础知识

    本系列将分为 8 篇 .今天是第二篇 .主要讲讲 TensorFlow 框架的特点和此系列笔记中涉及到的入门概念 . 1.Tensor .Flow .Session .Graphs TensorFlo ...

  9. 深度学习入门笔记(十五):深度学习框架(TensorFlow和Pytorch之争)

    欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...

最新文章

  1. java Scanner具有神奇的作用可惜大部分java开发人员不知
  2. http://blog.csdn.net/rongdeguoqian/article/details/8035080
  3. 社会管理网格化 源码_综治综合解决方案、社会治安综合治理信息平台方案
  4. 智能手机下半场迎来淘汰赛:有的拼供应链,有的打起了 AI 的主意
  5. c语言定义int 输出4386,C语言 · 矩阵乘法
  6. font-family 各字体一览表
  7. matlab双峰滤波,MATLAB中的单峰或双峰分布
  8. 如何使用小米手机对文档进行扫描
  9. 英语体系----词根词缀等----持续补充(词根词缀等,词汇,语法,简单句,长难句,写作)
  10. css3的过度,transition
  11. 在docker中挂载硬盘
  12. css气泡图片上下浮动
  13. GAMS系列分享12—GAMS基础知识——模型和求解
  14. Pytorch之深度学习实战
  15. BZOJ 2407: 探险/BZOJ 4398: 福慧双修 dijkstra 构造
  16. 博通Broadcom SDK源码学习与开发10——Cable Modem IPv6地址
  17. 汇智动力学院——Java 浅谈数据结构和算法
  18. 金融网站知识图谱问答系统:自学Python第一周
  19. 写字机上位机c语言,易懂 | 手把手教你编写你的第一个上位机
  20. 读书笔记-陆-《从你的全世界路过》

热门文章

  1. ubuntu 查看内存命令
  2. Android开发之--读取文件夹下图片生成略缩图并点击显示大图
  3. head first 23个设计模式总结
  4. 没有“熊猫“的熊猫快餐,凭什么能成为中式快餐第一?
  5. linux内核 阅读,Linux内核阅读感悟
  6. Python之路【第十七篇】:Django【进阶篇 】(转自银角大王博客)
  7. ../labgob: “../labgob“ is relative, but relative import paths are not supported in module modo
  8. 【蓝桥杯】第10届Scratch国赛第6题程序2 -- 捉迷藏
  9. AppCode3真的很不错
  10. 比汽车还贵的自行车品牌排行榜除辐轮王土拨鼠你还知道哪些?