利用pytorch搭建LeNet-5网络模型(win11)
目录
1. 前言
2. 程序
2.1 net
2.2 train
2.3 main
3. 总结
1. 前言
手写字体识别模型LeNet5诞生于1994年,是最早的卷积神经网络之一。LeNet5通过巧妙的设计,利用卷积、参数共享、池化等操作提取特征,避免了大量的计算成本,最后再使用全连接神经网络进行分类识别,这个网络也是最近大量神经网络架构的起点。
代码分析见视频: 从0开始撸代码--手把手教你搭建LeNet-5网络模型
环境配置见:基于Anaconda安装pytorch和paddle深度学习环境(win11)
卷积或池化输出图像尺寸的计算公式如下:
O=输出图像的尺寸;I=输入图像的尺寸;K=池化或卷积层的核尺寸;S=移动步长;P =填充数
2. 程序
程序分为三部分net、train、main(test)
2.1 net
定义网络结构,初始化输入输出参数
import torch
from torch import nn# 定义一个网络模型类
class MyLeNet5(nn.Module):# 初始化网络def __init__(self):super(MyLeNet5, self).__init__()# 输入大小为32*32,输出大小为28*28,输入通道为1,输出为6,卷积核为5self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)# 使用sigmoid激活函数self.Sigmoid = nn.Sigmoid()# 使用平均池化self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)self.flatten = nn.Flatten()self.f6 = nn.Linear(120, 84)self.output = nn.Linear(84, 10)def forward(self, x):# x输入为32*32*1, 输出为28*28*6x = self.Sigmoid(self.c1(x))# x输入为28*28*6, 输出为14*14*6x = self.s2(x)# x输入为14*14*6, 输出为10*10*16x = self.Sigmoid(self.c3(x))# x输入为10*10*16, 输出为5*5*16x = self.s4(x)# x输入为5*5*16, 输出为1*1*120x = self.c5(x)x = self.flatten(x)# x输入为120, 输出为84x = self.f6(x)# x输入为84, 输出为10x = self.output(x)return xif __name__ == "__main__":x = torch.rand([1, 1, 28, 28])model = MyLeNet5()y = model(x)
可以右键run 一下,检验程序是否正确。
2.2 train
数据转化为tensor格式,加载训练数据集,如果显卡可用,则用显卡进行训练,如果GPU可用则将模型转到GPU,定义损失函数,定义训练函数,定义验证函数,开始训练,这里为了节约时间只训练了20次。
import torch
from torch import nn
from net import MyLeNet5
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import os# 将数据转化为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
# 加载训练数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else 'cpu'# 调用net里面定义的模型,如果GPU可用则将模型转到GPU
model = MyLeNet5().to(device)# 定义损失函数(交叉熵损失)
loss_fn = nn.CrossEntropyLoss()# 定义优化器,SGD,
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)# 学习率每隔10epoch变为原来的0.1
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):loss, current, n = 0.0, 0.0, 0# enumerate返回为数据和标签还有批次for batch, (X, y) in enumerate(dataloader):# 前向传播X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)# torch.max返回每行最大的概率和最大概率的索引,由于批次是16,所以返回16个概率和索引_, pred = torch.max(output, axis=1)# 计算每批次的准确率, output.shape[0]为该批次的多少cur_acc = torch.sum(y == pred) / output.shape[0]# print(cur_acc)# 反向传播optimizer.zero_grad()cur_loss.backward()optimizer.step()# 取出loss值和精度值loss += cur_loss.item()current += cur_acc.item()n = n + 1print('train_loss' + str(loss / n))print('train_acc' + str(current / n))# 定义验证函数
def val(dataloader, model, loss_fn):# 将模型转为验证模式model.eval()loss, current, n = 0.0, 0.0, 0# 非训练,推理期用到(测试时模型参数不用更新, 所以no_grad)# print(torch.no_grad)with torch.no_grad():for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)output = model(X)cur_loss = loss_fn(output, y)_, pred = torch.max(output, axis=1)cur_acc = torch.sum(y == pred) / output.shape[0]loss += cur_loss.item()current += cur_acc.item()n = n + 1print('val_loss' + str(loss / n))print('val_acc' + str(current / n))return current / n# 开始训练
epoch = 20
min_acc = 0
for t in range(epoch):lr_scheduler.step()print(f"epoch{t + 1}\n-------------------")train(train_dataloader, model, loss_fn, optimizer)a = val(test_dataloader, model, loss_fn)# 保存最好的模型权重文件if a > min_acc:folder = 'sava_model'if not os.path.exists(folder):os.mkdir('sava_model')min_acc = aprint('save best model', )torch.save(model.state_dict(), "sava_model/best_model.pth")# 保存最后的权重文件if t == epoch - 1:torch.save(model.state_dict(), "sava_model/last_model.pth")
print('Done!')
可以run(F5) 一下,得到最佳模型,精度达到了96.16%,还不错呦
2.3 main
数据转化为tensor格式,加载训练数据集,如果显卡可用,则用显卡进行训练,如果GPU可用则将模型转到GPU,加载 train.py 里训练好的模型,把tensor转成Image, 方便可视化,获取预测结果,进入验证阶段,测试前五张图片。
import torch
from net import MyLeNet5
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage# 将数据转化为tensor格式
data_transform = transforms.Compose([transforms.ToTensor()
])# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
# 加载训练数据集
test_dataset = datasets.MNIST(root='./data', train=False, transform=data_transform, download=True)
# 给训练集创建一个数据加载器, shuffle=True用于打乱数据集,每次都会以不同的顺序返回。
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=16, shuffle=True)# 如果显卡可用,则用显卡进行训练
device = "cuda" if torch.cuda.is_available() else 'cpu'# 调用net里面定义的模型,如果GPU可用则将模型转到GPU
model = MyLeNet5().to(device)# 加载 train.py 里训练好的模型
model.load_state_dict(torch.load("sava_model/best_model.pth"))# 获取预测结果
classes = ["0","1","2","3","4","5","6","7","8","9",
]# 把tensor转成Image, 方便可视化
show = ToPILImage()# 进入验证阶段
model.eval()
# 对test_dataset里10000张手写数字图片进行推理
# for i in range(len(test_dataloader)):
for i in range(5):x, y = test_dataset[i][0], test_dataset[i][1]# tensor格式数据可视化show(x).show()# 扩展张量维度为4维x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)with torch.no_grad():pred = model(x)# 得到预测类别中最高的那一类,再把最高的这一类对应classes中的哪一类标签predicted, actual = classes[torch.argmax(pred[0])], classes[y]# 最终输出的预测值与真实值print(f'predicted: "{predicted}", actual:"{actual}"')
效果如下 ,全部都是正确的滴!
3. 总结
本文介绍利用pytorch搭建LeNet-5网络模型(win11),接下来我会记录我的pytorch和paddle深度学习记录,很高兴能和大家分享!
利用pytorch搭建LeNet-5网络模型(win11)相关推荐
- python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类
任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...
- 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络
Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...
- Pytorch搭建常见分类网络模型------VGG、Googlenet、ResNet50 、MobileNetV2(4)
接上一节内容:Pytorch搭建常见分类网络模型------VGG.Googlenet.ResNet50 .MobileNetV2(3)_一只小小的土拨鼠的博客-CSDN博客 mobilenet系列: ...
- (含源码下载)利用Pytorch搭建GPU视觉处理接口
利用Pytorch搭建视觉处理接口 NVIDIA 视觉编程接口 (VPI: Vision Programming Interface) 是 NVIDIA 的计算机视觉和图像处理软件库,使您能够实现在 ...
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- 实战:利用pytorch搭建VGG-16实现从数据获取到模型训练的猫狗分类网络
起 在学习了卷积神经网络的理论基础和阅读了VGG的论文之后,对卷积有了大致的了解,但这都只是停留在理论上,动手实践更为重要,于是便开始了0基础学习pytorch.图像处理,搭建模型. pytorch学 ...
- 基于PyTorch的生成对抗网络入门(3)——利用PyTorch搭建生成对抗网络(GAN)生成彩色图像超详解
目录 一.案例描述 二.代码详解 2.1 获取数据 2.2 数据集类 2.3 构建判别器 2.3.1 构造函数 2.3.2 测试判别器 2.4 构建生成器 2.4.1 构造函数 2.4.2 测试生成器 ...
- pytorch 搭建 VGG 网络
目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...
- 使用pytorch搭建AlexNet网络模型
使用pytorch搭建AlexNet网络模型 AlexNet详解 AlexNet是2012年ISLVRC 2012(ImageNet Large Scale Visual Recognition Ch ...
最新文章
- deepin-wine-qq无法加载图片解决方案
- 分布式存储系统设计的几个问题和考虑点
- c语言pID程序怎么设计,51单片机PID的算法实现程序C语言
- boost::mp11::mp_identity_t相关用法的测试程序
- SAP UI5 this.oModel.createBindingContext will trigger odata request
- html尾部代码_3分钟短文:Laravel Form,让你不再写 HTML 的好“库”
- Android开发BroadcastReceiver广播的使用
- 六下计算机教学总结,六年级信息技术教师教学工作总结
- 最新基于高德地图的android进阶开发(3)GPS地图定位
- 微言Netty:分布式服务框架
- 【已解决】Python将网页内容保存为PDF (url转pdf)
- PHP内容管理系统详细制作步骤
- html设置背景图片透明度代码,css设置图片背景透明度
- ros使用自动驾驶数据集KITTI【1】介绍与可视化
- 麻雀租房App 作品展示
- 经典递归算法之Fibonacci序列
- GIF转MP4 - 在线将GIF动态图转为MP4视频文件
- 金蝶eas系统服务器地址,金蝶eas服务器地址
- linux下qt不能加载控件,找不到或加载Qt平台插件“xcb”
- MCU学习笔记_STA及PT工具
热门文章
- 贪心算法之贪心的c小加问题
- Collectors.reducing总结Collectors.mapping+Collectors.reducing+TreeSet等等
- 数据报(datagram)网络与虚电路(virtual-circuit)网络是典型两类分组交换网络。
- linux密码sha512,如何在Linux上检查SHA1,SHA256和SHA512哈希 | MOS86
- Python IPy模块的使用
- python元组和列表
- 在淘宝开放平台创建应该步骤
- 转自南开bbs——看完了就不会有人再感慨奶粉事件了...
- 【0x7FFFFFFF】【0x3f3f3f3f】
- 新学 PHP 日记。(分页查询)