作者:Treant 人工智能爱好者社区专栏作者
博客专栏:https://www.cnblogs.com/en-heng

1.问题

Kaggle上有一个图像分类比赛Digit Recognizer,数据集是大名鼎鼎的MNIST——图片是已分割 (image segmented)过的28*28的灰度图,手写数字部分对应的是0~255的灰度值,背景部分为0。

from keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train[0] # .shape = 28*28
"""
[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
...
[ 0 0 0 0 0 0 0 0 0 0 0 0 3 18 18 18 126 136
175 26 166 255 247 127 0 0 0 0]
[ 0 0 0 0 0 0 0 0 30 36 94 154 170 253 253 253 253 253
225 172 253 242 195 64 0 0 0 0]
...
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]]
"""

手写数字图片是长这样的:

手写数字识别可以看做是一个图像分类问题——对二维向量的灰度图进行分类。

2.识别

Rodrigo Benenson给出50种方法在MNIST的错误率。本文将从传统方法过渡到深度学习,对比准确率来看。以下代码基于Python 3.6 + sklearn 0.18.1 + keras 2.0.4。

传统方法

kNN

思路比较简单:将二维向量拉直成一个一维向量,基于距离度量以判断向量间的相似性。显而易见,这种不带特征提取的朴素办法,丢掉了二维向量中最重要的四周相邻像素的信息。在比较干净的数据集MNIST还有不错的表现,准确率为96.927%。此外,kNN模型训练慢。

from sklearn import neighbors
from sklearn.metrics import precision_score

num_pixels = x_train[0].shape[0] * x_train[0].shape[1]
x_train = x_train.reshape((x_train.shape[0], num_pixels))
x_test = x_test.reshape((x_test.shape[0], num_pixels))

knn = neighbors.KNeighborsClassifier()
knn.fit(x_train, y_train)
pred = knn.predict(x_test)
precision_score(y_test, pred, average='macro') # 0.96927533865705706

MLP

多层感知器MLP (Multi Layer Perceptron)亦即三层的前馈神经网络,所采用的特征与kNN方法类似——每一个像素点的灰度值对应于输入层的一个神经元,隐藏层的神经元数为700(一般介于输入层与输出层的数量之间)。sklearn的MLPClassifier实现MLP分类,下面给出基于keras的MLP实现。没怎么细致地调参,准确率大概在98.530%左右。

from keras.layers import Dense
from keras.models import Sequential
from keras.utils import np_utils

# normalization
num_pixels = 28 * 28
x_train = x_train.reshape(x_train.shape[0], num_pixels).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], num_pixels).astype('float32') / 255
# one-hot enconder for class
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_train.shape[1]

model = Sequential([
Dense(700, input_dim=num_pixels, activation='relu', kernel_initializer='normal'), # hidden layer
Dense(num_classes, activation='softmax', kernel_initializer='normal') # output layer
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=600, batch_size=200, verbose=2)
model.evaluate(x_test, y_test, verbose=0) # [0.10381294689745164, 0.98529999999999995]

深度学习

LeCun早在1989年发表的论文 [1]中提出了用CNN (Convolutional Neural Networks)来做手写数字识别,后来 [2]又改进到Lenet-5,其网络结构如下图所示:

卷积、池化、卷积、池化,然后套2个全连接层,最后接个Guassian连接层。众所周知,CNN自带特征提取功能,不需要刻意地设计特征提取器。基于keras,Lenet-5 非正式实现如下:

import keras
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Dropout, Flatten, Activation
from keras.models import Sequential
from keras.utils import np_utils

img_rows, img_cols = 28, 28
# TensorFlow backend: image_data_format() == 'channels_last'
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1).astype('float32') / 255
# one-hot code for class
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_train.shape[1]

model = Sequential()
model.add(Conv2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Activation("sigmoid"))

model.add(Conv2D(16, kernel_size=(5, 5), padding='valid'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Activation("sigmoid"))
model.add(Dropout(0.25))
# full connection
model.add(Conv2D(120, kernel_size=(1, 1), padding='valid'))
model.add(Flatten())
# full connection
model.add(Dense(84, activation='sigmoid'))
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.SGD(lr=0.08, momentum=0.9),
metrics=['accuracy'])
model.summary()
model.fit(x_train, y_train, batch_size=32, epochs=8,
verbose=1, validation_data=(x_test, y_test))

以上三种方法的准确率如下:

3.参考资料

[1] LeCun, Yann, et al. "Backpropagation applied to handwritten zip code recognition." Neural computation 1.4 (1989): 541-551.
[2] LeCun, Yann, et al. "Gradient-based learning applied to document recognition." Proceedings of the IEEE 86.11 (1998): 2278-2324.
[3] Taylor B. Arnold, Computer vision: LeNet-5, AlexNet, VGG-19, GoogLeNet.

