计算机视觉是深度学习中最重要的一类应用,为了方便研究者使用,PyTorch 团队专门开发了一个视觉工具包torchvision,这个包独立于 PyTorch,需通过 pip instal torchvision 安装。

torchvision 主要包含三部分:

  • models:提供深度学习中各种经典网络的网络结构以及预训练好的模型,包括 AlexNetVGG 系列、ResNet 系列、Inception 系列等;
  • datasets: 提供常用的数据集加载,设计上都是继承 torch.utils.data.Dataset,主要包括 MNISTCIFAR10/100ImageNetCOCO等;
  • transforms:提供常用的数据预处理操作,主要包括对 Tensor 以及 PIL Image 对象的操作;
from torchvision import models
from torch import nn
from torchvision import datasets'''加载预训练好的模型,如果不存在会进行下载
预训练好的模型保存在 ~/.torch/models/下面'''
resnet34 = models.squeezenet1_1(pretrained=True, num_classes=1000)'''修改最后的全连接层为10分类问题(默认是ImageNet上的1000分类)'''
resnet34.fc=nn.Linear(512, 10)'''加上transform'''
transform  = T.Compose([T.ToTensor(),T.Normalize(mean=[0.4,], std=[0.2,]),
])
'''
# 指定数据集路径为data,如果数据集不存在则进行下载
# 通过train=False获取测试集
'''
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)

Transforms 中涵盖了大部分对 TensorPIL Image 的常用处理,这些已在上文提到,这里就不再详细介绍。需要注意的是转换分为两步,

  • 第一步:构建转换操作,例如 transf = transforms.Normalize(mean=x, std=y)
  • 第二步:执行转换操作,例如 output = transf(input) 。另外还可将多个处理操作用 Compose 拼接起来,形成一个处理转换流程。
from torchvision import transforms
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))

输出随机噪声,待补充:

torchvision 还提供了两个常用的函数。

  • 一个是 make_grid ,它能将多张图片拼接成一个网格中;
  • 另一个是 save_img ,它能将 Tensor 保存成图片。
len(dataset) # 10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4*4网格图片,且会转成3通道
to_img(img)

输出:(待补充)

save_image(img, 'a.png')
Image.open('a.png')

输出:(待补充)

1. datasets

使用 torchvision.datasets 可以轻易实现对这些数据集的训练集和测试集的下载,只需要使用 torchvision.datasets 再加上需要下载的数据集的名称就可以了。

比如在这个问题中我们要用到手写数字数据集,它的名称是 MNIST,那么实现下载的代码就是
torchvision.datasets.MNIST。其他常用的数据集如 COCOImageNetCIFCAR 等都可以通过这个方法快速下载和载入。实现数据集下载的代码如下:

import torch as t
from torchvision import datasets, transformsdata_train = datasets.MNIST(root="./data", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data", transform=transform, train=False)

其中,

  • root 用于指定数据集在下载之后的存放路径,这里存放在根目录下的 data 文件夹中;
  • transform 用于指定导入数据集时需要对数据进行哪种变换操作;

注意,要提前定义这些变换操作;train 用于指定在数据集下载完成后需要载入哪部分数据,

  • 如果设置为 True,则说明载入的是该数据集的训练集部分;
  • 如果设置为 False,则说明载入的是该数据集的测试集部分;

2. transforms

在计算机视觉中处理的数据集有很大一部分是图片类型的,而在 PyTorch 中实际进行计算的是 Tensor 数据类型的变量,所以我们首先需要解决的是数据类型转换的问题,如果获取的数据是格式或者大小不一的图片,则还需要进行归一化和大小缩放等操作,庆幸的是,这些方法在 torch.transforms 中都能找到。

torch.transforms 中有大量的数据变换类,其中有很大一部分可以用于实现数据增强(DataArgumentation)。若在我们需要解决的问题上能够参与到模型训练中的图片数据非常有限,则这时就要通过对有限的图片数据进行各种变换,来生成新的训练集了,这些变换可以是缩小或者放大图片的大小、对图片进行水平或者垂直翻转等,都是数据增强的方法。

不过在手写数字识别的问题上可以不使用数据增强的方法,因为可用于模型训练的数据已经足够了。对数据进行载入及有相应变化的代码如下:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

我们可以将以上代码中的 torchvision.transforms.Compose 类看作一种容器,它能够同时对多种数据变换进行组合。传入的参数是一个列表,列表中的元素就是对载入的数据进行的各种变换操作。

在以上代码中,在 torchvision.transforms.Compose 中只使用了一个类型的转换变换 transforms.ToTensor 和一个数据标准化变换transforms.Normalize

这里使用的标准化变换也叫作标准差变换法,这种方法需要使用原始数据的均值(Mean)和标准差(StandardDeviation)来进行数据的标准化,在经过标准化变换之后,数据全部符合均值为0、标准差为1的标准正态分布。

下面看看在 torchvision.transforms 中常用的数据变换操作。

  • torchvision.transforms.Resize:用于对载入的图片数据按我们需求的大小进行缩放。传递给这个类的参数可以是一个整型数据,也可以是一个类似于(h, w)的序列,其中,h 代表高度,w 代表宽度,但是如果使用的是一个整型数据,那么表示缩放的宽度和高度都是这个整型数据的值。
  • torchvision.transforms.Scale:用于对载入的图片数据按我们需求的大小进行缩放,用法和
    torchvision.transforms.Resize类似。
  • torchvision.transforms.CenterCrop:用于对载入的图片以图片中心为参考点,按我们需要的大小进行裁剪。传递给这个类的参数可以是一个整型数据,也可以是一个类似于(h,w)的序列。* torchvision.transforms.RandomCrop:用于对载入的图片按我们需要的大小进行随机裁剪。传递给这个类的参数可以是一个整型数据,也可以是一个类似于(h,w)的序列。
  • torchvision.transforms.RandomHorizontalFlip:用于对载入的图片按随机概率进行水平翻转。我们可以通过传递给这个类的参数自定义随机概率,如果没有定义,则使用默认的概率值 0.5。
  • torchvision.transforms.RandomVerticalFlip:用于对载入的图片按随机概率进行垂直翻转。我们可以通过传递给这个类的参数自定义随机概率,如果没有定义,则使用默认的概率值 0.5。
  • torchvision.transforms.ToTensor:用于对载入的图片数据进行类型转换,将之前构成 PIL 图片的数据转换成 Tensor 数据类型的变量,让 PyTorch 能够对其进行计算和处理。
  • torchvision.transforms.ToPILImage:用于将 Tensor 变量的数据转换成 PIL 图片数据,主要是为了方便图片内容的显示。

3. 数据预览和加载

在数据下载完成并且载入后,我们还需要对数据进行装载。我们可以将数据的载入理解为对图片的处理,在处理完成后,我们就需要将这些图片打包好送给我们的模型进行训练了,而装载就是这个打包
的过程。

在装载时通过 batch_size 的值来确认每个包的大小,通过 shuffle 的值来确认是否在装载的过程中打乱图片的顺序。装载图片的代码如下:

data_loader_train = torch.utils.data.DataLoader(dataset=data_train, batch_size = 64,shuffle = True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test, batch_size=64,shuffle = True)

对数据的装载使用的是 torch.utils.data.DataLoader 类,类中的

  • dataset 参数用于指定我们载入的数据集名称;
  • batch_size 参数设置了每个包中的图片数据个数,代码中的值是 64,所以在每个包中会包含64张图片;
  • shuffle 参数设置为 True,在装载的过程会将数据随机打乱顺序并进行打包;

在装载完成后,我们可以选取其中一个批次的数据进行预览。进行数据预览的代码如下:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1,2,0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64))

在以上代码中使用了 iternext 来获取一个批次的图片数据和其对应的图片标签,然后使用torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

需要传递给 torchvision.utils.make_grid 的参数就是一个批次的装载数据,每个批次的装载数据都是 4 维的,维度的构成从前往后分别为 batch_sizechannelheightweight ,分别对应一个批次中的数据个数、每张图片的色彩通道数、每张图片的高度和宽度。

在通过 torchvision.utils.make_grid 之后,图片的维度变成了( channel , height , weight ),这个批次的图片全部被整合到了一起,所以在这个维度中对应的值也和之前不一样了,但是色彩通道数保持不变。

若我们想使用Matplotlib将数据显示成正常的图片形式,则使用的数据首先必须是数组,其次这个数组的维度必须是(height,weight,channel),即色彩通道数在最后面。所以我们要通过 numpytranspose 完成原始数据类型的转换和数据维度的交换,这样才能够使用Matplotlib绘制出正确的图像。

4. 模型搭建和参数优化

(1)torch.nn.Conv2d:用于搭建卷积神经网络的卷积层,主要的输入参数有输入通道数、输出通道数、卷积核大小、卷积核移动步长和Paddingde值。其中,输入通道数的数据类型是整型,用于确定输入数据的层数;输出通道数的数据类型也是整型,用于确定输出数据的层数;卷积核大小的数据类型是整型,用于确定卷积核的大小;卷积核移动步长的数据类型是整型,用于确定卷积核每次滑动的步长;Paddingde 的数据类型是整型,值为0时表示不进行边界像素
的填充,如果值大于0,那么增加数字所对应的边界像素层数。

