前言

原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz

翻译:林不清(https://www.zhihu.com/people/lu-guo-92-42-88)

目录

60分钟入门PyTorch(一)——Tensors

60分钟入门PyTorch(二)——Autograd自动求导

60分钟入门Pytorch(三)——神经网络

60分钟入门PyTorch(四)——训练一个分类器

训练一个分类器

你已经学会如何去定义一个神经网络,计算损失值和更新网络的权重。

你现在可能在思考:数据哪里来呢?

关于数据

通常,当你处理图像,文本,音频和视频数据时,你可以使用标准的Python包来加载数据到一个numpy数组中.然后把这个数组转换成torch.*Tensor

  • 对于图像,有诸如Pillow,OpenCV包等非常实用

  • 对于音频,有诸如scipy和librosa包

  • 对于文本,可以用原始Python和Cython来加载,或者使用NLTK和SpaCy 对于视觉,我们创建了一个torchvision包,包含常见数据集的数据加载,比如Imagenet,CIFAR10,MNIST等,和图像转换器,也就是torchvision.datasetstorch.utils.data.DataLoader

这提供了巨大的便利,也避免了代码的重复。

在这个教程中,我们使用CIFAR10数据集,它有如下10个类别:’airplane’,’automobile’,’bird’,’cat’,’deer’,’dog’,’frog’,’horse’,’ship’,’truck’。这个数据集中的图像大小为3*32*32,即,3通道,32*32像素。

训练一个图像分类器

我们将按照下列顺序进行:

  • 使用torchvision加载和归一化CIFAR10训练集和测试集.

  • 定义一个卷积神经网络

  • 定义损失函数

  • 在训练集上训练网络

  • 在测试集上测试网络

1. 加载和归一化CIFAR10

使用torchvision加载CIFAR10是非常容易的。

%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms

torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。

注意

如果在Windows上运行时出现BrokenPipeError,尝试将torch.utils.data.DataLoader()的num_worker设置为0。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
#这个过程有点慢,会下载大约340mb图片数据。

我们展示一些有趣的训练图像。

import matplotlib.pyplot as plt
import numpy as np# functions to show an imagedef imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

2. 定义一个卷积神经网络

从之前的神经网络一节复制神经网络代码,并修改为接受3通道图像取代之前的接受单通道图像。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

3. 定义损失函数和优化器

我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. 训练网络

这是开始有趣的时刻,我们只需在数据迭代器上循环,把数据输入给网络,并优化。

for epoch in range(2):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')

保存一下我们的训练模型

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

点击这里查看关于保存模型的详细介绍

5. 在测试集上测试网络

我们在整个训练集上训练了两次网络,但是我们还需要检查网络是否从数据集中学习到东西。

我们通过预测神经网络输出的类别标签并根据实际情况进行检测,如果预测正确,我们把该样本添加到正确预测列表。

第一步,显示测试集中的图片一遍熟悉图片内容。

dataiter = iter(testloader)
images, labels = dataiter.next()# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

接下来,让我们重新加载我们保存的模型(注意:保存和重新加载模型在这里不是必要的,我们只是为了说明如何这样做):

net = Net()
net.load_state_dict(torch.load(PATH))

现在我们来看看神经网络认为以上图片是什么?

outputs = net(images)

输出是10个标签的概率。一个类别的概率越大,神经网络越认为他是这个类别。所以让我们得到最高概率的标签。

_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]for j in range(4)))

这结果看起来非常的好。

接下来让我们看看网络在整个测试集上的结果如何。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

结果看起来好于偶然,偶然的正确率为10%,似乎网络学习到了一些东西。

那在什么类上预测较好,什么类预测结果不好呢?

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

接下来干什么?

我们如何在GPU上运行神经网络呢?

在GPU上训练

你是如何把一个Tensor转换GPU上,你就如何把一个神经网络移动到GPU上训练。这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
#假设我们有一台CUDA的机器,这个操作将显示CUDA设备。
print(device)

接下来假设我们有一台CUDA的机器,然后这些方法将递归遍历所有模块并将其参数和缓冲区转换为CUDA张量:

net.to(device)

请记住,你也必须在每一步中把你的输入和目标值转换到GPU上:

inputs, labels = inputs.to(device), labels.to(device)

为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。

实践:

尝试增加你的网络的宽度(第一个nn.Conv2d的第2个参数, 第二个nn.Conv2d的第一个参数,他们需要是相同的数字),看看你得到了什么样的加速。

实现的目标:

  • 深入了解了PyTorch的张量库和神经网络

  • 训练了一个小网络来分类图片

在多GPU上训练

如果你希望使用所有GPU来更大的加快速度,请查看选读:[数据并行]:(https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)

