transforms运行机制

torchvision是pytorch的计算机视觉工具包,在torchvision中有三个主要的模块:

  • torchvision.transforms,常用的图像预处理方法,在transforms中提供了一系列的图像预处理方法,例如数据的标准化,中心化,旋转,翻转等等;
  • torchvision.datasets,定义了一系列常用的公开数据集的datasets,比如常用的MNIST,CIFAR-10,ImageNet等等;
  • torchvision.model,提供大量常用的预训练模型,例如AlexNet,VGG,ResNet,GoogLeNet等等;

transforms

torchvision.transforms:常用的图像预处理方法

  • 数据中心化
  • 数据标准化
  • 缩放
  • 裁剪
  • 旋转
  • 翻转
  • 填充
  • 噪声添加
  • 灰度变换
  • 线性变换
  • 仿射变换
  • 亮度、饱和度及对比度变换

深度学习是由数据驱动的,数据的数量以及分布对模型的优劣起到决定性作用,所以需要对数据进行一定的预处理以及数据增强,用来提升模型的泛化能力;

观察下面这个图,这是经过数据增强之后生成的一系列数据,一共有64张图片,这64张图片都来源于一张原始图片,经过一系列的缩放、裁剪、平移、变换等等操作的组合,生成了64张图片;对图片进行数据增强的原因是为了提高模型的泛化能力,类似于5年高考,3年模拟的卷子;5年高考的真题卷就类似于原始训练数据,3年模拟就相当于做一些数据增强,去丰富训练数据;假如在三年模拟的卷子中出现了当年的高考题,那么分数自然有所提高;同样的,如果我们做数据增强,生成了与测试样本很相似的图片,那么模型的泛化能力自然可以得到提高,这就是做数据增强的原因;

看一下代码,这里使用上一篇博客介绍的人民币二分类实验的代码的数据预处理部分,
数据标准化——transforms.normalize

# ============================ step 1/5 数据 ============================
# 这部分设置数据的路径
split_dir = os.path.join("C:/Users/10530/Desktop/pytorch/rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")#设置数据标准化的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]# transforms.Compose的功能是将一系列的transforms方法进行有序的组合包装,在具体实现的时候,会依次按顺序对图像进行操作
train_transform = transforms.Compose([transforms.Resize((32, 32)),  #Resize,将图像缩放到32*32的大小transforms.RandomCrop(32, padding=4),  #RandomCrop,对数据进行随机的裁剪transforms.ToTensor(),  #ToTensor,将图片转成张量的形式同时会进行归一化操作,把像素值的区间从0-255归一化到0-1transforms.Normalize(norm_mean, norm_std),  #标准化操作,将数据的均值变为0,标准差变为1
])   # Resize的功能是缩放,RandomCrop的功能是裁剪,ToTensor的功能是把图片变为张量#验证集的预处理的方法,对比训练集,少了RandomCrop这一部分,因为在验证集中是不需要对数据进行数据增强的
valid_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例,MyDataset必须是用户自己构建的
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)  # data_dir是数据的路径,transform是数据预处理
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)  # 一个用于训练,一个用于验证# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle=True,每一个epoch中样本都是乱序的
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

同样,在模型训练中设置断点,断点位置位于如下代码处:

for i, data in enumerate(train_loader):

进行debug,并点击step into进行操作,在跳转后的代码中进行一个是否采用多进程的判断:

    def __iter__(self):if self.num_workers == 0:return _SingleProcessDataLoaderIter(self)else:return _MultiProcessingDataLoaderIter(self)

选择单进程的运行机制,进入dataloader.py界面,找到def init(self)方法,点击Run to Cursor,程序就会运行到光标所在的行,具体如下***的代码:

    def __next__(self):****index = self._next_index()  # may raise StopIterationdata = self.dataset_fetcher.fetch(index)  # may raise StopIterationif self.pin_memory:data = _utils.pin_memory.pin_memory(data)return data

这一步的作用是获取Index,也就是要读取哪些数据。得到Index就可以进入dataset_fetcher.fetch(index),根据索引去获取数据;进入到fetch函数:

class _MapDatasetFetcher(_BaseDatasetFetcher):def __init__(self, dataset, auto_collation, collate_fn, drop_last):super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)def fetch(self, possibly_batched_index):if self.auto_collation:data = [self.dataset[idx] for idx in possibly_batched_index]else:data = self.dataset[possibly_batched_index]return self.collate_fn(data)

