4.1 keras基础实例 手写数字识别
1)手写数据集
手写数据集是深度学习中,最基础应用最广泛的数据集。
手写数据集内置在keras中
import keras
from keras import layers
import matplotlib. pyplot as plt
import numpy as np
import keras.datasets.mnist as mnist# 1)加载数据集
(train_image, train_label),(test_image,test_label) = mnist.load_data()# 2)验证数据集的性质
train_image.shape,train_label.shape
test_image.shape, test_label.shape
plt.imshow(train_image[0])# 3)初始化一个模型
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
# 建立全链接层, 使用relu激活
model.add(layers.Dense(64, activation='relu'))
# 添加一个分类层,使用softmax激活。输出0-9是个数字,所以单元数为10
model.add(layers.Dense(10, activation='softmax'))# 4)编译模型
# 当label是顺序编码的时候,计算交叉熵是 sparse_categorical_crossentropy
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])# 5)训练模型
model.fit(train_image, train_label, epochs=50, batch_size=512)
# batch_size = 512, 一个batch一个batch的去训练,不是将所有数据拿进去训练
# 原因:计算机的性能或者说计算机的内存容量在处理大型数据的时候,比如说图片数据的时候,
# 将全部数据加载进去,可能会引起内存爆炸。model.evaluate(train_image,train_label)
model.evaluate(test_image,test_label)# 预测test数据集的前10张图片
model.predict(test_image[:10])
# 预测的
np.argmax(model.predict(test_image[:10]),axis=1)
# 实际的
test_label[:10]# 模型的优化 ----- 进行过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=50,batch_size=512, validation_data=(test_image,test_label))# 模型的再优化 ---- 增加过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=200,batch_size=512, validation_data=(test_image,test_label))
2)代码分析:
1、验证数据集特性:
train_image.shape, train_label.shape
、
test_image.shape,test_label.shape
plt.imshow(train_image[0])
2、初始化模型
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
# 建立全链接层, 使用relu激活
model.add(layers.Dense(64, activation='relu'))
# 添加一个分类层,使用softmax激活。输出0-9是个数字,所以单元数为10
model.add(layers.Dense(10, activation='softmax'))
3、编译模型
# 当label是顺序编码的时候,计算交叉熵是 sparse_categorical_crossentropy
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
4、训练模型
model.fit(train_image, train_label, epochs=50, batch_size=512)
batch_size = 512, 一个batch(批次)一个batch(批次)的去训练,不是将所有数据拿进去训练
原因:计算机的性能或者说计算机的内存容量在处理大型数据的时候,比如说图片数据的时候,
将全部数据加载进去,可能会引起内存爆炸。
5、预测模型
# 预测test数据集的前10张图片
model.predict(test_image[:10])
# 预测的
np.argmax(model.predict(test_image[:10]),axis=1)
# 实际的
test_label[:10]
预测的
实际的
发现,倒数第三位预测错误了。7 和9长得比较像
3)模型的优化
1、初始化模型
添加多个隐藏层 --- 增加网络容量直到过拟合
# 模型的优化 ----- 进行过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))# 多添加的两个隐藏层
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))
2、编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
3、训练模型
model.fit(train_image,train_label,epochs=50,batch_size=512, validation_data=(test_image,test_label))
参数 “ validation_data ” 在训练过程中观察其 在测试数据集上的表现
4)模型的再优化与抑制过拟合
添加DropOut层抑制过拟合
# 模型的再优化 ---- 增加过拟合
model = keras.Sequential()
model.add(layers.Flatten()) #(60000, 28, 28) ----> (600000, 28*28)
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])model.fit(train_image,train_label,epochs=200,batch_size=512, validation_data=(test_image,test_label))
# 后续工作:继续增大网络容量,直到过拟合
参考《3.8 网络参数选择的总原则》
4.1 keras基础实例 手写数字识别相关推荐
- 使用tf.keras搭建mnist手写数字识别网络
使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...
- TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络
TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...
- 深度学习入门实例——基于keras的mnist手写数字识别
本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...
- 1、基于Keras、Mnist手写数字识别数据集构建全连接(FC)神经网络训练模型
文章目录 前言 一.MNIST数据集是什么? 二.构建神经网络训练模型 1.导入库 2.载入数据 3.数据处理 4.创建模型 5.编译模型 6.训练模型 7.评估模型 三.总代码 前言 提示: 1.本 ...
- keras框架实现手写数字识别
详细细节可学习从零开始神经网络:keras框架实现数字图像识别详解! 代码实现: [1] ''' 将训练数据和检测数据加载到内存中(第一次运行需要下载数据,会比较慢): (mnist是手写数据集) t ...
- Multi-class Classication (多分类问题)实例--手写数字识别
本实例整理自斯坦福机器学习课程课后练习ex3 本例是对一个手写体的数据集(0-9)进行分类,也就是对原有的数据集进行训练,然后给定一个手写体,识别该手写体是数字几.其分类思想就是之前Andrew Ng ...
- 【TensorFlow-windows】keras接口——卷积手写数字识别,模型保存和调用
前言 上一节学习了以TensorFlow为底端的keras接口最简单的使用,这里就继续学习怎么写卷积分类模型和各种保存方法(仅保存权重.权重和网络结构同时保存) 国际惯例,参考博客: 官方教程 [注] ...
- caffe(4):mnist实例---手写数字识别
深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...
- tensorflow2.0基础操作-手写数字识别实战
import tensorflow as tf from tensorflow import keras from tensorflow.keras import datasets, layers, ...
最新文章
- RGB-D对红外热像仪和毫米波雷达标定
- 王建民做客第六期青年学者月度沙龙 分享工业软件的开源创新发展模式
- 编译 glibc-2.14 时出现的一个LD_LIBRARY_PATH不路径bug
- web开发快速提高工作效率的一些资源
- FCN全连接卷积网络(5)--Fully Convolutional Networks for Semantic Segmentation阅读(相关工作部分)
- IOS之Swift5.x开发通讯录实战
- 如何通过建造餐厅来了解Scala差异
- 《Python Cookbook 3rd》笔记(3.8):分数运算
- java采集温湿度水浸_智能电力水浸监控解决方案
- INFO:AdminStudio Debug
- css3的那些高级选择器一
- win7网络改局域网计算机名,局域网共享一键修复工具(支持win7) 修复windows7各种共享问题...
- 进化吧,MySQL锁!无锁->偏向锁->轻量级锁->重量级锁(请自动脑补数码宝贝进化音)
- POTN——新时代网络融合的必经之路
- 微信小程序appid的修改方法
- 全球2%高智商天才必测脑力题!却只有1%的人,能在5分钟内全部做对!
- java split 冒号_Java中字符串split() 的使用方法,没你想的那么简单
- Explaining Knowledge Distillation by Quantifying the Knowledge
- 如何给apk安装包去毒,避免被识别为病毒和木马
- 室外无人驾驶挑战赛小结-2019‘恩智浦’杯全国大学生智能车竞赛