文档基于b站视频:https://www.bilibili.com/video/BV187411T7Ye

流程

  1. model.py ——定义LeNet网络模型
  2. train.py ——加载数据集并训练,训练集计算loss,测试集计算accuracy,保存训练好的网络参数
  3. predict.py——得到训练好的网络参数后,用自己找的图像进行分类测试

文件目录结构截图:

1. model.py

先给出代码,模型是基于LeNet做简单修改,层数很浅,容易理解:

# 使用torch.nn包来构建神经网络.
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):                  # 继承于nn.Module这个父类def __init__(self):                       # 初始化网络结构super(LeNet, self).__init__()      # 多继承需用到super函数self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):             # 正向传播过程x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

需注意

  • pytorch 中 tensor(也就是输入输出层)的 通道排序为:[batch, channel, height, width]
  • pytorch中的卷积、池化、输入输出层中参数的含义与位置,可配合下图一起食用:

2. train.py

2.1 导入数据集

导入包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time

数据预处理

对输入的图像数据做预处理,即由shape (H x W x C) in the range [0, 255] → shape (C x H x W) in the range [0.0, 1.0]

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

此demo用的是CIFAR10数据集,也是一个很经典的图像分类数据集,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集,一共包含 10 个类别的 RGB 彩色图片。

导入、加载 训练集和测试集

import torch
import torchvision
import torch.nn as nn
from LeNetTest.model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import timedevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
# 加载训练集,实际过程需要分批次(batch)训练
train_loader = torch.utils.data.DataLoader(train_set, batch_size=50,shuffle=True, num_workers=0)# 10000张测试图片
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=10000,shuffle=False, num_workers=0)
# classes = ('plane', 'car', 'bird', 'cat',
#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 获取测试集中的图像和标签,用于accuracy计算
test_data_iter = iter(test_loader)
test_image, test_label = test_data_iter.next()
#
# def imshow(img):  # 展示测试集图片和标签
#     img = img / 2 + 0.5     # unnormalize
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1, 2, 0)))
#     plt.show()
#
# # print labels
# print(' '.join('%5s' % classes[test_label[j]] for j in range(4)))
# # show images
# imshow(torchvision.utils.make_grid(test_label))

2.2 训练过程

名词 定义
epoch 对训练集的全部数据进行一次完整的训练,称为 一次 epoch
batch 由于硬件算力有限,实际训练时将训练集分成多个批次训练,每批数据的大小为 batch_size
iteration 或 step 对一个batch的数据训练的过程称为 一个 iteration 或 step

以本demo为例,训练集一共有50000个样本,batch_size=50,那么完整的训练一次样本:iteration或step=1000,epoch=1

net = LeNet()                                       # 定义训练的网络模型
loss_function = nn.CrossEntropyLoss()              # 定义损失函数为交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器(训练参数,学习率)for epoch in range(5):  # 一个epoch即对整个训练集进行一次训练running_loss = 0.0time_start = time.perf_counter()for step, data in enumerate(train_loader, start=0):   # 遍历训练集,step从0开始计算inputs, labels = data    # 获取训练集的图像和标签optimizer.zero_grad()   # 清除历史梯度# forward + backward + optimizeoutputs = net(inputs)                  # 正向传播loss = loss_function(outputs, labels) # 计算损失loss.backward()                      # 反向传播optimizer.step()                      # 优化器更新参数# 打印耗时、损失、准确率等数据running_loss += loss.item()if step % 1000 == 999:    # print every 1000 mini-batches,每1000步打印一次with torch.no_grad(): # 在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用outputs = net(test_image)                # 测试集传入网络(test_batch_size=10000),output维度为[10000,10]predict_y = torch.max(outputs, dim=1)[1] # 以output中值最大位置对应的索引(标签)作为预测输出accuracy = (predict_y == test_label).sum().item() / test_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %  # 打印epoch,step,loss,accuracy(epoch + 1, step + 1, running_loss / 500, accuracy))print('%f s' % (time.perf_counter() - time_start))        # 打印耗时running_loss = 0.0print('Finished Training')# 保存训练得到的参数
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

2.3 使用GPU/CPU训练

使用下面语句可以在有GPU时使用GPU,无GPU时使用CPU进行训练

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

对应的,需要用to()函数来将Tensor在CPU和GPU之间相互移动,分配到指定的device中计算

net = LeNet()
net.to(device) # 将网络分配到指定的device中
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001) for epoch in range(5): running_loss = 0.0time_start = time.perf_counter()for step, data in enumerate(train_loader, start=0):inputs, labels = dataoptimizer.zero_grad()outputs = net(inputs.to(device))                  # 将inputs分配到指定的device中loss = loss_function(outputs, labels.to(device))  # 将labels分配到指定的device中loss.backward()optimizer.step()running_loss += loss.item()if step % 1000 == 999:    with torch.no_grad(): outputs = net(test_image.to(device)) # 将test_image分配到指定的device中predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_label.to(device)).sum().item() / test_label.size(0) # 将test_label分配到指定的device中print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 1000, accuracy))print('%f s' % (time.perf_counter() - time_start))running_loss = 0.0print('Finished Training')save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

