Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)

第一次写CSDN博客,之前一直是靠着CSDN学学代码,这次不得不亲自上场了,就想着将学习的过程都记录下来。新人分享,可能菜了点,还请大家多多包涵。这次的目标是构建一个Kaggle猫狗大战的CNN识别网络,内容有点多,就分了几步讲。第一章就先讲讲一些准备工作,包括数据获取、程序的框架、预处理这些。

数据获取

首先你需要获得猫狗的数据,建议去Kaggle官网上下,缺点就是Kaggle官网上的train包里猫狗的标签是分开的,但是test包里是未区分的,不太方便验证,所以对train包进行拆分,选择23000张图片作为train包,2000张图片作为valid包,猫狗比例相等。新的train包和valid包的链接如下:
链接:https://pan.baidu.com/s/1c69WjBvh97PSU4hC4hFolg
提取码:lp34

程序框架

话不多说,上图(软件是Pycharm,语言是Python3.7,深度学习平台是Pytorch,话说这装也是真不好装,不过网上教程很多):

data是数据预处理和打包的程序,main是数据训练的主程序,models是自己写的CNN网络存放的程序,predict是模型训练好之后用来实验的程序。train包和valid包里面都把猫和狗分开存放,之后处理起来会方便很多。acc.png是保存下来的准确率曲线,best_model.pt是保存下来的测试效果最好的网络。

数据预处理

(整体的代码在最后)

load_data函数中比较重要的就是数据预处理,因为输入的图片是尺寸不一的,需要将其调整到同一尺寸。这里是新建了一个data_transforms函数,分为‘train’和‘valid’两种处理流程。

简单预处理的话,data_transforms[‘train’]三行足以:
(1)将图片随机裁剪再resize成固定尺寸(我设的是224,这个跟之后神经网络的全连接层参数有关,一般常用就是224和200);
(2)将灰度范围从0-255变换到0-1之间;
(3)把灰度范围从0-1变换到[X1,X2]之间(这是对整个数据集的灰度分布进行统计后得到的灰度的上下限,这样一来就是能使灰度分布更均匀,差异更大)。

transforms.RandomResizedCrop(input_size, scale=(0.7, 1)),  # 图像进行随机裁剪后再resize成固定大小
transforms.ToTensor(),  # 把灰度范围从0-255变换到0-1之间
transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])  # 把灰度范围从0-1变换到[X1,X2]之间

其中,data_transforms[‘train’]里面也可以添加一些数据增广的手段,用来提高训练集的泛化能力,比如水平/垂直翻转,旋转, 缩放,裁剪,剪切,平移,改变对比度,色彩抖动,添加噪声等。

具体调用的函数如下:

【1】随机比例缩放(按比例缩放)
torchvision.transforms.Resize()函数,函数有两个参数,第一个参数为缩放大小,如果为一个值则会按比例缩放,否则按传入的值缩放;第二个参数表示缩放图片使用的方法,默认的是双线性差值。

【2】随机位置截取
随机位置截取能够提取图片中的局部信息,使得网络接受的输入具有多尺度的特征。
在torchvision中主要有以下两种方式,一个是torchvision.transforms.RandomCrop(),传入的参数是截取出图片的长和宽,在图片的随机位置进行截取;第二个是torchvision.transforms.CenterCrop(),同样传入图片的长和宽,会在图片的中心进行截取。

【3】随机水平翻转和竖直翻转
torchvision.transforms.RandomHorizontalFlip()函数和torchvision.transforms.RandomVerticalFlip()函数,不需要参数。

【4】随机角度旋转
torchvision.transforms.RandomRotation()函数,传入的参数是角度,若不传入参数,则随机旋转。

【5】亮度、对比度和颜色变化
torchvision.transforms.ColorJitter()函数有四个参数。
第一个参数为brightness:如果brights<1,会变暗;如果>1,会更亮一些;
第二个参数为contrast:当对比度降低,会发灰。如果对比度升高,白色的地方会更白,灰色的地方会更灰,其值取0-1;
第三个参数为saturation:饱和度降低,图像更暗淡。饱和度升高,图像更鲜艳,其值取0-1
第四个参数为hue:改变底色(偏紫偏红),其值取0-0.5。

