• 1. 数据处理

    • 数据加载
    • ImageFolder
    • DataLoader加载数据
    • sampler:采样模块

1. 数据处理

数据加载

在Pytorch 中,数据加载可以通过自己定义的数据集对象来实现。数据集对象被抽象为Dataset类,实现自己定义的数据集需要继承Dataset,并实现两个Python魔法方法。

  • __getitem__: 返回一条数据或一个样本。obj[index]等价于obj.__getitem__(index).
  • __len__: 返回样本的数量。len(obj)等价于obj.__len__().
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as npclass DogCat(data.Dataset):def __init__(self,root):imgs=os.listdir(root)#所有图片的绝对路径#这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片self.imgs=[os.path.join(root, img) for img in imgs]def __getitem__(self, index):img_path=self.imgs[index]#dog->1, cat->0label=1 if 'dog' in img_path.split("/")[-1] else 0pil_img=Image.open(img_path)array=np.asarray(pil_img)data=t.from_numpy(array)return data,labeldef __len__(self):return len(self.image)dataset=DogCat('N:/百度网盘/kaggle/DogCat')
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:print(img.size(),img.float().mean(),label)

结果:

torch.Size([280, 300, 3]) tensor(71.6653) 0
torch.Size([396, 312, 3]) tensor(131.8400) 0
torch.Size([414, 500, 3]) tensor(156.6921) 0
torch.Size([375, 499, 3]) tensor(96.8243) 0
torch.Size([445, 431, 3]) tensor(103.8582) 1
torch.Size([373, 302, 3]) tensor(160.0512) 1
torch.Size([240, 288, 3]) tensor(95.1983) 1
torch.Size([499, 375, 3]) tensor(90.5196) 1

问题:结果大小不一,这对于batch训练的神经网络来说很不友好。
返回的样本数值交大,未归一化至【-1,1】

针对上述问题,pytorch提供了torchvision。它是一个视觉工具包,提供了很多视觉图像处理的工具。
其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image的常见操作如下:

  • Scale/Resize: 调整尺寸,长宽比保持不变; #Resize
  • CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片;
  • Pad: 填充;
  • ToTensor: 将PIL Image对象转换成Tensor,会自动将【0,255】归一化至【0,1】。

对Tensor的常见操作如下:

  • Normalize: 标准化,即减均值,除以标准差;
  • ToPILImage:将Tensor转为PIL Image.

    如果要对图片进行多个操作,可通过Compose将这些操作拼接起来,类似于nn.Sequential.
    这些操作定义之后是以对象的形式存在,真正使用时需要调用它的__call__方法,类似于nn.Mudule.
    例如:要将图片调整为224*224,首先应构建操作trans=Scale((224,224)),然后调用trans(img).

import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transforms=T.Compose([T.Resize(224),  #缩放图片(Image),保持长宽比不变,最短边为224像素T.CenterCrop(224), #从图片中间裁剪出224*224的图片T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])  #标准化至【-1,1】,规定均值和方差
])class DogCat(data.Dataset):def __init__(self,root, transforms=None):imgs=os.listdir(root)self.imgs=[os.path.join(root, img) for img in imgs]self.transforms=transformsdef __getitem__(self, index):img_path=self.imgs[index]#dog->1, cat->0label=1 if 'dog' in img_path.split("/")[-1] else 0data=Image.open(img_path)if self.transforms:data=self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)        dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:print(img.size(),label)

结果:

torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1

除了上述操作外,transforms还可以通过Lambda封装自定义的转换策略.
例如相对PIL Image进行随机旋转,则可以写成trans=T.Lambda(lambda img: img.rotate(random()*360)).

ImageFolder

下面介绍一个会经常使用到的Dataset——ImageFolder,它的实现和上述DogCat很相似。
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root, transform=None, target_transform = None, loader = default_loader)