打印信息如下:

可以看到,用GPU训练时,速度提升明显,耗时缩小。

3. predict.py

随便搜一张图片放入根目录下面,然后导入:

# 导入包
import torch
import torchvision.transforms as transforms
from PIL import Image
from LeNetTest.model import LeNet# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])im = Image.open("plane.jpg")
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0) # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width]# 实例化网络,加载训练好的模型参数
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))
# 预测
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()
print(classes[int(predict)])

输出即为预测的标签。

其实预测结果也可以用 softmax 表示,输出10个概率:

with torch.no_grad():outputs = net(im)predict = torch.softmax(outputs, dim=1)
print(predict)

深度学习:使用pytorch训练cifar10数据集(基于Lenet网络)相关推荐

  1. 【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

    文章目录 前言 CIFAR10简介 Backbone选择 训练+测试 训练环境及超参设置 完整代码 部分测试结果 完整工程文件 Reference 前言 分享一下本人去年入门深度学习时,在CIFAR1 ...

  2. MXNet学习:试用卷积-训练CIFAR-10数据集

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  3. 深度学习三(PyTorch物体检测实战)

    深度学习三(PyTorch物体检测实战) 文章目录 深度学习三(PyTorch物体检测实战) 1.网络骨架:Backbone 1.1.神经网络基本组成 1.1.1.卷积层 1.1.2.激活函数层 1. ...

  4. (深度学习)Pytorch之dropout训练

    (深度学习)Pytorch学习笔记之dropout训练 Dropout训练实现快速通道:点我直接看代码实现 Dropout训练简介 在深度学习中,dropout训练时我们常常会用到的一个方法--通过使 ...

  5. (翻译)60分钟入门深度学习工具-PyTorch

    60分钟入门深度学习工具-PyTorch 作者:Soumith Chintala 原文翻译自: https://pytorch.org/tutorials/beginner/deep_learning ...

  6. 【Pytorch进阶一】基于LeNet的CIFAR10图像分类

    [Pytorch进阶一]基于LeNet的CIFAR10图像分类 一.LeNet网络介绍 二.CIFAR10数据集介绍 三.程序架构介绍 3.1 LeNet模型(model.py) 3.2 训练(tra ...

  7. DL:深度学习框架Pytorch、 Tensorflow各种角度对比

    DL:深度学习框架Pytorch. Tensorflow各种角度对比 目录 先看两个框架实现同样功能的代码 1.Pytorch.Tensorflow代码比较 2.Tensorflow(数据即是代码,代 ...

  8. (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(2)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅰ

    开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...

  9. [深度学习] 分布式Pytorch介绍(三)

    [深度学习] 分布式模式介绍(一) [深度学习] 分布式Tensorflow介绍(二) [深度学习] 分布式Pytorch介绍(三) [深度学习] 分布式Horovod介绍(四)  一  Pytorc ...

最新文章

  1. 2022-2028年中国塑料安瓿瓶行业市场研究及前瞻分析报告
  2. fguillot json rpc_使用Hyperf框架搭建jsonrpc服务
  3. Oracle--序列和触发器的使用
  4. selenium基础框架的封装(Python版)这篇帖子在百度关键词搜索的第一位了,有图为证,开心!...
  5. 工作259:uni--页面--验证码添加
  6. 浙江义乌计算机中专学校,浙江义乌有没有中专学校?
  7. oracle 能被2整除_2021辽宁公务员考试:好用的“整除”法
  8. ios App开发的基本流程
  9. 【毕设】ASP.net校友录毕业设计(源代码+论文+开题报告+答辩PPT)
  10. unity 随机数_Unity 雨水滴到屏幕效果
  11. JQ插件jkscroll应用到页面中的效果
  12. is_array() 函数
  13. JS时间戳进行判断,判断是否超时三十分钟
  14. 三位数除以两位数竖式计算没有余数_四年级上册数学三位数除两位数练习题没有余数...
  15. Serverless Job—— 传统任务新变革
  16. 快速美化封面用word就可以
  17. ML之FE:特征工程处理中常用的数据变换(log取对数变换等)之详细攻略
  18. 股权和更高的薪资应该选哪个呢?
  19. STM32使用Jlink下载出现NO cortex-M SW device Found解决(超详细)
  20. 有一种情愫,它不属于暧昧

热门文章

  1. 实验四: IPv6路由选择协议配置
  2. 奔跑的业绩,需要配上奔跑的Excel条形图
  3. 分数阶 计算机应用,分数阶计算器
  4. 计算机开机后黑屏鼠标显示桌面图标,电脑开机后黑屏怎么解决只显示鼠标
  5. VS2019中Git源代码管理总结
  6. Jekyll 教程——合集(collections)
  7. 十年生死两茫茫,当我们已不再年轻——焦版小李飞刀
  8. [视频架构] Docker 全家桶 (上)
  9. 文件上传事件兼容性解决方案:兼容ie和谷歌
  10. 经济观察报:豆瓣的创业故事