tensorflow综合示例7:LeNet-5实现mnist识别
在本文中,我们使用tensorflow2.x实现了lenet-5,用于mnist的识别。
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow import keras
数据预处理
我们先载入mnist数据
(x_train, y_train),(x_test, y_test) = keras.datasets.mnist.load_data()
我们把特征数据增加一个纬度,用于LeNet5的输入:
print(x_train.shape, y_train.shape)
x_train = x_train.reshape(60000, 28, 28, 1)
x_test = x_test.reshape(10000, 28, 28, 1)
print(x_train.shape, y_train.shape)
(60000, 28, 28) (60000,)
(60000, 28, 28, 1) (60000,)
特征数据归一化:
x_train = x_train/255.0
x_test = x_test/255.0
标签做onehot:
y_train = np.array(pd.get_dummies(y_train))
y_test = np.array(pd.get_dummies(y_test))
构建模型
我们使用sequential构建LeNet-5模型:
model = keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(filters=6, kernel_size=(5,5), input_shape=(28,28,1), padding='same', activation='sigmoid'))
model.add(tf.keras.layers.AveragePooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding='valid', activation='sigmoid'))
model.add(tf.keras.layers.AveragePooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Conv2D(filters=120, kernel_size=(5,5), padding='valid', activation='sigmoid'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(84, activation='sigmoid'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
我们看一下模型的详细情况,包括每一层的输出大小,可训练参数数量,模型的总参数等。
model.summary()
Model: "sequential_9"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_14 (Conv2D) (None, 28, 28, 6) 156
_________________________________________________________________
average_pooling2d_10 (Averag (None, 14, 14, 6) 0
_________________________________________________________________
conv2d_15 (Conv2D) (None, 10, 10, 16) 2416
_________________________________________________________________
average_pooling2d_11 (Averag (None, 5, 5, 16) 0
_________________________________________________________________
conv2d_16 (Conv2D) (None, 1, 1, 120) 48120
_________________________________________________________________
flatten_2 (Flatten) (None, 120) 0
_________________________________________________________________
dense_3 (Dense) (None, 84) 10164
_________________________________________________________________
dense_4 (Dense) (None, 10) 850
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(60000, 28, 28, 1) (60000, 10) (10000, 28, 28, 1) (10000, 10)
训练模型
history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)
Epoch 1/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0638 - acc: 0.9805 - val_loss: 0.0618 - val_acc: 0.9801
Epoch 2/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0548 - acc: 0.9832 - val_loss: 0.0515 - val_acc: 0.9830
Epoch 3/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0480 - acc: 0.9851 - val_loss: 0.0727 - val_acc: 0.9763
Epoch 4/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0431 - acc: 0.9870 - val_loss: 0.0420 - val_acc: 0.9864
Epoch 5/10
1875/1875 [==============================] - 9s 5ms/step - loss: 0.0390 - acc: 0.9881 - val_loss: 0.0461 - val_acc: 0.9851
Epoch 6/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0347 - acc: 0.9889 - val_loss: 0.0394 - val_acc: 0.9866
Epoch 7/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0309 - acc: 0.9904 - val_loss: 0.0434 - val_acc: 0.9851
Epoch 8/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0279 - acc: 0.9908 - val_loss: 0.0373 - val_acc: 0.9879
Epoch 9/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0257 - acc: 0.9919 - val_loss: 0.0353 - val_acc: 0.9886
Epoch 10/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0229 - acc: 0.9930 - val_loss: 0.0361 - val_acc: 0.9876
准确率可以打到98%以上。
保存模型
model.save('mnist.h5')
加载模型&预测
我们使用上面的模型对手写数字进行预测
import cv2
img = cv2.imread('3.png', 0)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f602c0a07f0>
img = cv2.resize(img, (28,28))
img = img.reshape(1, 28, 28, 1)
img = img/255.0
my_model = tf.keras.models.load_model('mnist.h5')
predict = my_model.predict(img)
print(predict)
print(np.argmax(predict))
[[4.9507696e-09 7.0097293e-08 1.7773251e-06 9.9997258e-01 3.6114369e-091.9603556e-05 2.9246516e-10 1.3854858e-06 7.9077779e-07 3.7732302e-06]]
3
tensorflow综合示例7:LeNet-5实现mnist识别相关推荐
- tensorflow综合示例3:对结构化数据进行分类:csv keras feature_column
文章目录 1.数据集 1.1 使用 Pandas 从csv创建一个 dataframe 1.2 将 dataframe 拆分为训练.验证和测试集 1.3 用 tf.data 创建输入流水线Datase ...
- tensorflow综合示例1:tensorflow-keras的基本使用方式
import numpy as np import matplotlib.pyplot as plt import pandas as pd import tensorflow as tf from ...
- tensorflow综合示例5:图象分割
本文主要内容来自: https://www.tensorflow.org/tutorials/images/segmentation?hl=zh-cn 图像分割 这篇教程将重点讨论图像分割任务,使用的 ...
- tensorflow综合示例4:逻辑回归:使用Estimator
文章目录 1.加载csv格式的数据集并生成Dataset 1.1 pandas读取csv数据生成Dataframe 1.2 将Dataframe生成Dataset 2.将数据封装成Feature co ...
- tensorflow学习笔记(八):LSTM手写体(MNIST)识别
文章目录 一.LSTM简介 二.主要函数 三.LSTM手写体(MNIST)识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.LSTM简介 LSTM是一种特殊的RNN,很好的解决了RNN中 ...
- tensorflow学习笔记(七):CNN手写体(MNIST)识别
文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...
- make--变量与函数的综合示例 自动生成依赖关系
一.变量与函数的示例 示例的要求 1.自动生成target文件夹存放可执行文件 2.自动生成objs文件夹存放编译生成的目标文件 3.支持调试版本的编译选项 4.考虑代码的扩展性 完成该示例所需的 1 ...
- C结构体工具DirectStruct(综合示例二)
2019独角兽企业重金招聘Python工程师标准>>> C结构体工具DirectStruct(综合示例二) 1.编写定义文件,用工具dsc处理之,自动生成XML转换代码和ESQL代码 ...
- QT综合示例:QT串口通信
QT综合示例:QT串口通信 0.界面: 1.代码: 如果用qt写程序作为上位机,然后通过和usb和下位机通信的时候,就需要用到qt中的串口通信了. 0.界面: 1.代码: 1).pro 添加: QT ...
最新文章
- CocoaPods远程私有库
- 光电耦合NEC2051 的输入输出特性
- 安卓java修改按钮大小_android弹出activity设置大小的方法
- 是第一个成功设计微型计算机的人,()是第一个成功设计微型计算机的人。
- Dynamics CRM Publisher
- MySQL防止重复插入唯一限制的数据 4种方法
- JScrollPane实现自动滚动到底部
- 数据库的UNDO和REDO
- android日志收集存入mysql_rsyslog+analyzer+mysql实现日志收集展示
- Hadoop IO 文件压缩 序列化
- jQuery选择器总结[转]
- snmp 获取设备类型_SNMP开发系列(三)SNMP Agent的实现
- android程序表白,几条曲线构建Android表白程序
- 注意!恶意NPM包正在安装勒索软件和密码窃取木马
- JavaScript实现随机抽奖功能
- 省市区areacode反查的精简写法
- 新松机器人袁_中科新松许小刚:智能协作机器人是中国机器人产业发展新节点...
- java 地图四色着色算法_趣味地图系列之6 四色定理之我见
- 水利水电课程指导之建筑制图基础_第一章1.3 平面图形的尺寸标注
- 软件的知识产权保护---著作权法及实施条例
热门文章
- 【已解决】Error: could not open `C:\Program Files\Java\jre1.8.0_121\lib\amd64\jvm.cfg‘
- 问题描述: 在一个圆形操场的四周摆放着n 堆石子。现要将石子有次序地合并成一堆。 规定每次只能选相邻的2 堆石子合并成新的一堆,并将新的一堆石子数记为该次合并的得分。 试设计一个算法,计算出将n堆石子
- 9行代码AC——HDU 6857 -Clockwise or Counterclockwise(2020 Multi-University Training Contest 8)(判断三点顺序)
- oracle 存储过程挂起,library cache pin与PROCEDURE的重建
- shell脚本每日一练(二)
- C语言 void和void *(无类型指针)
- 区位码\机器码\内码关系
- cfile 修改某些位_王者荣耀:打野刀效果再次修改,自定义房间配置试运行!
- python3.6字典有序_Python如何按值对字典进行排序?
- 好想学python 怎么猜人物_想自学Python,如何才能坚持下来?