文章目录

  • 前言
  • 一、cifar-10 数据集介绍
  • 二、环境配置
  • 三、实验代码
    • 1.简单网络的代码
    • 2.VGG加深网络的代码
  • 四、运行结果
  • 五、遇到的问题
  • 总结

前言

本文的主要内容是基于 PyTorch 的 cifar-10 图像分类,文中包括 cifar-10 数据集介绍、环境配置、实验代码、运行结果以及遇到的问题这几个部分,本实验采用了基本网络和VGG加深网络模型,其中VGG加深网络模型的识别准确率是要优于基本网络模型的。


一、cifar-10 数据集介绍

cifar-10 数据集由 60000 张分辨率为 32x32 彩色图像组成,共分为 10 类,每类包含 6000 张图像,cifar-10 数据集有 50000 个训练图像和 10000 个测试图像。
数据集分为五个训练批次和一个测试批次,每个批次包含 10000 张图像,测试批次恰好包含从每个类中随机选择的 1000 张图像,训练批次以随机顺序包含其余图像,但某些训练批处理可能包含来自一个类的图像多于另一个类的图像,在它们之间,训练批次正好包含来自每个类的 5000 张图像。
下面是数据集中所包含的类以及每个类中的 10 个随机图像。

由上图可以看到,cifar-10 数据集包含飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船以及卡车这十类,这些类是完全相互排斥的,汽车和卡车之间也没有重叠,汽车包括轿车、SUV等诸如此类的东西,卡车仅包括大型卡车,但两者都不包括皮卡车。
该数据集可以在网址 https://www.cs.toronto.edu/~kriz/cifar.html 中进行下载,下载解压后包含以下几个文件。


二、环境配置

先安装 Anaconda,用来创建需要的环境,Anaconda 的安装可以参考:Anaconda 的安装及使用。
在安装好的 Anaconda 中安装 python 和 pytorch 以及代码中可能用到的包,可以参考:

在PyCharm中点击File——>Settings 打开如下界面,找到 Project 下的 Project interpreter ,再点击右边的齿轮,选择 Add。

在弹出的新界面中选择 Conda Environment,再选择Existing environment,在Interpreter这里找到你在 Anaconda 中 pytorch 环境下的 python 即可,然后点击OK。

可以看到,这里的 Project interpreter 已经发生了变化,点击 OK 即可。


上面两幅图中所包含的就是安装好python、pytorch以及本实验所用包后的信息了。


三、实验代码

本实验所用的代码有两个,一个是基于简单网络的,一个是基于VGG加深网络的。

1.简单网络的代码

#  声明:本代码并非自己编写,由他人提供
import torch
import torchvision
import torchvision.transforms as transforms
import sslfrom torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import timessl._create_default_https_context = ssl._create_unverified_contexttransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomGrayscale(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)def imshow(img):img = img / 2 + 0.5npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()if __name__ == '__main__':for epoch in range(20):timestart = time.time()running_loss = 0.0for i,data in enumerate(trainloader, 0):inputs, labels = datainputs, labels = Variable(inputs), Variable(labels)optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 500 == 499:print('[%d ,%5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 500))running_loss = 0.0print('epoch %d cost %3f sec' % (epoch + 1, time.time()-timestart))print('Finished Training')dataiter = iter(testloader)images, labels = dataiter.__next__()imshow(torchvision.utils.make_grid(images))print('GroundTruth:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))outputs = net(Variable(images))_, predicted = torch.max(outputs.data,1)print('Predicted:', ' '.join('%5s' % classes[labels[j]] for j in range(4)))correct = 0total = 0for data in testloader:images, labels = dataoutputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum()print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))for data in testloader:images, labels = dataoutputs = net(Variable(images))_, predicted = torch.max(outputs.data, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i]class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

2.VGG加深网络的代码

