今天我们将使用 Pytorch 来继续实现 LeNet-5 模型,并用它来解决 CIFAR10 数据集的识别。

正文开始!

二、使用LeNet-5网络结构创建CIFAR-10识别分类器

LeNet-5 网络本是用来识别 MNIST 数据集的,下面我们来将 LeNet-5 应用到一个比较复杂的例子,识别 CIFAR-10 数据集。

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( airlane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。

CIFAR-10 的图片样例如图所示。

2.1 下载并加载数据,并做出一定的预先处理

pipline_train = transforms.Compose([#随机旋转图片transforms.RandomHorizontalFlip(),#将图片尺寸resize到32x32transforms.Resize((32,32)),#将图片转化为Tensor格式transforms.ToTensor(),#正则化(当模型出现过拟合的情况时,用来降低模型的复杂度)transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
pipline_test = transforms.Compose([#将图片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
#下载数据集
train_set = datasets.CIFAR10(root="./data/CIFAR10", train=True, download=True, transform=pipline_train)
test_set = datasets.CIFAR10(root="./data/CIFAR10", train=False, download=True, transform=pipline_test)
#加载数据集
trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False)
# 类别信息也是需要我们给定的
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')

2.2 搭建 LeNet-5 神经网络结构,并定义前向传播的过程

LeNet-5 网络上文已经搭建过了,由于 CIFAR10 数据集图像是 RGB 三通道的,因此 LeNet-5 网络 C1 层卷积选择的滤波器需要 3 通道,网络其它结构跟上文都是一样的。

class LeNetRGB(nn.Module):def __init__(self):super(LeNetRGB, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)   # 3表示输入是3通道self.relu = nn.ReLU()self.maxpool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.maxpool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = x.view(-1, 16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)output = F.log_softmax(x, dim=1)return output

2.3 将定义好的网络结构搭载到 GPU/CPU,并定义优化器

使用 SGD(随机梯度下降)优化,学习率为 0.001,动量为 0.9。

#创建模型,部署gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNetRGB().to(device)
#定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

2.4 定义训练过程

def train_runner(model, device, trainloader, optimizer, epoch):#训练模型, 启用 BatchNormalization 和 Dropout, 将BatchNormalization和Dropout置为Truemodel.train()total = 0correct =0.0#enumerate迭代已加载的数据集,同时获取数据和数据下标for i, data in enumerate(trainloader, 0):inputs, labels = data#把模型部署到device上inputs, labels = inputs.to(device), labels.to(device)#初始化梯度optimizer.zero_grad()#保存训练结果outputs = model(inputs)#计算损失和#多分类情况通常使用cross_entropy(交叉熵损失函数), 而对于二分类问题, 通常使用sigmodloss = F.cross_entropy(outputs, labels)#获取最大概率的预测结果#dim=1表示返回每一行的最大值对应的列下标predict = outputs.argmax(dim=1)total += labels.size(0)correct += (predict == labels).sum().item()#反向传播loss.backward()#更新参数optimizer.step()if i % 1000 == 0:#loss.item()表示当前loss的数值print("Train Epoch{} \t Loss: {:.6f}, accuracy: {:.6f}%".format(epoch, loss.item(), 100*(correct/total)))Loss.append(loss.item())Accuracy.append(correct/total)return loss.item(), correct/total

2.5 定义测试过程

def test_runner(model, device, testloader):#模型验证, 必须要写, 否则只要有输入数据, 即使不训练, 它也会改变权值#因为调用eval()将不启用 BatchNormalization 和 Dropout, BatchNormalization和Dropout置为Falsemodel.eval()#统计模型正确率, 设置初始值correct = 0.0test_loss = 0.0total = 0#torch.no_grad将不会计算梯度, 也不会进行反向传播with torch.no_grad():for data, label in testloader:data, label = data.to(device), label.to(device)output = model(data)test_loss += F.cross_entropy(output, label).item()predict = output.argmax(dim=1)#计算正确数量total += label.size(0)correct += (predict == label).sum().item()#计算损失值print("test_avarage_loss: {:.6f}, accuracy: {:.6f}%".format(test_loss/total, 100*(correct/total)))

2.6 运行

#调用
epoch = 20
Loss = []
Accuracy = []
for epoch in range(1, epoch+1):print("start_time",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))loss, acc = train_runner(model, device, trainloader, optimizer, epoch)Loss.append(loss)Accuracy.append(acc)test_runner(model, device, testloader)print("end_time: ",time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())),'\n')print('Finished Training')
plt.subplot(2,1,1)
plt.plot(Loss)
plt.title('Loss')
plt.show()
plt.subplot(2,1,2)
plt.plot(Accuracy)
plt.title('Accuracy')
plt.show()

经历 20 次 epoch 迭代训练之后:

start_time 2021-11-27 22:29:09
Train Epoch20 Loss: 0.659028, accuracy: 68.750000%
test_avarage_loss: 0.030969, accuracy: 67.760000%
end_time:  2021-11-27 22:29:44

训练集的 loss 曲线和 Accuracy 曲线变化如下:

2.7 保存模型

print(model)
torch.save(model, './models/model-cifar10.pth') #保存模型

LeNet-5 的模型会 print 出来,并将模型模型命令为 model-cifar10.pth 保存在固定目录下。

LeNetRGB((conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))(relu): ReLU()(maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

2.8 模型测试

利用刚刚训练的模型进行 CIFAR10 类型图片的测试。

from PIL import Image
import numpy as npif __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('./models/model-cifar10.pth') #加载模型model = model.to(device)model.eval()    #把模型转为test模式#读取要预测的图片# 读取要预测的图片img = Image.open("./images/test_cifar10.png").convert('RGB') # 读取图像#img.show()plt.imshow(img) # 显示图片plt.axis('off') # 不显示坐标轴plt.show()# 导入图片,图片扩展后为[1,1,32,32]trans = transforms.Compose([#将图片尺寸resize到32x32transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = trans(img)img = img.to(device)img = img.unsqueeze(0)  #图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 预测 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print("概率:",prob)print(predict.item())value, predicted = torch.max(output.data, 1)predict = output.argmax(dim=1)pred_class = classes[predicted.item()]print("预测类别:",pred_class)

输出:

概率:tensor([[7.6907e-01, 3.3997e-03, 4.8003e-03, 4.2978e-05, 1.2168e-02, 6.8751e-06, 3.2019e-06, 1.6024e-04, 1.2705e-01, 8.3300e-02]],
grad_fn=<SoftmaxBackward>)
5
预测类别:plane

模型预测结果正确!

以上就是 PyTorch 构建 LeNet-5 卷积神经网络并用它来识别 CIFAR10 数据集的例子。全文的代码都是可以顺利运行的,建议大家自己跑一边。

值得一提的是,针对 MNIST 数据集和 CIFAR10 数据集,最大的不同就是 MNIST 是单通道的,CIFAR10 是三通道的,因此在构建 LeNet-5 网络的时候,C1层需要做不同的设置。至于输入图片尺寸不一样,我们可以使用 transforms.Resize 方法统一缩放到 32x32 的尺寸大小。

所有完整的代码我都放在 GitHub 上,GitHub地址为:

https://github.com/RedstoneWill/ObjectDetectionLearner/tree/main/LeNet-5

也可以点击阅读原文进入~

往期精彩回顾适合初学者入门人工智能的路线及资料下载中国大学慕课《机器学习》(黄海广主讲)机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载本站qq群955171419,加入微信群请扫码:

【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!相关推荐

  1. 深度学习环境配置(pytorch版本)----超级无敌详细版(有手就行)

    公众号文章--深度学习环境配置(pytorch版本) 写在前面:如果这篇文章对大家有帮助的话,欢迎关注Franpper的公众号:Franpper的知识铺,回复"进群",即可进入讨论 ...

  2. 李沐《动手学深度学习》新增PyTorch和TensorFlow实现,还有中文版

    李沐老师的<动手学深度学习>已经有Pytorch和TensorFlow的实现了,并且有了中文版. 网址:http://d2l.ai/ 简介 李沐老师的<动手学深度学习>自一年前 ...

  3. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  4. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  5. 深度学习入门之PyTorch学习笔记:深度学习介绍

    深度学习入门之PyTorch学习笔记:深度学习介绍 绪论 1 深度学习介绍 1.1 人工智能 1.2 数据挖掘.机器学习.深度学习 1.2.1 数据挖掘 1.2.2 机器学习 1.2.3 深度学习 第 ...

  6. Python深度学习:基于PyTorch [Deep Learning with Python and PyTorch]

    作者:吴茂贵,郁明敏,杨本法,李涛,张粤磊 著 出版社:机械工业出版社 品牌:机工出版 出版时间:2019-11-01 Python深度学习:基于PyTorch [Deep Learning with ...

  7. 干货|《深度学习入门之Pytorch》资料下载

    深度学习如今已经成为了科技领域中炙手可热的技术,而很多机器学习框架也成为了研究者和业界开发者的新宠,从早期的学术框架Caffe.Theano到如今的Pytorch.TensorFlow,但是当时间线来 ...

  8. 【深度学习】基于Pytorch进行深度神经网络计算(一)

    [深度学习]基于Pytorch进行深度神经网络计算(一) 文章目录 1 层和块 2 自定义块 3 顺序块 4 在正向传播函数中执行代码 5 嵌套块 6 参数管理(不重要) 7 参数初始化(重要) 8 ...

  9. 【深度学习】基于Pytorch进行深度神经网络计算(二)

    [深度学习]基于Pytorch进行深度神经网络计算(二) 文章目录 1 延后初始化 2 Pytorch自定义层2.1 不带参数的层2.2 带参数的层 3 基于Pytorch存取文件 4 torch.n ...

  10. 【深度学习】基于Pytorch的卷积神经网络概念解析和API妙用(一)

    [深度学习]基于Pytorch的卷积神经网络API妙用(一) 文章目录 1 不变性 2 卷积的数学分析 3 通道 4 互相关运算 5 图像中目标的边缘检测 6 基于Pytorch的卷积核 7 特征映射 ...

最新文章

  1. Building COM Objects in C#
  2. 今年数据分析到底有多火?全网跪求优质资源!
  3. fort77编译器安装
  4. java并发编程-Executor框架
  5. 快速删除数据库中所有表中的数据
  6. WEB安全基础-URL跳转漏洞
  7. CentOS7服务管理(重启,停止,自动启动命令)
  8. 16F877A和24C02通信汇编语言,PIC16f877A读写24c02程序
  9. javascript 之牛人感悟,必看学习
  10. php反射机制详解,PHP反射机制
  11. 胶东机场t1离哪个停车场近,青岛胶东国际机场停车场攻略
  12. OFFICE文档转换到PDF的几种方法与转换效率和性能的简单比较
  13. 换页符'\f'的问题
  14. matlab 帕多瓦数列 通项公式_matlab通分.ppt
  15. 解决windows 中python打开文本文档乱码问题
  16. excel表格拆分多个表如何操作?
  17. 读书有益——》民间治咳偏方
  18. 维护采购订单审批的特性Characteristic和类Class
  19. JAVA安装、配置及卸载
  20. 金融借贷平台大数据风控解决方案

热门文章

  1. 【jQuery】复选框的全选、反选,推断哪些复选框被选中
  2. HTML系列(七):多媒体
  3. asp.net三层架构应用详解【收录】
  4. 看看用 live write 发布日志的效果
  5. 对南昌杀人案的一些看法
  6. 软件测试作业1 -- 关于c++项目中类相互调用的问题与解决
  7. Okhttp、Volley和Gson的简单介绍和配合使用
  8. Python之路【第七篇】:初识Socket
  9. hql刪除語句,根據參數刪除
  10. DP_字串匹配(HDU_1501)