【从传统方法到深度学习】图像分类相关推荐

  1. 【百家稷学】从传统方法到深度学习,人脸算法和应用的演变(河南平顶山学院技术分享)...

    继续咱们百家稷学专题,本次聚焦在人脸方向.百家稷学专题的目标,是走进100所高校和企业进行学习与分享. 分享主题 本次分享是在河南平顶山学院,主题是<从传统方法到深度学习,人脸算法和应用的演变& ...

  2. 【视频课】一课彻底掌握深度学习图像分类各种问题,学习CV你值得拥有

    课程介绍 对于刚接触深度学习计算机视觉的初学者来说,图像分类问题是最常见的问题,如何最好图像分类任务,关系到大家能否正确顺利地入门.读了许多论文,可能仍然不懂代码如何实现.跑了代码,仍旧不懂如何运用图 ...

  3. 【AI-1000问】为什么深度学习图像分类的输入多是224*224

    文章首发于微信公众号<有三AI> [AI-1000问]为什么深度学习图像分类的输入多是224*224 写在前边的通知 大家好,今天这又是一个新专栏了,名叫<有三AI 1000问> ...

  4. python人脸识别框很小_人脸识别:从传统方法到深度学习

    人脸识别:从传统方法到深度学习 这开始于上世纪七十年代,人脸识别成为了计算机视觉领域和生物识别领域最具有研究型的话题之一.传统方法依赖于手工制作模型特征,通过深度神经网络训练大量的数据集的方法也在最近 ...

  5. 深度学习 图像分类_深度学习时代您应该阅读的10篇文章了解图像分类

    深度学习 图像分类 前言 (Foreword) Computer vision is a subject to convert images and videos into machine-under ...

  6. 深度学习图像分类(六):Stochastic_Depth_Net

    深度学习图像分类(六):Stochastic_Depth_Net 文章目录 深度学习图像分类(六):Stochastic_Depth_Net 前言 一.Motivation 二.核心结构:Drop P ...

  7. PyTorch深度学习图像分类--猫狗大战

    PyTorch深度学习图像分类--猫狗大战 1.背景介绍 2.环境配置 2.1软硬件清单 2.1.1配置PyPorch 2.1.2开发软件 2.1.3 显卡 2.2 数据准备 3 基础理论 3.1Py ...

  8. 统计深度学习与最优传输理论,传统方法vs深度学习,符号主义与联结主义

    统计深度学习与最优传输理论,传统方法vs深度学习,符号主义与联结主义 统计深度学习与最优传输理论 传统计算机视觉方法与基于统计的深度学习方法 符号主义与联结主义    本文多处摘引自当深度学习遇到3D ...

  9. 目标检测的二十年发展史—从传统方法到深度学习

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 本文转载自DeepBlue深兰科技 本文主要参考自文献[1]:Zhengxia Zou, Zh ...

最新文章

  1. linux 下mysql的管理,Linux下 MySQL安装和基本管理
  2. ISA SERVER 2004 对多重网络支持功能简述
  3. Android文档-开发者指南-第一部分:入门-中英文对照版
  4. [html] 404页面有什么作用?
  5. PL/SQL Developer的错误提示弹框的文本显示乱码问题
  6. [访问系统] C#计算机信息类ComputerInfo (转载)
  7. mysql as 后面字段,mysql 字段as详解及实例代码
  8. WEB项目中使用QQ表情
  9. Safari浏览器显示网页不全问题解决方法
  10. layui控制文本框只能填写数字
  11. 短信验证码收不到了怎么办?
  12. 产业互联网将不再只是虚无缥缈,触不可及的空中楼阁
  13. 劳动仲裁委员会的具体地址即(朝阳区酒仙桥南十里居28楼的具体路线)______转...
  14. numpy学习(五)——文件的保存和读写(np.save()、np.load()、np.savez()、np.savetxt()、np.loadtxt())
  15. 51单片机(STC)串口无阻塞发送函数
  16. 自己动手写H3C校园网登录客户端(Linux平台版)
  17. 【极术通讯】2022年十大科技应用趋势
  18. C# 将字符串(符合xml格式)与XML互转
  19. 扫盲:什么是单片机时序,如何看懂时序图
  20. 【C++】9.GIS应用:开源GIS平台开发入门(MapServer+QGIS+PostGIS+OpenLayers)

热门文章

  1. dbref java_java – Spring Data REST MongoDB:检索DBRef的对...
  2. 搭建AWStats日志分析系统
  3. Linux网络深入DHCP、FTP原理和配置方法(详细图解)
  4. servlet中弹出对话框
  5. activity5.1初始密码
  6. 无头结点单链表的逆置_第1章第2节练习题11 就地逆置单链表
  7. 怎么把路由的#号去掉_VLAN应用篇系列:交换机VLAN间路由与传统单臂路由(子接口)方式...
  8. python僵尸进程和孤儿进程_python中多进程应用及僵尸进程、孤儿进程
  9. Juniper 210 密码清不掉_三分钟学会如何找回mysql密码
  10. 研华数据采集卡如何采集压力信号转化为数字信号_涨知识啦!PLC编程中如何使用开关、模拟、脉冲量...