数据集介绍

首先是要下载数据集,下载地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition

数据解压之后会有两个文件夹,一个是“train”,一个是“test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据,也是网站要求提交标签的。

在train文件夹里边是一些已经命名好的图像,有猫也有狗

而在test文件夹中是只有编号名的图像

大致了解了数据集后,下边就开始划分数据集

代码

先放一段代码,这是从书中截取出来的:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as Tclass DogCat(data.Dataset):def __init__(self, root, transforms=None, train=True, test=False):"""主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据"""self.test = testimgs = [os.path.join(root, img) for img in os.listdir(root)]# test1: data/test1/8973.jpg# train: data/train/cat.10004.jpg if self.test:imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))else:imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))imgs_num = len(imgs)if self.test:self.imgs = imgselif train:self.imgs = imgs[:int(0.7 * imgs_num)]else:self.imgs = imgs[int(0.7 * imgs_num):]if transforms is None:normalize = T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])if self.test or not train:self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])else:self.transforms = T.Compose([T.Resize(256),T.RandomReSizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])def __getitem__(self, index):"""一次返回一张图片的数据"""img_path = self.imgs[index]if self.test:label = int(self.imgs[index].split('.')[-2].split('/')[-1])else:label = 1 if 'dog' in img_path.split('/')[-1] else 0data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):return len(self.imgs)

详解

这里建立了一个类,继承自data.Dataset,里边有三个方法是必须重写的:

class DogCat(data.Dataset):def __init__(self, root, transforms=None, train=True, test=False):"""主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据"""#这个__init__方法是初始化,里边可以对数据进行一些预处理def __getitem__(self, index):"""一次返回一张图片的数据"""#__getitem__方法是迭代器需要,当读取数据集的时候就会调用__getitem__方法,#一次读取一张照片,因此,这里主要实现返回图像与标签的功能def __len__(self):#这个函数的目的是返回数据集大小,也是必不可少的部分

下面开始解释每个方法中语句的功能

    def __init__(self, root, transforms=None, train=True, test=False):#root是根目录,用来存放数据#transforms是对图像做出转换#train和test是标记self.test = test#os.listdir(root)获取root目录下所有文件名imgs = [os.path.join(root, img) for img in os.listdir(root)]#根据测试集与训练集图片命名不同进行不同的划分if self.test:imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))else:imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))#获取图像数量imgs_num = len(imgs)#将test文件夹中图像作为测试集if self.test:self.imgs = imgs#将训练集70%作为训练集elif train:self.imgs = imgs[:int(0.7 * imgs_num)]#将训练集30%作为验证集else:self.imgs = imgs[int(0.7 * imgs_num):]#下边对图像做变换if transforms is None:normalize = T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])if self.test or not train:self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])else:self.transforms = T.Compose([T.Resize(256),T.RandomReSizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])
    def __getitem__(self, index):"""一次返回一张图片的数据"""#根据下标获取标签img_path = self.imgs[index]if self.test:label = int(self.imgs[index].split('.')[-2].split('/')[-1])else:label = 1 if 'dog' in img_path.split('/')[-1] else 0data = Image.open(img_path)data = self.transforms(data)#返回图像数据与标签return data, label
    def __len__(self):#返回数据集长度return len(self.imgs)

到此位置,数据集的划分与数据类已经完成

完整训练过程可以看我另一篇博客:

https://blog.csdn.net/qq_41685265/article/details/104898848

