LeNet-5模型

1990 年代提出的LeNet-5使卷积神经网络在当时成功商用,下图是 LeNet-5 的网络结构图,它接受32 × 32大小的数字、字符图片,这次将LeNet-5模型用来识别MINIST数据集中的数字,并在测试集中计算其识别准确率。

根据上图的网络结构,可以得出下图的模型结构图:

完整代码示例

第一部分:数据集的加载与预处理

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets  # 导入经典数据集加载模块
# 加载 MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data()   # 返回数组的形状
# 将数据集转换为DataSet对象,不然无法继续处理
train_db = tf.data.Dataset.from_tensor_slices((x, y))
# 将数据顺序打散
train_db = train_db.shuffle(10000)  # 数字为缓冲池的大小
# 设置批训练
train_db = train_db.batch(512)  # batch size 为 128
#预处理函数
def preprocess(x, y):   # 输入x的shape 为[b, 32, 32], y为[b]# 将像素值标准化到 0~1区间x = tf.cast(x, dtype=tf.float32) / 255.# 将图片改为28*28大小的x = tf.reshape(x, [-1, 28 * 28])y = tf.cast(y, dtype=tf.int32)  # 转成整型张量y = tf.one_hot(y, depth=10)return x, y
# 将数据集传入预处理函数,train_db支持map映射函数
train_db = train_db.map(preprocess)
# 训练20个epoch
train_db = train_db.repeat(20)
# 以同样的方式处理测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(512).map(preprocess)

第二部分:构建LeNet-5模型

