LeNet诞生于1994年,是最早的卷积神经网络之一,这个网络虽然很小,但是它包含了深度学习的基本模块:卷积层,池化层,全连接层。是其他深度学习模型的基础,这项由Yann LeCun完成的开拓性成果被命名为LeNet。

之前做过HOG特征+SVM手写数字识别
现在再用卷积神经网络来重新做一遍手写数字识别,选用的是LeNet-5结构,pytorch平台,直接贴代码。

一.数据处理

from torch.utils.data import Dataset
import struct
import torch
import numpy as np
class MNISTDataSet(Dataset):def __init__(self, img_path, label_path):with open(label_path, 'rb') as lbpath:magic, n = struct.unpack('>II', lbpath.read(8))self.labels = np.fromfile(lbpath, dtype=np.uint8)with open(img_path, 'rb') as imgpath:magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))# print("magic, num, rows, cols", magic, num, rows, cols)self.images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(self.labels), 784)def __getitem__(self, index):img = self.images[index].reshape(1,28, 28)# 扩展1维,表示1通道img = torch.from_numpy(img).to(torch.float32)label = self.labels[index].astype(np.long)return img, labeldef __len__(self):return len(self.images)

二.网络结构

import torch.nn as nnclass LeNet5(nn.Module):def __init__(self, num_class=3):super().__init__()self.num_class = num_classself.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.relu = nn.ReLU()       # 用relu代替sigmoidself.sf = nn.Softmax(dim=1)self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2, stride=1)self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0, stride=1)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, self.num_class)def forward(self, x):       # batch_size x 1,28,28x = self.conv1(x)       # batch_size x 6,28,28x = self.maxpool(x)     # batch_size x 6,14,14x = self.relu(x)x = self.conv2(x)       # batch_size x 16,10,10x = self.maxpool(x)     # batch_size x 16,5,5x = self.relu(x)x = x.view(x.size(0), -1)   # batch_size x (400)x = self.fc1(x)         # batch_size x (120)x = self.relu(x)x = self.fc2(x)         # batch_size x (84)x = self.relu(x)x = self.fc3(x)         # batch_size x (10)return self.sf(x)

三.训练以及测试

from LeNet import LeNet5
from DataSet import MNISTDataSet
from torch.utils.data import Dataset
import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 100def train():net = LeNet5(num_class=10)net.to(device)optimizer = torch.optim.Adam(net.parameters(), lr=1e-5)     # Adam优化算法是一种对随机梯度下降法的扩展loss = torch.nn.CrossEntropyLoss()  # 交叉熵损失函数train_data_set = MNISTDataSet('train-images.idx3-ubyte','train-labels.idx1-ubyte')train_data_set = torch.utils.data.DataLoader(train_data_set, batch_size = BATCH_SIZE, shuffle=True, num_workers=2)print("训练开始")for epoch_id in range(50):for batch_id, (img, label) in enumerate(train_data_set):img = img.to(device)label = label.to(device)optimizer.zero_grad()       # 清除梯度,grad=0,这个grad是net中的一个变量test_label = net(img)        # 这里的img应该是个四维的张量,100组数据,每组数据是1通道28*28loss_data = loss(test_label, label.long())    # 两个参数的形状不一样loss_data.backward()     # #后向传播,更新参数的过程,经此操作之后,grad中就有求导值了optimizer.step()        # 参考grad中的求导值,更新net中的参数print("Epoch:%d [%d|%d] loss:%f" % (epoch_id, batch_id, len(train_data_set), loss_data))torch.save(net.state_dict(), 'model.pth')print("训练结束")def test():preNet = LeNet5(num_class=10)preNet.load_state_dict(torch.load("model.pth"))test_data_set = MNISTDataSet('t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte')test_data_set = torch.utils.data.DataLoader(test_data_set, batch_size = BATCH_SIZE, shuffle=True, num_workers=2)allYes = 0print("测试开始")for batch_idx, (img, label) in enumerate(test_data_set):pre_result = preNet(img)     # 100组,每组十个概率pre_label = torch.argmax(pre_result,dim=1)print(batch_idx, "/", len(test_data_set), end=" ")for idx in range(len(pre_label)):if label[idx] == pre_label[idx]:allYes += 1print("√", end=" ")else:print("×", end=" ")print(" ")print("acc:", allYes / test_data_set.__len__())if __name__ == '__main__':train()test()

四.结果