【6】概率转换为灰度图
RandomGrayscale()函数是依据一定的概率将图片转换成灰度图。Grayscale是RandomGrayscale的一个特例,也就是概率等于1的RandomGrayscale。

【7】仿射变换
仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切(切成平行四边形)和翻转。

【8】随机遮挡
RandomErasing()函数,第一个参数为p,执行该操作的概率;第二个参数为scale,遮挡区域的面积;第三个参数为ratio,遮挡区域长宽比;第四个参数为value,遮挡区域的像素值;第五个参数inplace,是否执行原位操作。

需要注意的是,随机遮挡接受的是Tensor,它是在一个张量上进行操作。所以在之前要执行一个ToTensor()。后面的ToTensor()和Normalize可以不要。

【9】随机选取上述方法进行处理
(1)transforms.RandomChoice():在一系列方法中随机挑选一个。
(2)transforms.RandomApply():每次依概率执行还是不执行,执行就执行一组。
(3)transforms.RandomOrder():打乱顺序再执行。

比如我这里,就设计了四种变换(水平翻转、竖直翻转、对比度变化+颜色变化、错切+平移),然后每张图片按随机顺序进行这四种处理。

但是!!!要注意的是,数据增广不要弄得太离谱,理论上,数据增广会提高模型的准确率,但是会降低训练收敛速度。

