调用PyTorch相关接口实现一个LeNet-5网络,然后通过MNIST数据集训练模型,最后对生成的模型进行预测,主要包括2大部分:训练和预测

1.训练部分:

(1).加载MNIST数据集,通过调用TorchVision模块中的接口实现,将每幅图像缩放到32*32大小,小批量数据集数量设置为32;

(2).设置网络参数的初始值,这样保证每次重新训练时初始值都是固定的,便于查找定位问题;

(3).设计LeNet-5网络,并实例化一个网络对象,重载了__init__和forward两个函数,使用到的layer包括Conv2d、AvgPool2d、Linear;激活函数使用Tanh:

(4).指定优化算法,这里采用Adam;

(5).指定损失函数,这里采用CrossEntropyLoss;

(6).训练,epochs设置为10,给出每次的训练结果;

(7).保存模型,推荐使用state_dict。

代码段如下:

def load_mnist_dataset(img_size, batch_size):'''下载并加载mnist数据集img_size: 图像大小,宽高长度相同batch_size: 小批量数据集数量'''# 对PIL图像先进行缩放操作,然后转换成tensor类型transforms_ = transforms.Compose([transforms.Resize(size=(img_size, img_size)), transforms.ToTensor()])'''下载MNIST数据集root: mnist数据集存放目录名train: 可选参数, 默认为True; 若为True,则从MNIST/processed/training.pt创建数据集;若为False,则从MNIST/processed/test.pt创建数据集transform: 可选参数, 默认为None; 接收PIL图像并作处理target_transform: 可选参数, 默认为Nonedownload: 可选参数, 默认为False; 若为True,则从网络上下载数据集到root指定的目录'''train_dataset = datasets.MNIST(root="mnist_data", train=True, transform=transforms_, target_transform=None, download=True)valid_dataset = datasets.MNIST(root="mnist_data", train=False, transform=transforms_, target_transform=None, download=False)# 加载MNIST数据集:shuffle为True,则在每次epoch时重新打乱顺序train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)return train_loader, valid_loader, train_dataset, valid_datasetclass LeNet5(nn.Module):'''构建lenet网络'''def __init__(self, n_classes: int) -> None:super(LeNet5, self).__init__() # 调用父类Module的构造方法# n_classes: 类别数# nn.Sequential: 顺序容器,Module将按照它们在构造函数中传递的顺序添加,它允许将整个容器视为单个moduleself.feature_extractor = nn.Sequential( # 输入32*32nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0), # 卷积层,28*28*6nn.Tanh(), # 激活函数Tanh,使其值范围在(-1, 1)内nn.AvgPool2d(kernel_size=2, stride=None, padding=0), # 平均池化层,14*14*6nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0), # 10*10*16nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=None, padding=0), # 5*5*16nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1, padding=0), # 1*1*120nn.Tanh())self.classifier = nn.Sequential( # 输入1*1*120nn.Linear(in_features=120, out_features=84), # 全连接层,84nn.Tanh(),nn.Linear(in_features=84, out_features=n_classes) # 10)# LeNet5继承nn.Module,定义forward函数后,backward函数就会利用Autograd被自动实现# 只要实例化一个LeNet5对象并传入对应的参数x就可以自动调用forward函数def forward(self, x: Tensor):                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              x = self.feature_extractor(x)x = torch.flatten(input=x, start_dim=1) # 将输入按指定展平,start_dim=1则第一维度不变,后面的展平logits = self.classifier(x)probs = F.softmax(input=logits, dim=1) # 激活函数softmax: 使得每一个元素的范围都在(0,1)之间,并且所有元素的和为1return logits, probsdef validate(valid_loader, model, criterion, device):'''Function for the validation step of the training loop'''model.eval() # 将网络设置为评估模式running_loss = 0for X, y_true in valid_loader:X = X.to(device) # 将数据导入到指定的设备上(cpu或gpu)y_true = y_true.to(device)# Forward pass and record lossy_hat, _ = model(X) # 前向传播:调用Module的__call__方法, 此方法内会调用指定网络(如LeNet5)的forward方法loss = criterion(y_hat, y_true) # 计算loss,同上,通过__call__方法调用指定损失函数类(如CrossEntropyLoss)中的forward方法running_loss += loss.item() * X.size(0)epoch_loss = running_loss / len(valid_loader.dataset)return model, epoch_lossdef get_accuracy(model, data_loader, device):'''Function for computing the accuracy of the predictions over the entire data_loader'''correct_pred = 0n = 0with torch.no_grad(): # 临时将循环内的所有Tensor的requires_grad标志设置为False,不再计算Tensor的梯度(自动求导)model.eval() # 将网络设置为评估模式for X, y_true in data_loader:X = X.to(device) # 将数据导入到指定的设备上(cpu或gpu)y_true = y_true.to(device)_, y_prob = model(X) # y_prob.size(): troch.Size([32, 10]): [cols, rows]# torch.max(input):返回Tensor中所有元素的最大值# torch.max(input, dim):按维度dim返回最大值,并且返回索引# dim=0: 返回每一列中最大值的那个元素,并且返回索引# dim=1: 返回每一行中最大值的那个元素,并且返回索引_, predicted_labels = torch.max(y_prob, 1)n += y_true.size(0)correct_pred += (predicted_labels == y_true).sum()return correct_pred.float() / ndef train(train_loader, model, criterion, optimizer, device):'''Function for the training step of the training loop'''model.train() # 将网络设置为训练模式running_loss = 0for X, y_true in train_loader: # 先调用DataLoader类的__iter__函数,接着循环调用_DataLoaderIter类的__next__函数# X.size(shape: [n,c,h,w]): torch.Size([32, 1, 32, 32]); y_true.size: torch.Size([32]); n为batch_sizeoptimizer.zero_grad() # 将优化算法中的梯度重置为0,需要在计算下一个小批量数据集的梯度之前调用它,否则梯度将累积到现有的梯度中# 将Tensor数据导入到指定的设备上(cpu或gpu)X = X.to(device)y_true = y_true.to(device)y_hat, _ = model(X) # 前向传播:调用Module的__call__方法, 此方法内会调用指定网络(如LeNet5)的forward方法# y_hat.size(): torch.Size([32, 10]); _.size(): torch.Size([32, 10])loss = criterion(y_hat, y_true) # 计算loss,同上,通过__call__方法调用指定损失函数类(如CrossEntropyLoss)中的forward方法running_loss += loss.item() * X.size(0)loss.backward() # 反向传播,使用Autograd自动计算标量的当前梯度optimizer.step() # 根据梯度更新网络参数,优化器通过.grad中存储的梯度来调整每个参数epoch_loss = running_loss / len(train_loader.dataset)return model, optimizer, epoch_lossdef training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):'''Function defining the entire training loopmodel: 网络对象criterion: 损失函数对象optimizer: 优化算法对象train_loader: 训练数据集对象valid_loader: 测试数据集对象epochs: 重复训练整个训练数据集的次数device: 指定在cpu上还是在gpu上运行print_every: 每训练几次打印一次训练结果'''train_losses = []valid_losses = []for epoch in range(0, epochs):model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)train_losses.append(train_loss)# 每次训练完后通过测试数据集进行评估with torch.no_grad(): # 临时将循环内的所有Tensor的requires_grad标志设置为False,不再计算Tensor的梯度(自动求导)model, valid_loss = validate(valid_loader, model, criterion, device)valid_losses.append(valid_loss)if epoch % print_every == (print_every - 1):train_acc = get_accuracy(model, train_loader, device=device)valid_acc = get_accuracy(model, valid_loader, device=device)print(f'  {datetime.now().time().replace(microsecond=0)}:'f' Epoch: {epoch}', f' Train loss: {train_loss:.4f}', f' Valid loss: {valid_loss:.4f}'f' Train accuracy: {100 * train_acc:.2f}', f' Valid accuracy: {100 * valid_acc:.2f}')return model, optimizer, (train_losses, valid_losses)def train_and_save_model():print("#### start training ... ####")print("1. load mnist dataset")train_loader, valid_loader, _, _ = load_mnist_dataset(img_size=32, batch_size=32)print("2. fixed random init value")# 用于设置随机初始化;如果不设置每次训练时的网络初始化都是随机的,导致结果不确定;如果设置了,则每次初始化都是固定的torch.manual_seed(seed=42)#print("value:", torch.rand(1), torch.rand(1), torch.rand(1)) # 运行多次,每次输出的值都是相同的,[0, 1)print("3. instantiate lenet net object")model = LeNet5(n_classes=10).to('cpu') # 在CPU上运行print("4. specify the optimization algorithm: Adam")optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001) # 定义优化算法:Adam是一种基于梯度的优化算法print("5. specify the loss function: CrossEntropyLoss")criterion = nn.CrossEntropyLoss() # 定义损失函数:交叉熵损失print("6. repeated training")model, _, _ = training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs=10, device='cpu') # epochs为遍历训练整个数据集的次数print("7. save model")model_name = "../../../data/Lenet-5.pth"#torch.save(model, model_name) # 保存整个模型, 对应于model = torch.loadtorch.save(model.state_dict(), model_name) # 推荐:只保存模型训练好的参数,对应于model.load_state_dict(torch.load)

