2021.10.7 【学习日记】手写数字识别及神经网络基本模型

1 概述

张量(tensor)是数字的容器,是矩阵向任意维度的推广,其维度称为轴(axis)。深度学习的本质是对张量做各种运算处理,其分层几乎总是通过神经网络(neural network)的模型学习得到。
神经网络是高维空间中复杂的几何变换,通过一层层数据变换,抽丝剥茧,为复杂抽象、高度折叠的数据流找到简洁的表示。对于数据的处理大体上可分为分类(classification)、标注(tagging)和回归(regression)。
举个简单的例子,计算机识别手写数字时,实际上是我们输入每张图片对应的像素矩阵,图片大小为28*28像素,全连接网络把它展成784个像素组成的一维数组作为输入特征,经过一定的权重(weight)计算处理后得到预测值。用损失函数(loss function)评判预测和目标值的距离。通过优化器(optimizer)微调权重值——在深度学习领域称为反向传播(backpropagation)。通过设定一定的循环次数(loop)优化权重,使损失值越来越低。

近年来,人们提出了越来越多较为成功的模型,如VGG16,RNN,ResNet,使深度学习大放异彩。ImageNet大规模视觉识别挑战赛不断激励着研究人员挑战极限,促进了深度学习的继续发展。
Keras发布于2015年,是基于Python的深度学习框架,可以方便地定义和调用几乎所有类型的深度学习模型,封装程度和复用性高。其友好简洁的API使得解决深度学习问题像搭乐高积木一样简单。在机器学习竞赛网站Kaggle上,Keras也大受欢迎。
TensorFlow作为Keras的后端引擎之一,封装了低层次的张量运算库。本文所总结的基本模型基于tf.keras搭建。

2 通用模型

简单的神经网络搭建可以归纳为“六步法”。
一、 import需要的模块。引入我们需要的库并重命名,如import tensorflow as tf。
二、 加载数据集,划分和指定数据训练集的输入和标签、测试集的输入和标签。按照需要可以对数据类型转换astype,或改变形状reshape;对标签分类编码需用到keras.utils.np_utils;布置随机种子seed和打乱顺序shuffle;做归一化、去中心化等处理。在此不做赘述。
三、 用Sequential类或API搭建神经网络结构。Sequential是一个容器,装载着神经网络的各种信息,我们可以根据目标选择不同的层,如拉直flatten,全连接dense,卷积conv,LSTM层等,对层做线性堆叠。如果网络中存在跳连,如循环神经网络RNN,那么Sequential法不再适用,这时需要用类封装网格结构。


四、 compile配置训练方法,包括优化器optimizer、损失函数loss、评价指标如metrics等。
五、 fit执行训练过程,定义batch_size和epochs。其中会返回History对象,里面是包含训练过程中val_acc,acc,val_loss等数据的字典。
六、 summary打印网格结构和参数统计,绘制图像plt.plot直观地表示结果。
处理简单的深度学习问题都可以用上述基本模型,除了第三步需要根据实际情况配置,其余部分均无需改动。

3 实验结果

3.1手写数字识别介绍
下面用基本模型解决手写数字识别的问题。手写数字识别是深度学习中的”hello world”,这个实验非常简单,但重要性不可忽视。其中用到的MNIST数据集,被应用于机器学习或图像处理领域的各种场合。

NMIST由数字0到9的图像构成,有6万张训练图像,1万张测试图像,输出标签已经给出,为数字0到9。各图像数据是28*28像素的灰度图像,通道为1,像素取值为0到255。
导入keras中提供的MNIST数据集接口,尝试查看训练集中第一张图片,及其形状的数组形式,直观感受计算机看到的数据。

import tensorflow as tf
from matplotlib import pyplot as plt
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
plt.imshow(x_train[0], cmap='gray')
plt.show()
print("x_train[0]:\n", x_train[0])

显示结果如下:


3.2模型应用
像素值归一化,方便计算机处理0到1的小范围数据,使神经网络更快收敛。

x_train, x_test = x_train / 255.0, x_test / 255.0

