1.测试集数据:
链接:https://pan.baidu.com/s/1YDzwrWvd6dsYaKK7cyvdBQ
提取码:e0j5
第一部分:训练
2.导入相应的库

import os
import keras
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import skimage.io as io
from keras.callbacks import ModelCheckpoint

2.加载MNIST数据集:

(x_train,y_train),(x_test,y_test)=tf.keras.datasets.mnist.load_data()

print(tf.shape(x_train))
print(tf.shape(y_train))
print(tf.shape(x_test))
print(tf.shape(y_test))


3.增加一个维度:

x_train=tf.reshape(x_train,[-1,28,28,1])
x_test=tf.reshape(x_test,[-1,28,28,1])
print(tf.shape(x_train))
print(tf.shape(x_test))


4.归一化处理:

x_train=x_train/255
x_test=x_test/255

5.进行类别标签:

y_train=tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test=tf.keras.utils.to_categorical(y_test,num_classes=10)

6.定义日志文件:

model_checkpoint=ModelCheckpoint('Lenet5_memBrane.hdf5',monitor='loss',verbose=1,save_best_only=True)

7.定义Lenet5模型:

model=tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(28,28,1)),tf.keras.layers.Conv2D(20,kernel_size=[5,5],strides=[1,1]),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPooling2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Conv2D(50,kernel_size=[5,5],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPooling2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Flatten(),tf.keras.layers.Dense(500),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(10),tf.keras.layers.Activation('softmax')
])
model.summary()

8.模型编译:

model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])

9.模型训练:

model.fit(x_train,y_train,epochs=50,batch_size=32,callbacks=[model_checkpoint])
model.save('Lenet5_100.h5')

第二部分:测试
1.加载模型:

model=tf.keras.models.load_model('Lenet5_100.h5')

2.评估:

print(np.shape(x_test))
print(np.shape(y_test))
loss,accuracy=model.evaluate(x_test,y_test)

3.导入测试集数据文件中前十张图片进行预测:

Image_List = []
image_path='Test_Images_Number/'
for i in range(10):path=image_path+str(i)+'.jpg'Image_List.append(path)
print(Image_List)

注意:我这里是将测试集数据文件放在了当前的目录,注意一下这个文件路径就可以了,路径中最好不要有中文名。

import cv2
def load_Image(Image):img=cv2.imread(Image)img=cv2.resize(img,(28,28))print(np.shape(img))image_list=[]for item in img:row=[]
#         print('item: ',np.shape(item)) [28,3]for i in item:#这里的i[0]表示取第一层通道 ,因为我们训练的模型是基于灰度图来的row.append([i[0]])
#             print('i[0]: ',np.shape(i[0])) ()image_list.append(row)array=np.array(image_list)array=array/255image=np.expand_dims(array,axis=0)return image
class_Numbers={0:'0',1:'1',2:'2',3:'3',4:'4',5:'5',6:'6',7:'7',8:'8',9:'9'
}
for Image in Image_List:image=load_Image(Image)print(np.shape(image))predictions=model.predict(image)index_Number=np.argmax(predictions)print('图像预测结果: ',class_Numbers[index_Number])

实验效果:


从实验的效果中可以看出,不是很准确,但是还有很多的改进空间,可以尝试其他的模型,比如VGG13,VGG16等模型.

