在本文中,我们使用tensorflow2.x实现了lenet-5,用于mnist的识别。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from tensorflow import keras

数据预处理

我们先载入mnist数据

(x_train, y_train),(x_test, y_test) = keras.datasets.mnist.load_data()

我们把特征数据增加一个纬度,用于LeNet5的输入:

print(x_train.shape, y_train.shape)
x_train = x_train.reshape(60000, 28, 28, 1)
x_test = x_test.reshape(10000, 28, 28, 1)
print(x_train.shape, y_train.shape)
(60000, 28, 28) (60000,)
(60000, 28, 28, 1) (60000,)

特征数据归一化:

x_train = x_train/255.0
x_test = x_test/255.0

标签做onehot:

y_train = np.array(pd.get_dummies(y_train))
y_test = np.array(pd.get_dummies(y_test))

构建模型

我们使用sequential构建LeNet-5模型:

model = keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(filters=6, kernel_size=(5,5), input_shape=(28,28,1), padding='same', activation='sigmoid'))
model.add(tf.keras.layers.AveragePooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Conv2D(filters=16, kernel_size=(5,5), padding='valid', activation='sigmoid'))
model.add(tf.keras.layers.AveragePooling2D(pool_size=(2,2)))
model.add(tf.keras.layers.Conv2D(filters=120, kernel_size=(5,5), padding='valid', activation='sigmoid'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(84, activation='sigmoid'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

我们看一下模型的详细情况,包括每一层的输出大小,可训练参数数量,模型的总参数等。

model.summary()
Model: "sequential_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_14 (Conv2D)           (None, 28, 28, 6)         156
_________________________________________________________________
average_pooling2d_10 (Averag (None, 14, 14, 6)         0
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 10, 10, 16)        2416
_________________________________________________________________
average_pooling2d_11 (Averag (None, 5, 5, 16)          0
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 1, 1, 120)         48120
_________________________________________________________________
flatten_2 (Flatten)          (None, 120)               0
_________________________________________________________________
dense_3 (Dense)              (None, 84)                10164
_________________________________________________________________
dense_4 (Dense)              (None, 10)                850
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
_________________________________________________________________
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(60000, 28, 28, 1) (60000, 10) (10000, 28, 28, 1) (10000, 10)

训练模型

history = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10)
Epoch 1/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0638 - acc: 0.9805 - val_loss: 0.0618 - val_acc: 0.9801
Epoch 2/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0548 - acc: 0.9832 - val_loss: 0.0515 - val_acc: 0.9830
Epoch 3/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0480 - acc: 0.9851 - val_loss: 0.0727 - val_acc: 0.9763
Epoch 4/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0431 - acc: 0.9870 - val_loss: 0.0420 - val_acc: 0.9864
Epoch 5/10
1875/1875 [==============================] - 9s 5ms/step - loss: 0.0390 - acc: 0.9881 - val_loss: 0.0461 - val_acc: 0.9851
Epoch 6/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0347 - acc: 0.9889 - val_loss: 0.0394 - val_acc: 0.9866
Epoch 7/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0309 - acc: 0.9904 - val_loss: 0.0434 - val_acc: 0.9851
Epoch 8/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0279 - acc: 0.9908 - val_loss: 0.0373 - val_acc: 0.9879
Epoch 9/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0257 - acc: 0.9919 - val_loss: 0.0353 - val_acc: 0.9886
Epoch 10/10
1875/1875 [==============================] - 10s 5ms/step - loss: 0.0229 - acc: 0.9930 - val_loss: 0.0361 - val_acc: 0.9876

准确率可以打到98%以上。

保存模型

model.save('mnist.h5')

加载模型&预测

我们使用上面的模型对手写数字进行预测

import cv2
img = cv2.imread('3.png', 0)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f602c0a07f0>

img = cv2.resize(img, (28,28))
img = img.reshape(1, 28, 28, 1)
img = img/255.0
my_model = tf.keras.models.load_model('mnist.h5')
predict = my_model.predict(img)
print(predict)
print(np.argmax(predict))
[[4.9507696e-09 7.0097293e-08 1.7773251e-06 9.9997258e-01 3.6114369e-091.9603556e-05 2.9246516e-10 1.3854858e-06 7.9077779e-07 3.7732302e-06]]
3

tensorflow综合示例7:LeNet-5实现mnist识别相关推荐

  1. tensorflow综合示例3:对结构化数据进行分类:csv keras feature_column

    文章目录 1.数据集 1.1 使用 Pandas 从csv创建一个 dataframe 1.2 将 dataframe 拆分为训练.验证和测试集 1.3 用 tf.data 创建输入流水线Datase ...

  2. tensorflow综合示例1:tensorflow-keras的基本使用方式

    import numpy as np import matplotlib.pyplot as plt import pandas as pd import tensorflow as tf from ...

  3. tensorflow综合示例5:图象分割

    本文主要内容来自: https://www.tensorflow.org/tutorials/images/segmentation?hl=zh-cn 图像分割 这篇教程将重点讨论图像分割任务,使用的 ...

  4. tensorflow综合示例4:逻辑回归:使用Estimator

    文章目录 1.加载csv格式的数据集并生成Dataset 1.1 pandas读取csv数据生成Dataframe 1.2 将Dataframe生成Dataset 2.将数据封装成Feature co ...

  5. tensorflow学习笔记(八):LSTM手写体(MNIST)识别

    文章目录 一.LSTM简介 二.主要函数 三.LSTM手写体(MNIST)识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.LSTM简介 LSTM是一种特殊的RNN,很好的解决了RNN中 ...

  6. tensorflow学习笔记(七):CNN手写体(MNIST)识别

    文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...

  7. make--变量与函数的综合示例 自动生成依赖关系

    一.变量与函数的示例 示例的要求 1.自动生成target文件夹存放可执行文件 2.自动生成objs文件夹存放编译生成的目标文件 3.支持调试版本的编译选项 4.考虑代码的扩展性 完成该示例所需的 1 ...

  8. C结构体工具DirectStruct(综合示例二)

    2019独角兽企业重金招聘Python工程师标准>>> C结构体工具DirectStruct(综合示例二) 1.编写定义文件,用工具dsc处理之,自动生成XML转换代码和ESQL代码 ...

  9. QT综合示例:QT串口通信

    QT综合示例:QT串口通信 0.界面: 1.代码: 如果用qt写程序作为上位机,然后通过和usb和下位机通信的时候,就需要用到qt中的串口通信了. 0.界面: 1.代码: 1).pro 添加: QT ...

最新文章

  1. CocoaPods远程私有库
  2. 光电耦合NEC2051 的输入输出特性
  3. 安卓java修改按钮大小_android弹出activity设置大小的方法
  4. 是第一个成功设计微型计算机的人,()是第一个成功设计微型计算机的人。
  5. Dynamics CRM Publisher
  6. MySQL防止重复插入唯一限制的数据 4种方法
  7. JScrollPane实现自动滚动到底部
  8. 数据库的UNDO和REDO
  9. android日志收集存入mysql_rsyslog+analyzer+mysql实现日志收集展示
  10. Hadoop IO 文件压缩 序列化
  11. jQuery选择器总结[转]
  12. snmp 获取设备类型_SNMP开发系列(三)SNMP Agent的实现
  13. android程序表白,几条曲线构建Android表白程序
  14. 注意!恶意NPM包正在安装勒索软件和密码窃取木马
  15. JavaScript实现随机抽奖功能
  16. 省市区areacode反查的精简写法
  17. 新松机器人袁_中科新松许小刚:智能协作机器人是中国机器人产业发展新节点...
  18. java 地图四色着色算法_趣味地图系列之6 四色定理之我见
  19. 水利水电课程指导之建筑制图基础_第一章1.3 平面图形的尺寸标注
  20. 软件的知识产权保护---著作权法及实施条例

热门文章

  1. 【已解决】Error: could not open `C:\Program Files\Java\jre1.8.0_121\lib\amd64\jvm.cfg‘
  2. 问题描述: 在一个圆形操场的四周摆放着n 堆石子。现要将石子有次序地合并成一堆。 规定每次只能选相邻的2 堆石子合并成新的一堆,并将新的一堆石子数记为该次合并的得分。 试设计一个算法,计算出将n堆石子
  3. 9行代码AC——HDU 6857 -Clockwise or Counterclockwise(2020 Multi-University Training Contest 8)(判断三点顺序)
  4. oracle 存储过程挂起,library cache pin与PROCEDURE的重建
  5. shell脚本每日一练(二)
  6. C语言 void和void *(无类型指针)
  7. 区位码\机器码\内码关系
  8. cfile 修改某些位_王者荣耀:打野刀效果再次修改,自定义房间配置试运行!
  9. python3.6字典有序_Python如何按值对字典进行排序?
  10. 好想学python 怎么猜人物_想自学Python,如何才能坚持下来?