transforms用法介绍

 torchvision.transforms模块主要用于对图像进行转换等一系列预处理操作,其主要目的是对图像数据进行增强,进而提高模型的泛化能力。对图像预处理操作有数据中心化,缩放,裁剪,旋转,翻转,填充,添加噪声,灰度变换,线性变换,仿射变换,亮度,饱和度,对比变换等。

transforms.Compose

 transforms.Compose是将一系列的图像转换函数进行组合,实现时能够按照这些函数的顺序依次去图像进行处理操作,需要注意的是同样的功能也可以用torch.nn.Sequential函数来实现。

CLASS torchvision.transforms.Compose(transforms)

  • transforms:表示图像变换组合的列

transforms.Compose具体实例的代码如下所示

transform_train = transforms.Compose([transforms.RandomCrop(cut_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(),
])
transform_train = torch.nn.Sequential(transforms.RandomCrop(cut_size),transforms.RandomHorizontalFlip(),transforms.ToTensor(), )transform_test = transforms.Compose([transforms.TenCrop(cut_size),transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
transform_test =  torch.nn.Sequential(transforms.TenCrop(cut_size),transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),)

transforms.ToTensor

 transforms.ToTensor的作用是将一个PIL Image格式的图片或者是取值范围为[0,255][0,255][0,255],形状为[H×W×C][\mathrm{H} \times \mathrm{W} \times \mathrm{C}][H×W×C]numpy.ndarray的数组转换为取值范围为[0.0,1.0][0.0,1.0][0.0,1.0],形状为[C×H×W][\mathrm{C}\times \mathrm{H}\times \mathrm{W}][C×H×W]的tensor格式图片。

transforms.RandomCrop

 transforms.RandomCrop的作用是在图片的随机位置上进行裁剪并返回新的图片。

CLASS torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode=‘constant’)

  • size:表示裁剪图片的输出尺寸,如果参数是一个整数则裁剪的是一个正方形
  • padding:表示图像每个边框上的可选填充。默认值是None
  • pad_if_needed:如果图像小于所需大小,它将填充图像,以避免引发异常
  • fill:表示像素填充值,默认值为0。如果元组长度为3,则用于分别填充R、G、B通道
  • padding_mode:表示像素填充值的类型,默认是常值,也有边缘填充,反射和对称

transforms.RandomCrop具体实例的代码实现和对应的可视化图如下所示

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_imageplt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)def show(imgs):fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)for i, img in enumerate(imgs):img = T.ToPILImage()(img.to('cpu'))axs[0, i].imshow(np.asarray(img))axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])plt.show()img1 = read_image(str(Path('assets') / 'KOBE1.png'))
img2 = read_image(str(Path('assets') / 'KOBE2.png'))
show([img1, img2])transforms = T.Compose([T.RandomCrop(224),])# transforms = torch.nn.Sequential(
#     T.RandomCrop(224),
# )device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])

transforms.RandomHorizontalFlip

 transforms.RandomHorizontalFlip的作用是以特定的概率将图片进行水平翻转。

CLASS torchvision.transforms.RandomHorizontalFlip(p=0.5)

  • p:表示图片水平翻转的概率,默认值是0.5

transforms.RandomHorizontalFlip具体实例的代码实现和对应的可视化图如下所示

from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.io import read_imageplt.rcParams["savefig.bbox"] = 'tight'
torch.manual_seed(1)def show(imgs):fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)for i, img in enumerate(imgs):img = T.ToPILImage()(img.to('cpu'))axs[0, i].imshow(np.asarray(img))axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])plt.show()img1 = read_image(str(Path('assets') / 'KOBE1.png'))  # type : torch
img2 = read_image(str(Path('assets') / 'KOBE2.png'))  # type : torch
show([img1, img2])transforms = T.Compose([T.RandomHorizontalFlip(p=0.9),])# transforms = torch.nn.Sequential(
#      T.RandomHorizontalFlip(p=0.3),
# )device = 'cuda' if torch.cuda.is_available() else 'cpu'
img1 = img1.to(device)
img2 = img2.to(device)transformed_img1 = transforms(img1)
transformed_img2 = transforms(img2)
show([transformed_img1, transformed_img2])

transforms.TenCrop

 transforms.RandomCrop的作用是可以将一张图片的四个角和中心进行裁剪后,然后加上返回的翻转后共10张图片,其中默认翻转是水平翻转。

CLASS torchvision.transforms.TenCrop(size, vertical_flip=False)

  • size:表示裁剪图片的输出尺寸,如果参数是一个整数则裁剪的是一个正方形
  • vertical_flip:表示图片是否用垂直翻转代替水平翻转None