#  声明:本代码并非自己编写,由他人提供
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import os
import sslssl._create_default_https_context = ssl._create_unverified_contexttransform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomGrayscale(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])transform1 = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./cifar10_vgg', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./cifar10_vgg', train=False, download=True, transform=transform1)
testloader = torch.utils.data.DataLoader(testset, batch_size=50, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 64, 3, padding=1)self.conv2 = nn.Conv2d(64, 64, 3, padding=1)self.pool1 = nn.MaxPool2d(2, 2)self.bn1 = nn.BatchNorm2d(64)self.relu1 = nn.ReLU()self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.conv4 = nn.Conv2d(128, 128, 3, padding=1)self.pool2 = nn.MaxPool2d(2, 2, padding=1)self.bn2 = nn.BatchNorm2d(128)self.relu2 = nn.ReLU()self.conv5 = nn.Conv2d(128, 128, 3, padding=1)self.conv6 = nn.Conv2d(128, 128, 3, padding=1)self.conv7 = nn.Conv2d(128, 128, 1, padding=1)self.pool3 = nn.MaxPool2d(2, 2, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.conv8 = nn.Conv2d(128, 256, 3, padding=1)self.conv9 = nn.Conv2d(256, 256, 3, padding=1)self.conv10 = nn.Conv2d(256, 256, 1, padding=1)self.pool4 = nn.MaxPool2d(2, 2, padding=1)self.bn4 = nn.BatchNorm2d(256)self.relu4 = nn.ReLU()self.conv11 = nn.Conv2d(256, 512, 3, padding=1)self.conv12 = nn.Conv2d(512, 512, 3, padding=1)self.conv13 = nn.Conv2d(512, 512, 1, padding=1)self.pool5 = nn.MaxPool2d(2, 2, padding=1)self.bn5 = nn.BatchNorm2d(512)self.relu5 = nn.ReLU()self.fc14 = nn.Linear(512 * 4 * 4, 1024)self.drop1 = nn.Dropout2d()self.fc15 = nn.Linear(1024, 1024)self.drop2 = nn.Dropout2d()self.fc16 = nn.Linear(1024, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.pool1(x)x = self.bn1(x)x = self.relu1(x)x = self.conv3(x)x = self.conv4(x)x = self.pool2(x)x = self.bn2(x)x = self.relu2(x)x = self.conv5(x)x = self.conv6(x)x = self.conv7(x)x = self.pool3(x)x = self.bn3(x)x = self.relu3(x)x = self.conv8(x)x = self.conv9(x)x = self.conv10(x)x = self.pool4(x)x = self.bn4(x)x = self.relu4(x)x = self.conv11(x)x = self.conv12(x)x = self.conv13(x)x = self.pool5(x)x = self.bn5(x)x = self.relu5(x)# print(" x shape ",x.size())x = x.view(-1, 512 * 4 * 4)x = F.relu(self.fc14(x))x = self.drop1(x)x = F.relu(self.fc15(x))x = self.drop2(x)x = self.fc16(x)return xdef train_sgd(self, device):optimizer = optim.SGD(self.parameters(), lr=0.01)path = 'weights.tar'initepoch = 0if os.path.exists(path) is not True:loss = nn.CrossEntropyLoss()else:checkpoint = torch.load(path)self.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])initepoch = checkpoint['epoch']loss = checkpoint['loss']for epoch in range(initepoch, 20):  # loop over the dataset multiple timestimestart = time.time()running_loss = 0.0total = 0correct = 0for i, data in enumerate(trainloader, 0):# get the inputsinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = self(inputs)l = loss(outputs, labels)l.backward()optimizer.step()running_loss += l.item()if i % 500 == 499:print('[%d, %5d] loss: %.4f' %(epoch, i, running_loss / 500))running_loss = 0.0_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the %d tran images: %.3f %%' % (total,100.0 * correct / total))total = 0correct = 0torch.save({'epoch': epoch,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss}, path)print('epoch %d cost %3f sec' % (epoch, time.time() - timestart))print('Finished Training')def test(self, device):correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = self(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %.3f %%' % (100.0 * correct / total))def classify(self, device):class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = self(images)_, predicted = torch.max(outputs.data, 1)c = (predicted == labels).squeeze()for i in range(4):label = labels[i]class_correct[label] += c[i]class_total[label] += 1for i in range(10):print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")net = Net()net = net.to(device)net.train_sgd(device)net.test(device)net.classify(device)

四、运行结果

基于简单网络的代码运行过程如下。
代码运行后开始在cifar10的官网下载数据集 cifar-10-python.tar.gz 的压缩包。

下载成功后接着运行了20个epoch。

20个epoch运行完成后弹出了该图,可以看到画面是比较模糊的。

关闭该图后接着输出了各类识别的准确率。

基于VGG加深网络的代码运行过程如下,整个过程相当耗时。

最终输出各类识别的准确率。

绘制图对比一下,基于VGG加深网络的整体识别效果要比简单网络好很多。


五、遇到的问题

original error was: dll load failed: 找不到指定的模块。

这个问题在网上有好多的解决办法,我自己做了好多尝试,最后不知道是具体的哪一步起了作用,就可以运行程序了,总之将我尝试的方法都贴在下面吧,希望能够帮到你!
1、在Anaconda下安装python3.6版本(之前装了3.7和3.8都不太好使,有可能也不是版本的问题)。
2、先安装 matplotlib,再安装 pytorch(本实验用到了 matplotlib,我先安装的这一个)。
3、尝试过卸载 numpy 再重新安装(好多人通过这个方法解决了)。
4、卸载了电脑之前已安装的 python ,删除了其对应的环境变量(可能会与Anaconda下的python互相影响)。
5、配置 Anaconda 下的 python 环境变量。

上面的环境变量按照自己的安装路径配置。
6、在 PyCharm 下的Settings中把所有可以改变 Project Interpreter 的地方(下图左侧框住的这四个)都改为Anaconda 下的 python路径并保存。

7、看看自己存放 python 模块的文件夹下是否有之前版本 python 的文件,我这里就有一个名为_pycache_的文件夹,删除它。


总结

以上就是cifar-10图像分类的所有内容了,我在搭建环境上花费的时间比运行程序本身的时间都要长,所以在这个过程中遇到问题时要耐心一点,相信你也可以解决问题,让代码成功的跑起来!
参考网址:
Alex Krizhevsky的主页
https://www.kaggle.com/c/cifar-10

基于 PyTorch 的 cifar-10 图像分类相关推荐

  1. 【3D图像分类】基于Pytorch的3D立体图像分类2--数据增强篇

    增强篇主要是对基础篇的一个补充,补充的内要主要是包括以下两个大的方面 数据方面 网络模型方面 其中,数据方面主要是增加训练过程中的数据增强方式:网络模型方面引入残差结构的resent.mobile n ...

  2. 【项目实战课】基于Pytorch的InceptionNet花卉图像分类实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的InceptionNet花卉图像分类实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

  3. 【图像分类】 基于Pytorch的多类别图像分类实战

    欢迎大家来到图像分类专栏,本篇基于Pytorch完成一个多类别图像分类实战. 作者 | 郭冰洋 编辑 | 言有三 1 简介 实现一个完整的图像分类任务,大致需要分为五个步骤: 1.选择开源框架 目前常 ...

  4. 基于SVM的思想做CIFAR 10图像分类

    #SVM 回顾一下之前的SVM,找到一个间隔最大的函数,使得正负样本离该函数是最远的,是否最远不是看哪个点离函数最远,而是找到一个离函数最近的点看他是不是和该分割函数离的最近的. 使用large ma ...

  5. 【实战篇】基于Pytorch的3D立体图像分类--基础篇

    在一般的图像数据的采集场景中,得到的多是二维图像,所以大多数深度学习网络的雏形都是基于二维图像展开的工作. 但是,在某些场景下,比如医学影像CT数据,监控场景连续拍摄的视频和自动驾驶使用到的激光点云等 ...

  6. Pytorch基础知识(15)基于PyTorch的多标签图像分类

    早在 2012 年,神经网络就首次赢得了 ImageNet 大规模视觉识别挑战.Alex Krizhevsky,Ilya Sutskever 和 Geoffrey Hinton 彻底改变了图像分类领域 ...

  7. 基于PyTorch的卷积神经网络图像分类——猫狗大战(二):使用Pytorch定义网络模型

    文章目录 1. 需要用到的库 2. 模型定义 3. 测试 基于上一篇文章 https://blog.csdn.net/linghu8812/article/details/100044971,这次介绍 ...

  8. 【图像分类】 基于Pytorch的细粒度图像分类实战

    欢迎大家来到<图像分类>专栏,今天讲述基于pytorch的细粒度图像分类实战! 作者&编辑 | 郭冰洋 1 简介 针对传统的多类别图像分类任务,经典的CNN网络已经取得了非常优异的 ...

  9. 【项目实战课】基于Pytorch的DANet自然图像降噪实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的DANet自然图像降噪实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲 ...

  10. 【项目实战课】基于Pytorch的EnlightenGAN自然图像增强实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的EnlightenGAN自然图像增强实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行 ...

最新文章

  1. 人群行为分析--Understanding Pedestrian Behaviors from Stationary Crowd Groups
  2. android singleInstance返回问题
  3. django的表单系统
  4. sql server和mysql的区别是什么
  5. 【记忆化搜索】【dfs】【递归】Chocolate
  6. simulink 分析达芬方程
  7. 自定义标签处理器类的生命周期
  8. 禁止微信调整页面字体大小
  9. Egret入门学习日记 --- 第十五篇(书中 6.1~6.9节 内容)
  10. 决策树之CART算法
  11. XL4001 典型应用电路
  12. element-plus ui表格表头筛选功能
  13. ByVal和ByRef的区别
  14. 一条校招/社招潜规则~
  15. 数学建模笔记 day-03
  16. java 面试知识点总结
  17. 液晶屏工艺中的封口抹平和端口丝印
  18. 2018年Java大企业面试问题
  19. python精通大学_小白21天精通Python是如何做到的?
  20. 2021年全国职业院校技能大赛(中职组)网络安全竞赛试题(4)(总分100分)

热门文章

  1. pytest 测试开发 —— 上手 pytest 实现自动化测试
  2. MOS管中的米勒效应
  3. 天气预报查询 API + AI 等于王炸(一大波你未曾设想的天气预报查询 API 应用场景更新了)
  4. 最有效的方法来增加在Map中的值
  5. 影像组学病理切片案例教学课程(免费赠送)另有现成的代码,模型供大家上手使用
  6. 鸿蒙系统当贝市场,新款华为智慧屏如何安装第三方软件?当贝市场分享最详细的安装教程详解...
  7. Linux下fdisk格式化TF卡,创建分区
  8. 欧盟电源适配器外部电源2019/1782/EU ERP欧洲能效认证
  9. 通俗易懂的基金理财(小白)
  10. 一个抖音视频下载代码