执行结果如下所示:

2.手写数字图像识别部分:

(1).加载模型,推荐使用load_state_dict,对应于保存模型时使用的state_dict;

(2).设置网络到评估模式;

(3).准备测试图像,一共10幅,0到9各一幅,如下图所示,注意:训练图像背景色为黑色,而测试图像背景色为白色:

(4).依次对每幅图像进行识别。

代码段如下所示:

def list_files(filepath, filetype):'''遍历指定目录下的指定文件'''paths = []for root, dirs, files in os.walk(filepath):for file in files:if file.lower().endswith(filetype.lower()):paths.append(os.path.join(root, file))return pathsdef get_image_label(image_name, image_name_suffix):'''获取测试图像对应label'''index = image_name.rfind("/")if index == -1:print(f"Error: image name {image_name} is not supported")sub = image_name[index+1:]label = sub[:len(sub)-len(image_name_suffix)]return labeldef image_predict():print("#### start predicting ... ####")print("1. load model")model_name = "../../../data/Lenet-5.pth"model = LeNet5(n_classes=10).to('cpu') # 实例化一个网络对象model.load_state_dict(torch.load(model_name)) # 加载模型print("2. set net to evaluate mode")model.eval()print("3. prepare test images")image_path = "../../../data/image/handwritten_digits/"image_name_suffix = ".png"images_name = list_files(image_path, image_name_suffix)print("4. image recognition")with torch.no_grad():for image_name in images_name:#print("image name:", image_name)label = get_image_label(image_name, image_name_suffix)img = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)img = cv2.resize(img, (32, 32))# MNIST图像背景为黑色,而测试图像的背景色为白色,识别前需要做转换img = cv2.bitwise_not(img)#print("img shape:", img.shape)# 将opencv image转换到pytorch tensortransform = transforms.ToTensor()tensor = transform(img) # tensor shape: torch.Size([1, 32, 32])tensor = tensor.unsqueeze(0) # tensor shape: torch.Size([1, 1, 32, 32])#print("tensor shape:", tensor.shape)_, y_prob = model(tensor)_, predicted_label = torch.max(y_prob, 1)print(f"  predicted label: {predicted_label.item()}, ground truth label: {label}")

