PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类
在本教程中,我们将使用CIFAR10数据集。它有类别:“飞机”、“汽车”、“鸟”、“猫”、“鹿”、“狗”、“青蛙”、“马”、“船”、“卡车”。CIFAR-10中的图像大小为3x32x32,即3通道彩色图像大小为32x32像素。
我们将按以下顺序进行:
1.使用torchvision加载和规范CIFAR10培训和测试数据集
2.定义一个卷积神经网络
3.定义损失函数
4.根据培训数据对网络进行培训
5.在测试数据上测试网络
1.
import torch
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')Out:Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np# functions to show an imagedef imshow(img):img = img / 2 + 0.5 # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
Out:
car deer bird car
2.
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
3.
import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4.
for epoch in range(2): # loop over the dataset multiple timesrunning_loss = 0.0for i, data in enumerate(trainloader, 0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999: # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0print('Finished Training')Out:[1, 2000] loss: 2.247
[1, 4000] loss: 1.899
[1, 6000] loss: 1.702
[1, 8000] loss: 1.574
[1, 10000] loss: 1.504
[1, 12000] loss: 1.489
[2, 2000] loss: 1.401
[2, 4000] loss: 1.391
[2, 6000] loss: 1.353
[2, 8000] loss: 1.332
[2, 10000] loss: 1.309
[2, 12000] loss: 1.291
Finished Training
保存训练的模型
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
5.
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))Out:Accuracy of plane : 62 %
Accuracy of car : 66 %
Accuracy of bird : 40 %
Accuracy of cat : 42 %
Accuracy of deer : 59 %
Accuracy of dog : 32 %
Accuracy of frog : 46 %
Accuracy of horse : 64 %
Accuracy of ship : 75 %
Accuracy of truck : 63 %
就像你把一个张量转移到GPU上一样,你把神经网络转移到GPU上。
让我们首先定义我们的设备为第一个可见的cuda设备,如果我们有cuda可用:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# Assuming that we are on a CUDA machine, this should print a CUDA device:print(device)Out:cuda:0
然后这些方法将递归遍历所有模块,并将其参数和缓冲区转换为CUDA张量,也要记得把输入转移到GPU
net.to(device)
inputs, labels = data[0].to(device), data[1].to(device)
会发现GPU提速不大,这是因为网络规模小
PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类相关推荐
- PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)
卷积神经网络就是含卷积层的网络.介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1].这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经 ...
- PyTorch实战福利从入门到精通之五——搭建ResNet
Kaiming He的深度残差网络(ResNet)在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题. ...
- PyTorch实战福利从入门到精通之三——autograd
autograd 反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查.t ...
- PyTorch实战福利从入门到精通之一——PyTorch框架安装
使用conda安装是最不容易出错的,在pytroch的官网可以选择自己需要的操作系统.python版本.cuda版本的pytorch框架. 之后复制下面的命令就可以了 安装完这个还要安个numpy p ...
- PyTorch实战福利从入门到精通之六——线性回归
一元线性回归 一元线性模型非常简单,假设我们有变量 xix_ixi 和目标 yiy_iyi,每个 i 对应于一个数据点,希望建立一个模型 y^i=wxi+b\hat{y}_i = w x_i + ...
- PyTorch实战福利从入门到精通之八——深度卷积神经网络(AlexNet)
在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机.虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意.一方面,神经网络计算复 ...
- PyTorch实战福利从入门到精通之九——数据处理
在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像.文本.语音或其它二进制数据等.数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果.考虑到这 ...
- PyTorch实战福利从入门到精通之二——Tensor
Tensor又名张量,也是Tensorflow等框架中的重要数据结构.它可以是一个数(标量),一维数组(向量),二维数组或更高维数组.Tensor支持GPU加速. 创建Tensor 几种常见创建Ten ...
- 从入门到精通:卷积神经网络初学者指南
转载自:http://www.jiqizhixin.com/article/1363?utm_source=tuicool&utm_medium=referral 这是一篇向初学者讲解卷积神经 ...
最新文章
- 十二张图带你了解 Redis 的数据结构和对象系统
- keras构建前馈神经网络(feedforward neural network)进行多分类模型训练学习
- Python 技术篇-利用pyperclip库实现读取写入剪切板,超简单
- 03-kubeadm初始化Kubernetes集群
- 计算机视觉与深度学习 | 卷积神经网络实现异常行为识别(目标分割与提取)
- 鸿蒙是学生开发的系统,9岁小学生展示鸿蒙OS开发
- VTK:Filtering之ExtractVisibleCells
- CentOS安装Chrome
- staem被盗_如何检查照片是否被盗
- linux启动mqtt_linux下安装MQTT服务器 - EMQTT
- react构建_您应该了解的有关React的一切:开始构建所需的基础知识
- importanturlAndutl
- 小程序webview 页面被放大_Android中WebView加载的网页被放大的解决办法
- K3 设置为AP,用于软件路由的后级。
- 一则 HTTP 405 Method Not Allowed 的解决办法
- Word 2010 从任意页码重新开始
- 初学者学习哪种编程语言比较适合呢?
- 手机apk应用程序未安装解决办法
- HBuilderX快捷键大全
- 设计图片转换html5,在HTML5中翻转图片
热门文章
- 合成资产挖矿项目 ARCx 启动流动性挖矿
- SAP License:SAP S/4HANA Cloud [ERP 云]
- 智慧楼宇、消防系统、门禁管理、暖通空调、给排水、变配电、设备管理、停车管理、能源管理、故障检测、客流统计、运行控制、权限分配、物联网、Axure原型、rp原型、产品原型
- Python之Django之views中视图代码重复查询的优化
- Git使用教程之从远程库克隆项目(四)
- Python学习-基础篇14 Web框架本质及第一个Django实例
- 2017-2018-1 20155320 实验三——实时系统
- 三顺,因为你,我笑了。
- JProfiler 简要使用说明
- hdu 3657 最大点权独立集变形(方格取数的变形最小割,对于最小割建图很好的题)...