Keras Mnist

在这里,我们将利用Keras搭建一个深度学习网络对mnist数据集进行识别。

  • 本文参考 keras-mnist-tutorial
  • 整个代码分为三个部分:
    1. 数据准备
    2. 模型搭建
    3. 训练优化

让我们开始吧

首先先导入一些模块

%matplotlib inline
import numpy as np
import matplotlib.pyplot as pltfrom keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.utils import np_utils

数据准备

我们通过keras自带的数据集mnist进行导入数据,然后对其归一化处理,并且将原二维数据变成一维数据,作为网络的输入。

读入mnist数据集。可以看到每条样本是一个28*28的矩阵,共有60000个训练数据,10000个测试数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data();
print(x_train.shape)
print(x_test.shape)
(60000, 28, 28)
(10000, 28, 28)

将一些样本图像打印出来看看

for i in range(9):plt.subplot(3,3,i+1)plt.imshow(X_train[i], cmap='gray', interpolation='none')plt.title("Class {}".format(y_train[i]))

将二维数据变成一维数据

X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)

接下来对数据进行归一化。原来的数据范围是[0,255],我们通过归一化时靠近0附近。归一化的方式有很多,大家随意。

# uint不能有负数,我们先转为float类型
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train = (X_train - 127) / 127
X_test = (X_test - 127) / 127

接下来 One-hot encoding

nb_classes = 10
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test = np_utils.to_categorical(y_test, nb_classes)

搭建网络

数据已经准备好了,接下来我们进行网络的搭建,我们的网络有三层,都是全连接网络,大概长的像这样

这里新遇到一个Dropout,这是一种防止过拟合(overfitting)的方法,详见Dropout层

