EMNIST 数据集是一个包含手写字母,数字的数据集,它具有和MNIST相同的数据格式。The EMNIST Dataset | NIST

  1. 引用模块介绍:
import tensorflow as tfimport mnistfrom tensorflow.keras import datasets, layers, modelsimport numpy as npimport matplotlib.pyplot as pltimport gzip,os

其中要注意的是,tensorflow和keras和numpy的版本一定要对应,如果不对应就无法正常引用,python版本也不能太新,3.6到3.7最佳,如果python版本不满足,可以安装anaconda,在anconda prompt中创建虚拟环境,让其中的python=3.6.5即可

Tensorflow=2.3.1  numpy=1.19.5  keras=2.4.3  这是一种可行的库的版本

2.1首先导入数据集和可视化

路径最好继续用相对路径,这里的路径需要根据自己的文件路径进行修改

emnist数据集可以到官网上下载

def load_mnist(path):# 放置mnist.py的目录。注意斜杠f = np.load(path)x_train, y_train = f['x_train'], f['y_train']x_test, y_test = f['x_test'], f['y_test']f.close()return (x_train, y_train), (x_test, y_test)
def mnist_parse_file(fname):fopen = gzip.open if os.path.splitext(fname)[1] == '.gz' else openwith fopen(fname, 'rb') as fd:return mnist.parse_idx(fd)train_images = mnist_parse_file(".\\Dataset\\emnist-letters-train-images-idx3-ubyte.gz")
train_labels = mnist_parse_file(".\\Dataset\\emnist-letters-train-labels-idx1-ubyte.gz")
test_images = mnist_parse_file(".\\Dataset\\emnist-letters-test-images-idx3-ubyte.gz")
test_labels = mnist_parse_file(".\\Dataset\\emnist-letters-test-labels-idx1-ubyte.gz")

显示训练集的第6张图片

2.2 神经网络模型

首先查看训练集和测试集的大小,为后面的步骤做准备

#查看各集合大小
print(len(train_images),len(train_labels),len(test_images),len(test_labels))
print(test_images[0].shape)

训练集的大小是124800,测试集的大小是20800

接下来进行神经网络模型的构建

# 初始化序列模型   神经网络
model = models.Sequential()# 一层隐含层,92个神经元
model.add(layers.Dense(92,input_shape=[784]))
#第二层隐含层,92个神经元,激活函数为relu
model.add(layers.Dense(92, activation='relu'))
#第三层隐含层,92个神经元
model.add(layers.Dense(92, activation='relu'))
# 输出层,对应a-z
model.add(layers.Dense(27, activation='softmax'))#另一种创建model的方法
#model = tf.keras.models.Sequential([
#  tf.keras.layers.Dense(128, input_shape=[784]),
#  tf.keras.layers.Dense(40, activation='relu'),
#  tf.keras.layers.Dense(10, activation='softmax')
#])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.summary()

这里需要注意的是一个问题,如何确定神经网络的隐含层数以及每层的神经元数目?

这里附上一个在stackoverflow上看到的经验公式:

经过计算,神经元的数目最好不要超过157,大致在75左右为佳,但是经过多次实验,发现在92个神经元的时候效果比较好

激活函数使用relu,效果比较好

输出层要设置27个输出节点,原因是标签到达26,而如果设置26个节点,接受的标签范围是[0,26),不包括26。

迭代次数设置20次即可,过多的次数并不会提高准确性

x_train = train_images.reshape(-1, 784)
x_test = test_images.reshape(-1, 784)
x_train, x_test = x_train / 255.0, x_test / 255.0history = model.fit(x_train,train_labels,epochs=20,validation_data=(x_test,test_labels))
model.save("emnist_ann.model")

其中,x_train和x_test除以255是为了将变量归一化,从而降低计算量

执行结果如下:

最终测试的预测准确率在89.09%

2.3 卷积神经网络

建立卷积神经网络模型

# 初始化序列模型   卷积神经网络
model = models.Sequential()# 添加第一层卷积层,用32个3*3的卷积核,激活函数选择‘relu’,输入层(input_shape)# 是28*28(MNIST每一张图片的尺寸)后面的‘1’是图片的颜色数,MNIST是灰度图因# 此选1
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28,1)))
# 加一层最大池化,池化窗口2*2
model.add(layers.MaxPooling2D((2, 2)))
# 加一层卷积层,16个2*2的卷积核,之后的input_shape都是自动的。
model.add(layers.Conv2D(16,(2,2), activation='relu'))
# 加一层最大池化,池化窗口2*2
model.add(layers.MaxPooling2D((2,2)))
#加一层卷积层,8个3*3的卷积核
model.add(layers.Conv2D(8,(3,3), activation='relu'))
# 将卷积后的矩阵展开,这就是全连接层的第一层
model.add(layers.Flatten())
# 再加一层全连接,80个神经元,激活函数为relu
model.add(layers.Dense(80, activation='relu'))
# 输出层,对应a-z
model.add(layers.Dense(27))model.summary()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# fit,里面要解释的参数只有epochs和batch_size,epochs是用全部训练集的用例训练几次
# 模型的意思,为什么要用同样的数据集重复训练多次模型呢(这里举个恰当的例子),
# batch_size是每次迭代用到几个用例。
history = model.fit(train_images.reshape(124800, 28, 28, 1), train_labels, epochs=10, batch_size=32, validation_data=(test_images.reshape(20800, 28, 28, 1), test_labels))
print(history)
model.save("emnist_cnn.model")

