【学习日记】手写数字识别及神经网络基本模型
2021.10.7 【学习日记】手写数字识别及神经网络基本模型
1 概述
张量(tensor)是数字的容器,是矩阵向任意维度的推广,其维度称为轴(axis)。深度学习的本质是对张量做各种运算处理,其分层几乎总是通过神经网络(neural network)的模型学习得到。
神经网络是高维空间中复杂的几何变换,通过一层层数据变换,抽丝剥茧,为复杂抽象、高度折叠的数据流找到简洁的表示。对于数据的处理大体上可分为分类(classification)、标注(tagging)和回归(regression)。
举个简单的例子,计算机识别手写数字时,实际上是我们输入每张图片对应的像素矩阵,图片大小为28*28像素,全连接网络把它展成784个像素组成的一维数组作为输入特征,经过一定的权重(weight)计算处理后得到预测值。用损失函数(loss function)评判预测和目标值的距离。通过优化器(optimizer)微调权重值——在深度学习领域称为反向传播(backpropagation)。通过设定一定的循环次数(loop)优化权重,使损失值越来越低。
近年来,人们提出了越来越多较为成功的模型,如VGG16,RNN,ResNet,使深度学习大放异彩。ImageNet大规模视觉识别挑战赛不断激励着研究人员挑战极限,促进了深度学习的继续发展。
Keras发布于2015年,是基于Python的深度学习框架,可以方便地定义和调用几乎所有类型的深度学习模型,封装程度和复用性高。其友好简洁的API使得解决深度学习问题像搭乐高积木一样简单。在机器学习竞赛网站Kaggle上,Keras也大受欢迎。
TensorFlow作为Keras的后端引擎之一,封装了低层次的张量运算库。本文所总结的基本模型基于tf.keras搭建。
2 通用模型
简单的神经网络搭建可以归纳为“六步法”。
一、 import需要的模块。引入我们需要的库并重命名,如import tensorflow as tf。
二、 加载数据集,划分和指定数据训练集的输入和标签、测试集的输入和标签。按照需要可以对数据类型转换astype,或改变形状reshape;对标签分类编码需用到keras.utils.np_utils;布置随机种子seed和打乱顺序shuffle;做归一化、去中心化等处理。在此不做赘述。
三、 用Sequential类或API搭建神经网络结构。Sequential是一个容器,装载着神经网络的各种信息,我们可以根据目标选择不同的层,如拉直flatten,全连接dense,卷积conv,LSTM层等,对层做线性堆叠。如果网络中存在跳连,如循环神经网络RNN,那么Sequential法不再适用,这时需要用类封装网格结构。
四、 compile配置训练方法,包括优化器optimizer、损失函数loss、评价指标如metrics等。
五、 fit执行训练过程,定义batch_size和epochs。其中会返回History对象,里面是包含训练过程中val_acc,acc,val_loss等数据的字典。
六、 summary打印网格结构和参数统计,绘制图像plt.plot直观地表示结果。
处理简单的深度学习问题都可以用上述基本模型,除了第三步需要根据实际情况配置,其余部分均无需改动。
3 实验结果
3.1手写数字识别介绍
下面用基本模型解决手写数字识别的问题。手写数字识别是深度学习中的”hello world”,这个实验非常简单,但重要性不可忽视。其中用到的MNIST数据集,被应用于机器学习或图像处理领域的各种场合。
NMIST由数字0到9的图像构成,有6万张训练图像,1万张测试图像,输出标签已经给出,为数字0到9。各图像数据是28*28像素的灰度图像,通道为1,像素取值为0到255。
导入keras中提供的MNIST数据集接口,尝试查看训练集中第一张图片,及其形状的数组形式,直观感受计算机看到的数据。
import tensorflow as tf
from matplotlib import pyplot as plt
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
plt.imshow(x_train[0], cmap='gray')
plt.show()
print("x_train[0]:\n", x_train[0])
显示结果如下:
3.2模型应用
像素值归一化,方便计算机处理0到1的小范围数据,使神经网络更快收敛。
x_train, x_test = x_train / 255.0, x_test / 255.0
搭建网格架构。先对图像数据拉直为一维数据;再进行128路全连接,激活函数为relu;输出层为10路,用softmax概率分布,因为输出标签为0-9共十项,概率最大的判别为数字识别结果。
model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')])
随后编译训练网络的参数,优化器推荐选择adam,损失函数选择稀疏化处理(one-hot编码)的分类交叉熵,并将其作为观察指标。
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
拟合模型,设置训练数据和标签,一次喂入神经网络的数据量为32,迭代次数为5次,每迭代一次在测试集上测试一次准确率。此外还可以用validation_split划分数据集中测试集的比例。
model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1)
打印网格结构,统计参数。
model.summary()
3.3Class类介绍
在第二章介绍的通用模型第三步中,除了使用Sequential简单搭建网络,还可以用Class类的方式封装。
class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return y
如需使用该网络,将class实例化即可。运行结果和原方法完全一样。
model = MnistModel()
3.4结果分析
model.fit()返回一个History对象,其中的成员history是一个字典,存储了网络的训练和验证精度及损失共4个参数。为了方便观察训练结果,画图显示结果。
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize = (8,8))
plt.subplot(1,2,1)
plt.plot(acc, label='training accuracy')
plt.plot(val_acc, label='validation accuracy')
plt.legend()
plt.subplot(1,2,2)
plt.plot(loss, label='training loss')
plt.plot(val_loss, label='validation loss')
plt.legend()
plt.show()
由于网络初始化的随机性,每次生成图像可能会略有不同。
增加迭代次数,观察训练效果。随着迭代次数增加,验证误差显然在逐渐增大,这是因为发生了过拟合(overfit)。过拟合的最佳解决办法是增加训练数据,其次是对模型允许存储的信息加以约束,方法有权重正则化、dropout正则化等。
4 总结和收获
在实践深度学习框架的过程中,应善用TensorFlow的官方文档查询API。如果使用pycharm开发环境,也可以将鼠标放置在相应函数上按Ctrl键查看源码,了解输入参数、功能等信息。
我有一个直观的感受,训练神经网络处处存在着优化和泛化的对立统一,常常顾此失彼。深度学习模型通常适合拟合训练数据,但也容易产生过拟合。在第三章的实验中,通过提高迭代次数,就发现了这一现象。这时需要正则化,添加罚项等方法减少过拟合。
我们不可能把庞大的数据和训练量丢给计算机,随后便置之不理,这样得出的结果也许在验证集的表现非常优秀,却不知其实已经发生了信息泄露(information leak)。请牢记,理想的模型在欠拟合和过拟合、容量不足和大容量的边界上。
此外,在机器学习界有一种说法,不存在某种算法对所有问题都有效。处理不同问题时,我们要随机应变,力求找到更适合的模型。
本文只是简单讨论了一下深度学习中最为基础的框架,梳理了基本模型实现的思路。在解决更复杂的应用问题时,该模型是远远不够的,需要进行充分改进和优化,包括但不限于数据增强、dropout正则化、K折验证、RNN、GAN等等。篇幅有限,不作介绍。涉及到的相关内容,结合本人学习经验,推荐阅读“Keras之父”肖莱《Python深度学习》,邱锡鹏《神经网络与深度学习》,李航《统计学习方法》。
参考文献:
- Francois Chollet. Deep Learning with Python. 张亮译.Python深度学习.北京:人民邮电出版社,2018
- 斋藤康毅.Deep Learning from Scratch. 陆宇杰译. 北京:人民邮电出版社,2018
- 邱锡鹏.神经网络与深度学习.北京:机械工业出版社,2020
- 李航.统计学习方法.北京:清华大学出版社,2012
- www.tensorflow.org. TensorFlow官方文档.
2021.10.7
【学习日记】手写数字识别及神经网络基本模型相关推荐
- 基于深度学习的手写数字识别、python实现
基于深度学习的手写数字识别.python实现 一.what is 深度学习 二.加深层可以减少网络的参数数量 三.深度学习的手写数字识别 一.what is 深度学习 深度学习是加深了层的深度神经网络 ...
- 基于深度学习的手写数字识别Matlab实现
基于深度学习的手写数字识别Matlab实现 1.网络设计 2. 训练方法 3.实验结果 4.实验结果分析 5.结论 1.网络设计 1.1 CNN(特征提取网络+分类网络) 随着深度学习的迅猛发展,其应 ...
- 03_深度学习实现手写数字识别(python)
本次项目采用了多种模型进行测试,并尝试策略来提升模型的泛化能力,最终取得了99.67%的准确率,并采用pyqt5来制作可视化GUI界面进行呈现.具体代码已经开源. 代码详情见附录 1简介 早在1998 ...
- 【深度学习】手写数字识别Tensorflow2实验报告
实验一:手写数字识别 一.实验目的 利用深度学习实现手写数字识别,当输入一张手写图片后,能够准确的识别出该图片中数字是几.输出内容是0.1.2.3.4.5.6.7.8.9的其中一个. 二.实验原理 ( ...
- 实验四 手写数字识别的神经网络算法设计与实现
实验四 手写数字识别的神经网络算法设计与实现 一.实验目的 通过学习BP神经网络技术,对手写数字进行识别,基于结构的识别法及模板匹配法来提高识别率. 二.实验器材 PC机 matlab软件 三.实验内 ...
- 模式识别 实验四 手写数字识别的神经网络算法设计与实现
实验四 手写数字识别的神经网络算法设计与实现 一.实验目的 通过学习BP神经网络技术,对手写数字进行识别,基于结构的识别法及模板匹配法来提高识别率. 二.实验器材 PC机 matlab软件 三.实验内 ...
- 基于深度学习的手写数字识别算法Python实现
摘 要 深度学习是传统机器学习下的一个分支,得益于近些年来计算机硬件计算能力质的飞跃,使得深度学习成为了当下热门之一.手写数字识别更是深度学习入门的经典案例,学习和理解其背后的原理对于深度学习的理解有 ...
- Python基于深度学习的手写数字识别
Python基于深度学习的手写数字识别 1.代码的功能和运行方法 2. 网络设计 3.训练方法 4.实验结果分析 5.结论 1.代码的功能和运行方法 代码可以实现任意数字0-9的识别,只需要将图片载入 ...
- Python实现深度学习MNIST手写数字识别(单文件,非框架,无需GPU,适合初学者)
注: 本文根据阿卡蒂奥的Python深度学习博客文章代码进行调整,修复了少量问题,原文地址:https://blog.csdn.net/akadiao/article/details/78175737 ...
最新文章
- 发布一个验证码生成组件
- 实例分析C语言中strlen和sizeof的区别
- 调试Release发布版程序的Crash错误
- Swift之String的简单实用
- SAP BSP应用configuration的加载原理
- python函数的作用复用代码_Python-函数和代码复用
- jpa 实体映射视图_JPA教程:映射实体–第1部分
- Java-数据类型拓展
- Lamda和kappa架构
- Linux后门入侵检测工具 rkhunter 安装使用
- 计蒜客 蒜头君的数轴
- PY++ 自动将你的C++程序接口封装供python调用
- CentOS更换阿里yum源
- 机器人读懂人心的九大模型
- FOI冬令营 Day1
- 冬吃萝卜有讲究 名中医解疑惑
- 2015年全国大学生电子设计大赛综合测评题
- 主动事务处理器编写BFM
- 七、手写实现决策树算法
- 基于51单片机蓝牙密码锁