它主要有四个参数:

  • root :在root指定的路径下寻找图片
  • transform: 对PIL Image进行转换操作, transform的输入是使用loader读取图片返回的对象;
  • target_transform :对label的转换;
  • loader: 指定加载图片的函数,默认操作是读取为PIL Image对象。
    label是按照文件夹名顺序排序后存成字典的,即{类名:类序号(从0开始)} ,一般来说最好直接将文件夹命名为从0开始的数字,这样会和ImageFolder实际的label一直,如果不是这种命名规则,建议通过self.class_to_idx属性了解label和文件夹名的映射关系。
from torchvision.datasets import ImageFolder
dataset=ImageFolder('N:\\data\\')
dataset.class_to_idx

运行结果:

{'cat': 0, 'dog': 1}

输入:

#所有图片的路径和对应的label
dataset.imgs

输出:

[('N:\\data\\cat\\cat.1.jpg', 0),('N:\\data\\cat\\cat.2.jpg', 0),('N:\\data\\cat\\cat.3.jpg', 0),('N:\\data\\cat\\cat.4.jpg', 0),('N:\\data\\dog\\dog.9131.jpg', 1),('N:\\data\\dog\\dog.9132.jpg', 1),('N:\\data\\dog\\dog.9133.jpg', 1),('N:\\data\\dog\\dog.9134.jpg', 1)]
#没有任何的transform,所以返回的还是PIL Image对象
dataset[0][1]  #第一维是第几张图,第二维为1返回label

输出:0

dataset[0][0] #第一维是第几张图,第二维为0返回图片数据,返回的Image对象如图所示:

输出:

加上transform:

normilize=T.Normalize(mean=[0.4,0.4,0.4],std=[0.2,0.2,0.2])
transform=T.Compose([T.RandomResizedCrop (224),T.RandomHorizontalFlip(),T.ToTensor(),normilize,
])
dataset=ImageFolder('N:\\data\\',transform=transform)
#深度学习中图片数据一般保存为CxHxWx,即通道数x图片高x图片宽
dataset[0][0].size()

输出:

torch.Size([3, 224, 224])
to_img=T.ToPILImage()
#0.2和0.4是标准差和均值的近似
to_img(dataset[0][0]*0.2+0.4)

输出:

DataLoader加载数据

Dateset只负责数据的抽象,一次调用__getitem__只返回一个样本。
在训练神经网络时,是对一个batch的数据进行操作,同时还要进行shuffle和并行加速等。
对此,pytorch提供了DataLoader帮助我们实现这些功能。
DataLoader的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
  • dataset: 加载的数据集)Dataset对象;
  • batch_size: 批大小;
  • shuffle:是否将数据打乱;
  • sampler:样本抽样
  • num_workers:使用多进程加载的进程数,0代表不使用多进程;
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可;
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些;
  • drop_last:dataset 中的数据个数可能不是 batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
from torch.utils.data import DataLoader
dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)dataiter=iter(dataloader)
imgs,labels=next(dataiter)
imgs.size()

输出:

torch.Size([3, 3, 224, 224])

dataloader是一个可迭代的对象,我们可以像使用迭代器一样使用它,例如:

for batch_datas,batch_labels in dataloader:train()

dataiter=iter(dataloader)
batch_datas,batch_labels =next(dataiter)

sampler:采样模块

Pytorch 中还提供了一个sampler模块,用来对数据进行采样。
常用的有随机采样器RandonSampler,当dataloadershuffle参数为True时,系统会自动调用这个采样器 ,实现打乱数据。

默认的采样器是SequentialSampler, 它会按顺序一个一个进行采样。

这里介绍另外一个很有用的采样方法:它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它进行重采样。

构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部样本数目。
replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。
如果设为False,则当某一类样本被全部选取完,但样本数目仍为达到num_samples时,sampler将不会再从该类中选取数据,此时可能导致weights参数失效。
下面举例说明:
1)

#dataset=DogCat('N:/百度网盘/kaggle/DogCat/',transforms=transforms)
dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
#img,label=dataset[0]#相当于调用dataset.__getitem__(0)
#狗的图片取出的概率是猫的概率的两倍
#两类取出的概率与weights的绝对值大小无关,之和比值有关
weights=[2 if label==1 else 1 for data ,label in dataset]
weights