搭建网格架构。先对图像数据拉直为一维数据;再进行128路全连接,激活函数为relu;输出层为10路,用softmax概率分布,因为输出标签为0-9共十项,概率最大的判别为数字识别结果。

model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])

随后编译训练网络的参数,优化器推荐选择adam,损失函数选择稀疏化处理(one-hot编码)的分类交叉熵,并将其作为观察指标。

model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])

拟合模型,设置训练数据和标签,一次喂入神经网络的数据量为32,迭代次数为5次,每迭代一次在测试集上测试一次准确率。此外还可以用validation_split划分数据集中测试集的比例。

model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1)

打印网格结构,统计参数。

model.summary()


3.3Class类介绍
在第二章介绍的通用模型第三步中,除了使用Sequential简单搭建网络,还可以用Class类的方式封装。

class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return y

如需使用该网络,将class实例化即可。运行结果和原方法完全一样。

model = MnistModel()

3.4结果分析
model.fit()返回一个History对象,其中的成员history是一个字典,存储了网络的训练和验证精度及损失共4个参数。为了方便观察训练结果,画图显示结果。

history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize = (8,8))
plt.subplot(1,2,1)
plt.plot(acc, label='training accuracy')
plt.plot(val_acc, label='validation accuracy')
plt.legend()
plt.subplot(1,2,2)
plt.plot(loss, label='training loss')
plt.plot(val_loss, label='validation loss')
plt.legend()
plt.show()

由于网络初始化的随机性,每次生成图像可能会略有不同。
增加迭代次数,观察训练效果。随着迭代次数增加,验证误差显然在逐渐增大,这是因为发生了过拟合(overfit)。过拟合的最佳解决办法是增加训练数据,其次是对模型允许存储的信息加以约束,方法有权重正则化、dropout正则化等。



4 总结和收获

在实践深度学习框架的过程中,应善用TensorFlow的官方文档查询API。如果使用pycharm开发环境,也可以将鼠标放置在相应函数上按Ctrl键查看源码,了解输入参数、功能等信息。
我有一个直观的感受,训练神经网络处处存在着优化和泛化的对立统一,常常顾此失彼。深度学习模型通常适合拟合训练数据,但也容易产生过拟合。在第三章的实验中,通过提高迭代次数,就发现了这一现象。这时需要正则化,添加罚项等方法减少过拟合。


我们不可能把庞大的数据和训练量丢给计算机,随后便置之不理,这样得出的结果也许在验证集的表现非常优秀,却不知其实已经发生了信息泄露(information leak)。请牢记,理想的模型在欠拟合和过拟合、容量不足和大容量的边界上。
此外,在机器学习界有一种说法,不存在某种算法对所有问题都有效。处理不同问题时,我们要随机应变,力求找到更适合的模型。
本文只是简单讨论了一下深度学习中最为基础的框架,梳理了基本模型实现的思路。在解决更复杂的应用问题时,该模型是远远不够的,需要进行充分改进和优化,包括但不限于数据增强、dropout正则化、K折验证、RNN、GAN等等。篇幅有限,不作介绍。涉及到的相关内容,结合本人学习经验,推荐阅读“Keras之父”肖莱《Python深度学习》,邱锡鹏《神经网络与深度学习》,李航《统计学习方法》。
参考文献:

  1. Francois Chollet. Deep Learning with Python. 张亮译.Python深度学习.北京:人民邮电出版社,2018
  2. 斋藤康毅.Deep Learning from Scratch. 陆宇杰译. 北京:人民邮电出版社,2018
  3. 邱锡鹏.神经网络与深度学习.北京:机械工业出版社,2020
  4. 李航.统计学习方法.北京:清华大学出版社,2012
  5. www.tensorflow.org. TensorFlow官方文档.

2021.10.7

