目录

1. 前言

2. 程序

2.1 net

2.2 train

2.3 main

3. 总结


1. 前言

手写字体识别模型LeNet5诞生于1994年,是最早的卷积神经网络之一。LeNet5通过巧妙的设计,利用卷积、参数共享、池化等操作提取特征,避免了大量的计算成本,最后再使用全连接神经网络进行分类识别,这个网络也是最近大量神经网络架构的起点。

代码分析见视频: 从0开始撸代码--手把手教你搭建LeNet-5网络模型

环境配置见:基于Anaconda安装pytorch和paddle深度学习环境(win11)

卷积或池化输出图像尺寸的计算公式如下:

O=输出图像的尺寸;I=输入图像的尺寸;K=池化或卷积层的核尺寸;S=移动步长;P =填充数

2. 程序

程序分为三部分net、train、main(test)

2.1 net

定义网络结构,初始化输入输出参数

import torch
from torch import nn# 定义一个网络模型类
class MyLeNet5(nn.Module):# 初始化网络def __init__(self):super(MyLeNet5, self).__init__()# 输入大小为32*32,输出大小为28*28,输入通道为1,输出为6,卷积核为5self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)# 使用sigmoid激活函数self.Sigmoid = nn.Sigmoid()# 使用平均池化self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)self.flatten = nn.Flatten()self.f6 = nn.Linear(120, 84)self.output = nn.Linear(84, 10)def forward(self, x):# x输入为32*32*1, 输出为28*28*6x = self.Sigmoid(self.c1(x))# x输入为28*28*6, 输出为14*14*6x = self.s2(x)# x输入为14*14*6, 输出为10*10*16x = self.Sigmoid(self.c3(x))# x输入为10*10*16, 输出为5*5*16x = self.s4(x)# x输入为5*5*16, 输出为1*1*120x = self.c5(x)x = self.flatten(x)# x输入为120, 输出为84x = self.f6(x)# x输入为84, 输出为10x = self.output(x)return xif __name__ == "__main__":x = torch.rand([1, 1, 28, 28])model = MyLeNet5()y = model(x)

可以右键run 一下,检验程序是否正确。

2.2 train

数据转化为tensor格式,加载训练数据集,如果显卡可用,则用显卡进行训练,如果GPU可用则将模型转到GPU,定义损失函数,定义训练函数,定义验证函数,开始训练,这里为了节约时间只训练了20次。

import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 将数据转化为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
# 加载训练数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else 'cpu'# 调用net里面定义的模型,如果GPU可用则将模型转到GPU
model = MyLeNet5().to(device)# 定义损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()# 定义优化器,SGD,
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)# 学习率每隔10epoch变为原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0# enumerate返回为数据和标签还有批次for batch, (X, y) in enumerate(dataloader):# 前向传播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)# torch.max返回每行最大的概率和最大概率的索引,由于批次是16,所以返回16个概率和索引_, pred = torch.max(output, axis=1)# 计算每批次的准确率, output.shape[0]为该批次的多少cur_acc = torch.sum(y == pred) / output.shape[0]# print(cur_acc)# 反向传播optimizer.zero_grad()cur_loss.backward()optimizer.step()# 取出loss值和精度值loss += cur_loss.item()current += cur_acc.item()n = n + 1print('train_loss' + str(loss / n))print('train_acc' + str(current / n))# 定义验证函数
def val(dataloader, model, loss_fn):# 将模型转为验证模式model.eval()loss, current, n = 0.0, 0.0, 0# 非训练,推理期用到(测试时模型参数不用更新, 所以no_grad)# print(torch.no_grad)with torch.no_grad():for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print('val_loss' + str(loss / n))print('val_acc' + str(current / n))return current / n# 开始训练
epoch = 20
min_acc = 0
for t in range(epoch):lr_scheduler.step()print(f"epoch{t + 1}\n-------------------")train(train_dataloader, model, loss_fn, optimizer)a = val(test_dataloader, model, loss_fn)# 保存最好的模型权重文件if a > min_acc:folder = 'sava_model'if not os.path.exists(folder):os.mkdir('sava_model')min_acc = aprint('save best model', )torch.save(model.state_dict(), "sava_model/best_model.pth")# 保存最后的权重文件if t == epoch - 1:torch.save(model.state_dict(), "sava_model/last_model.pth")
print('Done!')

可以run(F5) 一下,得到最佳模型,精度达到了96.16%,还不错呦

2.3 main

数据转化为tensor格式,加载训练数据集,如果显卡可用,则用显卡进行训练,如果GPU可用则将模型转到GPU,加载 train.py 里训练好的模型,把tensor转成Image, 方便可视化,获取预测结果,进入验证阶段,测试前五张图片。

