图像增广

在深度卷积神经网络里我们提到过,大规模数据集是成功应用深度神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。例如,我们可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同位置,从而减轻模型对物体出现位置的依赖性。我们也可以调整亮度、色彩等因素来降低模型对色彩的敏感度。可以说,在当年AlexNet的成功中,图像增广技术功不可没。本节我们将讨论这个在计算机视觉里被广泛使用的技术。

首先,导入实验所需的包或模块。

%matplotlib inline
import os
import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import sys
from PIL import Imagesys.path.append("/home/input/")
#置当前使用的GPU设备仅为0号设备
os.environ["CUDA_VISIBLE_DEVICES"] = "0"   import d2lzh1981 as d2l# 定义device,是否使用GPU,依据计算机配置自动会选择
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

常用的图像增广方法

我们来读取一张形状为400×500400\times 500400×500(高和宽分别为400像素和500像素)的图像作为实验的样例。

d2l.set_figsize()
img = Image.open('/home/kesci/input/img2083/img/cat1.jpg')  #数据集文件路径
d2l.plt.imshow(img)


下面定义绘图函数show_images。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def show_images(imgs, num_rows, num_cols, scale=2):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)for i in range(num_rows):for j in range(num_cols):axes[i][j].imshow(imgs[i * num_cols + j])axes[i][j].axes.get_xaxis().set_visible(False)axes[i][j].axes.get_yaxis().set_visible(False)return axes

大部分图像增广方法都有一定的随机性。为了方便观察图像增广的效果,接下来我们定义一个辅助函数apply。这个函数对输入图像img多次运行图像增广方法aug并展示所有的结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for _ in range(num_rows * num_cols)]show_images(Y, num_rows, num_cols, scale)

翻转和裁剪

左右翻转图像通常不改变物体的类别。它是最早也是最广泛使用的一种图像增广方法。下面我们通过torchvision.transforms模块创建RandomHorizontalFlip实例来实现一半概率的图像水平(左右)翻转。

apply(img, torchvision.transforms.RandomHorizontalFlip())


上下翻转不如左右翻转通用。但是至少对于样例图像,上下翻转不会造成识别障碍。下面我们创建RandomVerticalFlip实例来实现一半概率的图像垂直(上下)翻转。

apply(img, torchvision.transforms.RandomVerticalFlip())


在我们使用的样例图像里,猫在图像正中间,但一般情况下可能不是这样。在5.4节(池化层)里我们解释了池化层能降低卷积层对目标位置的敏感度。除此之外,我们还可以通过对图像随机裁剪来让物体以不同的比例出现在图像的不同位置,这同样能够降低模型对目标位置的敏感性。

在下面的代码里,我们每次随机裁剪出一块面积为原面积10%∼100%10\% \sim 100\%10%∼100%的区域,且该区域的宽和高之比随机取自0.5∼20.5 \sim 20.5∼2,然后再将该区域的宽和高分别缩放到200像素。若无特殊说明,本节中aaa和bbb之间的随机数指的是从区间[a,b][a,b][a,b]中随机均匀采样所得到的连续值。

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)


另一类增广方法是变化颜色。我们可以从4个方面改变图像的颜色:亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。在下面的例子里,我们将图像的亮度随机变化为原图亮度的50%50\%50%(1−0.51-0.51−0.5)∼150%\sim 150\%∼150%(1+0.51+0.51+0.5)。

apply(img, torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0))


我们也可以随机变化图像的色调。

apply(img, torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5))


我们也可以同时设置如何随机变化图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

叠加多个图像增广方法

实际应用中我们会将多个图像增广方法叠加使用。我们可以通过Compose实例将上面定义的多个图像增广方法叠加起来,再应用到每张图像之上。

augs = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

使用图像增广训练模型

下面我们来看一个将图像增广应用在实际训练中的例子。这里我们使用CIFAR-10数据集,而不是之前我们一直使用的Fashion-MNIST数据集。这是因为Fashion-MNIST数据集中物体的位置和尺寸都已经经过归一化处理,而CIFAR-10数据集中物体的颜色和大小区别更加显著。下面展示了CIFAR-10数据集中前32张训练图像。

CIFAR_ROOT_PATH = '/home/input/cifar102021'
all_imges = torchvision.datasets.CIFAR10(train=True, root=CIFAR_ROOT_PATH, download = True)
# all_imges的每一个元素都是(image, label)
show_images([all_imges[i][0] for i in range(32)], 4, 8, scale=0.8);

为了在预测时得到确定的结果,我们通常只将图像增广应用在训练样本上,而不在预测时使用含随机操作的图像增广。在这里我们只使用最简单的随机左右翻转。此外,我们使用ToTensor将小批量图像转成PyTorch需要的格式,即形状为(批量大小, 通道数, 高, 宽)、值域在0到1之间且类型为32位浮点数。

flip_aug = torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor()])no_aug = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

接下来我们定义一个辅助函数来方便读取图像并应用图像增广。有关DataLoader的详细介绍,可参考更早的图像分类数据集(Fashion-MNIST)。

num_workers = 0 if sys.platform.startswith('win32') else 4
def load_cifar10(is_train, augs, batch_size, root=CIFAR_ROOT_PATH):dataset = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=augs, download=False)return DataLoader(dataset, batch_size=batch_size, shuffle=is_train, num_workers=num_workers)

使用图像增广训练模型

我们先定义train函数使用GPU训练并评价模型。

# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = d2l.evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

然后就可以定义train_with_data_aug函数使用图像增广来训练模型了。该函数使用Adam算法作为训练使用的优化算法,然后将图像增广应用于训练数据集之上,最后调用刚才定义的train函数训练并评价模型。

