在本教程中,我们将使用CIFAR10数据集。它有类别:“飞机”、“汽车”、“鸟”、“猫”、“鹿”、“狗”、“青蛙”、“马”、“船”、“卡车”。CIFAR-10中的图像大小为3x32x32,即3通道彩色图像大小为32x32像素。

我们将按以下顺序进行:
1.使用torchvision加载和规范CIFAR10培训和测试数据集
2.定义一个卷积神经网络
3.定义损失函数
4.根据培训数据对网络进行培训
5.在测试数据上测试网络

1.

import torch
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')Out:Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np# functions to show an imagedef imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

Out:

car  deer  bird   car

2.

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

3.

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4.

for epoch in range(2):  # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:    # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')Out:[1,  2000] loss: 2.247
[1,  4000] loss: 1.899
[1,  6000] loss: 1.702
[1,  8000] loss: 1.574
[1, 10000] loss: 1.504
[1, 12000] loss: 1.489
[2,  2000] loss: 1.401
[2,  4000] loss: 1.391
[2,  6000] loss: 1.353
[2,  8000] loss: 1.332
[2, 10000] loss: 1.309
[2, 12000] loss: 1.291
Finished Training

保存训练的模型

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

5.

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))Out:Accuracy of plane : 62 %
Accuracy of   car : 66 %
Accuracy of  bird : 40 %
Accuracy of   cat : 42 %
Accuracy of  deer : 59 %
Accuracy of   dog : 32 %
Accuracy of  frog : 46 %
Accuracy of horse : 64 %
Accuracy of  ship : 75 %
Accuracy of truck : 63 %

就像你把一个张量转移到GPU上一样,你把神经网络转移到GPU上。
让我们首先定义我们的设备为第一个可见的cuda设备,如果我们有cuda可用:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# Assuming that we are on a CUDA machine, this should print a CUDA device:print(device)Out:cuda:0

然后这些方法将递归遍历所有模块,并将其参数和缓冲区转换为CUDA张量,也要记得把输入转移到GPU

net.to(device)
inputs, labels = data[0].to(device), data[1].to(device)

会发现GPU提速不大,这是因为网络规模小

PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类相关推荐

  1. PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)

    卷积神经网络就是含卷积层的网络.介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1].这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经 ...

  2. PyTorch实战福利从入门到精通之五——搭建ResNet

    Kaiming He的深度残差网络(ResNet)在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题. ...

  3. PyTorch实战福利从入门到精通之三——autograd

    autograd 反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查.t ...

  4. PyTorch实战福利从入门到精通之一——PyTorch框架安装

    使用conda安装是最不容易出错的,在pytroch的官网可以选择自己需要的操作系统.python版本.cuda版本的pytorch框架. 之后复制下面的命令就可以了 安装完这个还要安个numpy p ...

  5. PyTorch实战福利从入门到精通之六——线性回归

    一元线性回归 一元线性模型非常简单,假设我们有变量 xix_ixi​ 和目标 yiy_iyi​,每个 i 对应于一个数据点,希望建立一个模型 y^i=wxi+b\hat{y}_i = w x_i + ...

  6. PyTorch实战福利从入门到精通之八——深度卷积神经网络(AlexNet)

    在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机.虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意.一方面,神经网络计算复 ...

  7. PyTorch实战福利从入门到精通之九——数据处理

    在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像.文本.语音或其它二进制数据等.数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果.考虑到这 ...

  8. PyTorch实战福利从入门到精通之二——Tensor

    Tensor又名张量,也是Tensorflow等框架中的重要数据结构.它可以是一个数(标量),一维数组(向量),二维数组或更高维数组.Tensor支持GPU加速. 创建Tensor 几种常见创建Ten ...

  9. 从入门到精通:卷积神经网络初学者指南

    转载自:http://www.jiqizhixin.com/article/1363?utm_source=tuicool&utm_medium=referral 这是一篇向初学者讲解卷积神经 ...

最新文章

  1. 十二张图带你了解 Redis 的数据结构和对象系统
  2. keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习
  3. Python 技术篇-利用pyperclip库实现读取写入剪切板,超简单
  4. 03-kubeadm初始化Kubernetes集群
  5. 计算机视觉与深度学习 | 卷积神经网络实现异常行为识别(目标分割与提取)
  6. 鸿蒙是学生开发的系统,9岁小学生展示鸿蒙OS开发
  7. VTK:Filtering之ExtractVisibleCells
  8. CentOS安装Chrome
  9. staem被盗_如何检查照片是否被盗
  10. linux启动mqtt_linux下安装MQTT服务器 - EMQTT
  11. react构建_您应该了解的有关React的一切:开始构建所需的基础知识
  12. importanturlAndutl
  13. 小程序webview 页面被放大_Android中WebView加载的网页被放大的解决办法
  14. K3 设置为AP,用于软件路由的后级。
  15. 一则 HTTP 405 Method Not Allowed 的解决办法
  16. Word 2010 从任意页码重新开始
  17. 初学者学习哪种编程语言比较适合呢?
  18. 手机apk应用程序未安装解决办法
  19. HBuilderX快捷键大全
  20. 设计图片转换html5,在HTML5中翻转图片

热门文章

  1. 合成资产挖矿项目 ARCx 启动流动性挖矿
  2. SAP License:SAP S/4HANA Cloud [ERP 云]
  3. 智慧楼宇、消防系统、门禁管理、暖通空调、给排水、变配电、设备管理、停车管理、能源管理、故障检测、客流统计、运行控制、权限分配、物联网、Axure原型、rp原型、产品原型
  4. Python之Django之views中视图代码重复查询的优化
  5. Git使用教程之从远程库克隆项目(四)
  6. Python学习-基础篇14 Web框架本质及第一个Django实例
  7. 2017-2018-1 20155320 实验三——实时系统
  8. 三顺,因为你,我笑了。
  9. JProfiler 简要使用说明
  10. hdu 3657 最大点权独立集变形(方格取数的变形最小割,对于最小割建图很好的题)...