文章目录

  • 前言
  • CIFAR10简介
  • Backbone选择
  • 训练+测试
    • 训练环境及超参设置
    • 完整代码
  • 部分测试结果
  • 完整工程文件
  • Reference

前言

分享一下本人去年入门深度学习时,在CIFAR10数据集上做的图像分类任务,使用了多个主流的backbone网络,希望可以为同样想入门深度学习的同志们,提供一个方便上手、容易理解的参考教程。

CIFAR10简介

CIFAR-10数据集是图像分类领域经典的数据集,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理得到,一共包含10个类别的 RGB彩色图片:飞机( airplane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck ),图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示

 
Pytorch中提供了如下命令可以直接将CIFAR10数据集下载到本地:

import torchvision
dataset = torchvision.datasets.CIFAR10(root, train=True, download=True, transform)
  • root:数据集加载到本地的路径
  • train=True:True表示加载训练集,False加载测试集
  • download=True:True表示加载数据集到root,若数据集已经存在,则不会再加载
  • transform:数据增强

这里分享一个加载CIFAR10数据集的完整代码:

# 设置数据增强
print('==> Preparing data..')
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])# 加载CIFAR10数据集
trainset = torchvision.datasets.CIFAR10(root=opt.data, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root=opt.data, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

Backbone选择

本文主要尝试了以下几个主流的backbone网络,并在CIFAR10上实现了图像分类任务:

  • LetNet
  • AlexNet
  • VGG
  • GoogLeNet(InceptionNet)
  • ResNet
  • DenseNet
  • ResNeXt
  • SENet
  • MobileNetv2-v3
  • ShuffleNetv2
  • EfficientNetB0
  • Darknet53
  • CSPDarknet53

这里放上测试结果最好的ResNet模块的构建代码,其他代码放到最后完整工程backbone文件夹中:

"""
pytorch实现ResNet50、ResNet101和ResNet152:
"""
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F# conv1 7 x 7 64 stride=2
def Conv1(channel_in, channel_out, stride=2):return nn.Sequential(nn.Conv2d(channel_in,channel_out,kernel_size=7,stride=stride,padding=3,bias=False),nn.BatchNorm2d(channel_out),# 会改变输入数据的值# 节省反复申请与释放内存的空间与时间# 只是将原来的地址传递,效率更好nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=stride, padding=1))# 构建ResNet18-34的网络基础模块
class BasicBlock(nn.Module):expansion = 1def __init__(self, in_planes, planes, stride=1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.shortcut = nn.Sequential()if stride != 1 or in_planes != self.expansion * planes:self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion * planes,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(self.expansion * planes))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return out# 构建ResNet50-101-152的网络基础模块
class Bottleneck(nn.Module):expansion = 4def __init__(self, in_planes, planes, stride=1):super(Bottleneck, self).__init__()# 构建 1x1, 3x3, 1x1的核心卷积块self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=stride, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.conv3 = nn.Conv2d(planes, self.expansion *planes, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(self.expansion * planes)# 采用1x1的kernel,构建shout cut# 注意这里除了第一个bottleblock之外,都需要下采样,所以步长要设置为stride=2self.shortcut = nn.Sequential()if stride != 1 or in_planes != self.expansion * planes:self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion * planes,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(self.expansion * planes))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.relu(self.bn2(self.conv2(out)))out = self.bn3(self.conv3(out))out += self.shortcut(x)out = F.relu(out)return out# 搭建ResNet模板块
class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(ResNet, self).__init__()self.in_planes = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3,stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)# 逐层搭建ResNetself.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)self.linear = nn.Linear(512 * block.expansion, num_classes)# 参数初始化# for m in self.modules():#     if isinstance(m, nn.Conv2d):#         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')#     elif isinstance(m, nn.BatchNorm2d):#         nn.init.constant_(m.weight, 1)#         nn.init.constant_(m.bias, 0)def _make_layer(self, block, planes, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)# layers = [ ] 是一个列表# 通过下面的for循环遍历配置列表,可以得到一个由 卷积操作、池化操作等 组成的一个列表layers# return nn.Sequential(*layers),即通过nn.Sequential函数将列表通过非关键字参数的形式传入(列表layers前有一个星号)layers = []for stride in strides:layers.append(block(self.in_planes, planes, stride))self.in_planes = planes * block.expansionreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = F.avg_pool2d(out, 4)out = out.view(out.size(0), -1)out = self.linear(out)return outdef ResNet18():return ResNet(BasicBlock, [2, 2, 2, 2])def ResNet34():return ResNet(BasicBlock, [3, 4, 6, 3])def ResNet50():return ResNet(Bottleneck, [3, 4, 6, 3])def ResNet101():return ResNet(Bottleneck, [3, 4, 23, 3])def ResNet152():return ResNet(Bottleneck, [3, 8, 36, 3])# 测试
# if __name__ == '__main__':
#     model = ResNet50()
#     print(model)
#
#     input = torch.randn(1, 3, 32, 32)
#     out = model(input)
#     print(out.shape)

训练+测试

训练环境及超参设置

本文的训练环境和超参数设置如下

  • 1块1080 Ti GPU
  • epoch为100
  • batch-size为128
  • 优化器:SGD
  • 学习率:余弦退火有序调整学习率

主要步骤如下

  • 加载数据集

    • 将数据集加载到本地
    • 按batch-size加载到dataLoader
  • 设置相关参数
    • 指定GPU
    • 训练相关参数
    • 断点续训
    • 模型保存参数
  • 设置优化器
  • 设置学习率
  • 循环每个epoch
    • 开启训练
    • 开启测试
    • 学习率调整
  • 数据可视化
  • 打印结果

完整代码

'''Train CIFAR10 with PyTorch.'''
import torchvision.transforms as transforms
import time
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import argparse# 导入模型
from backbones.ResNet import ResNet18# 指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'# 用于计算GPU运行时间
def time_sync():# pytorch-accurate timeif torch.cuda.is_available():torch.cuda.synchronize()return time.time()# Training
def train(epoch):model.train()train_loss = 0correct = 0total = 0train_acc = 0# 开始迭代每个batch中的数据for batch_idx, (inputs, targets) in enumerate(trainloader):# inputs:[b,3,32,32], targets:[b]# train_outputs:[b,10]inputs, targets = inputs.to(device), targets.to(device)# print(inputs.shape)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()# 计算损失train_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()# 计算准确率train_acc = correct / total# 每训练100个batch打印一次训练集的loss和准确率if (batch_idx + 1) % 100 == 0:print('[INFO] Epoch-{}-Batch-{}: Train: Loss-{:.4f}, Accuracy-{:.4f}'.format(epoch + 1,batch_idx + 1,loss.item(),train_acc))# 计算每个epoch内训练集的acctotal_train_acc.append(train_acc)# Testing
def test(epoch, ckpt):global best_accmodel.eval()test_loss = 0correct = 0total = 0test_acc = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(testloader):inputs, targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss = criterion(outputs, targets)test_loss += loss.item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()test_acc = correct / totalprint('[INFO] Epoch-{}-Test Accurancy: {:.3f}'.format(epoch + 1, test_acc), '\n')total_test_acc.append(test_acc)# 保存权重文件acc = 100. * correct / totalif acc > best_acc:print('Saving..')state = {'net': model.state_dict(),'acc': acc,'epoch': epoch,}if not os.path.isdir('checkpoint'):os.mkdir('checkpoint')torch.save(state, ckpt)best_acc = accif __name__ == '__main__':# 设置超参parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch_size', type=int, default=128)parser.add_argument('--data', type=str, default='cifar10')parser.add_argument('--T_max', type=int, default=100)parser.add_argument('--lr', default=0.1, type=float, help='learning rate')parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')parser.add_argument('--checkpoint', type=str, default='checkpoint/ResNet18-CIFAR10.pth')opt = parser.parse_args()# 设置相关参数device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'best_acc = 0  # best test accuracystart_epoch = 0  # start from epoch 0 or last checkpoint epochclasses = ('plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')# 设置数据增强print('==> Preparing data..')transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])# 加载CIFAR10数据集trainset = torchvision.datasets.CIFAR10(root=opt.data, train=True, download=True, transform=transform_train)trainloader = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root=opt.data, train=False, download=True, transform=transform_test)testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)# print(trainloader.dataset.shape)# 加载模型print('==> Building model..')model = ResNet18().to(device)# DP训练if device == 'cuda':model = torch.nn.DataParallel(model)cudnn.benchmark = True# 加载之前训练的参数if opt.resume:# Load checkpoint.print('==> Resuming from checkpoint..')assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'checkpoint = torch.load(opt.checkpoint)model.load_state_dict(checkpoint['net'])best_acc = checkpoint['acc']start_epoch = checkpoint['epoch']# 设置损失函数与优化器criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=opt.lr,momentum=0.9, weight_decay=5e-4)# 余弦退火有序调整学习率scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.T_max)# ReduceLROnPlateau(自适应调整学习率)# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)# 记录training和testing的acctotal_test_acc = []total_train_acc = []# 记录训练时间tic = time_sync()# 开始训练for epoch in range(opt.epochs):train(epoch)test(epoch, opt.checkpoint)# 动态调整学习率scheduler.step()# ReduceLROnPlateau(自适应调整学习率)# scheduler.step(loss_val)# 数据可视化plt.figure()plt.plot(range(opt.epochs), total_train_acc, label='Train Accurancy')plt.plot(range(opt.epochs), total_test_acc, label='Test Accurancy')plt.xlabel('Epoch')plt.ylabel('Accurancy')plt.title('ResNet18-CIFAR10-Accurancy')plt.legend()plt.savefig('output/ResNet18-CIFAR10-Accurancy.jpg')  # 自动保存plot出来的图片plt.show()# 输出best_accprint(f'Best Acc: {best_acc * 100}%')toc = time_sync()# 计算本次运行时间t = (toc - tic) / 3600print(f'Training Done. ({t:.3f}s)')

部分测试结果

Backbone Best Acc
MobileNetv2 93.37%
VGG16 93.80%
DenseNet121 94.55%
GoogLeNet 95.02%
ResNeXt29_32×4d 95.18%
ResNet50 95.20%
SENet18 95.22%
ResNet18 95.23%

完整工程文件

Pytorch实现CIFAR10图像分类任务测试集准确率达95%

Reference

CIFAR-10 数据集

深度学习入门基础教程(二) CNN做CIFAR10数据集图像分类 pytorch版代码

Pytorch CIFAR10 图像分类篇 汇总

pytorch-cifar:使用PyTorch在CIFAR10上为95.47%

【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%相关推荐

  1. Pytorch实战2:ResNet-18实现Cifar-10图像分类(测试集分类准确率95.170%)

    版权说明:此文章为本人原创内容,转载请注明出处,谢谢合作! Pytorch实战2:ResNet-18实现Cifar-10图像分类 实验环境: Pytorch 0.4.0 torchvision 0.2 ...

  2. 用深度学习keras的cnn做图像识别分类,准确率达97%

    Keras是一个简约,高度模块化的神经网络库. 可以很容易和快速实现原型(通过总模块化,极简主义,和可扩展性) 同时支持卷积网络(vision)和复发性的网络(序列数据).以及两者的组合. 无缝地运行 ...

  3. pytorch 训练过程acc_深度学习Pytorch实现分类模型

    今天将介绍深度学习中的分类模型,以下主要介绍Softmax的基本概念.神经网络模型.交叉熵损失函数.准确率以及Pytorch实现图像分类.01Softmax基本概念 在分类问题中,通常标签都为类别,可 ...

  4. 深度学习经典网络解析图像分类篇(二):AlexNet

    深度学习经典网络解析图像分类篇(二):AlexNet 1.背景介绍 2.ImageNet 3.AlexNet 3.1AlexNet简介 3.2AlexNet网络架构 3.2.1第一层(CONV1) 3 ...

  5. 【动手学深度学习PyTorch版】27 数据增强

    上一篇请移步[动手学深度学习PyTorch版]23 深度学习硬件CPU 和 GPU_水w的博客-CSDN博客 目录 一.数据增强 1.1 数据增强(主要是关于图像增强) ◼ CES上的真实的故事 ◼ ...

  6. 深度学习PyTorch笔记(12):线性神经网络——softmax回归

    深度学习PyTorch笔记(12):线性神经网络--softmax回归 6 线性神经网络--softmax回归 6.1 softmax回归 6.1.1 概念 6.1.2 softmax运算 6.2 图 ...

  7. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

  8. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  9. 深度学习+pytorch实战Kaggle比赛(一)——房价预测

    参考书籍<动手学深度学习(pytorch版),参考网址为: https://zh-v2.d2l.ai/chapter_multilayer-perceptrons/kaggle-house-pr ...

最新文章

  1. GitLab 配置邮箱
  2. Gartner评出2017年最值得关注的11个顶级信息安全技术
  3. centos samba 配置
  4. linux打开pythonshall,linux系统shell脚本后台运行python程序
  5. java中的final关键字(2013-10-11-163 写的日志迁移
  6. viewflipper_Android ViewFlipper示例教程
  7. tomcat优化实例
  8. API章节--第四节包装类总结
  9. fd抓包数据类型_终端抓包神器 | tcpdump参数解析及使用
  10. java 罗马数字_罗马数字 | 学步园
  11. A40i 平台应用笔记-华为-ME909S-4G 模块的移植应用
  12. 京东云 linux无法远程,怎样远程登录京东云云主机.pdf
  13. 软件测试周刊(第23期):你理想中的工作是什么?
  14. 非极大值抑制(NMS)的几种实现优化
  15. HTTP传输协议详解(传输过程及数据格式详细)
  16. ABAQUS怎样导出部分节点的编号
  17. html密码框ml表单文本框,表单组件 PasswordInput 密码输入框 - 闪电教程JSRUN
  18. C3P0的三种使用方法
  19. IBM Rational Rhapsody 8.0和Rhapsody Design Manager 4.0中的新增功能
  20. java生成pdf分页_java itext导出PDF 分页 github

热门文章

  1. 用卡尔曼滤波器跟踪导弹
  2. [转] 肾有多好人就有多年轻
  3. XXX处有未经处理的异常: 0xC0000374: 堆已损坏,处有未经处理的异常: 0xC0000005: 读取位置 0x4F774B16 时发生访问冲突。
  4. ksps什么单位_[转载]采样频率Hz 采样率KSPS或MSPS,两种单位的换算关系
  5. 人工智能刷题(个人向)
  6. 论文阅读 Jointly Optimize Data Augmentation and Network Training
  7. java版汉字转换拼音(大小写)
  8. 大唐:我家阁楼通公主府(三)
  9. arduino驱动LD3320语音识别模块
  10. 卡券php小程序,微信小程序领取卡券 - osc_88a08cel的个人空间 - OSCHINA - 中文开源技术交流社区...