卷积神经网络实现多个数字识别
NVIDIA DLI 深度学习入门培训 | 特设三场!!
4月28日/5月19日/5月26日一天密集式学习 轻松带你入门阅读全文>
正文共6095个字,6张图,预计阅读时间15分钟。
数据集:MNIST
框架:Keras
显卡:NVIDIA GEFORCE 750M
参考:Keras中文文档(http://keras-cn.readthedocs.io/en/latest/other/datasets/#mnist)
这是优达学城的深度学习项目,数据集和需求都很简单,关键是为了熟悉框架的使用以及项目搭建的套路,只要用很简单的卷积神经网络就能实现,准确率轻轻松松就能上90%。
需求描述
随机从MNIST数据集中选择5个或5个以下的数字,拼成一张图片,如下图所示。搭建一个模型,识别图片中的数字,空白字符的类型为0。
imgs and labels
项目实战
载入数据集
keras有
from keras.datasets import mnist
(X_raw, y_raw), (X_raw_test, y_raw_test) = mnist.load_data()
n_train, n_test = X_raw.shape[0], X_raw_test.shape[0]
查看数据集
import matplotlib.pyplot as pltimport random
%matplotlib inline
%config InlineBackend.figure_format = 'retina'for i in range(15):
plt.subplot(3, 5, i+1)
index = random.randint(0, n_train-1)
plt.title(str(y_raw[index]))
plt.imshow(X_raw[index], cmap='gray')
plt.axis('off')
dataset
合成数据
载入数据集的时候将数据集分成了训练集X_raw和测试集X_test,这里需要从X_raw中随机选取数字,然后拼成新的图片,并将20%设为验证集,防止模型过拟合。
注意:数字的长度不一定为5,不到5的以空白填充,最终图片高28长28x5=140
为什么将数据分成训练集、验证集和测试集?
训练集是用来训练模型的;验证集是用来对训练的模型进行进一步调参优化,如果使用测试集验证,网络就会记住测试集,容易使模型过拟合;测试集用来测试模型表现。
难点:
原图是28x28,拼成28x140,原来一行有28,现在一行有140,是每行做的append,用list.append效率会很低,用矩阵转置就会很快。
import numpy as np
from sklearn.model_selection import train_test_split
n_class, n_len, width, height = 11, 5, 28, 28
def generate_dataset(X, y):
X_len = X.shape[0] # 原数据集有几个,新数据集还要有几个
# 新数据集的shape为(X_len, 28, 28*5, 1),X_len是X的个数,原数据集是28x28,取5个数字(包含空白)拼接,则为28x140, 1是颜色通道,灰度图,所以是1
X_gen = np.zeros((X_len, height, width*n_len, 1), dtype=np.uint8)
# 新数据集对应的label,最终的shape为(5, X_len,11)
y_gen = [np.zeros((X_len, n_class), dtype=np.uint8) for i in range(n_len)]
for i in range(X_len):
# 随机确定数字长度
rand_len = random.randint(1, 5)
lis = list()
# 设置每个数字
for j in range(0, rand_len):
# 随机找一个数
index = random.randint(0, X_len - 1)
# 将对应的y置1, y是经过onehot编码的,所以y的第三维是11,0~9为10个数字,10为空白,哪个索引为1就是数字几
y_gen[j][i][y[index]] = 1
lis.append(X[index].T)
# 其余位取空白
for m in range(rand_len, 5):
# 将对应的y置1
y_gen[m][i][10] = 1
lis.append(np.zeros((28, 28),dtype=np.uint8))
lis = np.array(lis).reshape(140,28).T
X_gen[i] = lis.reshape(28,140,1)
return X_gen, y_gen
X_raw_train, X_raw_valid, y_raw_train, y_raw_valid = train_test_split(X_raw, y_raw, test_size=0.2, random_state=50)
X_train, y_train = generate_dataset(X_raw_train, y_raw_train)
X_valid, y_valid = generate_dataset(X_raw_valid, y_raw_valid)
X_test, y_test = generate_dataset(X_raw_test, y_raw_test)
显示合成的图片
# 显示生成的图片for i in range(15):
plt.subplot(5, 3, i+1)
index = random.randint(0, n_test-1)
title = ''
for j in range(n_len):
title += str(np.argmax(y_test[j][index])) + ','
plt.title(title)
plt.imshow(X_test[index][:,:,0], cmap='gray')
plt.axis('off')
合成的图片
CNN搭建
使用了keras的函数式模型,很方便,可以参考官方文档。
由于数据集比较简答,所以随便一个网络结构都能有不错的表现,我用的是两层卷机模型,卷积层、最大池化层、卷积层、最大池化层,然后两个全连接层。
from keras.models import Modelfrom keras.layers import *import tensorflow as tf# This returns a tensorinputs = Input(shape=(28, 140, 1))
conv_11 = Conv2D(filters= 32, kernel_size=(5,5), padding='Same', activation='relu')(inputs)
max_pool_11 = MaxPool2D(pool_size=(2,2))(conv_11)
conv_12 = Conv2D(filters= 10, kernel_size=(3,3), padding='Same', activation='relu')(max_pool_11)
max_pool_12 = MaxPool2D(pool_size=(2,2), strides=(2,2))(conv_12)
flatten11 = Flatten()(max_pool_12)
hidden11 = Dense(15, activation='relu')(flatten11)
prediction1 = Dense(11, activation='softmax')(hidden11)
hidden21 = Dense(15, activation='relu')(flatten11)
prediction2 = Dense(11, activation='softmax')(hidden21)
hidden31 = Dense(15, activation='relu')(flatten11)
prediction3 = Dense(11, activation='softmax')(hidden31)
hidden41 = Dense(15, activation='relu')(flatten11)
prediction4 = Dense(11, activation='softmax')(hidden41)
hidden51 = Dense(15, activation='relu')(flatten11)
prediction5 = Dense(11, activation='softmax')(hidden51)
model = Model(inputs=inputs, outputs=[prediction1,prediction2,prediction3,prediction4,prediction5])
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
可视化网络
依赖 pydot-ng 和 graphviz,若出现错误,用命令行输入pip install pydot-ng & brew install graphviz
windows需要安装一下graphviz,配置一下环境
from keras.utils.vis_utils import plot_model, model_to_dotfrom IPython.display import Image, SVG
SVG(model_to_dot(model).create(prog='dot', format='svg'))
网络可视化
训练模型
训练20代,如果验证集上的准确率连续两次没有提高,就减小学习率。显卡不是很好,但依然很快,大概20分钟左右就学好了。
from keras.callbacks import ReduceLROnPlateau
learnrate_reduce_1 = ReduceLROnPlateau(monitor='val_dense_2_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_2 = ReduceLROnPlateau(monitor='val_dense_4_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_3 = ReduceLROnPlateau(monitor='val_dense_6_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_4 = ReduceLROnPlateau(monitor='val_dense_8_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_5 = ReduceLROnPlateau(monitor='val_dense_10_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
model.fit(X_train, y_train, epochs=20, batch_size=128,
validation_data=(X_valid, y_valid),
callbacks=[learnrate_reduce_1,learnrate_reduce_2,learnrate_reduce_3,learnrate_reduce_4,learnrate_reduce_5])
计算模型准确率
5个数字全部识别正确为正确,错一个即为错。可以用循环一一比对,我这里用了些概率论知识,因为都是独立事件,所以5个数字的准确率乘起来就是模型准确率。
def evaluate(model):
# TODO: 按照错一个就算错的规则计算准确率.
result = model.evaluate(np.array(X_test).reshape(len(X_test),28,140,1), [y_test[0], y_test[1], y_test[2], y_test[3], y_test[4]], batch_size=32) return result[6] * result[7] * result[8] * result[9] * result[10]
evaluate(model)
最后可以得到0.9476的正确率。
预测值可视化
def get_result(result):
# 将 one_hot 编码解码
resultstr = ''
for i in range(n_len):
resultstr += str(np.argmax(result[i])) + ','
return resultstr
index = random.randint(0, n_test-1)
y_pred = model.predict(X_test[index].reshape(1,28,140,1))
plt.title('real: %s\npred:%s'%(get_result([y_test[x][index] for x in range(n_len)]), get_result(y_pred)))
plt.imshow(X_test[index,:,:,0], cmap='gray')
plt.axis('off')
预测结果可视化
保存模型
model.save('model.h5'
原文链接:https://www.jianshu.com/p/79265078c95b
查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:
www.leadai.org
请关注人工智能LeadAI公众号,查看更多专业文章
大家都在看
LSTM模型在问答系统中的应用
基于TensorFlow的神经网络解决用户流失概览问题
最全常见算法工程师面试题目整理(一)
最全常见算法工程师面试题目整理(二)
TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络
装饰器 | Python高级编程
今天不如来复习下Python基础
卷积神经网络实现多个数字识别相关推荐
- 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)
1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...
- 深度学习 卷积神经网络-Pytorch手写数字识别
深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...
- 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)
基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...
- 【图像识别】基于卷积神经网络cnn实现银行卡数字识别matlab源码
1 基于卷积神经网络cnn实现银行卡数字识别模型 模型参考这里. 2 部分代码 %印刷体识别 clc;clear;close all; addpath('util/'); addpath('data/ ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- keras从入门到放弃(十三)卷积神经网络处理手写数字识别
今天来一个cnn例子 手写数字识别,因为是图像数据 import keras from keras import layers import numpy as np import matplotlib ...
- 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...
LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...
- CNN卷积神经网络实现手写数字识别(基于tensorflow)
1.1卷积神经网络简介 文章目录 1.1卷积神经网络简介 1.2 神经网络 1.2.1 神经元模型 1.2.2 神经网络模型 1.3 卷积神经网络 1.3.1卷积的概念 1.3.2 卷积的计算过程 1 ...
- 【图像识别】基于卷积神经网络CNN手写数字识别matlab代码
1 简介 针对传统手写数字的随机性,无规律性等问题,为了提高手写数字识别的检测准确性,本文在研究手写数字区域特点的基础上,提出了一种新的手写数字识别检测方法.首先,对采集的手写数字图像进行预处理,由于 ...
最新文章
- Swing 实现聊天系统 私发与群发
- CSDP是个好东西——CSDP 认证考试简介
- ansible的安装和ansible的模板
- 从头开始学eShopOnContainers——Visual Studio 2017环境配置
- thoughtworks面试题分析与解答
- 基于SSM的学生宿舍管理系统
- hibernate详细教程(入门到熟练)
- 推荐一个项目管理工具:TAPD
- 掌握USB/HDMI/MHL/DP验证规范 高速接口传输一次上手
- PCL点云滤波器总结
- PS选区工具和羽化的运用
- startx 命令详解
- Kali Linxu中打开Apache服务
- 《如何阅读一本书》——读书方法的整理
- CMD命令下修改和查看IP地址,DNS,网关
- configure: error: no acceptable C compiler found in $PATH 问题解决
- 统计学基础(假设检验、两个总体均值之差检验,独立样本t检验,配对样本t检验)
- Python编程:实现凯撒密码加密解密
- 不止SQL优化,数据库还有哪些优化大法?
- OptiView® XG 网络分析平板电脑特性(上)
热门文章
- python第一周小测验_测验1: Python基本语法元素 (第1周)-程序题
- java获取文件视图_springmvc-直接访问视图文件
- 遍历数组长度_Java基础之数组
- java购物车商品排序_Java购物车
- 计算机网络的自我介绍和评价,计算机网络自我介绍范文
- android studio光标变成黑块,解决Android Studio 代码无提示无颜色区分问题
- 怎么修改服务器上的分数,教资成绩查询服务器爆了?这里有个小技巧教你!
- 自动化测试---Assert
- 网络安全实验报告 第一章
- UIButton 的简单运用