手写体数字识别(理解起来更简单一点)相关推荐

  1. keras框架下的深度学习(一)手写体数字识别

    文章目录 前言 一.keras的介绍及其操作使用 二.手写题数字识别 1.介绍 2.对数据的预处理 3.搭建网络框架 4.编译 5.循环训练 6.测试训练的网络模 7.总代码 三.附:梯度下降算法 1 ...

  2. Tensorflow 改进的MNIST手写体数字识别

    上篇简单的Tensorflow解决MNIST手写体数字识别可扩展性并不好.例如计算前向传播的函数需要将所有的变量都传入,当神经网络的结构变得复杂.参数更多时,程序的可读性变得非常差.而且这种方式会导致 ...

  3. Tensorflow解决MNIST手写体数字识别

    这里给出的代码是来自<Tensorflow实战Google深度学习框架>,以供参考和学习. 首先这个示例应用了几个基本的方法: 使用随机梯度下降(batch) 使用Relu激活函数去线性化 ...

  4. 基于matlab的手写体数字识别系统

    摘要:随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文章在matlab软件的基础上,利用BP神经网络算法完成手写体数字的识别. 机器学习 ...

  5. 基于MATLAB的手写体数字识别算法的实现

    基于MATLAB的手写体数字识别 一.课题介绍 手写数字识别是模式识别领域的一个重要分支,它研究的核心问题是:如何利用计算机自动识别人手写在纸张上的阿拉伯数字.手写体数字识别问题,简而言之就是识别出1 ...

  6. 基于TensorFlow的手写体数字识别

    目录 一.MNIST数据集介绍 二.原理 2.1.卷积神经网络简介( convolutional neural network 简称CNN) 2.1.1卷积运算过程 2.1.2滑动的步长 2.1.3卷 ...

  7. 分类 手写体数字识别

    分类 手写体数字识别 1.数据集 分离训练集和测试集 2.训练一个二分类器 3.评价分类器的性能 使用交叉验证分类准确率 精准率和召回率 混淆矩阵 精准率和召回率的折衷 ROC 曲线 4.多分类器 5 ...

  8. 基于AlexNet卷积神经网络的手写体数字识别系统研究-附Matlab代码

    ⭕⭕ 目 录 ⭕⭕ ✳️ 一.引言 ✳️ 二.手写体数字识别系统 ✳️ 2.1 MNIST 数据集 ✳️ 2.2 CNN ✳️ 2.3 网络训练 ✳️ 三.手写体数字识别结果 ✳️ 四.参考文献 ✳️ ...

  9. Matlab深度学习-手写体数字识别

    Matlab深度学习 文章目录 Matlab深度学习 前言 一.MNIST手写体数字数据 二.用到的深度学习框架-LeNet5 2-0 LeNet5的网络架构 2-1 框架实现-通过Matlab GU ...

最新文章

  1. PostgreSQL SQL 语言:并行查询
  2. windows安装包安装mysql5.7_Windows7 64位压缩包安装MySQL5.7.9
  3. Couchbase 集群小实践
  4. 选择select 标签中指定值的option
  5. 回文字符串—回文子串—Manacher算法
  6. [Winodows Phone 7控件详解]容器控件
  7. python排名上升_TIOBE:2019年7月全球编程语言排行 Python热度继续上升
  8. STL_stack/queue
  9. python如何安装第三方库
  10. 力扣101. 对称二叉树(JavaScript)
  11. 一文读懂 JavaScript 和 Python 九大语义区别
  12. SQL Server 中添加表注释
  13. Ubuntu编译:error: ‘usleep’ was not declared in this scope
  14. ppt文字磨砂玻璃效果制作教程
  15. 推荐一个免费超级好用的简历模板网站
  16. java全景图片切割 全景,基于Three.js实现360度全景图片
  17. Android必知必会-长按返回健退出
  18. 总有你不知道的,你说呢?
  19. Revit教程- Revit中如何控制屋顶的标高
  20. Android环信即时通信集成全过程(含demo)

热门文章

  1. 树的高度(小米2017秋招真题)
  2. 复习计算机网络基础 day7--网络层
  3. C语言,分解质因数一个解法!_只愿与一人十指紧扣_新浪博客
  4. 必看干货:如何在 JavaScript 中实现 8 种基本图形算法
  5. 如何评判一个深度学习框架?
  6. 经验 | 机器学习要避开十大雷区
  7. 卷积神经网络是如何实现不变性特征提取的
  8. 链表问题2——在单链表中删除倒数第K个节点
  9. 报错——StackOverflowError
  10. 我是架构师-设计模式-工厂模式-工厂方法