机器学习问题可以分为回归问题和分类问题,回归问题已经在线性回归讲过,本文学习分类问题。分类问题跟回归问题有明显的区别,回归问题是连续的数值,而分类问题是离散的类别,比如将性别分为[男,女],将图片分为[猫,狗,兔]等。

学习分类问题用的最多的示例是LeCun的MNIST手写数字识别。

MNIST手写数字数据集介绍

MNIST手写数字数据集包含60000张训练图片和10000张测试图片,这些图片是从0~9的手写数字,分辨率为28*28,大致是下面这个样子:

MNIST数据集

该数据集广泛适用于各大机器学习平台入门程序,有非常成熟的处理方案。不过为了深入理解数据格式,我们先来手动下载并读取这些数据,数据格式详见。

手动读取MINST数据集:

import 

输出:

预览MINIST数据集

分类任务的误差函数

前面说过回归型任务可以用MSE(Mean squared error)均方误差作为loss函数,分类任务的误差如何表示呢?比如手写数字识别任务需要将图片分成10类,既0~9每个数字一类,很容易想到的一个误差衡量方式是:如果分类正确误差是0,如果分类不正确误差是1。这样做人很容易理解,但对计算机来说作用不大,因为并没法表示出分类不正确的程度。

通常使用交叉熵(Cross Entropy)表示分类问题的loss函数,其定义如下:

  • 表示类别的个数,比如数字识别的类别是个数是10个
  • 表示真实的类别,如果属于当前类别
    ,则为值1,否则为值0
  • 表示预测输出属于类别
    的概率,如果有10个类别,则根据总概率为1的约束一定有:

特别的如果只有两个类别,比如性别[男,女]这种二分类问题,上述公式可以简写成:

TensorFlow实现手写数字识别(一)

可以用神经网络来做手写数字识别,一个只有一层全连接层的网络拓扑如下:

一层神经网络

注:图上全连接的线条没有全部画出,两层之间任意两点之间都有一条连线。

网络非常的简单,784个像素输入通过一个全连接层跟10个输出相连。网络所需参数个数为权重 784 * 10,外加10个bias。用公式表达就是:

  • 其中

    表示训练数据向量,对于MNIST就是(60000, 784)维的向量
  • 是权重矩阵,维度是(784, 10)
  • 是bias向量,维度是(10)
  • activation是激活函数,他是一个将线性空间向非线性空间映射的函数,用以增强网络的表达能力。常见的激活函数有relu, sigmoid, tanh等。
  • 表示预测属于各个类别的概率,最终将采用概率最大的那个类别作为最终类别。

为什么需要激活函数?

多个线性函数的组合仍然是线性函数,激活函数将线性空间映射到非线性空间,增强网络的表达能力。比如任意多个一元线性回归的加和是多元线性回归,无法表示抛物线、正弦曲线等非线性数据。而多个线性单元经过激活函数之后再做组合,理论上可以拟合任意复杂的公式。详见https://zhuanlan.zhihu.com/p/165849993

常见激活函数

上述神经网络拓扑对应的代码实现如下,该版本的代码尽可能的展现细节,除了微分计算采用了tf自动求导其余的操作均手工完成,对理解原理很有帮助

import 

输出结果:

这么简单的网络经过1000轮的训练,在测试集上取得了77.42%的识别正确率,且到后期随着训练轮数增加预测效果也不再增加,算是差强人意吧。

上图中看到训练误差在逐渐降低,预测效果在逐渐升高但有些波动,这根learning_rate的设置有关系,比如学习率设置过大导致训练效果来回抖动。学习率如何设置是一个单独的话题,感兴趣的话可以了解一下“优化器”。

TensorFlow实现手写数字识别(二)

上述网络只有一层表现力较弱,识别正确率只有77.42%,通过增加网络层数可以提高网络的表达能力,两层网络结构如下:

两层神经网络

注:图上全连接的线条没有全部画出,两层之间任意两点之间都有一条连线。

代码如下:

import 

执行结果:

识别效果达到了82.16%,比一层网络要好一些。

keras实现手写数字识别

深层网络比浅层有更好的表现力,不过手写深层网络显得很麻烦,keras大大简化网络构造的复杂度。

keras是一套针对模型的的极简API,可以用几行代码写出复杂的网络结构,完全屏蔽实现细节(那是具体后端的工作,比如tensorflow)

使用keras编写两层网络的手写数字识别:

import 

网络构建和训练超乎想象的简洁!最终训练效果如下:

注意:keras.datasets.mnist会从https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz下载数据。可以手动下载这个数据并放在~/.keras/datasets,全路径名为 ~/.keras/datasets/mnist.npz

网络summary
训练和预测结果

最终识别效果95%,效果比手写的示例有大幅提升。

keras实现手写数字识别——卷积神经网络

针对图像处理有专门的神经网络结构——卷积神经网络。使用keras可以很容易的搭建卷积神经网络:

卷积神经网络是一种专门用于处理图像的特殊的神经网络,他在传统神经网络中加入“卷积”和“池化”层模拟人的大脑对视觉的处理方式,大幅提升图片的识别准确度。

import 

输出:

网路summary
训练和预测结果

识别精度高达99%,卷积神经网络处理图像分类果然有奇效!

总结

