NVIDIA DLI 深度学习入门培训 | 特设三场!!

4月28日/5月19日/5月26日一天密集式学习 轻松带你入门阅读全文>

正文共6095个字,6张图,预计阅读时间15分钟。

数据集:MNIST
框架:Keras
显卡:NVIDIA GEFORCE 750M
参考:Keras中文文档(http://keras-cn.readthedocs.io/en/latest/other/datasets/#mnist)

这是优达学城的深度学习项目,数据集和需求都很简单,关键是为了熟悉框架的使用以及项目搭建的套路,只要用很简单的卷积神经网络就能实现,准确率轻轻松松就能上90%。

需求描述

随机从MNIST数据集中选择5个或5个以下的数字,拼成一张图片,如下图所示。搭建一个模型,识别图片中的数字,空白字符的类型为0。

imgs and labels

项目实战

载入数据集

keras有

from keras.datasets import mnist

(X_raw, y_raw), (X_raw_test, y_raw_test) = mnist.load_data()

n_train, n_test = X_raw.shape[0], X_raw_test.shape[0]

查看数据集

import matplotlib.pyplot as pltimport random

%matplotlib inline
%config InlineBackend.figure_format = 'retina'for i in range(15):
   plt.subplot(3, 5, i+1)
   index = random.randint(0, n_train-1)
   plt.title(str(y_raw[index]))
   plt.imshow(X_raw[index], cmap='gray')

plt.axis('off')

dataset

合成数据

载入数据集的时候将数据集分成了训练集X_raw和测试集X_test,这里需要从X_raw中随机选取数字,然后拼成新的图片,并将20%设为验证集,防止模型过拟合。

注意:数字的长度不一定为5,不到5的以空白填充,最终图片高28长28x5=140

  • 为什么将数据分成训练集、验证集和测试集?
    训练集是用来训练模型的;验证集是用来对训练的模型进行进一步调参优化,如果使用测试集验证,网络就会记住测试集,容易使模型过拟合;测试集用来测试模型表现。

难点:

原图是28x28,拼成28x140,原来一行有28,现在一行有140,是每行做的append,用list.append效率会很低,用矩阵转置就会很快。

import numpy as np

from sklearn.model_selection import train_test_split

n_class, n_len, width, height = 11, 5, 28, 28

def generate_dataset(X, y):
   X_len = X.shape[0] # 原数据集有几个,新数据集还要有几个
   # 新数据集的shape为(X_len, 28, 28*5, 1),X_len是X的个数,原数据集是28x28,取5个数字(包含空白)拼接,则为28x140, 1是颜色通道,灰度图,所以是1
   X_gen = np.zeros((X_len, height, width*n_len, 1), dtype=np.uint8)

# 新数据集对应的label,最终的shape为(5,  X_len,11)

y_gen = [np.zeros((X_len, n_class), dtype=np.uint8) for i in range(n_len)]    
   for i in range(X_len):

# 随机确定数字长度
       rand_len = random.randint(1, 5)
       lis = list()

# 设置每个数字
       for j in range(0, rand_len):

# 随机找一个数
           index = random.randint(0, X_len - 1)

# 将对应的y置1, y是经过onehot编码的,所以y的第三维是11,0~9为10个数字,10为空白,哪个索引为1就是数字几
           y_gen[j][i][y[index]] = 1
           lis.append(X[index].T)

# 其余位取空白    
       for m in range(rand_len, 5):

# 将对应的y置1
           y_gen[m][i][10] = 1
           lis.append(np.zeros((28, 28),dtype=np.uint8))
       lis = np.array(lis).reshape(140,28).T
    
       X_gen[i] = lis.reshape(28,140,1)        
   return X_gen, y_gen
X_raw_train, X_raw_valid, y_raw_train, y_raw_valid = train_test_split(X_raw, y_raw, test_size=0.2, random_state=50)

X_train, y_train = generate_dataset(X_raw_train, y_raw_train)
X_valid, y_valid = generate_dataset(X_raw_valid, y_raw_valid)

X_test, y_test = generate_dataset(X_raw_test, y_raw_test)

显示合成的图片

# 显示生成的图片for i in range(15):
   plt.subplot(5, 3, i+1)
   index = random.randint(0, n_test-1)
   title = ''
   for j in range(n_len):
       title += str(np.argmax(y_test[j][index])) + ','
   
   plt.title(title)
   plt.imshow(X_test[index][:,:,0], cmap='gray')

plt.axis('off')

合成的图片

CNN搭建

使用了keras的函数式模型,很方便,可以参考官方文档。

由于数据集比较简答,所以随便一个网络结构都能有不错的表现,我用的是两层卷机模型,卷积层、最大池化层、卷积层、最大池化层,然后两个全连接层。

from keras.models import Modelfrom keras.layers import *import tensorflow as tf# This returns a tensorinputs = Input(shape=(28, 140, 1))

conv_11 = Conv2D(filters= 32, kernel_size=(5,5), padding='Same', activation='relu')(inputs)
max_pool_11 = MaxPool2D(pool_size=(2,2))(conv_11)
conv_12 = Conv2D(filters= 10, kernel_size=(3,3), padding='Same', activation='relu')(max_pool_11)
max_pool_12 = MaxPool2D(pool_size=(2,2), strides=(2,2))(conv_12)
flatten11 = Flatten()(max_pool_12)
hidden11 = Dense(15, activation='relu')(flatten11)
prediction1 = Dense(11, activation='softmax')(hidden11)

hidden21 = Dense(15, activation='relu')(flatten11)
prediction2 = Dense(11, activation='softmax')(hidden21)

hidden31 = Dense(15, activation='relu')(flatten11)
prediction3 = Dense(11, activation='softmax')(hidden31)

hidden41 = Dense(15, activation='relu')(flatten11)
prediction4 = Dense(11, activation='softmax')(hidden41)

hidden51 = Dense(15, activation='relu')(flatten11)
prediction5 = Dense(11, activation='softmax')(hidden51)

model = Model(inputs=inputs, outputs=[prediction1,prediction2,prediction3,prediction4,prediction5])

model.compile(optimizer='rmsprop',
             loss='categorical_crossentropy',

metrics=['accuracy'])

可视化网络

依赖 pydot-ng 和 graphviz,若出现错误,用命令行输入pip install pydot-ng & brew install graphviz

windows需要安装一下graphviz,配置一下环境

from keras.utils.vis_utils import plot_model, model_to_dotfrom IPython.display import Image, SVG

SVG(model_to_dot(model).create(prog='dot', format='svg'))

网络可视化

训练模型

训练20代,如果验证集上的准确率连续两次没有提高,就减小学习率。显卡不是很好,但依然很快,大概20分钟左右就学好了。

from keras.callbacks import ReduceLROnPlateau

learnrate_reduce_1 = ReduceLROnPlateau(monitor='val_dense_2_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_2 = ReduceLROnPlateau(monitor='val_dense_4_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_3 = ReduceLROnPlateau(monitor='val_dense_6_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_4 = ReduceLROnPlateau(monitor='val_dense_8_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)
learnrate_reduce_5 = ReduceLROnPlateau(monitor='val_dense_10_acc', patience=2, verbose=1,factor=0.8, min_lr=0.00001)

model.fit(X_train, y_train, epochs=20, batch_size=128,
         validation_data=(X_valid, y_valid),

callbacks=[learnrate_reduce_1,learnrate_reduce_2,learnrate_reduce_3,learnrate_reduce_4,learnrate_reduce_5])

计算模型准确率

5个数字全部识别正确为正确,错一个即为错。可以用循环一一比对,我这里用了些概率论知识,因为都是独立事件,所以5个数字的准确率乘起来就是模型准确率。

def evaluate(model):
   # TODO: 按照错一个就算错的规则计算准确率.
   result = model.evaluate(np.array(X_test).reshape(len(X_test),28,140,1), [y_test[0], y_test[1], y_test[2], y_test[3], y_test[4]], batch_size=32)    return result[6] * result[7] * result[8] * result[9] * result[10]

evaluate(model)

最后可以得到0.9476的正确率。

预测值可视化

def get_result(result):
   # 将 one_hot 编码解码
   resultstr = ''
   for i in range(n_len):
       resultstr += str(np.argmax(result[i])) + ','
   return resultstr

index = random.randint(0, n_test-1)
y_pred = model.predict(X_test[index].reshape(1,28,140,1))

plt.title('real: %s\npred:%s'%(get_result([y_test[x][index] for x in range(n_len)]), get_result(y_pred)))
plt.imshow(X_test[index,:,:,0], cmap='gray')

plt.axis('off')

预测结果可视化

保存模型

model.save('model.h5'

原文链接:https://www.jianshu.com/p/79265078c95b

查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org

请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看

LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础

卷积神经网络实现多个数字识别相关推荐

  1. 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)

    1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...

  2. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  3. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  4. 【图像识别】基于卷积神经网络cnn实现银行卡数字识别matlab源码

    1 基于卷积神经网络cnn实现银行卡数字识别模型 模型参考这里. 2 部分代码 %印刷体识别 clc;clear;close all; addpath('util/'); addpath('data/ ...

  5. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  6. keras从入门到放弃(十三)卷积神经网络处理手写数字识别

    今天来一个cnn例子 手写数字识别,因为是图像数据 import keras from keras import layers import numpy as np import matplotlib ...

  7. 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

    LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...

  8. CNN卷积神经网络实现手写数字识别(基于tensorflow)

    1.1卷积神经网络简介 文章目录 1.1卷积神经网络简介 1.2 神经网络 1.2.1 神经元模型 1.2.2 神经网络模型 1.3 卷积神经网络 1.3.1卷积的概念 1.3.2 卷积的计算过程 1 ...

  9. 【图像识别】基于卷积神经网络CNN手写数字识别matlab代码

    1 简介 针对传统手写数字的随机性,无规律性等问题,为了提高手写数字识别的检测准确性,本文在研究手写数字区域特点的基础上,提出了一种新的手写数字识别检测方法.首先,对采集的手写数字图像进行预处理,由于 ...

最新文章

  1. Swing 实现聊天系统 私发与群发
  2. CSDP是个好东西——CSDP 认证考试简介
  3. ansible的安装和ansible的模板
  4. 从头开始学eShopOnContainers——Visual Studio 2017环境配置
  5. thoughtworks面试题分析与解答
  6. 基于SSM的学生宿舍管理系统
  7. hibernate详细教程(入门到熟练)
  8. 推荐一个项目管理工具:TAPD
  9. 掌握USB/HDMI/MHL/DP验证规范 高速接口传输一次上手
  10. PCL点云滤波器总结
  11. PS选区工具和羽化的运用
  12. startx 命令详解
  13. Kali Linxu中打开Apache服务
  14. 《如何阅读一本书》——读书方法的整理
  15. CMD命令下修改和查看IP地址,DNS,网关
  16. configure: error: no acceptable C compiler found in $PATH 问题解决
  17. 统计学基础(假设检验、两个总体均值之差检验,独立样本t检验,配对样本t检验)
  18. Python编程:实现凯撒密码加密解密
  19. 不止SQL优化,数据库还有哪些优化大法?
  20. OptiView® XG 网络分析平板电脑特性(上)

热门文章

  1. python第一周小测验_测验1: Python基本语法元素 (第1周)-程序题
  2. java获取文件视图_springmvc-直接访问视图文件
  3. 遍历数组长度_Java基础之数组
  4. java购物车商品排序_Java购物车
  5. 计算机网络的自我介绍和评价,计算机网络自我介绍范文
  6. android studio光标变成黑块,解决Android Studio 代码无提示无颜色区分问题
  7. 怎么修改服务器上的分数,教资成绩查询服务器爆了?这里有个小技巧教你!
  8. 自动化测试---Assert
  9. 网络安全实验报告 第一章
  10. UIButton 的简单运用