其中,可以将全连接的神经元适当增加,提高准确率

在测试集中,预测的准确率在91.65%,相比于普通的神经网络,卷积神经网络的准确率更高一些,但是在算法上也更加复杂

基于tensorflow、keras利用emnist数据集构建CNN卷积神经网络进行手写字母识别相关推荐

  1. 【FPGA教程案例100】深度学习1——基于CNN卷积神经网络的手写数字识别纯Verilog实现,使用mnist手写数字数据库

    FPGA教程目录 MATLAB教程目录 ---------------------------------------- 目录 1.软件版本 2.CNN卷积神经网络的原理 2.1 mnist手写数字数 ...

  2. CNN卷积神经网络实现手写数字识别(基于tensorflow)

    1.1卷积神经网络简介 文章目录 1.1卷积神经网络简介 1.2 神经网络 1.2.1 神经元模型 1.2.2 神经网络模型 1.3 卷积神经网络 1.3.1卷积的概念 1.3.2 卷积的计算过程 1 ...

  3. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  4. keras从入门到放弃(十三)卷积神经网络处理手写数字识别

    今天来一个cnn例子 手写数字识别,因为是图像数据 import keras from keras import layers import numpy as np import matplotlib ...

  5. 用Keras搭建神经网络 简单模版(三)—— CNN 卷积神经网络(手写数字图片识别)...

    # -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...

  6. 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

    LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...

  7. 基于卷积神经网络的手写数字识别、python实现

    一.CNN网络结构与构建 参数: 输入数据的维数,通道,高,长 input_dim=(1, 28, 28) 卷积层的超参数,filter_num:滤波器数量,filter_size:滤波器大小,str ...

  8. DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构进行迁移学习

    DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构迁移学习 目录 数据集 输出结果 设计思路 1.基模型 2.思路导图 核心代码 更多输出 数据集 Datas ...

  9. DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构FineTuning

    DL之VGG16:基于VGG16(Keras)利用Knifey-Spoony数据集对网络架构FineTuning 输出结果   False: input_1 False: block1_conv1 F ...

最新文章

  1. jquery.autocomplete修改 实现键盘上下键 自动填充
  2. msb Lesson00_Object_Class.scala
  3. 如何用matlab分割颜色,Matlab:基于颜色的分割
  4. Qt Creator测验Testing
  5. Deeplearnng.AI第四部分第二周、经典网络
  6. Re(正则表达式)库入门
  7. jquery 字符串查找_JQuery、Vue等考点
  8. 2021湖北高考成绩查询热线,湖北招生考试网:2021年湖北高考成绩查询入口、查分系统...
  9. 二叉树C++ | 广度优先遍历(层级顺序遍历)_2
  10. navmenu 收起没有动画 element_ABC360等3家英语动画片课程测评:用动画片学英语不靠谱?...
  11. mybatis collection标签_MyBatis第二天(结果映射+动态sql+关联查询)
  12. LayUI表单验证select定位失效问题
  13. Layabox 屏幕适配
  14. 如何在职场上获得良好的起点
  15. 页面中多次使用TWEEN.update()的坑
  16. Mock.js + RAP 使用介绍
  17. 处理linux centos7中登陆plsql后退格键上下键使用乱码问题
  18. 如何做到像百度云或者网易公开课一样动态更换APP启动图
  19. 技术人员的职业发展规划思考书单推荐
  20. linux升级之后黑屏,fedora升级到28之后gnome登录黑屏的解决方法

热门文章

  1. 计算机桌面锁写快捷,锁定计算机快捷键_锁定计算机的快捷键
  2. 深度学习之遥感图像标注(二)
  3. 智能温室管理系统种蘑菇是怎么样的
  4. 连连看算法 Unity版
  5. Prenetics拟赴美上市:预计2021年收入翻两倍,阿里、平安均为股东
  6. JUC-Callable
  7. SVN服务器端的安装和配置(服务端的使用)
  8. svg矢量图制作工具(Sketsa SVG Editor) v7.1.1 中文免费版
  9. 研发组织中的“长尾类”问题如何看待和消除?
  10. u盘被格式化了文件还可以恢复吗?