基于Keras搭建mnist数据集训练识别的Pipeline
搭建模型
import tensorflow as tf
from tensorflow import keras# get data
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()# setup model
model = keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),keras.layers.Dense(128, activation=tf.nn.relu),keras.layers.Dense(10, activation=tf.nn.softmax)
])# 编译
model.compile(optimizer=tf.train.AdamOptimizer(), loss='sparse_categorical_crossentropy',metrics=['accuracy'])# train model
model.fit(train_images, train_labels, epochs=5)# evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels)print('test accuracy:', test_acc)
模型构建方法二:
# 加载MNIST数据集 -- 《深度学习》第二章案例
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images.shape # (60000, 28, 28)
train_labels.shape # (60000,)# 将训练数据和标签关联在一起
from keras import models, layersnet = models.Sequential()
net.add(layers.Dense(512, activation='relu', input_shape=(28*28,))) # 直接输入展平的张量
net.add(layers.Dense(10,activation='softmax')) # 最后输出10个结果# 编译网络
net.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])# 训练数据准备
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255 # 数据归一化test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255 # 数据归一化# 对标签进行分类编码:变成one-hot编码
from keras.utils import to_categoricaltrain_labels_final = to_categorical(train_labels)
test_labels_final = to_categorical(test_labels)train_images.shape # (60000, 784)
train_labels[0] # 5,这是非onehot编码的标签train_labels_final.shape # (60000, 10)# 拟合网络:训练开始
history = net.fit(train_images, train_labels_final, epochs=5, batch_size=128)'''
Epoch 1/5
60000/60000 [==============================] - 5s 78us/step - loss: 0.0288 - acc: 0.9916
Epoch 2/5
60000/60000 [==============================] - 5s 80us/step - loss: 0.0221 - acc: 0.9932
Epoch 3/5
60000/60000 [==============================] - 5s 78us/step - loss: 0.0171 - acc: 0.9950
Epoch 4/5
60000/60000 [==============================] - 5s 79us/step - loss: 0.0133 - acc: 0.9961
Epoch 5/5
60000/60000 [==============================] - 5s 80us/step - loss: 0.0100 - acc: 0.9970
'''# 保存模型
import os
model_name = "keras_mnist_trained_model.h5"
save_dir = os.path.join(os.getcwd(), 'saved_models')if not os.path.isdir(save_dir):os.makedirs(save_dir)model_path = os.path.join(save_dir, model_name)
net.save(model_path)
print("模型保存在:%s" % model_path)# 评估模型
test_loss, test_acc = net.evaluate(test_images, test_labels_final)
print('test_loss: ', test_loss) # test_loss: 0.06558757138366236,交叉熵损失函数,分类用
print('test_acc: ', test_acc) # test_acc: 0.9831# 使用模型进行测试数据集预测
# test_images.shape # (10000, 784)
test_images[0].shape # (784,)
to_be_predicted = test_images[0].reshape((1,784)) # .reshape([1,784])
res = net.predict(to_be_predicted) # array([[2.7967730e-13, 4.2122917e-16, 6.3757026e-09, 1.7213833e-07,# 6.7121612e-19, 6.7293619e-12, 3.6417281e-21, 9.9999988e-01,# 7.6961736e-12, 5.2838995e-09]], dtype=float32)
res.argmax() # 下标从0开始,这个结果是7
注意单个图片加载进来预测,一定要扩展一个维度,因为训练时也预留了位置,给批量输入。
画图看看
# 显示图片
import matplotlib.pyplot as plt
# plt.imshow(test_images[0].reshape((28,28,1)))im = test_images[0].reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
# plt.savefig("test.png") # 保存成文件
plt.close()
加载模型以预测
上面我们在saved_models
中保存了keras_mnist_trained_model.h5
,现在我们加载这个模型,并看看模型的summary
.
# 加载模型进行预测
from keras.models import load_model
model = load_model('./saved_models/keras_mnist_trained_model.h5')model.summary()
'''
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_3 (Dense) (None, 512) 401920
_________________________________________________________________
dense_4 (Dense) (None, 10) 5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
'''
模型加载完毕,现在开始处理。
from PIL import Image
from keras.preprocessing.image import img_to_array# 整合为一个cell
def load_image_to_array(path):image = Image.open(path)image = image.resize((28, 28)) # resize(28,28)是错的 image = img_to_array(image) # 此时是4个通道,加载进来的是png图像image = image[:,:,0]image = image.reshape([1,28*28]) # reshape到网络可以接收return imagedef softmax_to_label(res):return res.argmax()image_to_predict = load_image_to_array('./9.png')
res = model.predict(image_to_predict)label = softmax_to_label(res)print("The number is: ", label)
这种通过模型输出的很准,测试下来都是准确预测。
这种我随便在网上找的,然后拆分预测,结果都是2,取了其中的一个通道。
这也可以看出模型并非完全智能,模型学到的是像素的分布。
总之,本篇的目的是通过最简单的MNIST数据集来打通使用Keras做训练、预测的Pipeline。
END.
基于Keras搭建mnist数据集训练识别的Pipeline相关推荐
- 基于Keras搭建cifar10数据集训练预测Pipeline
基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...
- DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化
DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...
- TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络
TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...
- 基于Keras搭建CNN、TextCNN文本分类模型
基于Keras搭建CNN.TextCNN文本分类模型 一.CNN 1.1 数据读取分词 1.2.数据编码 1.3 数据序列标准化 1.4 构建模型 1.5 模型验证 二.TextCNN文本分类 2.1 ...
- 基于Keras搭建LSTM网络实现文本情感分类
基于Keras搭建LSTM网络实现文本情感分类 一.语料概况 1.1 数据统计 1.1.1 查看样本均衡情况,对label进行统计 1.1.2 计句子长度及长度出现的频数 1.1.3 绘制句子长度累积 ...
- 使用tf.keras搭建mnist手写数字识别网络
使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...
- 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%
基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...
- DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)
DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...
- DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)
DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...
最新文章
- Android CheckedTextView 实现单选与多选
- Http Module 介绍(转)
- MapReduce执行过程
- Android NDK开发——人脸检测与静默活体检测
- 成功解决WARNING:tensorflow:From :read_data_sets (from tensorflow.contrib.learn.python.learn.
- Windows环境下利用VS和mingw编译LLVM
- 【转】LUA内存分析
- 在数组中查找第k个最大元素_查找数组中每个元素的最近最大邻居
- 实现后台高级查询(中级版)
- AndroidStudio_从Eclipse到AndroidStudio开发工具_两者使用的区别_通过向导新建项目和引入module---Android原生开发工作笔记68
- .net winform 的 OnKeyDown 与 方向键
- 代码很烂,所以离职。
- 《微信背后的产品观》一书
- 计算机主板别称是什么城,上海别称什么城?
- 苦心研究两周,我特么终于搞懂啥是「元宇宙」了
- FusionComputer密码修改
- 福特汉姆大学计算机科学专业,福特汉姆大学优势专业
- Git 配置别名 —— 让命令变得更简单
- C语言实现二叉平衡树
- 热烈祝贺|酒事有鲤盛装亮相2023中国(山东)精酿啤酒产业发展创新论坛暨展览会
热门文章
- visual studio 2012 下配置OPENcv3.1 和CMAKE问题总结
- phpboot使用mysql_PHP MySQL 插入数据
- java调用js模板引擎_JavaScript模板引擎应用场景及实现原理详解
- 全局中断_【安全圈】微软更新造成Office 365等多个在线服务中断!
- python列表生成式和map效率_Python列表生成式12个小功能,你常用哪几个?
- ssh整合mysql不能自动生成表_ssh整合思想 Spring与Hibernate的整合 项目在服务器启动则自动创建数据库表...
- 静态配置_配置静态LSP示例
- python如何查看类信息_关于如何查看本地python类库详细信息的方法
- c语言指针++_C和C ++中的指针
- 如何在C ++中实现内联函数?