在fetch函数中,代码

data = [self.dataset[idx] for idx in possibly_batched_index]

调用了dataset,接着进入dataset所在的代码位置,如下所示:

    def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等return img, label

dataest代码位于类RMBDataset(Dataset)中的def getitem()函数,在getitem()中根据索引去获取图片的路径以及标签;然后采用代码

img = Image.open(path_img).convert('RGB')     # 0~255

打开图片,读取进来的图片是一个PIL的数据类型,然后在getitem中调用transform()进行图像预处理操作,通过step_into进入transform()代码位置进行分析,代码位于transform中的def call()函数

    def __call__(self, img):for t in self.transforms:img = t(img)return img

call()函数是一个for循环,也就是依次有序地从compose中去调用预处理方法,第一个预处理方法是t(img),其功能是是Resize缩放;第二个功能是裁剪,第三个功能是进行张量操作,第四个功能是进行归一化;对compose的四个功能循环结束之后,就会返回transform。

transform是在__getitem__()中调用,并且在__getitem__()中实现数据预处理,然后通过__getitem__返回一个样本;

    def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:img = self.transform(img)   # 在这里做transform,转为tensor等等

执行step out操作返回fetch()函数,接着就是不断地循环index获取一个batch_size大小的数据,最后在return的时候调用collate_fn()函数,将数据整理成一个batch_data的形式。

然后执行step out操作返回到dataloader.py中的__next__()函数中,然后跳出dataloader.py回到主代码当中,接着数据就读取进来了。这就是pytorch数据读取和transforms的运行机制。

回顾上面的数据读取流程图,transforms是在getitem中使用的;在getitem中读取一张图片,然后对这一张图片进行一系列预处理,然后返回图片以及标签。

了解了transforms的机制,现在学习一个比较常用的预处理方法,数据的标准化transforms.Normalize;

transforms.Normalize

  • 功能:逐channel的对图像进行标准化,即数据的均值变为0,标准差变为1
  • 标准化的计算公式为 output=(input−mean)/stdoutput = (input - mean) /stdoutput=(input−mean)/std
  • mean:各通道的均值
  • std:各通道的标准差
  • inplace:是否原位操作
transform.Normalize(mean,std,inplace=False)

回到代码中看一下normalize的具体实现方法,transform是在dataset的getitem中实现的,所以可以直接去dataset的getitem函数中设置断点,具体如下:

    def __getitem__(self, index):path_img, label = self.data_info[index]img = Image.open(path_img).convert('RGB')     # 0~255if self.transform is not None:***img = self.transform(img)   # 在这里做transform,转为tensor等等return img, label

代码中***标注的地方就是断点的设置位置,进行debug操作,点击step into进入详细代码环境,进入了transforms.py中的call()函数中,在call函数中循环transforms。

    def __call__(self, tensor):"""Args:tensor (Tensor): Tensor image of size (C, H, W) to be normalized.Returns:Tensor: Normalized Tensor image."""return F.normalize(tensor, self.mean, self.std, self.inplace)

接着进入transforms中查看normalize的实现,来到了normalize()类中的__call__()函数中,代码只有一行,实际上这行代码是调用了pytorch中的function中normalize方法;pytorch的function提供了很多常用的函数,使用step into查看normalize中的具体实现。

    if not _is_tensor_image(tensor):  #输入的合法性判断raise TypeError('tensor is not a torch image.')if not inplace:   #判断是否需要原地操作tensor = tensor.clone()dtype = tensor.dtypemean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)std = torch.as_tensor(std, dtype=dtype, device=tensor.device)tensor.sub_(mean[:, None, None]).div_(std[:, None, None])   #归一化公式return tensor

首先是输入的合法性判断,输入的是tensor,也就是原始的图像,接着判断是否要原地操作,如果不是inplace就需要将张量复制一份到新的内存空间中。下面的代码就是获取数据的均值和标准差,并将数据转换为张量。注意在sub_和div_后面有下划线,意思是进行原位操作,这样就完成了数据标准化的操作。

对数据进行标准化之后可以加快模型的收敛,具体可以看百面机器学习的第一章。

