
这篇主要介绍torchvision.transforms,基本上PyTorch中的resize、crop、normalize等常见的数据预处理及数据增强(data augmentation)操作都可以通过该接口实现。

torchvision.transforms主要涉及两个文件:在transforms.py中定义了各种data augmentation的类,每个类通过调用functional.py中对应的函数完成具体的data augmentation操作。

$ vim /home/lwp/.local/lib/python2.7/site-packages/torchvision/transforms/

使用示例:这是Re-ID MGN模型实现代码中的一段,用到了Resize、RandomHorizontalFlip、ToTensor和Normalize。

from importlib import import_module
from torchvision import transforms
from utils.random_erasing import RandomErasing
from data.sampler import RandomSampler
from torch.utils.data import dataloader


class Data:
    """Build the train / test / query datasets and their data loaders.

    Args:
        args: parsed command-line options. This constructor reads
            ``height``, ``width``, ``random_erasing``, ``probability``,
            ``test_only``, ``data_train``, ``data_test``, ``batchid``,
            ``batchimage``, ``batchtest`` and ``nThread``.

    Raises:
        ValueError: if ``args.data_test`` is not a supported test dataset
            (subclass of ``Exception``, so existing handlers still catch it).
    """

    def __init__(self, args):
        # Per-channel normalization statistics shared by both pipelines.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        train_list = [
            # interpolation=3 is PIL's BICUBIC filter constant.
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
        if args.random_erasing:
            train_list.append(RandomErasing(probability=args.probability, mean=[0.0, 0.0, 0.0]))
        train_transform = transforms.Compose(train_list)

        test_transform = transforms.Compose([
            transforms.Resize((args.height, args.width), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            self.trainset = getattr(module_train, args.data_train)(args, train_transform, 'train')
            # shuffle is intentionally not passed: DataLoader does not allow
            # shuffle together with a custom sampler; the ordering comes from
            # RandomSampler (presumably identity-balanced batches — confirm).
            self.train_loader = dataloader.DataLoader(
                self.trainset,
                sampler=RandomSampler(self.trainset, args.batchid, batch_image=args.batchimage),
                batch_size=args.batchid * args.batchimage,
                num_workers=args.nThread,
            )
        else:
            self.train_loader = None

        if args.data_test not in ['Market1501']:
            raise ValueError('Unsupported test dataset: {0}'.format(args.data_test))

        # The test dataset class lives in its own module (data.<data_test>);
        # importing it via data_train broke whenever the two datasets differed.
        module_test = import_module('data.' + args.data_test.lower())
        self.testset = getattr(module_test, args.data_test)(args, test_transform, 'test')
        self.queryset = getattr(module_test, args.data_test)(args, test_transform, 'query')

        self.test_loader = dataloader.DataLoader(self.testset, batch_size=args.batchtest, num_workers=args.nThread)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=args.batchtest, num_workers=args.nThread)


  • from . import functional as F:导入了functional.py中具体的data augmentation函数;
  • __all__列表定义了可以从外部import的函数名或类名。
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps, ImageEnhance
try:import accimage
except ImportError:accimage = None
import numpy as np
import numbers
import types
import collections
import warnings

from . import functional as F

# Public names exported by ``from torchvision.transforms import *``.
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

# Human-readable names for PIL resampling filters; used by the __repr__ of
# transforms that take an ``interpolation`` argument (e.g. Resize).
_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}



class Compose(object):
    """Chain several transforms into one callable.

    The input is fed to the first transform, and each following transform
    receives the previous one's output.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One indented line per transform, closing paren on its own line.
        lines = [self.__class__.__name__ + '(']
        lines.extend('    {0}'.format(t) for t in self.transforms)
        return '\n'.join(lines) + '\n)'


Convert a PIL Image or numpy.ndarray to tensor.
在做数据归一化之前必须要把PIL Image转成Tensor,其它resize或crop操作不需要。

class ToTensor(object):
    """Turn a ``PIL Image`` or ``numpy.ndarray`` into a tensor.

    A PIL Image or numpy.ndarray laid out as (H x W x C) with values in
    [0, 255] becomes a torch.FloatTensor of shape (C x H x W) with values in
    [0.0, 1.0]. The conversion itself is done by ``F.to_tensor``.
    """

    def __call__(self, pic):
        """Convert ``pic``.

        Args:
            pic (PIL Image or numpy.ndarray): image to convert.

        Returns:
            Tensor: the converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return '{0}()'.format(self.__class__.__name__)


Convert a tensor or an ndarray to PIL Image.



class Normalize(object):
    """Standardize each channel of a tensor image by a per-channel mean and std.

    With means ``(M1,...,Mn)`` and standard deviations ``(S1,..,Sn)`` for
    ``n`` channels, every channel of the input ``torch.*Tensor`` is mapped to
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``.
    The arithmetic is delegated to ``F.normalize``.

    Args:
        mean (sequence): per-channel means.
        std (sequence): per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Normalize ``tensor``.

        Args:
            tensor (Tensor): image tensor of size (C, H, W).

        Returns:
            Tensor: the normalized image.
        """
        return F.normalize(tensor, self.mean, self.std)

    def __repr__(self):
        name = self.__class__.__name__
        return '{0}(mean={1}, std={2})'.format(name, self.mean, self.std)


PIL Image实现resize操作。

  • 如果输入为单个int值,则将输入图像的短边resize到这个int数,长边根据对应比例调整,图像长宽比保持不变。
  • 如果输入为(h,w),且h、w为int,则直接将输入图像resize到(h,w)尺寸,图像的长宽比可能会发生变化。

__call__方法中调用了functional.py脚本中的resize函数来完成resize操作。因为输入是PIL Image,所以resize函数基本是在调用Image的各种方法。

class Resize(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
        # ``collections.abc`` on Python 3 and in ``collections`` on Python 2.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)

    def __repr__(self):
        # Fall back to the raw value for filters missing from the name table
        # instead of raising KeyError from repr().
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


size可以给单个数值(此时会被转换为(int(size), int(size)),即做正方形crop),也可以直接给出(h, w)序列。

class CenterCrop(object):
    """Crop the given PIL Image around its center.

    Args:
        size (sequence or int): Desired output size of the crop. A single
            number produces a square crop (size, size); a sequence is used
            as-is as (h, w).
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)



class RandomCrop(object):
    """Crop the given PIL Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop, as (h, w).

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.

        Raises:
            ValueError: if the requested crop is larger than the image.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        if w < tw or h < th:
            # Same exception type random.randint would raise, with a usable message.
            raise ValueError('Required crop size {0} is larger than input image size {1}'.format(
                (th, tw), (h, w)))
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def _has_padding(self):
        """Return True if ``self.padding`` asks for any border padding.

        Handles both a single number and a per-border sequence; comparing a
        sequence with ``> 0`` raises TypeError on Python 3.
        """
        if isinstance(self.padding, numbers.Number):
            return self.padding > 0
        return any(p != 0 for p in self.padding)

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        if self._has_padding():
            img = F.pad(img, self.padding)

        # pad the width if needed (symmetric, rounded up per side)
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
        # pad the height if needed (symmetric, rounded up per side)
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)

        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)