from tensorflow.keras import Sequential
from tensorflow.keras import layers,losses, optimizers
def main():network = Sequential([layers.Conv2D(6, kernel_size=3, strides=1),  # 第一个卷积层, 6 个 3x3 卷积核layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Conv2D(16, kernel_size=3, strides=1),  # 第二个卷积层, 16 个 3x3 卷积核layers.MaxPooling2D(pool_size=2, strides=2),  # 高宽各减半的池化层layers.ReLU(),  # 激活函数layers.Flatten(),  # 打平层,形成一维向量,方便全连接层处理layers.Dense(120, activation='relu'),  # 全连接层, 120 个节点layers.Dense(84, activation='relu'),  # 全连接层, 84 节点layers.Dense(10)  # 全连接层, 10 个节点])# 构建网络模型,给输入X的形状,其中4为随意的BatchSizenetwork.build(input_shape=(4, 28, 28, 1))# 统计网络信息# print(network.summary())optimizer = optimizers.Adam(lr=1e-4)loss_all = []# 创建损失函数的类,在实际计算时直接调用类实例即可criteon = losses.CategoricalCrossentropy(from_logits=True)for step, (x,  y) in enumerate(train_db):# 将输入张量x的shape[512.784]变成[x = tf.reshape(x, (-1, 28, 28))with tf.GradientTape() as tape:# 插入通道维度,=>[b,28,28,1]x = tf.expand_dims(x, axis=3)# 前向计算,获得10类别的预测分布,[b, 784] => [b, 10]out = network(x)# 将真实标签转化为one-hot编码,[b] => [b, 10]# 计算交叉熵损失函数,标量loss = criteon(y, out)# 自动计算梯度,关键看如何表示待优化变量grads = tape.gradient(loss, network.trainable_variables)# 自动更新参数optimizer.apply_gradients(zip(grads, network.trainable_variables))# step为80次时,记录并输出损失函数结果if step % 100 == 0:print(step, 'loss:', float(loss))loss_all.append(float(loss))# step为80次时,用测试集验证模型if step % 100 == 0:total, total_correct = 0., 0correct, total = 0, 0for x, y in test_db:  # 遍历所有训练集样本# 插入通道维度,=>[b,28,28,1]x = tf.reshape(x, (-1, 28, 28))x = tf.expand_dims(x, axis=3)# 前向计算,获得10类别的预测分布,[b, 784] => [b, 10]out = network(x)# 真实的流程时先经过softmax,再argmax# 但是由于softmax不改变元素的大小相对关系,故省去pred = tf.argmax(out, axis=-1)y = tf.cast(y, tf.int64)y = tf.argmax(y, axis=-1)# 统计预测正确数量correct += float(tf.reduce_sum(tf.cast(tf.equal(pred, y), tf.float32)))# 统计预测样本总数total += x.shape[0]# 计算准确率print('test acc:', correct / total)

第三部分 测试结果

经过2300step迭代训练后,识别数字的准确率达到了97.68%

Tensorflow2.0:实战LeNet-5识别MINIST数据集相关推荐

  1. Tensorflow2.0实战练习之猫狗数据集(包含自定义训练和迁移学习)

    最近在学习使用Tenforflow2.0,写下这篇文章,用来帮助和我一样的初学者,文章中如果存在某些问题,还希望各位指出. 目录 数据集介绍 数据处理及增强 VGG模型介绍 模型搭建 训练及结果展示 ...

  2. 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战

    基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...

  3. 基于tensorflow2.0+opencv的花卉识别系统源码(含数据集)

    花卉识别-基于tensorflow2.3实现 完整代码下载地址:基于tensorflow2.0+opencv的花卉识别系统源码( 文件目录 # 数据下载地址 https://storage.googl ...

  4. 第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...

  5. 笔记3:Tensorflow2.0实战之MNSIT数据集

    最近Tensorflow相继推出了alpha和beta两个版本,这两个都属于tensorflow2.0版本:早听说新版做了很大的革新,今天就来用一下看看 这里还是使用MNSIT数据集进行测试 导入必要 ...

  6. 数据挖掘实战(9.5)--使用神经网络识别MINIST数据集

    一.minist数据集 minist数据集分为两个部分,训练集和测试集,然后在不同的集合中分为两个文件,数据Images文件和Labels文件.在数据集中一个有60000个训练数据和10000个测试数 ...

  7. 基于TensorFlow2.0的摄像头数字识别

    import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...

  8. 基于 opencv tenserflow2.0 实战 CNN 人脸识别锁定与解锁 win10 屏幕

    代码参考:https://download.csdn.net/download/weixin_55771290/87430422 前言 windows hello 的低阶板本,没有 Windows h ...

  9. 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(安装TensorFlow2.0)

    创建环境并激活 conda create --name tensorflow2.0 python==3.7 activate tensorflow2.0 安装相关软件包(conda命令或pip命令2选 ...

最新文章

  1. Metaspace 引起的 FullGC 问题排查过程及解决方案
  2. Bootstrap -- 插件: 按钮状态、折叠样式、轮播样式
  3. 在Entity Framework中使用存储过程(一):实现存储过程的自动映射
  4. 【PC工具】好用的搜索引擎DogeDoge替代百度搜索,中国的duckduckgo
  5. 关于文件保存/关闭时报错:文件正由另一进程使用,因此该进程无法访问此文件。...
  6. 美国政府已关闭 5800 个数据中心,计划关闭 1400 个
  7. Pandas 操作 csv 文件
  8. Linux文件操作实用笔记
  9. ASP.NET Core分布式项目实战(客户端集成IdentityServer)--学习笔记
  10. 使用表中的数组数据类型
  11. linux内核有ebpf吗,聊聊很重要的内核技术eBPF
  12. client_hello_cb、get_session_cb、servername_cb、cert_cb
  13. 返回零长度的数组或者集合,而不是null
  14. 免费MD5破解、在线查询网站
  15. html是什么文本文件,纯文本文件是什么意思
  16. IDEA 导入项目 导入不进去
  17. LCD/OLED显示产品从新品导入量产的线体认证策划
  18. HTML_旅行志界面
  19. CVPR2021 视频目标检测论文推荐
  20. 数据结构-C语言代码 day6-栈及其应用

热门文章

  1. 以瓴羊QuickBI为例,教你如何制作在线电子表格
  2. 论文超详细精读|八千字:AS-GCN
  3. 511遇见易语言API模块视进入许可证(EnterCriticalSection)
  4. ​2022年护士资格证考试实践能力模拟考习题及答案
  5. 做什么类型的网站容易挂google广告容易取得长久收入
  6. 微信小程序分享网络路径图片
  7. 事件绑定(onmouseout,onmouseover)
  8. ABB利用官方API二次开发之控制信号
  9. 【华安php入门系列】--第2天-php的变量
  10. react放大镜组件