LeNet非常简单,在MNIST数据集运行速度很快,所以开辟LeNet试验系列文章,以试验各种语句、技巧的效果,分析神经网络的一些特性。

1,Pytorch版本LeNet代码

  数据路径为’minst/’,文件夹内放置minst集中的四个gz文件,代码文件放在文件夹外面。

import gzip, struct
import numpy as npimport torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader#读取数据的函数
def _read(image, label):minist_dir = 'mnist/'with gzip.open(minist_dir + label) as flbl:magic, num = struct.unpack(">II", flbl.read(8))label = np.fromstring(flbl.read(), dtype=np.int8)with gzip.open(minist_dir + image, 'rb') as fimg:magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)return image, label #读取数据
def get_data():train_img, train_label = _read('train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz')test_img, test_label = _read('t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz')return [train_img, train_label, test_img, test_label]#定义lenet5
class LeNet5(nn.Module):def __init__(self):'''构造函数,定义网络的结构'''super().__init__()#定义卷积层self.conv1 = nn.Conv2d(1, 6, 5, padding=2)#第二个卷积层,6个输入,16个输出,5*5的卷积filter self.conv2 = nn.Conv2d(6, 16, 5)#最后是三个全连接层self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):'''前向传播函数'''#先卷积,然后调用relu激活函数,再最大值池化操作x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))#第二次卷积+池化操作x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))#重新塑形,将多维数据重新塑造为二维数据,256*400x = x.view(-1, self.num_flat_features(x))#print('size', x.size())#第一个全连接x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):#x.size()返回值为(256, 16, 5, 5),size的值为(16, 5, 5),256是batch_sizesize = x.size()[1:]  num_features = 1for s in size:num_features *= sreturn num_features #训练函数
def train(epoch):#调用前向传播model.train()     train_loss = 0for batch_idx, (data, target) in enumerate(train_loader):if use_gpu:data, target = data.cuda(), target.cuda()optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()train_loss += loss.item()train_loss /= len(train_loader.dataset)print('Train Epoch: {} \tTrain Loss: {:.6f}'.format(epoch, train_loss))#定义测试函数
def test():model.eval()  #测试模式,主要是保证dropout和BN和训练过程一致。test_loss = 0correct = 0for data, target in test_loader:if use_gpu:data, target = data.cuda(), target.cuda()output = model(data)#计算总的损失test_loss += criterion(output, target).item()pred = output.data.max(1, keepdim=True)[1]   #获得得分最高的类别correct += pred.eq(target.data.view_as(pred)).cpu().sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))class DealDataset(Dataset):"""数据封装成dataset类型"""def __init__(self,mode='train'):X, y, Xt, yt = get_data()if mode=='train':self.x_data = Xself.y_data = yelif mode=='test':self.x_data = Xtself.y_data = ytself.x_data = torch.from_numpy(self.x_data.reshape(-1, 1, 28, 28)).float()self.y_data = torch.from_numpy(self.y_data).long()self.len = self.x_data.shape[0]def __getitem__(self, index):data = self.x_data[index]target = self.y_data[index]return data, targetdef __len__(self):return self.len#封装数据集
train_dataset = DealDataset(mode='train')
test_dataset = DealDataset(mode='test')
#定义数据加载器
kwargs = {"num_workers": 0, "pin_memory": True}
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=256, **kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=256, **kwargs)#实例化网络
model = LeNet5()
#是否使用GPU
use_gpu = torch.cuda.is_available()
if use_gpu:model = model.cuda()print('USE GPU')
else:print('USE CPU')
#使用交叉熵损失函数
criterion = nn.CrossEntropyLoss(size_average=False)
#优化器
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.99))#执行训练和测试
for epoch in range(1, 101):train(epoch)test()

2,运行

运行结果:

下面几个画图的代码上面没有。
train_loss:

test_loss:

accuarcy:

LeNet试验(一) 搭建pytorch版模型及运行相关推荐

  1. windows10使用cuda11搭建pytorch深度学习框架——运行Dlinknet提取道路(三)——模型精度评估代码完善

    重新调试好代码,使用Dinknet34模型对数据集进行训练 数据集大小为1480张图片 运行时间为2022年1月12日16:00 记录下该模型训练时间 但如何评估模型的精度也是一个问题,因此作如下总结 ...

  2. windows10使用cuda11搭建pytorch深度学习框架——运行Dlinknet提取道路(二)——代码运行问题解决

    运行程序 去github上下载Dlinknet的代码 https://github.com/zlckanata/DeepGlobe-Road-Extraction-Challenge 把数据集放进da ...

  3. <计算机视觉四> pytorch版yolov3网络搭建

    鼠标点击下载     项目源代码免费下载地址 <计算机视觉一> 使用标定工具标定自己的目标检测 <计算机视觉二> labelme标定的数据转换成yolo训练格式 <计算机 ...

  4. Pytorch搭建自己的模型

    前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...

  5. 最强NLP模型BERT喜迎PyTorch版!谷歌官方推荐,也会支持中文

    郭一璞 夏乙 发自 凹非寺  量子位 报道 | 公众号 QbitAI 谷歌的最强NLP模型BERT发布以来,一直非常受关注,上周开源的官方TensorFlow实现在GitHub上已经收获了近6000星 ...

  6. Pytorch一行代码便可以搭建整个transformer模型

    transformer模型是在NLP领域发表的论文attention is all you need中提出的一种语言处理模型,其transformer模型由于加速了模型推理时间与训练精度,越来越受到了 ...

  7. 搭建GPU版PyTorch Docker镜像

    提要: 记录手动搭建GPU版PyTorch Docker镜像的过程.本地主机已经装好了显卡驱动,CUDA, cuDNN, 因此不再累述.本篇博客覆盖以下内容: Docker常用命令 搭建GPU版PyT ...

  8. PyTorch中模型的可复现性

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:AI算法与图像处理 在深度学习模型的训练过程中,难免引入 ...

  9. 364 页 PyTorch 版《动手学深度学习》分享(全中文,支持 Jupyter 运行)

    1 前言 最近有朋友留言要求分享一下李沐老师的<动手学深度学习>,小汤本着一直坚持的"好资源大家一起分享,共同学习,共同进步"的初衷,于是便去找了资料,而且还是中文版的 ...

最新文章

  1. XMLDOM对象方法:对象属性
  2. 【树莓派】树莓派移动网络连接(配置4G网卡)
  3. C++ Primer 5th笔记(10)chapter10 泛型算法 : read
  4. kafka和zookeeper一键启停脚本(以及kafka关不掉问题解决)
  5. RabbitMQ 镜像集群配置_05
  6. SQLServer 优化SQL语句:in 和not in的替代方案
  7. Python爬虫 Day 3
  8. (器) 构建自由通行的IOS开发者地图
  9. java两人猜数字游戏,三人背后猜数字游戏
  10. html网页设计优秀作品和代码,优秀的网页设计作品(一)
  11. C语言笔试经典编程题目(汇总帖)
  12. 计算机电源大小,常见电脑主板和电源尺寸
  13. SI24R1可以替代NRF24L01P软件硬件DIY兼容成功
  14. C#程序设计与应用课程教学总结:自评与改进
  15. 智能机器人走迷宫c语言游戏,(动态规划)机器人走迷宫问题(示例代码)
  16. Oracle的OFA架构
  17. vue2+element ui 导入和导出后端传过来的文件
  18. 程序员婚恋那点事儿(四):程序员与程序媛的婚礼
  19. Python基本语法一
  20. pytorch断点续传

热门文章

  1. 插入排序法算长度为10的数组
  2. Jmeter实现压力测试(多并发测试)
  3. STM32下载库资料
  4. Hessian 初探
  5. xcode:关于Other Linker Flags
  6. Endnote技巧:解决中英参考文献混排问题,附国标文件
  7. 如何将gitub的项目在eclipse中运行
  8. php中按引用传递参数,如何通过PHP中的引用传递可变参数的参数?
  9. ul去掉li前面的点_一年级语文上册期末考点:拼音重、难点总结,全面,建议收藏...
  10. 图像处理--角点检测与匹配