%matplotlib inline

训练分类器

就是这个。您已经了解了如何定义神经网络,计算损耗并更新网络权重。

现在你可能在想

数据怎么样?

通常,当您必须处理图像,文本,音频或视频数据时,您可以使用标准的python包将数据加载到numpy数组中。然后你可以将这个数组转换成一个torch.*Tensor

  • 对于图像,Pillow,OpenCV等软件包很有用
  • 对于音频,包括scipy和librosa
  • 对于文本,无论是原始Python还是基于Cython的加载,还是NLTK和SpaCy都很有用
    特别是对于视觉,我们创建了一个名为的包 torchvision,其中包含用于常见数据集的数据加载器,如Imagenet,CIFAR10,MNIST等,以及用于图像的数据转换器,即 torchvision.datasetstorch.utils.data.DataLoader

这提供了极大的便利并避免编写样板代码。

在本教程中,我们将使用CIFAR10数据集。它有类:‘飞机’,‘汽车’,‘鸟’,‘猫’,‘鹿’,‘狗’,‘青蛙’,‘马’,‘船’,‘卡车’。CIFAR-10中的图像尺寸为3x32x32,即尺寸为32x32像素的3通道彩色图像。

cifar10

训练图像分类器

我们将按顺序执行以下步骤:

  • 使用加载和标准化CIFAR10训练和测试数据集 torchvision
  • 定义卷积神经网络
  • 定义损失函数
  • 在训练数据上训练网络
  • 在测试数据上测试网络

加载和标准化CIFAR10

使用torchvision,加载CIFAR10非常容易。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision数据集的输出是范围[0,1]的PILImage图像。我们将它们转换为归一化范围的张量[-1,1]

Note

If running on Windows and you get a BrokenPipeError, try setting the num_worker of torch.utils.data.DataLoader() to 0.

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,##数据集的数据加载器,train=True时使用训练集,False时为使用测试集download=True, transform=transform)##获取分类任务所需要的数据集,transform给输入图像施加变换
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
Files already downloaded and verified

让我们展示一些训练图像,以获得乐趣。

import matplotlib.pyplot as plt
import numpy as np# functions to show an imagedef imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

[外链图片转存失败(img-tdbXaAP1-1568109363235)(output_7_0.png)]

 ship plane  ship  deer

定义卷积神经网络

从神经网络部分复制神经网络并修改它以获取3通道图像(而不是定义的1通道图像)。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

定义Loss函数和优化器

让我们使用分类交叉熵损失和SGD动量。

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

训练网络

事情开始变得有趣了。我们只需循环遍历数据迭代器,并将输入提供给网络并进行优化。

for epoch in range(2):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')
[1,  2000] loss: 2.168
[1,  4000] loss: 1.833
[1,  6000] loss: 1.693
[1,  8000] loss: 1.589
[1, 10000] loss: 1.527
[1, 12000] loss: 1.488
[2,  2000] loss: 1.418
[2,  4000] loss: 1.371
[2,  6000] loss: 1.355
[2,  8000] loss: 1.331
[2, 10000] loss: 1.300
[2, 12000] loss: 1.292
Finished Training

在测试数据上测试网络

我们已经在训练数据集上训练了2次。但我们需要检查网络是否已经学到了什么。

我们将通过预测神经网络输出的类标签来检查这一点,并根据地面实况进行检查。如果预测正确,我们将样本添加到正确预测列表中。

好的,第一步。让我们从测试集中显示一个图像以熟悉。

dataiter = iter(testloader)
images, labels = dataiter.next()# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

[外链图片转存失败(img-NPfLQZwH-1568109363236)(output_15_0.png)]

GroundTruth:    cat  ship  ship plane

好的,现在让我们看看神经网络认为上面这些例子是什么:

outputs = net(images)

输出是10类的精度。一个类的精度越高,网络认为图像是特定类的越多。那么,让我们得到最高精度的指数:

_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]for j in range(4)))
Predicted:    cat  ship  ship  ship

结果似乎很好。

让我们看看网络如何在整个数据集上执行。

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
Accuracy of the network on the 10000 test images: 52 %

这看起来好于偶然,这是10%的准确性(从10个班级中随机挑选一个班级)。似乎网络学到了一些东西。

嗯,什么是表现良好的类,以及表现不佳的类:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 34 %
Accuracy of   car : 43 %
Accuracy of  bird : 38 %
Accuracy of   cat : 39 %
Accuracy of  deer : 57 %
Accuracy of   dog : 57 %
Accuracy of  frog : 58 %
Accuracy of horse : 55 %
Accuracy of  ship : 73 %
Accuracy of truck : 67 %