transforms1 = transforms.RandomHorizontalFlip(p=0.5)transforms2 = transforms.RandomVerticalFlip(p=0.5)transforms3 = transforms.ColorJitter(contrast=0.5,hue=0.4)transforms4 = transforms.RandomAffine(0,translate=(0.2,0.2),shear=(20,20))data_transforms = {'train': transforms.Compose([## your code here# transforms.RandomOrder([transforms1,transforms2,transforms3,transforms4]),transforms.RandomResizedCrop(input_size, scale=(0.7, 1)),  # 图像进行随机裁剪后再resize成固定大小transforms.ToTensor(),  # 把灰度范围从0-255变换到0-1之间transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])  # 把灰度范围从0-1变换到[X1,X2]之间]),'valid': transforms.Compose([transforms.Resize(input_size),  # 图像短边长度变为input_sizetransforms.CenterCrop(input_size),  # 从正中间剪正方形transforms.ToTensor(),transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])]),}

data_transforms[‘valid’]则要简单不少,主要就是缩放然后裁剪,调整灰度分布即可。

transforms.Resize(input_size),  # 图像短边长度变为input_size
transforms.CenterCrop(input_size),  # 从正中间剪正方形
transforms.ToTensor(),
transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])

数据载入

网上的做法一般是新建一个DogCatDataSet的新数据集类,需要继承Pytorch中的data.Dataset父类。其实不新建数据集类,直接用Pytorch的父类也可以。使用torchvision.datasets.ImageFolder()函数:

image_dataset_train = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train'])
image_dataset_valid = datasets.ImageFolder(os.path.join(data_dir, 'valid'), data_transforms['valid'])

ImageFolder()函数的作用是:假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

它主要有四个参数:
root:在root指定的路径下寻找图片
transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
target_transform:对label的转换
loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象
label是按照文件夹名顺序排序后存成字典,即{类名:类序号(从0开始)},一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一致,如果不是这种命名规范,建议看看self.class_to_idx属性以了解label和文件夹名的映射关系。

然后是数据打包,我不太清楚该如何称呼这一步,按我的理解,这一步就是将训练集(或测试集)中的数据按batch_size打成数据包,传给神经网络(如果你用的是GPU,那就是CPU将数据打成一个个数据包,传给GPU进行计算)。

train_loader = torch.utils.data.DataLoader(image_dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
# DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个又一个batch大小的Tensor,数据给模型进行训练测试
valid_loader = torch.utils.data.DataLoader(image_dataset_valid, batch_size=batch_size, shuffle=True, num_workers=8)

利用DataLoader函数,按指定的batch_size、是否shuffle、并行线路数、是否pin_memory来将数据打成包。其中一些参数的意义如下:
(1)batch_size:一个数据包里的数据个数,理解就是多少张图片,batch_size越小,数据包越小,需要传递的次数也就越多,训练速度也就越慢(因为每次epoch都是要训练所有图片一次才完成,GPU训练一批数据是很快的,但是CPU向GPU传数据很慢);
(2)shuffle:随机处理,表示将数据顺序打乱,对训练集需要进行这一步处理(不然数据就是连续一堆猫,然后连续一堆狗),测试集可做可不做;
(3)num_workers:同时进行数据包传递的并行线路个数,一般看你的CPU是多少核的,num_workers越多,传递越快(理论上,但是不建议设太多,一般和CPU数相同,或取一半);
(4)pin_memory:是否将数据先传给缓存、再传给GPU,如果电脑性能好,建议=true,可以节省一些数据传递的时间。

data.py

提示:单个程序无法运行,需要去我另外两个博客(本系列的2、3)里,把main.py和models.py两个程序都摘下来,然后运行main.py,才能训练神经网络。

/*data.py*/
import torchvision
from torchvision import datasets, transforms
import torch.utils.data
import os, random, glob
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np## Note that: here we provide a basic solution for loading data and transforming data.
## You can directly change it if you find something wrong or not good enough.## the mean and standard variance of imagenet dataset
## mean_vals = [0.4864, 0.4533, 0.4154]
## std_vals = [0.2625, 0.2558, 0.2586]class DogCatDataSet(torch.utils.data.Dataset):                    # 新建一个数据集类,并且需要继承PyTorch中的data.Dataset父类def __init__(self, img_dir, transform=None):                  # 默认构造函数,传入数据集类别(训练或测试),以及数据集路径self.transform = transform                                # 转换关系dog_dir = os.path.join(img_dir, "dog")                    # “狗”文件夹的路径cat_dir = os.path.join(img_dir, "cat")imgsLib = []imgsLib.extend(glob.glob(os.path.join(dog_dir, "*.jpg"))) # glob()函数,对某一元素进行匹配;extend()函数,将可迭代的元素添加到列表中imgsLib.extend(glob.glob(os.path.join(cat_dir, "*.jpg")))random.shuffle(imgsLib)                                   # 打乱数据集self.imgsLib = imgsLib# 作为迭代器必须要有的def __getitem__(self, index):img_path = self.imgsLib[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0      # 狗的label设为1,猫的设为0img = Image.open(img_path).convert("RGB")img = self.transform(img)return img, labeldef __len__(self):return len(self.imgsLib)def load_data(data_dir="./data/", input_size=224, batch_size=36):data_transforms = {'train': transforms.Compose([## your code heretransforms.RandomResizedCrop(input_size, scale=(0.7, 1)),  # 图像进行随机裁剪后再resize成固定大小transforms.ToTensor(),  # 把灰度范围从0-255变换到0-1之间transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])  # 把灰度范围从0-1变换到[X1,X2]之间]),'valid': transforms.Compose([transforms.Resize(input_size),  # 图像短边长度变为input_sizetransforms.CenterCrop(input_size),  # 从正中间剪正方形transforms.ToTensor(),transforms.Normalize([0.4864, 0.4533, 0.4154], [0.2625, 0.2558, 0.2586])]),}image_dataset_train = DogCatDataSet(os.path.join(data_dir, 'train'),data_transforms['train'])image_dataset_valid = DogCatDataSet(os.path.join(data_dir, 'valid'), data_transforms['valid'])train_loader = torch.utils.data.DataLoader(image_dataset_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)# DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个又一个batch大小的Tensor,数据给模型进行训练测试valid_loader = torch.utils.data.DataLoader(image_dataset_valid, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)return train_loader, valid_loader

Kaggle猫狗大战——基于Pytorch的CNN网络分类:数据获取、预处理、载入(1)相关推荐

  1. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

  2. 基于Pytorch实现猫狗分类

    基于Pytorch实现猫狗分类 一.环境配置 二.数据集准备 三.猫狗分类的实例 四.实现分类预测测试 五.参考资料 一.环境配置 1.环境使用 Anaconda 2.配置Pytorch pip in ...

  3. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)...

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  4. 基于tensorflow、CNN网络识别花卉的种类(图像识别)

    基于tensorflow.CNN网络识别花卉的种类 这是一个图像识别项目,基于 tensorflow,现有的 CNN 网络可以识别四种花的种类.适合新手对使用 tensorflow进行一个完整的图像识 ...

  5. 深度学习(十七)基于改进Coarse-to-fine CNN网络的人脸特征点定位

    基于改进Coarse-to-fine CNN网络的人脸特征点定位 原文地址:http://blog.csdn.net/hjimce/article/details/50099115 作者:hjimce ...

  6. 基于keras的CNN图片分类模型的搭建以及参数调试

    基于keras的CNN图片分类模型的搭建与调参 更新一下这篇博客,因为最近在CNN调参方面取得了一些进展,顺便做一下总结. 我的项目目标是搭建一个可以分五类的卷积神经网络,然后我找了一些资料看了一些博 ...

  7. 基于pytorch的胶囊网络minst图像分类实现

    关于<Dynamic Routing Between Capsules>这篇论文的代码复现网上有很多,基本都是做图像重构的.我修改了其中一部分代码,实现了minst图像分类. 参考:基于p ...

  8. 基于Pytorch的猫狗分类

    无偿分享~ 猫狗二分类文件下载地址 在下一章说        猫狗分类这个真是困扰我好几天,找了好多资料都是以TensorFlow的猫狗分类,但我们要求的是以pytorch的猫狗分类.刚开始我找到了也 ...

  9. 基于pytorch的Faster-Rcnn网络实现视力表字符检测

    今天要做的是使用一个基于pytorch环境下的Faster-Rcnn网络实现对视力表字符的检测任务. 使用平台:pycharm:环境:torch1.5.0.cuda10.2 目录 一.制作数据集 二. ...

最新文章

  1. TensorFlow(6)神经网络训练(DNN)
  2. [云炬创业基础笔记]第五章创业机会评估测试9
  3. 计算智能-群智能算法-蚁群算法matlab实现
  4. 【Java文件下载】如何让浏览器直接下载后端返回的图片,而不是直接打开
  5. 作为新手程序员,掉过的那些坑!
  6. LNMP - nginx代理详解
  7. RF工具ride使用
  8. GOM跟GEE登陆器列表文件加密教程
  9. c语言 数据结构面试题及答案,数据结构c语言版试题大全(含答案).docx
  10. leetcode刷题(32)——88. 合并两个有序数组
  11. 对计算机科学的认识论文,关于对计算机的认识论文
  12. 论文参考文献格式与设置
  13. 关于Win10 driver irql not less or equal ndis.sys的个人解决过程
  14. 单片机炫彩灯实训报告_单片机跑马灯实验报告
  15. DRM系列(3)之DRM_IOCTL_MODE_MAP_DUMB
  16. 前端实现压缩图片的功能(vue-element)
  17. Mockplus Cloud自动生成规格,Mockplus Cloud交互式动画原型
  18. IBM Cloud 2015 - Invoice - 03 payment 支付方式
  19. 结构体的定义、初始化
  20. Exchange-获取主、所有SMTP地址

热门文章

  1. 一篇文章教会你使用HTML5 SVG 标签
  2. flink的timeWindowAll流无法输出数据
  3. Message:Message: 前言中不允许有内容
  4. 计算机网络(三) 广播信道及局域网
  5. Python-常用正则
  6. SQL2017及管理工具安装全过程
  7. 含胶原蛋白的食物有哪些?
  8. 思考、创新、坚持——阿里做了七年前端,我的成长经验分享
  9. Linux 如何刷新 DNS 缓存
  10. 定时器循环彩灯实验c语言,单片机实验6__定时器控制循环彩灯实验.doc