(2)torch.nn.MaxPool2d:用于实现卷积神经网络中的最大池化层,主要的输入参数是池化窗口大小、池化窗口移动步长和Paddingde值。同样,池化窗口大小的数据类型是整型,用于确定池化窗口的大小。池化窗口步长的数据类型也是整型,用于确定池化窗口每次移动的步长。Paddingde值和在torch.nn.Conv2d中定义的Paddingde值的用法和意义是一样的。

(3)torch.nn.Dropout:torch.nn.Dropout类用于防止卷积神经网络在训练的过程中发生过拟合,其工作原理简单来说就是在模型训练的过程中,以一定的随机概率将卷积神经网络模型的部分参数归零,以达到减少相邻两层神经连接的目的。图 6-3显示了 Dropout方法的效果。

PyTorch 笔记(20)— torchvision 的 datasets、transforms 数据预览和加载、模型搭建(torch.nn.Conv2d/MaxPool2d/Dropout)相关推荐

  1. pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题

    首先很多网上的博客,讲的都不对,自己跟着他们踩了很多坑 1.单卡训练,单卡加载 这里我为了把三个模块save到同一个文件里,我选择对所有的模型先封装成一个checkpoint字典,然后保存到同一个文件 ...

  2. 【pytorch】(六)保存和加载模型

    文章目录 保存和加载模型 保存加载模型参数 保存加载模型和参数 保存和加载模型 import torch from torch import nn from torch.utils.data impo ...

  3. PyTorch | 保存和加载模型教程

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来自 Unsplash,作者: Jenny Caywood 2019 ...

  4. python保存模型与参数_基于pytorch的保存和加载模型参数的方法

    当我们花费大量的精力训练完网络,下次预测数据时不想再(有时也不必再)训练一次时,这时候torch.save(),torch.load()就要登场了. 保存和加载模型参数有两种方式: 方式一: torc ...

  5. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

  6. vue-element-xlsx在线读取Excel数据预览

    vue-element-xlsx在线读取Excel数据预览 1.安装XLSX npm install xlsx -s 2.复制过去就可以用 <template><div>< ...

  7. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  8. 【pytorch笔记】(五)自定义损失函数、学习率衰减、模型微调

    本文目录: 1. 自定义损失函数 2. 动态调整学习率 3. 模型微调-torchvision 3.1 使用已有模型 3.2 训练特定层 1. 自定义损失函数 虽然pytorch提供了许多常用的损失函 ...

  9. pytorch学习笔记(6):GPU和如何保存加载模型

    参考文档:https://mp.weixin.qq.com/s/kmed_E4MaDwN-oIqDh8-tg 上篇文章我们完成了一个 vgg 网络的实现,那么现在已经掌握了一些基础的网络结构的实现,距 ...

最新文章

  1. topcoder-SRM565-div2-第二题-500分--搜索/动态规划
  2. Cisco 交换机密码重置步骤
  3. arm9重启ssh服务_部署ssh使用rsa登录配置
  4. Python为你打开一扇门
  5. Bailian4029 数字反转【进制】(POJ NOI0105-29)
  6. XenDesktop 学习笔记1之DDC
  7. 运动会管理系统的需求调研会纪要
  8. DosBox装Windows98
  9. LaTeX语法环境配置:TeXLive + WinEdt
  10. 基于java+springboot+mybatis+laiyu实现学科竞赛管理系统《建议收藏》
  11. SQL Server添加Northwind数据库
  12. Acer 常见笔记本产品内存扩展对照表
  13. 年轻时欠下风流情债的十大男女明星(组图)
  14. 如何在PCB中放置禁止触摸标志
  15. google api设计指南-简介
  16. docker创建mysql容器
  17. Java P1035 [NOIP2002 普及组] 级数求和 洛谷入门题
  18. 用Java和Jquery实现了一个砸金蛋例子
  19. Android八门神器(一):OkHttp框架源码解析
  20. C++常见十六进制数组转换char数组方法

热门文章

  1. 2021-2027全球与中国经颅磁刺激仪(TMS)市场现状及未来发展趋势
  2. 联合索引最左匹配原则成因
  3. 广东java工资一般多少_广东java工资待遇,广东java工资一般多少,广东java工资底薪最低多少...
  4. 2022-2028年中国交通建设PPP模式深度分析及发展战略研究报告(全卷)
  5. C++核心编程(三)
  6. Registry仓库Harbor的部署与简介
  7. ViewGroup的Touch事件分发(源码分析)
  8. 零起点学算法01——第一个程序Hello World!
  9. 0x02 mysql 表格相关操作
  10. How to Use tomcat on Linux