利用pytorch实现多分类器相关推荐

  1. 使用pytorch构建图片分类器

    分类器任务和数据介绍 构造一个将不同图像进行分类的神经网络分类器, 对输入的图片进行判别并完成分类. 本案例采用CIFAR10数据集作为原始图片数据. CIFAR10数据集介绍: 数据集中每张图片的尺 ...

  2. 利用Pytorch实现GoogLeNet网络

    目  录 1 GoogLeNet网络 1.1 网络结构及参数 1.2 Inception结构 1.3 带降维功能的Inception结构 1.4 辅助分类器 2 利用Pytorch实现GoogLeNe ...

  3. 使用Pytorch构建一个分类器(CIFAR10模型)

    分类器任务和数据介绍 ·构建一个将不同图像进行分类的神经网络分类器,对输入的的图片进行判别并完成分类. ·本案例采用CIFAR10数据集作为原始图片数据 ·CIFAR10数据集介绍:数据集中每张图片的 ...

  4. CUDA:利用Pytorch查看自己电脑上CUDA版本及其相关信息

    CUDA:利用Pytorch查看自己电脑上CUDA版本及其相关信息 目录 利用Pytorch查看自己电脑上CUDA的版本信息

  5. PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

    PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN 目录 训练过程 代码设计 训练过程 代码设计 #PyTorch:利用PyTorch实现 ...

  6. PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

    PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析 目录 输出结果 核心代码 输出结果 核心代码 #PyTorch:采用skle ...

  7. 深度学习练手项目(二)-----利用PyTorch进行线性回归

    前言 深度学习并没有想象的那么难,甚至比有些传统的机器学习更简单.所用到的数学知识也不需要特别的高深.这篇文章将利用PyTorch来实现线性回归这个经典的模型. 一.线性回归理论 线性回归时利用数理统 ...

  8. 利用Pytorch的C++前端(libtorch)读取预训练权重并进行预测

    本篇使用的平台为Ubuntu,Windows平台的请看Pytorch的C++端(libtorch)在Windows中的使用 前言 距离发布Pytorch-1.0-Preview版的发布已经有两个多月, ...

  9. matlab ann-bp分类器,利用matlab真的BP-ANN分类器设计.doc

    利用matlab真的BP-ANN分类器设计,ann分类器,bp神经网络分类器,bp分类器,贝叶斯分类器matlab,svm分类器matlab程序,matlab分类器,matlab分类器工具箱,soft ...

最新文章

  1. 天软考c语言,软考中C语言试题问答精选
  2. 华为数据中心服务器数量,IDC 与华为联合发布《全闪存数据中心白皮书》,目前已有多个应用...
  3. 录屏 模拟器_Scrcpy-在电脑无缝操作手机 (投屏/录屏/免Root)
  4. ICCV 2021 | 通过显式寻找物体的extremity区域加快DETR的收敛
  5. BZOJ 1047: [HAOI2007]理想的正方形 单调队列瞎搞
  6. Exynos4412 中断驱动开发(二)—— 中断处理流程分析
  7. 阿里云容器服务新增支持Kubernetes编排系统,性能重大提升 1
  8. SpringBoot—JPA: javax.persistence.TransactionRequiredException
  9. Unreal Engine 4 —— Ghost Mesh Plugin的开发日志
  10. python的缩进规则是什么意思_Python编程思想(2):Python主要特性、命名规则与代码缩进...
  11. pgadmin 转成oracle,pgAdmin快速备份还原数据库
  12. python pandas 教程下载_如何用Python处理Excel?Pandas视频教程官方文档来啦~
  13. win10sas安装教程_Android Studio详细安装教程
  14. 【C#】NPOI导出Excel格式设置
  15. 2022-4-12作业
  16. 一文搞懂supervisor进程管理
  17. python之路金角大王_Python 之路03 - Python基础3
  18. 解决studio 3T时间到期方法
  19. 笔记《基于无人驾驶方程式赛车的传感器融合目标检测算法研究及实现》
  20. 统计学——数据的分类

热门文章

  1. 关于ImportError: DLL load failed: 找不到指定的模块
  2. typec扩展坞hdmi没反应_解决电脑接口不足难题,实测绿联九合一type-c扩展坞到底有多强...
  3. ES压测工具(四):esrally实例操作
  4. 推荐算法最前沿|CIKM2020推荐系统论文一览
  5. 旋转电机设计_尤哈·皮罗内 PDF完整版下载 网盘分享
  6. 昨晚罗老师的手机发布了,传闻发布会的门票就卖了
  7. 酸了!她在MSRA实习短短半年时间内便完成了两篇入选顶级学术会议 AAAI、ACL 的研究...
  8. i5 i7 Oracle,Intel Core i5/i7哪款最适合你?Intel Core i5/i7处理器简略对比评测
  9. win10环境 pip 安装theano(gpu) python3.6
  10. TextView和EditText