class RandomHorizontalFlip(object):
    """Mirror the given PIL Image left-to-right with probability ``p``.

    Args:
        p (float): chance that the image is flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: the flipped image, or ``img`` itself when not flipped.
        """
        should_flip = random.random() < self.p
        return F.hflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)



class RandomVerticalFlip(object):
    """Mirror the given PIL Image top-to-bottom with probability ``p``.

    Args:
        p (float): chance that the image is flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: the flipped image, or ``img`` itself when not flipped.
        """
        should_flip = random.random() < self.p
        return F.vflip(img) if should_flip else img

    def __repr__(self):
        return '{0}(p={1})'.format(self.__class__.__name__, self.p)


CenterCrop和RandomCrop在crop时使用固定的size,而RandomResizedCrop则是random size的crop。

该类源码需要3个参数:size、scale、ratio。这里我在使用中将接口中的size修改成了size_h、size_w。具体方法为先crop,再resize到指定尺寸。

class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio, then resize it.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.

    Args:
        size_h: output height. If ``size_w`` is omitted, ``size_h`` may instead be a
            single number (square output) or an (h, w) sequence, like the upstream
            torchvision ``size`` argument.
        size_w (optional): output width.
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size_h, size_w=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation=Image.BILINEAR):
        if size_w is not None:
            self.size = (size_h, size_w)
        elif isinstance(size_h, numbers.Number):
            self.size = (size_h, size_h)
        else:
            self.size = tuple(size_h)
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        width, height = img.size
        area = width * height  # loop-invariant, computed once
        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback after 10 failed attempts: largest centered square crop.
        w = min(width, height)
        i = (height - w) // 2
        j = (width - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)

    def __repr__(self):
        # Fall back to the raw value for filters missing from the name table.
        interpolate_str = _pil_interpolation_to_str.get(self.interpolation, str(self.interpolation))
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