输出:

[1, 1, 1, 1, 2, 2, 2, 2]

2)

from torch.utils.data.sampler import WeightedRandomSampler
sampler=WeightedRandomSampler(weights,num_samples=9,replacement=True)
dataloader=DataLoader(dataset,batch_size=3,sampler=sampler)
for datas,labels in dataloader:print(labels.tolist())

输出:

[1, 0, 1]
[1, 0, 1]
[1, 0, 1]

可见猫狗样本比例约为1:2,另外一共有8个样本,却返回了9个样本,说明样本有被重复返回的,这就是replacement参数的作用。
下面我们将replacement设置为False.

from torch.utils.data.sampler import WeightedRandomSampler
sampler=WeightedRandomSampler(weights,num_samples=8,replacement=False)
dataloader=DataLoader(dataset,batch_size=4,sampler=sampler)
for datas,labels in dataloader:print(labels.tolist())

输出:

[0, 0, 1, 0]
[1, 0, 1, 1]

在这种情况下,num_samples等于dataset的样本总数,为了 不重复选取,sampler会将每个样本都返回,这样就失去了weight的意义。

从上面的例子可见sampler在采样中的作用:如果指定了samplershuffle将不再生效,并且sampler.num_smples会覆盖dataset的实际大小,即一个epoch返回的图片总数取决于sampler.num_samples.


总结:
完整代码:

import os
from PIL import Image
from torch.utils import data
#import numpy as np
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSamplertransforms=T.Compose([T.Resize(224),  #缩放图片(Image),保持长宽比不变,最短边为224像素T.CenterCrop(224), #从图片中间裁剪出224*224的图片T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])  #标准化至【-1,1】,规定均值和方差
])class DogCat(data.Dataset):def __init__(self,root, transforms=None):imgs=os.listdir(root)self.imgs=[os.path.join(root, img) for img in imgs]self.transforms=transformsdef __getitem__(self, index):img_path=self.imgs[index]#dog->1, cat->0label=1 if 'dog' in img_path.split("/")[-1] else 0data=Image.open(img_path)if self.transforms:data=self.transforms(data)return data,labeldef __len__(self):return len(self.imgs)        dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
print("******dataset*************")
print("dataset")
for img,label in dataset:print(img.size(),label)dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)#狗的图片取出的概率是猫的概率的两倍
#两类取出的概率与weights的绝对值大小无关,之和比值有关
weights=[2 if label==1 else 1 for data ,label in dataset]
print("******weights**************")print("weight:{}".format(weights))print("******sampler**************")sampler=WeightedRandomSampler(weights,num_samples=8,replacement=False)
dataloader=DataLoader(dataset,batch_size=4,sampler=sampler)
for datas,labels in dataloader:print(labels.tolist())

输出:

******dataset*************
dataset
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
******weights**************
weight:[1, 1, 1, 1, 2, 2, 2, 2]
******sampler**************
[0, 0, 1, 1]
[1, 0, 1, 0]