需要注意的是transforms.TenCrop函数的输入必须是PIL\mathrm{PIL}PIL的图片格式,具体实例的代码实现和对应的可视化图如下所示

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as npimport torch
import torchvision.transforms as Tplt.rcParams["savefig.bbox"] = 'tight'torch.manual_seed(0)def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):if not isinstance(imgs[0], list):# Make a 2d grid even if there's just 1 rowimgs = [imgs]num_rows = len(imgs)num_cols = len(imgs[0]) + with_origfig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)for row_idx, row in enumerate(imgs):row = [orig_img] + row if with_orig else rowfor col_idx, img in enumerate(row):ax = axs[row_idx, col_idx]ax.imshow(np.asarray(img), **imshow_kwargs)ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])if with_orig:axs[0, 0].set(title='Original image')axs[0, 0].title.set_size(8)if row_title is not None:for row_idx in range(num_rows):axs[row_idx, 0].set(ylabel=row_title[row_idx])plt.tight_layout()plt.show()orig_img = Image.open(Path('assets') / 'KOBE1.png')   # tyep : PIL
(top_left, top_right, bottom_left, bottom_right, center, flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center) = T.TenCrop(size=(200,200))(orig_img)
plot([[top_left, top_right, bottom_left, bottom_right, center], [flip_top_left, flip_top_right, flip_bottom_left, flip_bottom_right, flip_center]],with_orig=False)

格外需要注意transforms.TenCrop对于每张图片会返回10张变换后的图片,尤其是在测试阶段会导致图片数量和标签数量不匹配,可以进行如下处理

transform = Compose([FiveCrop(size), # this is a list of PIL ImagesLambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor])
#In your test loop you can do the following:
input, target = batch # input is a 5d tensor, target is 2d
bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops
result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops

pytorch中的transforms介绍相关推荐

  1. pytorch中的transforms.ToTensor和transforms.Normalize理解

  2. pytorch中的MSELoss函数

    基本概念 均方误差(mean square error, MSE),是反应估计量与被估计量之间差异程度的一种度量,设ttt是根据子样确定的总体参数θ\thetaθ的一个估计量,(θ−t)2(\thet ...

  3. Pytorch中transforms.Compose()的使用

    torchvision介绍 torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型.torchvision.transforms主要是用于 ...

  4. pytorch中torch.optim的介绍

    pytorch中torch.optim的介绍 这是torch自带的一个优化器,里面自带了求导,更新等操作.开门见山直接讲怎么使用: 常用的引入: import torch.optim as optim ...

  5. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  6. PyTorch中的循环神经网络RNN函数及词嵌入函数介绍

    一.pytroch中的RNN相关函数介绍 1.对于简单的RNN结构,有两种方式进行调用: 1.1 torch.nn.RNN():可以接收一个序列的输入,默认会传入全0的隐藏状态,也可以自己定义初始的隐 ...

  7. python中squeeze函数_详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...

  8. 简单介绍pytorch中分布式训练DDP使用 (结合实例,快速入门)

    文章目录 DDP原理 pytorch中DDP使用 相关的概念 使用流程 如何启动 torch.distributed.launch spawn调用方式 针对实例voxceleb_trainer多卡介绍 ...

  9. pytorch中Parameter()介绍

    用法介绍  pytorch中的Parameter函数可以对某个张量进行参数化.它可以将不可训练的张量转化为可训练的参数类型,同时将转化后的张量绑定到模型可训练参数的列表中,当更新模型的参数时一并将其更 ...

  10. PyTorch中tensor介绍

          PyTorch中的张量(Tensor)如同数组和矩阵一样,是一种特殊的数据结构.在PyTorch中,神经网络的输入.输出以及网络的参数等数据,都是使用张量来进行描述.       torc ...

最新文章

  1. 最短路径树 php,CCNA-开放式最短路径优先(OSPF)真实考题
  2. python3知识点之---------字符串的介绍
  3. HTML 中的marquee标签详解
  4. 观点 | 云原生时代来袭 下一代云数据库技术将走向何方?
  5. SAP HANA 三大特点
  6. 一文读懂图卷积GCN
  7. python认识if语句_python初认识、基础数据类型以及 if 流程控制
  8. 吝啬的国度 ---用vector 来构图
  9. 思维导图软件哪个好?盘点10款好用的思维导图软件
  10. 华为方会提供一份CRS(客户需求)和SOW(工作任务书)
  11. 转载 elm中文手册
  12. 办公室计算机收不到主机打印机,图文详解电脑怎么连接办公室打印机 一招教你搞定!...
  13. 国际象棋 java_A和B和国际象棋
  14. 转载:技术大停滞——范式春梦中的地球工业文明4:范式春梦外的阴影
  15. Elasticsearch实现内容精确匹配查询
  16. canvas橡皮擦功能
  17. 图像分类——猫狗大战问题
  18. node js 通过url下载文件到本地指定目录
  19. NAS还是HFS?教你1分钟免费搭建私有云
  20. 基于keepalived的mysql_【实用】基于keepalived的mysql双主高可用系统

热门文章

  1. 简单大学生静态HTML网页作品 HTML5+CSS大作业——圣诞节节日(7页) 带轮播特效
  2. yxy小蒟蒻的201119总结
  3. 6-2 求解一元二次方程实根的函数 (10 分)
  4. you have got to find what you love
  5. 史上最贵的merge代码,新浪程序员因加班错失年会77万大奖!
  6. ES6转化ES5方法(处理低版本手机白屏等兼容问题)
  7. matlab语言中的[~,b]=sort(A)用法介绍
  8. 前期总结+开学展望(WYL)
  9. 计算机名校远程在职硕士信息汇总Online Master
  10. the little schemer 笔记(10.1)