model = Sequential()model.add(Dense(512, input_shape=(784,), kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Dropout(0.2)) model.add(Dense(512, kernel_initializer='he_normal'))
model.add(Activation('relu'))
model.add(Dropout(0.2)) model.add(Dense(nb_classes))
model.add(Activation('softmax'))

OK!模型搭建好了,我们通过编译对学习过程进行配置

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

那么我们进行训练吧

model.fit(X_train, y_train, epochs=20, batch_size=64, verbose=1, validation_split=0.05)
Train on 57000 samples, validate on 3000 samples
Epoch 1/20
57000/57000 [==============================] - 19s 327us/step - loss: 0.0811 - acc: 0.9761 - val_loss: 0.0702 - val_acc: 0.9810
Epoch 2/20
57000/57000 [==============================] - 19s 328us/step - loss: 0.0752 - acc: 0.9772 - val_loss: 0.0720 - val_acc: 0.9813
Epoch 3/20
57000/57000 [==============================] - 19s 331us/step - loss: 0.0687 - acc: 0.9788 - val_loss: 0.0670 - val_acc: 0.9830
Epoch 4/20
57000/57000 [==============================] - 20s 350us/step - loss: 0.0667 - acc: 0.9794 - val_loss: 0.0755 - val_acc: 0.9810
Epoch 5/20
57000/57000 [==============================] - 20s 353us/step - loss: 0.0688 - acc: 0.9794 - val_loss: 0.0671 - val_acc: 0.9820
Epoch 6/20
57000/57000 [==============================] - 20s 346us/step - loss: 0.0639 - acc: 0.9807 - val_loss: 0.0744 - val_acc: 0.9790
Epoch 7/20
57000/57000 [==============================] - 20s 342us/step - loss: 0.0626 - acc: 0.9805 - val_loss: 0.0685 - val_acc: 0.9837
Epoch 8/20
57000/57000 [==============================] - 21s 365us/step - loss: 0.0669 - acc: 0.9796 - val_loss: 0.0988 - val_acc: 0.9757
Epoch 9/20
57000/57000 [==============================] - 20s 345us/step - loss: 0.0605 - acc: 0.9819 - val_loss: 0.0769 - val_acc: 0.9833
Epoch 10/20
57000/57000 [==============================] - 19s 338us/step - loss: 0.0592 - acc: 0.9820 - val_loss: 0.0576 - val_acc: 0.9870
Epoch 11/20
57000/57000 [==============================] - 19s 336us/step - loss: 0.0600 - acc: 0.9822 - val_loss: 0.0689 - val_acc: 0.9847
Epoch 12/20
57000/57000 [==============================] - 20s 345us/step - loss: 0.0625 - acc: 0.9813 - val_loss: 0.0689 - val_acc: 0.9843
Epoch 13/20
57000/57000 [==============================] - 20s 346us/step - loss: 0.0573 - acc: 0.9829 - val_loss: 0.0679 - val_acc: 0.9853
Epoch 14/20
57000/57000 [==============================] - 19s 342us/step - loss: 0.0555 - acc: 0.9833 - val_loss: 0.0642 - val_acc: 0.9850
Epoch 15/20
57000/57000 [==============================] - 20s 359us/step - loss: 0.0571 - acc: 0.9831 - val_loss: 0.0779 - val_acc: 0.9833
Epoch 16/20
57000/57000 [==============================] - 21s 361us/step - loss: 0.0564 - acc: 0.9831 - val_loss: 0.0610 - val_acc: 0.9867
Epoch 17/20
57000/57000 [==============================] - 20s 354us/step - loss: 0.0574 - acc: 0.9834 - val_loss: 0.0669 - val_acc: 0.9867
Epoch 18/20
57000/57000 [==============================] - 20s 353us/step - loss: 0.0526 - acc: 0.9848 - val_loss: 0.0863 - val_acc: 0.9830
Epoch 19/20
57000/57000 [==============================] - 20s 349us/step - loss: 0.0548 - acc: 0.9832 - val_loss: 0.0726 - val_acc: 0.9847
Epoch 20/20
57000/57000 [==============================] - 20s 352us/step - loss: 0.0512 - acc: 0.9845 - val_loss: 0.0735 - val_acc: 0.9860<keras.callbacks.History at 0x2904822cd30>

训练完毕,测试测试

loss, accuracy = model.evaluate(X_test, y_test)
print('Test loss:', loss)
print('Accuracy:', accuracy)
10000/10000 [==============================] - 1s 80us/step
Test loss: 0.0864374790877
Accuracy: 0.9817

Keras-2 Keras Mnist相关推荐

  1. 【Keras】30 秒上手 Keras+实例对mnist手写数字进行识别准确率达99%以上

    本文我们将学习使用Keras一步一步搭建一个卷积神经网络.具体来说,我们将使用卷积神经网络对手写数字(MNIST数据集)进行识别,并达到99%以上的正确率. @为什么选择Keras呢? 主要是因为简单 ...

  2. Keras——用Keras搭建自编码神经网络(AutoEncoder)

    文章目录 1.前言 2.用Keras搭建自编码神经网络 2.1.导入必要模块 2.2.数据预处理 2.3.搭建模型 2.4.实例化并激活模型 2.5.训练 2.6.可视化 1.前言 自编码,简单来说就 ...

  3. Keras——用Keras搭建RNN分类循环神经网络

    文章目录 1.前言 2.用Keras搭建RNN循环神经网络 2.1.导入必要模块 2.2.超参数设置 2.3.数据预处理 2.4.搭建模型 2.5.激活模型 2.6.训练+测试 1.前言 这次我们用循 ...

  4. Keras——用Keras搭建分类神经网络

    文章目录 1.前言 2.用Keras搭建分类神经网络 2.1.导入必要模块 2.2.数据预处理 2.3.搭建模型 2.4.激活模型 2.5.训练+测试 1.前言 今天用 Keras 来构建一个分类神经 ...

  5. Magnitude-based weight pruning with Keras(keras模型权重裁剪)

    keras模型权重裁剪 https://github.com/lixiaolei1982/model-optimization/blob/master/tensorflow_model_optimiz ...

  6. keras系列︱keras是如何指定显卡且限制显存用量

    keras系列︱keras是如何指定显卡且限制显存用量 原创 2017年07月21日 10:59:24 标签: keras / gpu / 显卡 / 指定 / 限制 6630 keras在使用GPU的 ...

  7. DL之Keras:keras保存网络结构、网络拓扑图、网络模型(json、yaml、h5等)注意事项及代码实现

    DL之Keras:keras保存网络结构.网络拓扑图.网络模型(json.yaml.h5等)注意事项及代码实现 目录 keras保存网络结构.网络拓扑图.网络模型(json.yaml.h5等)注意事项 ...

  8. DL之Keras: Keras深度学习框架的注意事项(默认下载存放路径等)、使用方法之详细攻略

    DL之Keras: Keras深度学习框架的注意事项(自动下载存放路径等).使用方法之详细攻略 目录 Keras深度学习框架的注意事项 1.Keras自动下载默认数据集/模型存放位置 Windows系 ...

  9. 【keras】keras教程(参考官方文档)

    文章目录 一.callbacks篇 1.ReduceLROnPlateau 训练过程优化学习率 2.EarlyStopping 早停操作 3.ModelCheckpoint 用于设置保存的方式 4.T ...

  10. keras学习- No module named ' tensorflow.keras ' 报错,看清 tf.keras与keras

    环境描述: 系统ubantu16.04 安装anaconda  版本conda 4.5.4 创建虚拟环境 tf-gpu tensorflow-gpu版本(1.7.0-gpu, 能够import ten ...

最新文章

  1. 一个java的DES加解密类转换成C#
  2. oracle账户锁定解决方法
  3. Linux permission denied解决方法
  4. 卷积神经网络 池化层上采样(upsampling、interpolating)、下采样(subsampled、downsampled)是什么?(上采样为放大图像或图像插值、下采样为缩小图像)
  5. ImageMagick 打水印支持透明度设置
  6. 【渝粤题库】陕西师范大学292969 会计学 作业 (专升本、高起本)
  7. otis电梯服务器tt使用说明_南充私人电梯
  8. 私人定制-代码生成器3
  9. 线程的共享资源和私有资源
  10. 25款精选免费小程序源码demo下载
  11. Originpro拟合Gompertz模型
  12. VBM法MRI图像处理——记第一次使用cat12
  13. python根据词性进行词频统计_如何根据词性来确定语篇中的词频?
  14. lookup无序查找_Excel LOOKUP不排序怎么快速找到数据_lookup函数讲解
  15. YELP NLP 英文文本断句
  16. 【ABAP系列】SAP ABAP smartforms设备类型CNSAPWIN不支持页格式ZXXX
  17. 理解体检报告10个必须项目
  18. 反垃圾邮件的一些相关链接
  19. 苹果敢对“赞赏”分成30%真的是靠垄断吗?
  20. ArcGIS API for JavaScript创建 3D 地图

热门文章

  1. etag java_你知道HTTP协议的ETag是干什么的吗?
  2. sendfile实现文件服务器,sendfile
  3. 大学计算机专业全民,计算机专业大学排名实力顺序(上大学国内计算机专业大学哪个好值得报读)...
  4. 关于游戏的C 语言的课设报告,猜单词游戏C课程设计报告.doc
  5. 运行项目报错invalid notify_url
  6. Java学习笔记2.1.2 Java基本语法 - Java三种注释方式
  7. 《天天数学》连载37:二月六日
  8. 统计学基础学习笔记:描述统计量
  9. 【BZOJ4562】食物链,拓扑DP
  10. 【BZOJ1911】【codevs1318】特别行动队,斜率优化DP