接下来做什么?

  • 训练神经网络玩电子游戏

  • 在ImageNet上训练最好的ResNet

  • 使用对抗生成网络来训练一个人脸生成器

  • 使用LSTM网络训练一个字符级的语言模型

  • 更多示例

  • 更多教程

  • 在论坛上讨论PyTorch

  • 在Slack上与其他用户聊天

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑
本站知识星球“黄博的机器学习圈子”(92416895)
本站qq群704220115。
加入微信群请扫码:

【深度学习】翻译:60分钟入门PyTorch(四)——训练一个分类器相关推荐

  1. PyTorch深度学习60分钟闪电战:04 训练一个分类器

    本系列是PyTorch官网Tutorial Deep Learning with PyTorch: A 60 Minute Blitz 的翻译和总结. PyTorch概览 Autograd - 自动微 ...

  2. PyTorch深度学习:60分钟入门(Translation)

    这是https://zhuanlan.zhihu.com/p/25572330的学习笔记. Tensors Tensors和numpy中的ndarrays较为相似, 因此Tensor也能够使用GPU来 ...

  3. 【深度学习】翻译:60分钟入门PyTorch(三)——神经网络

    前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute Blitz 翻译:林不清(https://www.zhihu.com/people/lu-guo-92 ...

  4. 可下载:60分钟入门PyTorch(中文翻译全集)

    来源:机器学习初学者本文约9500字,建议阅读20分钟官网教程翻译:60分钟入门PyTorch(全集) 前言 原文翻译自:Deep Learning with PyTorch: A 60 Minute ...

  5. PyTorch深度学习:60分钟闪电战

    使用PYTORCH进行深度学习:60分钟的闪电战 本教程的目标: 全面了解PyTorch的Tensor库和神经网络. 训练一个小型神经网络对图像进行分类 请确保您有 torch 和 torchvisi ...

  6. Pytorch 学习(2):神经网络及训练一个分类器(cifar10_tutorial的网络结构图)

    Pytorch 学习(2):神经网络及训练一个分类器(cifar10_tutorial的网络结构图) 本文代码来自Pytorch官网入门教程,相关内容可以从Pytorch官网学习. cifar10_t ...

  7. 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型(附链接)

    来源:机器之心 本文约800字,建议阅读5分钟. 本文介绍了官方教程入门PyTorch的技巧训练. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框 ...

  8. 60分钟入门PyTorch,官方教程手把手教你训练第一个深度学习模型

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 近期的一份调查报告显示:PyTorch 已经力压 TensorFlow 成为各大顶会的主流深度学习框架.想发论文,不学 PyTor ...

  9. Pytorch 60分钟入门之(四) TRAINING A CLASSIFIER 训练一个分类器

    目录 TRAINING A CLASSIFIER 训练一个分类器 数据呢? Training an image classifier 训练一个图像分类器 1. 载入和归一化CIFAR10 2. Def ...

最新文章

  1. api1.7oracle,API 支持
  2. grunt使用watch和livereload的Gruntfile.js的配置
  3. protobuf windows java_protobuf windows java 环境搭建
  4. ensp安装包_教你如何安装华为模拟器Ensp,另分享全套安装包
  5. 按键中断异步通知实现
  6. SpringBoot2 集成 xxl-job任务调度中心
  7. centos7 安装Golang环境
  8. 面向模式的软件体系结构
  9. 交流:Ghost版系统安装简单分析
  10. Windows下,Unicode、UTF8,GBK(GB2312)互转
  11. Kibana 操作 Elasticsearch
  12. GPS经纬度转百度地图经纬度
  13. 为什么百度查到的ip和ipconfig查到的不一样?
  14. 淘宝天猫京东拼多多苏宁抖音等平台关键词监控价格API接口(店铺商品价格监控API接口调用展示)
  15. (杭电2188)选拔志愿者
  16. 增强型Rabin签名算法
  17. 【浅谈Java项目技术开发基础】
  18. 美国政府与科技巨头讨论开源软件安全、近八万网站受开源软件漏洞影响|1月18日全球网络安全热点
  19. “元宇宙”来了 城市会消亡吗?
  20. 感应电动机的启动压降计算

热门文章

  1. 做领导应该注意的几个问题
  2. 在 Oracle 和 PHP 中使用 LOB
  3. PHP编程最快明白 by www.kuphp.com 案例实战zencart1.38a支付模块简化Fast and Easy Checkout配置...
  4. Flex不支持SOAP1.2
  5. Promise的源码实现(完美符合Promise/A+规范)
  6. CSS浮动(三)---Float
  7. 【NOIP校内模拟】塔
  8. 《一江春水向东流》——任正非
  9. Delphi实现类似Android锁屏的密码锁控件
  10. 在Windows64位环境下.net访问Oracle解决方案(转)