The PyTorch framework has a very commonly used package: torchvision.
torchvision mainly consists of 3 subpackages: torchvision.datasets, torchvision.models, and torchvision.transforms.
For details, see: http://pytorch.org/docs/master/torchvision/index.html
GitHub: https://github.com/pytorch/vision/tree/master/torchvision.


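This post mainly covers torchvision.transforms. Essentially all the common data preprocessing and data augmentation operations in PyTorch, such as resize, crop, and normalize, can be implemented through this interface.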

torchvision.transforms mainly involves two files, transforms.py and functional.py. transforms.py defines the various data augmentation classes, and each class performs its operation by calling the corresponding function in functional.py.

$ vim /home/lwp/.local/lib/python2.7/site-packages/torchvision/transforms/transforms.py

A usage example: this snippet comes from an implementation of the Re-ID MGN model (https://github.com/lwplw/re-id_mgn/blob/master/pytorch_MGN/data/init.py). It uses Resize, RandomHorizontalFlip, ToTensor, and Normalize:

from importlib import import_module
from torchvision import transforms
from utils.random_erasing import RandomErasing
from data.sampler import RandomSampler
from torch.utils.data import dataloader


class Data:
    def __init__(self, args):
        # Training pipeline: resize -> random horizontal flip -> tensor -> normalize
        # (interpolation=3 is PIL.Image.BICUBIC)
        train_list = [
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ]
        # Optional Random Erasing, appended after ToTensor/Normalize
        if args.random_erasing:
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))

        train_transform = transforms.Compose(train_list)

        # Test pipeline: same resize and normalization, no random augmentation
        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            self.train_loader = dataloader.DataLoader(self.trainset,
                            sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                            #shuffle=True,
                            batch_size=args.batchid * args.batchimage,
                            num_workers=args.nThread)
        else:
            self.train_loader = None

        if args.data_test in ['Market1501']:
            module = import_module('data.' + args.data_train.lower())
            self.testset = getattr(module, args.data_test)(args, test_transform, 'test')
            self.queryset = getattr(module, args.data_test)(args, test_transform, 'query')
        else:
            raise Exception()

        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)

The classes for the individual operations are defined in transforms.py:

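  • from . import functional as F imports the concrete data augmentation functions from functional.py;
  • the __all__ list declares the public function and class names that the module exports through a wildcard import (from ... import *).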
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from . import functional as F

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}
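In Python, __all__ only controls what a wildcard import (from module import *) copies. An explicit import of an unlisted name still works. The sketch below checks this with a hypothetical in-memory module named fake_transforms that imitates the layout of transforms.py. Nothing in it comes from torchvision itself.

```python
import sys
import types

# Source of a tiny, hypothetical module that mimics transforms.py's layout:
# public classes listed in __all__, plus functions that are not listed.
SOURCE = '''
__all__ = ["Compose", "ToTensor"]

class Compose(object):
    pass

class ToTensor(object):
    pass

def not_exported():
    pass
'''

# Build the module in memory and register it so the import system can find it.
fake = types.ModuleType("fake_transforms")
exec(SOURCE, fake.__dict__)
sys.modules["fake_transforms"] = fake

# A wildcard import only copies the names listed in __all__.
namespace = {}
exec("from fake_transforms import *", namespace)
star_names = sorted(k for k in namespace if not k.startswith("__"))
print(star_names)  # ['Compose', 'ToTensor']

# An explicit import still reaches a name that is not in __all__.
from fake_transforms import not_exported
print(callable(not_exported))  # True
```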

Compose()

Manages a sequence of transforms. Its __call__ method applies every transform to the input img in turn, passing each output on to the next transform.

class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
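The sketch below shows the same chaining loop in pure Python, so it runs without torchvision. Two made-up functions, add_one and double, stand in for image transforms. It also shows that the order of the list matters.

```python
class Compose(object):
    """Minimal re-implementation for illustration: apply callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, img):
        # Each transform receives the previous transform's output.
        for t in self.transforms:
            img = t(img)
        return img


def add_one(x):
    return x + 1


def double(x):
    return x * 2


pipeline = Compose([add_one, double])
print(pipeline(3))                    # (3 + 1) * 2 = 8
print(Compose([double, add_one])(3))  # 3 * 2 + 1 = 7
```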

ToTensor()

Convert a PIL Image or numpy.ndarray to tensor.
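Before normalization, a PIL Image must be converted to a Tensor. The resize and crop operations do not need this conversion, because they work directly on PIL Images.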

class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
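The sketch below is a rough illustration of what the docstring describes: a layout change from H x W x C to C x H x W plus scaling by 1/255, done on nested lists. It assumes 8-bit input and ignores everything torch-specific, such as dtypes and non-RGB PIL modes. to_tensor_like is a hypothetical name, not a torchvision function.

```python
def to_tensor_like(hwc):
    """Convert an H x W x C nested list of 0..255 ints to a C x H x W nested
    list of floats in [0.0, 1.0] (layout change + scaling, no torch involved)."""
    h = len(hwc)
    w = len(hwc[0])
    c = len(hwc[0][0])
    return [[[hwc[y][x][ch] / 255.0 for x in range(w)]
             for y in range(h)]
            for ch in range(c)]


# A 1 x 2 RGB "image": one red pixel and one white pixel.
img = [[[255, 0, 0], [255, 255, 255]]]
chw = to_tensor_like(img)
print(len(chw), len(chw[0]), len(chw[0][0]))  # 3 1 2
print(chw[0])  # red channel:   [[1.0, 1.0]]
print(chw[1])  # green channel: [[0.0, 1.0]]
```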

ToPILImage()

Convert a tensor or an ndarray to PIL Image.
The inverse operation of ToTensor().
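The sketch below is a round trip on nested lists, using hypothetical helper names. Converting to the C x H x W float layout and back recovers the original 8-bit values. to_pil_like rounds so that the division by 255 is undone exactly. That rounding is a choice made for this sketch, not a claim about how ToPILImage itself converts floats.

```python
def to_tensor_like(hwc):
    # H x W x C ints in 0..255  ->  C x H x W floats in [0, 1]
    return [[[hwc[y][x][c] / 255.0 for x in range(len(hwc[0]))]
             for y in range(len(hwc))]
            for c in range(len(hwc[0][0]))]


def to_pil_like(chw):
    # C x H x W floats in [0, 1]  ->  H x W x C ints in 0..255
    c, h, w = len(chw), len(chw[0]), len(chw[0][0])
    return [[[int(round(chw[ch][y][x] * 255)) for ch in range(c)]
             for x in range(w)]
            for y in range(h)]


img = [[[12, 34, 56], [200, 100, 0]]]
restored = to_pil_like(to_tensor_like(img))
print(restored == img)  # True
```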


Normalize()

Normalizes the data. The input must already be converted to a Tensor before this is called.

class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return F.normalize(tensor, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
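The sketch below applies the docstring's formula, input[channel] = (input[channel] - mean[channel]) / std[channel], by hand to nested lists; normalize_like is a hypothetical helper. With mean = std = 0.5, values in [0, 1] (the ToTensor range) map to [-1, 1]. With the ImageNet statistics used in the MGN example above, a pixel equal to the channel mean maps to 0.

```python
def normalize_like(chw, mean, std):
    """Per-channel (x - mean[c]) / std[c] on a C x H x W nested list."""
    if len(mean) != len(chw) or len(std) != len(chw):
        raise ValueError("mean and std need one entry per channel")
    return [[[(v - mean[c]) / std[c] for v in row] for row in chw[c]]
            for c in range(len(chw))]


# Single-channel 1 x 3 image already scaled to [0, 1] (i.e. after ToTensor).
print(normalize_like([[[0.0, 0.5, 1.0]]], mean=[0.5], std=[0.5]))  # [[[-1.0, 0.0, 1.0]]]

# 3 x 1 x 1 image whose pixel equals the ImageNet channel means.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
at_mean = [[[m]] for m in imagenet_mean]
print(normalize_like(at_mean, imagenet_mean, imagenet_std))  # [[[0.0]], [[0.0]], [[0.0]]]
```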

Resize()

Resizes a PIL Image.

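  • If size is a single int, the shorter edge of the input image is resized to that value and the longer edge is scaled proportionally, so the aspect ratio is preserved.
  • If size is (h, w) with h and w ints, the input image is resized directly to (h, w), and its aspect ratio may change.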

The __call__ method delegates to the resize function in functional.py. Because the input is a PIL Image, resize mostly just calls methods of Image.

class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
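The size rule from the two bullets above can be checked with a small helper. This is a sketch based on my reading of functional.resize, not a copy of it, and resize_output_size is a hypothetical name. It returns (height, width). When reading the real code, note that PIL's Image.resize takes (width, height), so an (h, w) size has to be reversed before it reaches PIL.

```python
def resize_output_size(h, w, size):
    """Return the (height, width) that Resize(size) targets for an h x w image.

    An int matches the shorter edge and keeps the aspect ratio;
    an (h, w) pair is used as-is (the aspect ratio may change).
    """
    if isinstance(size, int):
        if w <= h:
            # width is the shorter edge
            return int(size * h / w), size
        # height is the shorter edge
        return size, int(size * w / h)
    out_h, out_w = size
    return out_h, out_w


print(resize_output_size(480, 640, 240))         # (240, 320)
print(resize_output_size(640, 480, 240))         # (320, 240)
print(resize_output_size(480, 640, (256, 128)))  # (256, 128)
```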

CenterCrop()

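Crops a region of the given size centered on the center of the input image img. It is generally not used for data augmentation: for a fixed size, applying CenterCrop N times to the same img always gives the same result.
size can be a single int, which is converted to (int(size), int(size)) (a square crop), or a sequence (h, w).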

class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
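The sketch below works through the centered-crop arithmetic on a small nested-list "image", using hypothetical helpers center_crop_params and crop. Rounding the offset to the nearest int is an assumption of this sketch. The point it demonstrates is the one made above: there is no randomness, so every call returns the same crop.

```python
def center_crop_params(h, w, th, tw):
    """Top-left corner (i, j) and size of a centered th x tw crop."""
    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return i, j, th, tw


def crop(grid, i, j, th, tw):
    # grid is an H x W nested list; keep rows i..i+th-1 and columns j..j+tw-1
    return [row[j:j + tw] for row in grid[i:i + th]]


grid = [[r * 10 + c for c in range(4)] for r in range(4)]  # 4 x 4 "image"
params = center_crop_params(4, 4, 2, 2)
print(params)               # (1, 1, 2, 2)
print(crop(grid, *params))  # [[11, 12], [21, 22]]
# Deterministic: computing it again gives exactly the same crop.
print(crop(grid, *center_crop_params(4, 4, 2, 2)) == crop(grid, *params))  # True
```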

RandomCrop()

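RandomCrop is more commonly used than CenterCrop. The difference is that RandomCrop places the crop at a random position instead of the image center: get_params draws the top-left corner (i, j) uniformly. As a result, nearly every crop it produces is different.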

class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self.padding > 0:
            img = F.pad(img, self.padding)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
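Below, the get_params logic is pulled out into a standalone function over plain integers; random_crop_params is a hypothetical name. The demo uses a seeded random.Random only to make it repeatable. random.randint is inclusive at both ends, so i ranges over 0..h-th. It raises ValueError when the crop is larger than the image, which is the situation padding and pad_if_needed are meant to prevent.

```python
import random


def random_crop_params(h, w, th, tw, rng=random):
    """Pick a uniformly random top-left corner so a th x tw crop fits in h x w."""
    if h == th and w == tw:
        return 0, 0, h, w
    if th > h or tw > w:
        # The original would raise from random.randint here.
        raise ValueError("crop is larger than the image")
    i = rng.randint(0, h - th)  # inclusive on both ends
    j = rng.randint(0, w - tw)
    return i, j, th, tw


rng = random.Random(0)  # seeded only to make the demo repeatable
draws = [random_crop_params(32, 32, 28, 28, rng) for _ in range(200)]
print(draws[:5])  # (i, j, 28, 28) tuples with 0 <= i, j <= 4
```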

RandomHorizontalFlip()

Randomly flips the image horizontally. The flip probability p defaults to 0.5.

class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
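The sketch below is a pure-Python stand-in for RandomHorizontalFlip on nested lists. The class name RandomHorizontalFlipSketch and its rng parameter are additions for this sketch. random.random() returns a value in [0, 1), so p=1.0 always flips and p=0.0 never does. With the default p=0.5, roughly half of the calls flip.

```python
import random


def hflip(grid):
    # Mirror left-right: reverse the order of columns in every row.
    return [list(reversed(row)) for row in grid]


class RandomHorizontalFlipSketch(object):
    def __init__(self, p=0.5, rng=None):
        self.p = p
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, grid):
        # Flip when the uniform draw in [0, 1) falls below p.
        if self.rng.random() < self.p:
            return hflip(grid)
        return grid


grid = [[1, 2, 3],
        [4, 5, 6]]
print(hflip(grid))                              # [[3, 2, 1], [6, 5, 4]]
print(RandomHorizontalFlipSketch(p=1.0)(grid))  # always flipped
print(RandomHorizontalFlipSketch(p=0.0)(grid))  # never flipped

flipper = RandomHorizontalFlipSketch(rng=random.Random(0))
flip_rate = sum(flipper(grid) != grid for _ in range(10000)) / 10000
print(flip_rate)  # close to 0.5
```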

RandomVerticalFlip()

Randomly flips the image vertically. p likewise defaults to 0.5.

class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.vflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
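The vertical counterpart reverses the order of the rows rather than the contents of each row. The small sketch below uses hypothetical hflip/vflip helpers. It also shows that applying both flips equals a 180-degree rotation. If RandomHorizontalFlip and RandomVerticalFlip are both in one pipeline with p=0.5 and independent draws, about a quarter of the samples therefore come out rotated by 180 degrees.

```python
def hflip(grid):
    # reverse the column order in each row (left <-> right)
    return [row[::-1] for row in grid]


def vflip(grid):
    # reverse the row order (top <-> bottom)
    return [list(row) for row in grid[::-1]]


grid = [[1, 2],
        [3, 4]]
print(vflip(grid))         # [[3, 4], [1, 2]]
# Both flips together equal a 180-degree rotation.
print(hflip(vflip(grid)))  # [[4, 3], [2, 1]]
```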

RandomResizedCrop()

CenterCrop and RandomCrop crop a fixed size, whereas RandomResizedCrop crops a random size.

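The original class takes 3 parameters: size, scale, and ratio. In my own use I changed size in the interface to size_h, size_w. The transform crops first and then resizes the crop to the specified size.
The crop's position, width, and height come from the get_params method. It first draws a random number within the range given by scale and multiplies it by the input image's area to get the target crop area. It then draws a random number within the range given by ratio as the aspect ratio (width/height). These two values determine the crop's width and height, which are swapped with probability 0.5. As in RandomCrop, the crop's position is then chosen at random. If no crop fits inside the image after 10 attempts, get_params falls back to a centered square crop whose side is the image's shorter edge.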
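The size computation in get_params can be written out as follows. S is the input image area (W times H), s is the value drawn from scale, and r is the value drawn from ratio; these symbol names are mine, not the source's. The worked numbers assume a hypothetical 100 x 80 image with s = 0.5 and r = 4/3.

```latex
S = W H, \qquad s \sim \mathcal{U}(\mathrm{scale}), \qquad r \sim \mathcal{U}(\mathrm{ratio})

w = \mathrm{round}\left(\sqrt{s S r}\right), \qquad
h = \mathrm{round}\left(\sqrt{s S / r}\right)
\quad\Longrightarrow\quad w h \approx s S, \qquad \frac{w}{h} \approx r

% Worked example: W = 100, H = 80, s = 0.5, r = 4/3
s S = 0.5 \cdot 8000 = 4000, \qquad
w = \mathrm{round}(\sqrt{5333.3}) = \mathrm{round}(73.03) = 73, \qquad
h = \mathrm{round}(\sqrt{3000}) = \mathrm{round}(54.77) = 55

w h = 4015 \approx 4000, \qquad w / h \approx 1.327 \approx 4/3
```

Rounding makes both relations approximate rather than exact. In this example the pair fits either way: 73 x 55 and the swapped 55 x 73 both stay within 100 x 80.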

class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size_h, size_w, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        self.size = (size_h, size_w)
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
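The function below is a standalone version of the get_params logic. It takes an image size instead of a PIL Image; random_resized_crop_params and its width/height arguments are hypothetical. It lets you check that every accepted crop stays inside the image. The last call uses a deliberately extreme scale and ratio so that all 10 attempts fail and the centered-square fallback is taken.

```python
import math
import random


def random_resized_crop_params(width, height, scale=(0.08, 1.0),
                               ratio=(3. / 4., 4. / 3.), rng=random):
    """Return (i, j, h, w): top row, left column, crop height, crop width."""
    area = width * height
    for _ in range(10):
        target_area = rng.uniform(*scale) * area
        aspect_ratio = rng.uniform(*ratio)
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        if rng.random() < 0.5:
            w, h = h, w  # also allow the transposed shape
        if w <= width and h <= height:
            i = rng.randint(0, height - h)
            j = rng.randint(0, width - w)
            return i, j, h, w
    # Fallback after 10 failed attempts: centered square on the shorter edge.
    side = min(width, height)
    return (height - side) // 2, (width - side) // 2, side, side


rng = random.Random(0)  # seeded only so the demo is repeatable
samples = [random_resized_crop_params(100, 80, rng=rng) for _ in range(1000)]
print(all(i + h <= 80 and j + w <= 100 for i, j, h, w in samples))  # True

# A 179 x 45 (or 45 x 179) crop never fits in 100 x 80, so the fallback is used.
print(random_resized_crop_params(100, 80, scale=(1.0, 1.0), ratio=(4.0, 4.0)))  # (0, 10, 80, 80)
```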
