Pytorch作为新兴的深度学习框架,目前的使用率正在逐步上升。相比TensorFlow,Pytorch的上手难度更低,同时Pytorch支持对图的动态定义,并且能够方便的将网络中的tensor格式数据与numpy格式数据进行转换,使得其对某些特殊结构的网络定义起来更加方便,但是Pytorch对于分布式训练之类的支持相对较差,同时没有Tensorboard之类的工具对网络进行方便的可视化。当然,Tensorflow能够选择Keras之类的框架,来大幅简化网络架设工作。

Pytorch拥有一个不错的官方教程https://pytorch.org/tutorials/,包含了从基本运算到图像分类、语义识别、增强学习和今年大火的GAN等案例,解释的也非常清楚。这里主要依据官网的这篇教程https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html,并且对网络结构了一些改进,来练习Pytorch的使用。

这里也按照官网的步骤来,首先是通过torchvision库导入cifar10数据集:

import torch
import torchvision
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

查看图片之类的操作就看官网教程好了,这里省略掉。

第二步是定义卷积神经网络,官网使用的作为示例的网络如下:

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()

这个网络是卷积+全连接层的形式,这种结构的网络效果其实不好,因为全连接层传递效率较低,同时会干扰到卷积层提取出的局部特征,并且也没有用到BatchNorm和Dropout来防止过拟合的问题。现在流行的网络结构大多采用全卷积层的结构,下面的结构效果会好很多:

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)self.conv3 = nn.Conv2d(64, 128, 3, padding = 1)self.conv4 = nn.Conv2d(128, 128, 3, padding = 1)self.conv5 = nn.Conv2d(128, 256, 3, padding = 1)self.conv6 = nn.Conv2d(256, 256, 3, padding = 1)self.maxpool = nn.MaxPool2d(2, 2)self.avgpool = nn.AvgPool2d(2, 2)self.globalavgpool = nn.AvgPool2d(8, 8)self.bn1 = nn.BatchNorm2d(64)self.bn2 = nn.BatchNorm2d(128)self.bn3 = nn.BatchNorm2d(256)self.dropout50 = nn.Dropout(0.5)self.dropout10 = nn.Dropout(0.1)self.fc = nn.Linear(256, 10)def forward(self, x):x = self.bn1(F.relu(self.conv1(x)))x = self.bn1(F.relu(self.conv2(x)))x = self.maxpool(x)x = self.dropout10(x)x = self.bn2(F.relu(self.conv3(x)))x = self.bn2(F.relu(self.conv4(x)))x = self.avgpool(x)x = self.dropout10(x)x = self.bn3(F.relu(self.conv5(x)))x = self.bn3(F.relu(self.conv6(x)))x = self.globalavgpool(x)x = self.dropout50(x)x = x.view(x.size(0), -1)x = self.fc(x)return xnet = Net()

Pytorch也可以用nn.Sequential函数来很简单的定义序列网络,和keras的Sequential差不多,但是pytorch需要给出每一层网络的输入与输出参数,这一点就不像keras那么无脑。由于pytorch不像keras自带GlobalAveragePooling,手写一个怕自己忘记,其实这里不加效果会更好,毕竟这相当于强行压缩特征数据之后再进行分类。

这里再举个nn.Sequential的基本栗子:

channel_1 = 32
channel_2 = 16
model = nn.Sequential(nn.Conv2d(3, channel_1, 5, padding = 2),nn.ReLU(),nn.Conv2d(channel_1, channel_2, 3, padding = 1),nn.ReLU(),Flatten(),nn.Linear(channel_2 * 32 * 32, 10),
)

第三步是定义损失函数和优化器,官网这里用的是带动量项的SGD,但是个人感觉Adam对复杂函数的优化效果会比SGD好,所以这里用Adam来代替:

import torch.optim as optimcriterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

接下来,第四步就是训练网络了。可以首先使用下列语句来自动判断使用GPU还是CPU进行计算,不过一般而言,GPU和同档次的CPU计算速度可以差到50~70倍……

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

下面就可以开始训练了,这里要注意训练的数据也要.to(device):

for epoch in range(10):running_loss = 0.batch_size = 100for i, data in enumerate(torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=2), 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print('[%d, %5d] loss: %.4f' %(epoch + 1, (i+1)*batch_size, loss.item()))print('Finished Training')

当然,Pytorch不如Keras直接一个.fit来的方便,但是也不算麻烦。由于不用在一个session里边进行计算,灵活性还是比tensorflow和封装的严严实实的keras高很多。

之后,可以用下面的语句存储或读取保存好的模型:

torch.save(net, 'cifar10.pkl')
net = torch.load('cifar10.pkl')