本文首先对MNIST数据集进行了介绍,之后用多种方式实现了手写数字识别的神经网络,对理解神经网络原理很有帮助。Keras是编写神经网络的利器,用几行代码就能构建出复杂的网络模型,极大简化编码复杂度。另外简单介绍了激活函数、keras、卷积神经网络等概念。

如有收获请直接点赞~

matlab朴素贝叶斯手写数字识别_从“手写数字识别”学习分类任务相关推荐

  1. 用matlab朴素贝叶斯,Matlab朴素贝叶斯

    你好我正在使用KDD 1999数据集,我正在寻找在matlab中应用朴素贝叶斯.我想知道的是,kdd数据集是一个494021x42数据数组,如果您注意到下面的朴素贝叶斯代码中的"traini ...

  2. matlab 朴素贝叶斯模型 代码及其案例

    简介 朴素贝叶斯分类器(Naive Bayes Classifier 或 NBC)发源于古典数学理论,有着坚实的数学基础,以及稳定的分类效率.同时,NBC模型所需估计的参数很少,对缺失数据不太敏感,算 ...

  3. MATLAB朴素贝叶斯(德国信用卡案例)

    我们matlab建模课的案例 数据以及代码 链接:https://pan.baidu.com/s/18qpV2qsHzwbnOgZBMBHdGQ?pwd=r8g2  提取码:r8g2 参考书:MATL ...

  4. matlab朴素贝叶斯手写数字识别_基于MNIST数据集实现手写数字识别

    介绍 在TensorFlow的官方入门课程中,多次用到mnist数据集.mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx*-ubyte ...

  5. matlab朴素贝叶斯手写数字识别_TensorFlow手写数字识别(一)

    本篇文章通过TensorFlow搭建最基础的全连接网络,使用MNIST数据集实现基础的模型训练和测试. MNIST数据集 MNIST数据集 :包含7万张黑底白字手写数字图片,其中55000张为训练集, ...

  6. matlab朴素贝叶斯手写数字识别_机器学习系列四:MNIST 手写数字识别

    4. MNIST 手写数字识别 机器学习中另外一个相当经典的例子就是MNIST的手写数字学习.通过海量标定过的手写数字训练,可以让计算机认得0~9的手写数字.相关的实现方法和论文也很多,我们这一篇教程 ...

  7. matlab朴素贝叶斯工具箱,朴素贝叶斯分类matlab实现.doc

    朴素贝叶斯分类matlab实现 实验二 朴素贝叶斯分类 一.实验目的 通过实验,加深对统计判决与概率密度估计基本思想.方法的认识,了解影响Bayes分类器性能的因素,掌握基于Bayes决策理论的随机模 ...

  8. 朴素贝叶斯算法python sklearn实现_朴素贝叶斯算法优化与 sklearn 实现

    进行拉普拉斯平滑运算后,我们运行程序,仍然得出了两个测试样本均属于非侮辱类的结果,这是为什么呢? 我们查看最终计算出的 p0 和 p1 会发现,他们的结果都是 0,这又是为什么呢? 这是因为出现了另一 ...

  9. 机器学习之朴素贝叶斯(一)思想及典型例题手写实现

    条件概率.全概率公式.贝叶斯公式 条件概率公式: 在事件B发生的条件下,A发生的概率 换一种写法: 理解了条件概率公式后,用一个引例介绍后面两个公式:村子里有三个小偷,事件B={村子失窃},已知小偷们 ...

最新文章

  1. IDEAWebstorm使用
  2. 帧中继简单总结(修改)
  3. SpringMVC如何实现aop
  4. java web 打印控件_web打印,web打印控件,dotnet web打印控件,java web打印控件,webprint...
  5. Where we can find sharepoint user list
  6. 系统讲解——更好的实施专案(Porject)
  7. 天然气表怎么看多少方_宝宝奶粉的的营养成分表,到底怎么看?
  8. Oracle 10g 问题集锦
  9. linux目录隐藏技术,Linux环境下的高级隐藏技术
  10. 从零开始学习音视频编程技术(二) 音频格式讲解
  11. 小程序入门学习06--data、url传参、调用豆瓣api
  12. VC文档与视图结构学习总结
  13. go模板引擎生成html,goweb-模板引擎
  14. oa系统在线试用,零成本开始研发协作免费试用
  15. android怎么实现记住密码功能,Android实现用户登录记住密码功能
  16. Oracle 设置数据库时区
  17. Google SketchUp SDK
  18. 做硬件,想当然,犯大错
  19. hdu 3932 Groundhog Build Home
  20. TPTP测试项目的性能

热门文章

  1. java 模块化 soa_OSGI与SOA的千丝万缕
  2. 苹果明年有望推出15英寸版MacBook Air
  3. 799元!乐视智能门锁新品Le1S发布
  4. 剧集《赘婿》向流媒体平台Watcha授出翻拍权
  5. 研究称:苹果开始感受到全球芯片短缺影响,但三星等受影响更大
  6. 三星电子与索尼在CMOS图像传感器市场份额差距缩小
  7. 路痴福音!高德地图上线真AR步行导航,可实景指引
  8. 戴志坚接替李小加出任职港交所行政总裁 基本年薪700万港元
  9. 特斯拉副总裁回应“质量不合格”报道:离谱 已准备起诉
  10. 华为P50 Pro最新渲染图曝光:后置造型有点奇怪