狗猫分类数据集划分详解相关推荐

  1. python从date目录导入数据集_PyTorch加载自己的数据集实例详解

    数据预处理在解决深度学习问题的过程中,往往需要花费大量的时间和精力. 数据处理的质量对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练, 更会提高模型性能.为解决这一问题,PyTorch提供 ...

  2. IP地址分类及范围详解

    IP地址分为公网IP地址(合法IP地址)和私有IP地址 公网IP地址主要应用于Internet上的主机访问,而私有IP地址应用于局域网中计算机的相互通信. IP地址的表示形式:分为二进制表示和点分十进 ...

  3. IPv4、IPv6地址、组播地址及子网子划分详解二子网划分

    IPv4.IPv6地址.组播地址及子网子划分详解二子网划分 5.子网划分 5.1.子网掩码 5.2.无类域间选择CIDR 5.3.根据子网掩码和CIDR值划分子网 5.4.二进制AND运算在划分子网中 ...

  4. IPv4、IPv6地址、组播地址及子网子划分详解三可变长子网掩码

    IPv4.IPv6地址.组播地址及子网子划分详解三可变长子网掩码 5.5.可变长子网掩码(VlSM) 5.5.可变长子网掩码(VlSM) 先看一下分类组网,路由选择协议RIPv1没有包含子网信息的字段 ...

  5. 计算机网络c类网络划分子网介绍,IP地址的子网划分详解

    原标题:IP地址的子网划分详解 来源:今日头条北京炫亿时代 一.子网划分基础 1.子网划分的若干个好处: ①减少网络流量 ②提高网络性能 ③简化管理 ④可以更为灵活的形成大覆盖范围的网络 2.你最好遵 ...

  6. 子网掩码必须是相邻的是什么意思_零基础IP子网划分详解

    零基础IP子网划分详解 2016.8.22修正,感谢道友刘先生的提醒 在学习IP子网划分前,首先的明白以下几个基础概念: 1.IP地址组成 IP地址组成示意图 IP地址由32位二进制组成,32位二进制 ...

  7. C/C++内存区域划分详解

    C/C++内存区域划分详解 C/C++内存分布 C/C++中,内存主要分为.堆.栈.全局/静态存储区和常量存储区. 栈:栈又叫堆栈,就是那些由编译器在需要的时候分配,在不需要的时候自动清除的变量的存储 ...

  8. functional java_java中functional interface的分类和使用详解

    java 8引入了lambda表达式,lambda表达式实际上表示的就是一个匿名的function. 在java 8之前,如果需要使用到匿名function需要new一个类的实现,但是有了lambda ...

  9. Java调用SMSLib用单口短信猫发送短信详解

    技术园地 当前位置:短信猫网站主页 > 技术园地 > [转载]Java调用SMSLib用单口短信猫发送短信详解 发布时间:2017/02/09 点击量:620 SMSLib是Apache的 ...

  10. c语言中%s的作用,C语言中%c与%s的区别与划分详解

    %c格式对应的是单个字符,%s格式对应的是字符串. 例: char a; char b[20]; scanf("%c",&a); //只能输入一个字符. scanf(&qu ...

最新文章

  1. 存储过程与函数oracle
  2. Google adwords新手推广常见错误
  3. Python之路---函数进阶
  4. iPhone 11终于没涨价但依然暴利 外媒:64GB起始容量就是个笑话
  5. zabbix 服务器监控之数据库操作
  6. MySQL(9)-----多表创建及描述表关系(需求)
  7. 【2022.3】尚硅谷Vue.js从入门到精通基础笔记(理论+实操+知识点速查)
  8. 如何扩展计算机c盘的控件,电脑C盘空间不足,怎么把c盘空间可以扩大
  9. 数学建模各种软件对比(MATLAB/Lingo/SAS/SPSS)
  10. 工作流引擎之-activiti6使用
  11. 由jar文件生成jad文件
  12. 人机交互-语音交互的优势和劣势
  13. 金融笔记:货币的概念
  14. python中sqrt(4)*sqrt(9)_Python表达式sqrt(4)*sqrt(9)的值为()
  15. QCustomPlot画带数值标签的柱状图
  16. python通过ssh通道连接PostgreSQL数据库(mysql等类同)
  17. 【三维目标检测】Complex-Yolov4详解(一): 数据处理
  18. 推荐系统学习笔记——四、Netfilx经典推荐系统架构
  19. 电脑进共享云盘报错“不允许一个用户使用一个以上用户名与服务器或共享资源的多重连接......”
  20. K3 单据,单据体自定义字段显示及时库存

热门文章

  1. 系统提示“无法删除文件,无法读取源文件或磁盘”的解决办法
  2. Linux期末复习总结
  3. 最新麻瓜编程实用主义学Python分享
  4. 纯前端js导出Excel文件
  5. C++实现简易五子棋游戏
  6. 使用Latex排版一篇IEEE Robotics and Automation Letters期刊文章
  7. 解决python中No module named ‘numpy‘问题
  8. JS 正则表达式 手机号码正则
  9. SPSS——相关分析——偏相关(Partial)分析
  10. 以太网和wifi协议