1)手写数据集

手写数据集是深度学习中,最基础应用最广泛的数据集。

手写数据集内置在keras中

import keras
from keras import layers
import matplotlib. pyplot as plt
import numpy as np
import keras.datasets.mnist as mnist# 1)加载数据集
(train_image, train_label),(test_image,test_label) = mnist.load_data()# 2)验证数据集的性质
train_image.shape,train_label.shape
test_image.shape, test_label.shape
plt.imshow(train_image[0])# 3)初始化一个模型
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
# 建立全链接层, 使用relu激活
model.add(layers.Dense(64, activation='relu'))
# 添加一个分类层,使用softmax激活。输出0-9是个数字,所以单元数为10
model.add(layers.Dense(10, activation='softmax'))# 4)编译模型
# 当label是顺序编码的时候,计算交叉熵是 sparse_categorical_crossentropy
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])# 5)训练模型
model.fit(train_image, train_label, epochs=50, batch_size=512)
# batch_size = 512, 一个batch一个batch的去训练,不是将所有数据拿进去训练
# 原因:计算机的性能或者说计算机的内存容量在处理大型数据的时候,比如说图片数据的时候,
# 将全部数据加载进去,可能会引起内存爆炸。model.evaluate(train_image,train_label)
model.evaluate(test_image,test_label)# 预测test数据集的前10张图片
model.predict(test_image[:10])
# 预测的
np.argmax(model.predict(test_image[:10]),axis=1)
# 实际的
test_label[:10]# 模型的优化   -----  进行过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=50,batch_size=512, validation_data=(test_image,test_label))# 模型的再优化   ---- 增加过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=200,batch_size=512, validation_data=(test_image,test_label))

2)代码分析:

1、验证数据集特性:

train_image.shape, train_label.shape

 、

test_image.shape,test_label.shape

plt.imshow(train_image[0])

2、初始化模型

model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
# 建立全链接层, 使用relu激活
model.add(layers.Dense(64, activation='relu'))
# 添加一个分类层,使用softmax激活。输出0-9是个数字,所以单元数为10
model.add(layers.Dense(10, activation='softmax'))

3、编译模型

# 当label是顺序编码的时候,计算交叉熵是 sparse_categorical_crossentropy
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

4、训练模型

model.fit(train_image, train_label, epochs=50, batch_size=512)

batch_size = 512, 一个batch(批次)一个batch(批次)的去训练,不是将所有数据拿进去训练
原因:计算机的性能或者说计算机的内存容量在处理大型数据的时候,比如说图片数据的时候,
将全部数据加载进去,可能会引起内存爆炸。

5、预测模型

# 预测test数据集的前10张图片
model.predict(test_image[:10])
# 预测的
np.argmax(model.predict(test_image[:10]),axis=1)
# 实际的
test_label[:10]

预测的

实际的

发现,倒数第三位预测错误了。7 和9长得比较像

3)模型的优化

1、初始化模型

添加多个隐藏层 --- 增加网络容量直到过拟合

# 模型的优化   -----  进行过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))# 多添加的两个隐藏层
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))

2、编译模型

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])

3、训练模型

model.fit(train_image,train_label,epochs=50,batch_size=512, validation_data=(test_image,test_label))

参数 “ validation_data ”         在训练过程中观察其 在测试数据集上的表现

4)模型的再优化与抑制过拟合

添加DropOut层抑制过拟合

# 模型的再优化   ---- 增加过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28)  ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=200,batch_size=512, validation_data=(test_image,test_label))

#  后续工作:继续增大网络容量,直到过拟合

参考《3.8 网络参数选择的总原则》

4.1 keras基础实例 手写数字识别相关推荐

  1. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

  2. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  3. 深度学习入门实例——基于keras的mnist手写数字识别

    本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...

  4. 1、基于Keras、Mnist手写数字识别数据集构建全连接(FC)神经网络训练模型

    文章目录 前言 一.MNIST数据集是什么? 二.构建神经网络训练模型 1.导入库 2.载入数据 3.数据处理 4.创建模型 5.编译模型 6.训练模型 7.评估模型 三.总代码 前言 提示: 1.本 ...

  5. keras框架实现手写数字识别

    详细细节可学习从零开始神经网络:keras框架实现数字图像识别详解! 代码实现: [1] ''' 将训练数据和检测数据加载到内存中(第一次运行需要下载数据,会比较慢): (mnist是手写数据集) t ...

  6. Multi-class Classication (多分类问题)实例--手写数字识别

    本实例整理自斯坦福机器学习课程课后练习ex3 本例是对一个手写体的数据集(0-9)进行分类,也就是对原有的数据集进行训练,然后给定一个手写体,识别该手写体是数字几.其分类思想就是之前Andrew Ng ...

  7. 【TensorFlow-windows】keras接口——卷积手写数字识别,模型保存和调用

    前言 上一节学习了以TensorFlow为底端的keras接口最简单的使用,这里就继续学习怎么写卷积分类模型和各种保存方法(仅保存权重.权重和网络结构同时保存) 国际惯例,参考博客: 官方教程 [注] ...

  8. caffe(4):mnist实例---手写数字识别

    深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...

  9. tensorflow2.0基础操作-手写数字识别实战

    import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets, layers, ...

最新文章

  1. RGB-D对红外热像仪和毫米波雷达标定
  2. 王建民做客第六期青年学者月度沙龙 分享工业软件的开源创新发展模式
  3. 编译 glibc-2.14 时出现的一个LD_LIBRARY_PATH不路径bug
  4. web开发快速提高工作效率的一些资源
  5. FCN全连接卷积网络(5)--Fully Convolutional Networks for Semantic Segmentation阅读(相关工作部分)
  6. IOS之Swift5.x开发通讯录实战
  7. 如何通过建造餐厅来了解Scala差异
  8. 《Python Cookbook 3rd》笔记(3.8):分数运算
  9. java采集温湿度水浸_智能电力水浸监控解决方案
  10. INFO:AdminStudio Debug
  11. css3的那些高级选择器一
  12. win7网络改局域网计算机名,局域网共享一键修复工具(支持win7) 修复windows7各种共享问题...
  13. 进化吧,MySQL锁!无锁->偏向锁->轻量级锁->重量级锁(请自动脑补数码宝贝进化音)
  14. POTN——新时代网络融合的必经之路
  15. 微信小程序appid的修改方法
  16. 全球2%高智商天才必测脑力题!却只有1%的人,能在5分钟内全部做对!
  17. java split 冒号_Java中字符串split() 的使用方法,没你想的那么简单
  18. Explaining Knowledge Distillation by Quantifying the Knowledge
  19. 如何给apk安装包去毒,避免被识别为病毒和木马
  20. 室外无人驾驶挑战赛小结-2019‘恩智浦’杯全国大学生智能车竞赛

热门文章

  1. ShardingSphere分库分表实战
  2. Electron自定义软件卸载流程
  3. 桌面显示器带Type-c接口 支持65W充电和投屏方案
  4. 网络监控2013:IP Camera民用市场暴增
  5. jsp 分页查找算法
  6. NP管理器V3.0.18之第三方MT管理器VIP版
  7. 14.URL重写技术
  8. 【手把手教你整合SSM项目并且完成入门项目到成功运行!!!】
  9. c语言宏函数返回值,C++宏定义方法的返回值
  10. jq slimScroll 滚动条插件 回到顶部方法