执行结果如下图所示:

GitHub: https://github.com/fengbingchun/PyTorch_Test

通过PyTorch构建的LeNet-5网络对手写数字进行训练和识别相关推荐

  1. PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

    PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN 目录 训练过程 代码设计 训练过程 代码设计 #PyTorch:利用PyTorch实现 ...

  2. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  3. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  4. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  5. pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...

  6. python识别手写数字字体_基于tensorflow框架对手写字体MNIST数据集的识别

    本文我们利用python语言,通过tensorflow框架对手写字体MNIST数据库进行识别. 学习每一门语言都有一个"Hello World"程序,而对数字手写体数据库MNIST ...

  7. 深度学习实战——利用卷积神经网络对手写数字二值图像分类(附代码)

    系列文章目录 深度学习实战--利用卷积神经网络对手写数字二值图像分类(附代码) 目录 系列文章目录 前言 一.案例需求 二.MATLAB算法实现 三.MATLAB源代码 参考文献 前言 本案例利用MA ...

  8. matlab对手写数字聚类的方法_scikitlearn — 聚类

    可以使用模块sklearn.cluster对未标记的数据进行聚类.每个聚类算法都有两种变体:一个是类(class)实现的 fit方法来学习训练数据上的聚类:另一个是函数(function)实现,给定训 ...

  9. 机器学习速成课程 | 练习 | Google Development——编程练习:使用神经网络对手写数字进行分类

    使用神经网络对手写数字进行分类 学习目标: 训练线性模型和神经网络,以对传统 MNIST 数据集中的手写数字进行分类 比较线性分类模型和神经网络分类模型的效果 可视化神经网络隐藏层的权重 我们的目标是 ...

最新文章

  1. oracle12c dml语句缓存,Oracle --DML、DDL、DCL
  2. IIS调用批处理权限的处理
  3. python实例(一)
  4. 了解CMS(Concurrent Mark-Sweep)垃圾回收器
  5. python导入标准库对象的语句_Python项目中如何优雅的import
  6. SpringBoot_数据访问-JDBC自动配置原理
  7. Android开发之LayoutInflater.from(context).inflate()方法参数介绍解决RecyclerView加载布局不全的问题
  8. 最大连续子矩阵和算法
  9. Scala连接mongodb数据库
  10. jmeter跨线程组传多个值_jmeter多用户登录跨线程组操作传值
  11. 一级计算机框线设置为窄线,计算机等级一级MS Office考题:第二套字处理题
  12. 关于电路的书的读后感_通知 | 2021.1.1日起,专利和集成电路布图设计收费启用电子票据...
  13. 永久删除掉qq安全防护进程q盾
  14. 反诈题库---合计100道(解析版最新)
  15. 专业计算机怎么关机,Win10如何使用快捷键来关机?_win10专业版技巧
  16. Spring Boot电商项目54:订单模块三:【前台:订单详情】接口;
  17. webflux之reactor-Subscriber
  18. c#和java部署pytorch同事识别两个图片_Pytorch转NCNN的流程记录
  19. MATALB-结构体
  20. matlab 画梯形,转向梯形优化设计matlab程序

热门文章

  1. H5 WebSQL每日成语
  2. Stock 股票因子
  3. linux shell 判断字符串包含
  4. mysql useradd_一天一个linux基础命令之添加用户useradd
  5. 面试必备,MySQL InnoDB MVCC机制
  6. idea多个项目合并一个窗口
  7. WebRequest 请求
  8. A Survey of Simultaneous Localization and Mapping 论文精读笔记
  9. 你就是孩子最好的玩具——情感引导式教育
  10. MySQL看这一篇就够了