深度学习-2.1 LeNet5-手写字体识别相关推荐

  1. 深度学习(4)手写数字识别实战

    深度学习(4)手写数字识别实战 Step0. 数据及模型准备 1. X and Y(数据准备) 2. out=relu{relu{relu[X@W1+b1]@W2+b2}@W3+b3}out=relu ...

  2. 深度学习(3)手写数字识别问题

    深度学习(3)手写数字识别问题 1. 问题归类 2. 数据集 3. Image 4. Input and Output 5. Regression VS Classification 6. Compu ...

  3. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  4. 利用python卷积神经网络手写数字识别_Keras深度学习:卷积神经网络手写数字识别...

    引言:最近在闭关学习中,由于多久没有写博客了,今天给大家带来学习的一些内容,还在学习神经网络的同学,跑一跑下面的代码,给你一些自信吧!Nice 奥里给! 正文:首先该impor的库就不多说了,不会的就 ...

  5. python手写字体程序_深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...

  6. 深度学习,实现手写字体识别(大数据人工智能公司)

    手写字体识别是指给定一系列的手写字体图片以及对应的标签,构建模型进行学习,目标是对于一张新的手写字体图片能够自动识别出对应的文字或数字.通过深度学习构建普通神经网络和卷积神经网络,处理手写字体数据.通 ...

  7. 神经网络学习(二)Tensorflow-简单神经网络(全连接层神经网络)实现手写字体识别

    神经网络学习(二)神经网络-手写字体识别 框架:Tensorflow 1.10.0 数据集:mnist数据集 策略:交叉熵损失 优化:梯度下降 五个模块:拿数据.搭网络.求损失.优化损失.算准确率 一 ...

  8. 【PyTorch学习笔记_04】--- PyTorch(开始动手操作_案例1:手写字体识别)

    手写字体识别的流程 定义超参数(自己定义的参数) 构建transforms, 主要是对图像做变换 下载,加载数据集MNIST 构建网络模型(重要,自己定义) 定义训练方法 定义测试方法 开始训练模型, ...

  9. 第六讲 Keras实现手写字体识别分类

    一 本节课程介绍 1.1 知识点 1.图像识别分类相关介绍: 2.Mnist手写数据集介绍: 3.标准化数据预处理: 4.实验手写字体识别 二 课程内容 2.1 图像识别分类基本介绍 计算机的图像识别 ...

  10. numpy完成手写字体识别(机器学习作业02)

    numpy完成手写字体识别(机器学习02) 参考代码:mnielsen/neural-networks-and-deep-learning: 参考讲解:深度学习多分类任务的损失函数详解 - 知乎 (z ...

最新文章

  1. linux脚本获取当前用户,bash shell 获取当前正在执行脚本的绝对路径
  2. 动态规划-换钱最少货币数
  3. Angular单元测试遇到的错误消息:Uncaught Error - Cannot find module tslib
  4. 那些拧不开瓶盖的女生全都是装的?理工男这样想......
  5. 李洋疯狂C语言之用递归解决李白喝酒问题(附填空题解法)
  6. Ubuntu 18.04.1 搭建Java环境和HelloWorld 1
  7. php nowdoc用来做什么,PHP中nowdoc和heredoc使用需要注意的一点
  8. 2015-2020年各类国际会议与期刊基于图像的三维对象重建论文综述(6)——Training
  9. Python说文解字_半成品再加工
  10. php生成GIF动态验证码图片(代码家园)
  11. 挖矿木马 sustes 追踪溯源分析
  12. SQL Server 聚合函数 (方差和标准差)
  13. 领域驱动设计(DDD)入门概要
  14. HNUST Java 数据库系统课程设计:学生管理系统
  15. VMware NSX 4.0 -- 网络安全虚拟化平台
  16. 海底光缆是如何铺设出来的?
  17. 广东指导晚造水稻工作 国稻种芯·中国水稻节:惠州加强防治
  18. 2020年UI设计行业的就业状况如何?
  19. netlogo建模案例_NetLogo用于科研:建模
  20. 动态磁盘和基本磁盘转换

热门文章

  1. 流利阅读day2 Anti-vaccination
  2. java毕业设计校园商铺mybatis+源码+调试部署+系统+数据库+lw
  3. ssm+JSP计算机毕业设计中国瑰宝——戏曲赏析网92n88【源码、程序、数据库、部署】
  4. 软件工程中理想团队模式构建的设想与软件流程的理解
  5. 物理笔记(1)--波函数中x/u符号的小思考
  6. MATLAB R2014a Builder JA总结
  7. DHCP Options Classless static route 121/249
  8. 小米手环5NFC 门卡模拟 使用CH340模拟加密卡
  9. 导出Word内容全部乱码问题
  10. 日本浮世绘的艺术配色