import torch
from net import MyLeNet5
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage# 将数据转化为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
# 加载训练数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)#  如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else 'cpu'# 调用net里面定义的模型,如果GPU可用则将模型转到GPU
model = MyLeNet5().to(device)# 加载 train.py 里训练好的模型
model.load_state_dict(torch.load("sava_model/best_model.pth"))# 获取预测结果
classes = ["0","1","2","3","4","5","6","7","8","9",
]# 把tensor转成Image, 方便可视化
show = ToPILImage()# 进入验证阶段
model.eval()
# 对test_dataset里10000张手写数字图片进行推理
# for i in range(len(test_dataloader)):
for i in range(5):x, y = test_dataset[i][0], test_dataset[i][1]# tensor格式数据可视化show(x).show()# 扩展张量维度为4维x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)with torch.no_grad():pred = model(x)# 得到预测类别中最高的那一类,再把最高的这一类对应classes中的哪一类标签predicted, actual = classes[torch.argmax(pred[0])], classes[y]# 最终输出的预测值与真实值print(f'predicted: "{predicted}", actual:"{actual}"')

效果如下 ,全部都是正确的滴!

3. 总结

本文介绍利用pytorch搭建LeNet-5网络模型(win11),接下来我会记录我的pytorch和paddle深度学习记录,很高兴能和大家分享!

利用pytorch搭建LeNet-5网络模型(win11)相关推荐

  1. python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类

    任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...

  2. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  3. Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)

    接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...

  4. (含源码下载)利用Pytorch搭建GPU视觉处理接口

    利用Pytorch搭建视觉处理接口 NVIDIA 视觉编程接口 (VPI: Vision Programming Interface) 是 NVIDIA 的计算机视觉和图像处理软件库,使您能够实现在 ...

  5. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  6. 实战:利用pytorch搭建VGG-16实现从数据获取到模型训练的猫狗分类网络

    起 在学习了卷积神经网络的理论基础和阅读了VGG的论文之后,对卷积有了大致的了解,但这都只是停留在理论上,动手实践更为重要,于是便开始了0基础学习pytorch.图像处理,搭建模型. pytorch学 ...

  7. 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解

    目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...

  8. pytorch 搭建 VGG 网络

    目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...

  9. 使用pytorch搭建AlexNet网络模型

    使用pytorch搭建AlexNet网络模型 AlexNet详解 AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Ch ...

最新文章

  1. deepin-wine-qq无法加载图片解决方案
  2. 分布式存储系统设计的几个问题和考虑点
  3. c语言pID程序怎么设计,51单片机PID的算法实现程序C语言
  4. boost::mp11::mp_identity_t相关用法的测试程序
  5. SAP UI5 this.oModel.createBindingContext will trigger odata request
  6. html尾部代码_3分钟短文:Laravel Form,让你不再写 HTML 的好“库”
  7. Android开发BroadcastReceiver广播的使用
  8. 六下计算机教学总结,六年级信息技术教师教学工作总结
  9. 最新基于高德地图的android进阶开发(3)GPS地图定位
  10. 微言Netty:分布式服务框架
  11. 【已解决】Python将网页内容保存为PDF (url转pdf)
  12. PHP内容管理系统详细制作步骤
  13. html设置背景图片透明度代码,css设置图片背景透明度
  14. ros使用自动驾驶数据集KITTI【1】介绍与可视化
  15. 麻雀租房App 作品展示
  16. 经典递归算法之Fibonacci序列
  17. GIF转MP4 - 在线将GIF动态图转为MP4视频文件
  18. 金蝶eas系统服务器地址,金蝶eas服务器地址
  19. linux下qt不能加载控件,找不到或加载Qt平台插件“xcb”
  20. MCU学习笔记_STA及PT工具

热门文章

  1. 贪心算法之贪心的c小加问题
  2. Collectors.reducing总结Collectors.mapping+Collectors.reducing+TreeSet等等
  3. 数据报(datagram)网络与虚电路(virtual-circuit)网络是典型两类分组交换网络。
  4. linux密码sha512,如何在Linux上检查SHA1,SHA256和SHA512哈希 | MOS86
  5. Python IPy模块的使用
  6. python元组和列表
  7. 在淘宝开放平台创建应该步骤
  8. 转自南开bbs——看完了就不会有人再感慨奶粉事件了...
  9. 【0x7FFFFFFF】【0x3f3f3f3f】
  10. 新学 PHP 日记。(分页查询)