第五章——Pytorch中常用的工具相关推荐

  1. 鸟哥的Linux私房菜(服务器)- 第五章、 Linux 常用网络指令

    第五章. Linux 常用网络指令 最近更新日期:2011/07/18 Linux 的网络功能相当的强悍,一时之间我们也无法完全的介绍所有的网络指令,这个章节主要的目的在介绍一些常见的网络指令而已. ...

  2. 《SVN宇宙版教程》:第五章 TortoiseSVN中Repo-browser介绍

    第五章 TortoiseSVN中Repo-browser介绍 导言: 窗口Repo-browser是TortoiseSVN提供的一个管理工作副本或仓库文件的工具,此窗口在使用TortoiseSVN工具 ...

  3. 计算机图形学 opengl版本 第三版------胡事民 第四章 图形学中的向量工具

    计算机图形学 opengl版本 第三版------胡事民 第四章  图形学中的向量工具 一   基础 1:向量分析和变换   两个工具  可以设计出各种几何对象 点和向量基于坐标系定义 拇指指向z轴正 ...

  4. 第十三章_Java中常用集合大整理(含底层数据结构简单介绍)

    第十三章_Java中常用集合大整理 1.集合和数组的区别 集合 既可以存储基本数据类型还可以存储引用数据类型 定长–>是数组最大的特点,也是最大的缺点 数组 只能存储引用数据类型 长度可变 相同 ...

  5. shell编程系列7--shell中常用的工具find、locate、which、whereis

    shell编程系列7--shell中常用的工具find.locate.which.whereis1.文件查找之find命令语法格式:find [路径] [选项] [操作]选项 -name 根据文件名查 ...

  6. 工作中常用,实用工具推荐!

    原文:工作中常用,实用工具推荐! Red Gate 家族 大名鼎鼎的RedGate,相信大家都不会陌生,Reflector就是它家做的.这里就不介绍了.我本地安装的是09年下的一个套装,我介绍下常用的 ...

  7. linux性能监控工具perf,Linux性能分析中常用的工具perf介绍

    今天小编要跟大家分享的文章是关于Linux性能分析中常用的工具perf介绍.系统级性能优化通常包括两个阶段:性能剖析(performance profiling)和代码优化.性能剖析的目标是寻找性能瓶 ...

  8. JAVA 开发中常用的工具有哪些?

    Java开发中常用的工具有以下几种: Eclipse:一款非常流行的开发工具,提供了很多方便的功能,如代码自动补全.调试.版本控制等. IntelliJ IDEA:一款功能强大的Java集成开发环境, ...

  9. 分享一些工作中常用的工具软件,值得收藏!

    前言 我之前分享过一篇:分享一些常用的网站和工具,值得收藏!,今天再分享一波关于工作中常用的工具软件! 文章首发在公众号(月伴飞鱼),之后同步到个人网站:http://xiaoflyfish.cn/ ...

  10. 极智AI | Pytorch 中常用乘法的 TensorRT 实现

      欢迎关注我的公众号 [极智视界],获取我的更多笔记分享   大家好,我是极智视界,本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现.   pytorch 用于训练,TensorR ...

最新文章

  1. 工作中总结的一些C#小经验,随时更新
  2. python基础一循环
  3. Gradle 学习二
  4. 【三代增强干货一枚】外向交货单Delivery (VL01N)Header屏幕增强
  5. UVALive - 3126 Taxi Cab Scheme(最小路径覆盖-二分图最大匹配)
  6. qregexp限制数字范围_数字系统实现电压电流控制的必经之路数模转换器
  7. 一个MD4在线加密脚本源码
  8. wire routing 网格寻址
  9. 手把手教你升级到MySQL 8.0
  10. 用DISM修复Win10系统文件教程
  11. IDC基础知识-名词解释
  12. java swing 颜色_Java Swing按钮颜色
  13. 闹钟和时间管理工具Alarm Clock Pro mac
  14. Allegro172版本DFM规则之DFT outline
  15. ChatGPT4.0中国怎么使用
  16. 【关于memset和0x3f3f3f3f】
  17. 关于报错FAILURE: Build failed with an exception.
  18. 大华人脸门禁(人脸闸机)sdk集成对接javaweb接口springboot版
  19. 02 事务伴生源-Propagation
  20. PostgreSQL 操作

热门文章

  1. spring事务传播特性_关于spring的事务的传播propagation特性
  2. mybatis连接mysql url_MyBatis与JDBC连接数据库所使用的url之间的差异
  3. Ubuntu18环境下安装ROS
  4. (day 16 - 双指针)剑指 Offer 35. 复杂链表的复制
  5. java安卓开发异步任务_java – 如何从android中的任何异步操作中获...
  6. php 改变地址栏,php如何修改url
  7. python编程新手常犯的错误_Python新手常犯的10个错误 - 里维斯社
  8. Javascript:json数据根据某一个字段进行排序
  9. Git:常用命令(自用)
  10. java 内存屏障_关于Java中的内存屏障