pytorch —— 图像预处理模块(Transforms)相关推荐

  1. 系统学习Pytorch笔记三:Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

    Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html? Pytorch中文文档:https://pytorch-cn.readthedocs ...

  2. pytorch 图像预处理之减去均值,除以方差

    在使用 torchvision.transforms进行数据处理时我们经常进行的操作是: transforms.Normalize((0.485,0.456,0.406), (0.229,0.224, ...

  3. Pytorch图像预处理——归一化、标准化

    在深度学习图像分类.物体检测等过程中,首先要对图像进行归一化和标准化. 原理: 归一化: 式中,input表示输入的图像像素值:max().min()分别表示输入像素的最大值和最小值.output为输 ...

  4. 4.3 pytorch数据预处理:transforms图像增强方法

    一.数据增强概述 二.数据增强方法:裁剪 三.数据增强方法:翻转和旋转 四.数据增强方法:变换 五.transforms方法的选择操作 一.数据增强概述 我们来看图片中的数据增强是怎么样的. 左边的图 ...

  5. pytorch图像预处理

    1.原始图 from PIL import Image from torchvision import transforms as tfs#原始图 im = Image.open('yeban1.JP ...

  6. SDI接口图像预处理模块

  7. PyTorch主要组成模块 | 数据读入 | 数据预处理 | 模型构建 | 模型初始化 | 损失函数 | 优化器 | 训练与评估

    文章目录 一.深度学习任务框架 二.数据读入 三.数据预处理模块-transforms 1.数据预处理transforms模块机制 2.二十二种transforms数据预处理方法 1.裁剪 2. 翻转 ...

  8. PyTorch框架学习六——图像预处理transforms(二)

    PyTorch框架学习六--图像预处理transforms(二) (续)二.transforms的具体方法 4.图像变换 (1)尺寸变换:transforms.Resize() (2)标准化:tran ...

  9. PyTorch框架学习五——图像预处理transforms(一)

    PyTorch框架学习五--图像预处理transforms(一) 一.transforms运行机制 二.transforms的具体方法 1.裁剪 (1)随机裁剪:transforms.RandomCr ...

最新文章

  1. R语言White’s检验实战:检验回归模型中是否存在异方差性(heteroscedasticity)、发生了异常差(heteroscedasticity)问题如何解决
  2. 十分钟完成Bash 脚本进阶!列举Bash经典用法及其案例
  3. 【控制】《多智能体系统一致性协同演化控制理论与技术》纪良浩老师-第3章-有向二阶多智能体系统脉冲一致性
  4. 电阻参数_压敏电阻原理、参数、选型
  5. android软件开发基础课程(一)
  6. 杭电2524 矩形A + B
  7. Android studio SweetAlert for Android
  8. apache禁止访问文件或目录执行权限、禁止运行脚本PHP文件的设置方法
  9. 中兴6908的三层交换
  10. pytorch 神经网络构造
  11. MySQL内存----使用说明全局缓存+线程缓存) 转
  12. 支持树莓派的路由器系统_真香!国产64位树莓派系统上手评测
  13. 第1章-确定superboot210如何为smart210的nand flash进行的分区划分
  14. Android Studio ADB 命令大全
  15. 计算机算法设计与分析 第5版 (王晓东) 课后答案[解析]
  16. 【洛谷 P3191】 [HNOI2007]紧急疏散EVACUATE(二分答案,最大流)
  17. 操作系统的概念、四个特征以及os的发展和分类
  18. windows文件格式转换为linux格式
  19. 一文搞懂 Cocos Creator 3.x 坐标转换!建议收藏
  20. 我最有用的IntelliJ IDEA键盘快捷键

热门文章

  1. asp.net弹出alert提示框
  2. Java jdbc数据库连接池
  3. Jeecg-boot 使用心得建议
  4. 分布式面试 - 分布式服务接口请求的顺序性如何保证?
  5. Docker快速搭建Taiga敏捷开发项目管理平台
  6. C语言,期末复习之穷举法鸡兔同笼问题
  7. 最全免费C语言之苏小红版《高级语言程序设计》第七章188页小学计算机辅助教学系统程序设计
  8. 仿微信选取图片发表朋友圈功能
  9. 函数式编程 -- 函数组合
  10. netlify 部署vue_如何使用Netlify构建和部署网站-全面的教程