在训练完成之后,就可以用测试集查看训练结果:

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

这个网络模型在batch_size=100的条件下训练10个epoch之后,测试集正确率大概在80%左右,对cifar10数据集而言还算可以啦。

源码放在github上,欢迎取用~地址:https://github.com/PolarisShi/cifar10

基于Pytorch的cifar10分类网络模型相关推荐

  1. Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)

    接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...

  2. 基于 PyTorch 的 cifar-10 图像分类

    文章目录 前言 一.cifar-10 数据集介绍 二.环境配置 三.实验代码 1.简单网络的代码 2.VGG加深网络的代码 四.运行结果 五.遇到的问题 总结 前言 本文的主要内容是基于 PyTorc ...

  3. Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练

    pytorch进行CIFAR-10分类(4)训练 我的系列博文: Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理 Pytorch打怪路(一)pyt ...

  4. 基于ResNet50的CIFAR10分类

    Q3 CIFAR10 图像分类 CIFAR10 本次运用了 ResNet50进行了图像分类处理(基于Pytorch) 一.数据集 1. 数据集说明 CIFAR-10数据集共有60000张彩色图像,这些 ...

  5. [PyTorch] 基于Python和PyTorch的cifar-10分类

    cifar-10数据集介绍 CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像.有50000个训练图像和10000个测试图像. 数据集分为5个训练批次和1个测 ...

  6. [Github项目]基于PyTorch的深度学习网络模型实现

    2019 年第 48 篇文章,总第 72 篇文章 本文大约 1500 字,阅读大约需要 4 分钟 今天主要分享两份 Github 项目,都是采用 PyTorch 来实现深度学习网络模型,主要是一些常用 ...

  7. 基于pytorch的人工智能分类垃圾桶

    Hello,大家好,作者终于考完研了,现在开始更新自己以前的科研项目来供大家一起学习参考,开源共享,作者github网址:https://github.com/czzq1999,欢迎加油一起学习,一起 ...

  8. 基于pytorch搭建神经网络的花朵种类识别(深度学习)

    基于pytorch搭建神经网络的花朵种类识别(深度学习) 文章目录 基于pytorch搭建神经网络的花朵种类识别(深度学习) 一.知识点 1.特征提取.神经元逐层判断 2.中间层(隐藏层) 3.学习权 ...

  9. Deep Learning:基于pytorch搭建神经网络的花朵种类识别项目(内涵完整文件和代码)—超详细完整实战教程

    基于pytorch的深度学习花朵种类识别项目完整教程(内涵完整文件和代码) 相关链接:: 超详细--CNN卷积神经网络教程(零基础到实战) 大白话pytorch基本知识点及语法+项目实战 文章目录 基 ...

最新文章

  1. Http协议简单介绍
  2. 【大话存储】学习笔记(7章), OSI模型
  3. 简单两步干掉WordPress里面的fonts.googleapis
  4. Java反射自定义注解底层设计原理
  5. Java 获取命令行输入数据(命令行输入,Scanner类)
  6. Vue 自定义按键修饰符对应表
  7. 在webstorm中怎么配置本地服务器
  8. 怎么理解汉罗塔问题_小白理解的汉诺塔中的递归问题
  9. 第一百八十四节,jQuery-UI,验证注册表单
  10. TensorFlow MNIST(手写识别 softmax)实例运行
  11. LSTM 文本分类模型的实现
  12. 绵阳市:充分利用区块链等技术 为农民工证照办理提供线上便捷服务
  13. windows 下vscode coderunner+bash 编程
  14. c语言中变量要加引号吗,CMake中引号用法总结
  15. 如何获取MySQL中表的最后更新时间
  16. JS基础之数组--概述、创建数组的几种方式、数组的特点、数组的常用方法、数组的解构赋值、数组高级API
  17. Pandas高级教程之:自定义选项
  18. 性能测试之nmon对linux服务器的监控 侵删
  19. 【剑指offer-54】20190907/03 字符流中第一个不重复的字符
  20. Javascript Flash harness

热门文章

  1. 爬取ajax加载的豆瓣电影
  2. 从RCNN,Fast-RCNN到Fater-RCNN的演化过程
  3. 论文阅读笔记《Adaptive Image-Based Visual Servoing Using Reinforcement Learning With Fuzzy State Coding》
  4. jupyter 内核似乎挂掉了 它很快将自动重启---解决方案
  5. php ECShop form,在ecshop中添加页面并且实现后台管理
  6. ac Let‘s Play Curling
  7. 推出 BlazePose:实现设备端实时人体姿态追踪
  8. 北斗定位,定位追踪,防盗追踪系统设计方案
  9. Column ‘‘ in field list is a ambiguous
  10. Log4j2 重大漏洞与解决方案