【学习日记】手写数字识别及神经网络基本模型相关推荐

  1. 基于深度学习的手写数字识别、python实现

    基于深度学习的手写数字识别.python实现 一.what is 深度学习 二.加深层可以减少网络的参数数量 三.深度学习的手写数字识别 一.what is 深度学习 深度学习是加深了层的深度神经网络 ...

  2. 基于深度学习的手写数字识别Matlab实现

    基于深度学习的手写数字识别Matlab实现 1.网络设计 2. 训练方法 3.实验结果 4.实验结果分析 5.结论 1.网络设计 1.1 CNN(特征提取网络+分类网络) 随着深度学习的迅猛发展,其应 ...

  3. 03_深度学习实现手写数字识别(python)

    本次项目采用了多种模型进行测试,并尝试策略来提升模型的泛化能力,最终取得了99.67%的准确率,并采用pyqt5来制作可视化GUI界面进行呈现.具体代码已经开源. 代码详情见附录 1简介 早在1998 ...

  4. 【深度学习】手写数字识别Tensorflow2实验报告

    实验一:手写数字识别 一.实验目的 利用深度学习实现手写数字识别,当输入一张手写图片后,能够准确的识别出该图片中数字是几.输出内容是0.1.2.3.4.5.6.7.8.9的其中一个. 二.实验原理 ( ...

  5. 实验四 手写数字识别的神经网络算法设计与实现

    实验四 手写数字识别的神经网络算法设计与实现 一.实验目的 通过学习BP神经网络技术,对手写数字进行识别,基于结构的识别法及模板匹配法来提高识别率. 二.实验器材 PC机 matlab软件 三.实验内 ...

  6. 模式识别 实验四 手写数字识别的神经网络算法设计与实现

    实验四 手写数字识别的神经网络算法设计与实现 一.实验目的 通过学习BP神经网络技术,对手写数字进行识别,基于结构的识别法及模板匹配法来提高识别率. 二.实验器材 PC机 matlab软件 三.实验内 ...

  7. 基于深度学习的手写数字识别算法Python实现

    摘 要 深度学习是传统机器学习下的一个分支,得益于近些年来计算机硬件计算能力质的飞跃,使得深度学习成为了当下热门之一.手写数字识别更是深度学习入门的经典案例,学习和理解其背后的原理对于深度学习的理解有 ...

  8. Python基于深度学习的手写数字识别

    Python基于深度学习的手写数字识别 1.代码的功能和运行方法 2. 网络设计 3.训练方法 4.实验结果分析 5.结论 1.代码的功能和运行方法 代码可以实现任意数字0-9的识别,只需要将图片载入 ...

  9. Python实现深度学习MNIST手写数字识别(单文件,非框架,无需GPU,适合初学者)

    注: 本文根据阿卡蒂奥的Python深度学习博客文章代码进行调整,修复了少量问题,原文地址:https://blog.csdn.net/akadiao/article/details/78175737 ...

最新文章

  1. 发布一个验证码生成组件
  2. 实例分析C语言中strlen和sizeof的区别
  3. 调试Release发布版程序的Crash错误
  4. Swift之String的简单实用
  5. SAP BSP应用configuration的加载原理
  6. python函数的作用复用代码_Python-函数和代码复用
  7. jpa 实体映射视图_JPA教程:映射实体–第1部分
  8. Java-数据类型拓展
  9. Lamda和kappa架构
  10. Linux后门入侵检测工具 rkhunter 安装使用
  11. 计蒜客 蒜头君的数轴
  12. PY++ 自动将你的C++程序接口封装供python调用
  13. CentOS更换阿里yum源
  14. 机器人读懂人心的九大模型
  15. FOI冬令营 Day1
  16. 冬吃萝卜有讲究 名中医解疑惑
  17. 2015年全国大学生电子设计大赛综合测评题
  18. 主动事务处理器编写BFM
  19. 七、手写实现决策树算法
  20. 基于51单片机蓝牙密码锁

热门文章

  1. vrf中的ipsec
  2. RNN模型与NLP应用笔记(2):文本处理与词嵌入详解及完整代码实现(Word Embedding)
  3. 视频会议时听不到声音该如何处理?
  4. 传感器常见技术参数介绍
  5. INT 21H中断大全
  6. poj 3095 Linear Pachinko 模拟水题
  7. 实分析笔记(1):康托尔基数理论
  8. 激光打印机的粉盒装粉
  9. Android Launcher分析和修改1——Launcher默认界面配置(default_workspace)
  10. qt客户端显示服务器发送的图片,qt客户端显示服务器发送的图片