Tensorflow2.0:实战LeNet-5识别MINIST数据集
LeNet-5模型
1990 年代提出的LeNet-5使卷积神经网络在当时成功商用,下图是 LeNet-5 的网络结构图,它接受32 × 32大小的数字、字符图片,这次将LeNet-5模型用来识别MINIST数据集中的数字,并在测试集中计算其识别准确率。
根据上图的网络结构,可以得出下图的模型结构图:
完整代码示例
第一部分:数据集的加载与预处理
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets # 导入经典数据集加载模块
# 加载 MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data() # 返回数组的形状
# 将数据集转换为DataSet对象,不然无法继续处理
train_db = tf.data.Dataset.from_tensor_slices((x, y))
# 将数据顺序打散
train_db = train_db.shuffle(10000) # 数字为缓冲池的大小
# 设置批训练
train_db = train_db.batch(512) # batch size 为 128
#预处理函数
def preprocess(x, y): # 输入x的shape 为[b, 32, 32], y为[b]# 将像素值标准化到 0~1区间x = tf.cast(x, dtype=tf.float32) / 255.# 将图片改为28*28大小的x = tf.reshape(x, [-1, 28 * 28])y = tf.cast(y, dtype=tf.int32) # 转成整型张量y = tf.one_hot(y, depth=10)return x, y
# 将数据集传入预处理函数,train_db支持map映射函数
train_db = train_db.map(preprocess)
# 训练20个epoch
train_db = train_db.repeat(20)
# 以同样的方式处理测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(512).map(preprocess)
第二部分:构建LeNet-5模型
from tensorflow.keras import Sequential
from tensorflow.keras import layers,losses, optimizers
def main():network = Sequential([layers.Conv2D(6, kernel_size=3, strides=1), # 第一个卷积层, 6 个 3x3 卷积核layers.MaxPooling2D(pool_size=2, strides=2), # 高宽各减半的池化层layers.ReLU(), # 激活函数layers.Conv2D(16, kernel_size=3, strides=1), # 第二个卷积层, 16 个 3x3 卷积核layers.MaxPooling2D(pool_size=2, strides=2), # 高宽各减半的池化层layers.ReLU(), # 激活函数layers.Flatten(), # 打平层,形成一维向量,方便全连接层处理layers.Dense(120, activation='relu'), # 全连接层, 120 个节点layers.Dense(84, activation='relu'), # 全连接层, 84 节点layers.Dense(10) # 全连接层, 10 个节点])# 构建网络模型,给输入X的形状,其中4为随意的BatchSizenetwork.build(input_shape=(4, 28, 28, 1))# 统计网络信息# print(network.summary())optimizer = optimizers.Adam(lr=1e-4)loss_all = []# 创建损失函数的类,在实际计算时直接调用类实例即可criteon = losses.CategoricalCrossentropy(from_logits=True)for step, (x, y) in enumerate(train_db):# 将输入张量x的shape[512.784]变成[x = tf.reshape(x, (-1, 28, 28))with tf.GradientTape() as tape:# 插入通道维度,=>[b,28,28,1]x = tf.expand_dims(x, axis=3)# 前向计算,获得10类别的预测分布,[b, 784] => [b, 10]out = network(x)# 将真实标签转化为one-hot编码,[b] => [b, 10]# 计算交叉熵损失函数,标量loss = criteon(y, out)# 自动计算梯度,关键看如何表示待优化变量grads = tape.gradient(loss, network.trainable_variables)# 自动更新参数optimizer.apply_gradients(zip(grads, network.trainable_variables))# step为80次时,记录并输出损失函数结果if step % 100 == 0:print(step, 'loss:', float(loss))loss_all.append(float(loss))# step为80次时,用测试集验证模型if step % 100 == 0:total, total_correct = 0., 0correct, total = 0, 0for x, y in test_db: # 遍历所有训练集样本# 插入通道维度,=>[b,28,28,1]x = tf.reshape(x, (-1, 28, 28))x = tf.expand_dims(x, axis=3)# 前向计算,获得10类别的预测分布,[b, 784] => [b, 10]out = network(x)# 真实的流程时先经过softmax,再argmax# 但是由于softmax不改变元素的大小相对关系,故省去pred = tf.argmax(out, axis=-1)y = tf.cast(y, tf.int64)y = tf.argmax(y, axis=-1)# 统计预测正确数量correct += float(tf.reduce_sum(tf.cast(tf.equal(pred, y), tf.float32)))# 统计预测样本总数total += x.shape[0]# 计算准确率print('test acc:', correct / total)
第三部分 测试结果
经过2300step迭代训练后,识别数字的准确率达到了97.68%
Tensorflow2.0:实战LeNet-5识别MINIST数据集相关推荐
- Tensorflow2.0实战练习之猫狗数据集(包含自定义训练和迁移学习)
最近在学习使用Tenforflow2.0,写下这篇文章,用来帮助和我一样的初学者,文章中如果存在某些问题,还希望各位指出. 目录 数据集介绍 数据处理及增强 VGG模型介绍 模型搭建 训练及结果展示 ...
- 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战
基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕实战 基于opencv tensorflow2.0开发的人脸识别锁定与解锁win10屏幕 前言 运行python环境 ...
- 基于tensorflow2.0+opencv的花卉识别系统源码(含数据集)
花卉识别-基于tensorflow2.3实现 完整代码下载地址:基于tensorflow2.0+opencv的花卉识别系统源码( 文件目录 # 数据下载地址 https://storage.googl ...
- 第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/LQ_qing/article/deta ...
- 笔记3:Tensorflow2.0实战之MNSIT数据集
最近Tensorflow相继推出了alpha和beta两个版本,这两个都属于tensorflow2.0版本:早听说新版做了很大的革新,今天就来用一下看看 这里还是使用MNSIT数据集进行测试 导入必要 ...
- 数据挖掘实战(9.5)--使用神经网络识别MINIST数据集
一.minist数据集 minist数据集分为两个部分,训练集和测试集,然后在不同的集合中分为两个文件,数据Images文件和Labels文件.在数据集中一个有60000个训练数据和10000个测试数 ...
- 基于TensorFlow2.0的摄像头数字识别
import numpy as np import cv2 from skimage import data, segmentation, measure, morphology, color imp ...
- 基于 opencv tenserflow2.0 实战 CNN 人脸识别锁定与解锁 win10 屏幕
代码参考:https://download.csdn.net/download/weixin_55771290/87430422 前言 windows hello 的低阶板本,没有 Windows h ...
- 神经网络与深度学习——TensorFlow2.0实战(笔记)(二)(安装TensorFlow2.0)
创建环境并激活 conda create --name tensorflow2.0 python==3.7 activate tensorflow2.0 安装相关软件包(conda命令或pip命令2选 ...
最新文章
- Metaspace 引起的 FullGC 问题排查过程及解决方案
- Bootstrap -- 插件: 按钮状态、折叠样式、轮播样式
- 在Entity Framework中使用存储过程(一):实现存储过程的自动映射
- 【PC工具】好用的搜索引擎DogeDoge替代百度搜索,中国的duckduckgo
- 关于文件保存/关闭时报错:文件正由另一进程使用,因此该进程无法访问此文件。...
- 美国政府已关闭 5800 个数据中心,计划关闭 1400 个
- Pandas 操作 csv 文件
- Linux文件操作实用笔记
- ASP.NET Core分布式项目实战(客户端集成IdentityServer)--学习笔记
- 使用表中的数组数据类型
- linux内核有ebpf吗,聊聊很重要的内核技术eBPF
- client_hello_cb、get_session_cb、servername_cb、cert_cb
- 返回零长度的数组或者集合,而不是null
- 免费MD5破解、在线查询网站
- html是什么文本文件,纯文本文件是什么意思
- IDEA 导入项目 导入不进去
- LCD/OLED显示产品从新品导入量产的线体认证策划
- HTML_旅行志界面
- CVPR2021 视频目标检测论文推荐
- 数据结构-C语言代码 day6-栈及其应用