def train_with_data_aug(train_augs, test_augs, lr=0.001):batch_size, net = 256, d2l.resnet18(10)optimizer = torch.optim.Adam(net.parameters(), lr=lr)loss = torch.nn.CrossEntropyLoss()train_iter = load_cifar10(True, train_augs, batch_size)test_iter = load_cifar10(False, test_augs, batch_size)train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

下面使用随机左右翻转的图像增广来训练模型。

train_with_data_aug(flip_aug, no_aug)

从零开始学Pytorch(十五)之数据增强相关推荐

  1. 从零开始学Pytorch(五)之欠拟合和过拟合

    本文首发于微信公众号"计算机视觉cv" 模型选择.过拟合和欠拟合 训练误差和泛化误差 训练误差(training error)指模型在训练数据集上表现出的误差,泛化误差(gener ...

  2. Java从零开始学四十五(Socket编程基础)

    一.网络编程中两个主要的问题 一个是如何准确的定位网络上一台或多台主机,另一个就是找到主机后如何可靠高效的进行数据传输. 在TCP/IP协议中IP层主要负责网络主机的定位,数据传输的路由,由IP地址可 ...

  3. 从零开始学Pytorch(第5天)

    从零开始学Pytorch(第5天) 前言 一.模块类的构建 1. nn.Module 2.构建一个线性回归类 二.计算图和自动求导机制 1.计算图 2.自动求导 总结 前言 今天主要了解和学习Pyto ...

  4. 从零开始学Pytorch(零)之安装Pytorch

    本文首发于公众号"计算机视觉cv" Pytorch优势   聊聊为什么使用Pytorch,个人觉得Pytorch比Tensorflow对新手更为友善,而且现在Pytorch在学术界 ...

  5. WCF技术剖析之十五:数据契约代理(DataContractSurrogate)在序列化中的作用

    如果一个类型,不一定是数据契约,和给定的数据契约具有很大的差异,而我们要将该类型的对象序列化成基于数据契约对应的XML.反之,对于一段给定的基于数据契约的XML,要通过反序列化生成该类型的对象,我们该 ...

  6. Excel数据分析从入门到精通(十五)数据透视表之动态仪表盘

    Excel数据分析从入门到精通(十五)数据透视表之动态仪表盘 1.绘制销售额汇总情况 2.绘制种类销售额情况+种类销售额占比 种类销售额情况 种类销售占比 3.绘制地区销售额情况和地区销售额占比 地区 ...

  7. E-捡贝壳 2021年广东工业大学第十五届文远知行杯程序设计竞赛(同步赛)

    E-捡贝壳 2021年广东工业大学第十五届文远知行杯程序设计竞赛(同步赛) 小明来到一片海滩上,他很喜欢捡贝壳,但他只喜欢质量为x的倍数的贝壳. 贝壳被排列成一条直线,下标从1到n编号,小明打算从编号 ...

  8. Excel数据分析从入门到精通(十五)数据透视表基础

    Excel数据分析从入门到精通(十五)数据透视表 1.Excel透视表前言 2.Excel透视表的创建 3.Excel透视表的组成 4.Excel透视表的十大技巧 ①如何创建汇总行 ②如何展示百分比 ...

  9. G-分割 2021年广东工业大学第十五届文远知行杯程序设计竞赛(同步赛)

    G-分割 2021年广东工业大学第十五届文远知行杯程序设计竞赛(同步赛) 在一个二维平面, 有n条平行于y轴的直线, 他们的x坐标是x[i], m条平行于x轴的直线y[i],他们的y坐标是y[i]. ...

  10. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

最新文章

  1. 数据挖掘原理与算法:Jupyter
  2. Centos 开机无法输入密码的问题
  3. Cstring转化为String
  4. dataframe类型数据的遍历_Python零基础入门到爬虫再到数据分析,这些你都是要学会的...
  5. c语言用指针两个字母交换,c语言指针基础之用指针交换两个数(代码实例)
  6. 微信小程序 ---- 学习目标认识小程序
  7. [机器学习笔记]奇异值分解SVD简介及其在推荐系统中的简单应用
  8. pthread之条件变量pthread_cond_t
  9. 电子邮箱怎么填写格式,手机邮箱格式怎么填写?
  10. 动作游戏的打击感和音效的关系
  11. 米世金《货币经济学》思维导图 附自制PPT
  12. Learning through Auxiliary Tasks——辅助任务学习or自监督学习中的pretext
  13. Alpha 冲刺(5/10)
  14. Effective Approaches to Attention-based Neural Machine Translation笔记
  15. 欧冠 欧洲杯免费直播平台
  16. 初级中学计算机知识,计算机基础知识(初级中学级教学方案课程教案).doc
  17. 使用office的邮件合并和文档附件制作带照片的准考证
  18. 【智能制造】动力电池行业智能制造发展趋势分析
  19. 我要学编程,看什么书好?--^_^,这里推荐一些个人觉得很不错的书(三)
  20. B站黑马Java基础+就业班+各种项目idea版本(正在更新)2 IO流

热门文章

  1. OpenDayLight Helium实验一 OpenDaylight的C/S模式实验
  2. Linux系统故障处理案例(一)【转】
  3. linux shell sleep/wait(转载)
  4. [FZYZOJ 1202] 金坷垃
  5. __stdcall函数调用约定
  6. IOS APP 国际化(实现不跟随系统语言,不用重启应用,代码切换stroyboard ,xib ,图片,其他资源)...
  7. 软件工程复习提纲——第四章
  8. wolfssl 何如 https post_干货:手把手教你优化关键词|亚马逊|流量|搜索量|长尾词|https...
  9. OceanBase杨传辉:一体化架构的分布式数据库已成为企业级系统首选
  10. 参会指南丨3分钟带你玩转2020数据技术嘉年华!