搭建模型

import tensorflow as tf
from tensorflow import keras# get data
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()# setup model
model = keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),keras.layers.Dense(128, activation=tf.nn.relu),keras.layers.Dense(10, activation=tf.nn.softmax)
])# 编译
model.compile(optimizer=tf.train.AdamOptimizer(), loss='sparse_categorical_crossentropy',metrics=['accuracy'])# train model
model.fit(train_images, train_labels, epochs=5)# evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels)print('test accuracy:', test_acc)

模型构建方法二:

# 加载MNIST数据集 -- 《深度学习》第二章案例
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images.shape # (60000, 28, 28)
train_labels.shape # (60000,)# 将训练数据和标签关联在一起
from keras import models, layersnet = models.Sequential()
net.add(layers.Dense(512, activation='relu', input_shape=(28*28,))) # 直接输入展平的张量
net.add(layers.Dense(10,activation='softmax')) # 最后输出10个结果# 编译网络
net.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])# 训练数据准备
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255 # 数据归一化test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255 # 数据归一化# 对标签进行分类编码:变成one-hot编码
from keras.utils import to_categoricaltrain_labels_final = to_categorical(train_labels)
test_labels_final = to_categorical(test_labels)train_images.shape # (60000, 784)
train_labels[0] # 5,这是非onehot编码的标签train_labels_final.shape # (60000, 10)# 拟合网络:训练开始
history = net.fit(train_images, train_labels_final, epochs=5, batch_size=128)'''
Epoch 1/5
60000/60000 [==============================] - 5s 78us/step - loss: 0.0288 - acc: 0.9916
Epoch 2/5
60000/60000 [==============================] - 5s 80us/step - loss: 0.0221 - acc: 0.9932
Epoch 3/5
60000/60000 [==============================] - 5s 78us/step - loss: 0.0171 - acc: 0.9950
Epoch 4/5
60000/60000 [==============================] - 5s 79us/step - loss: 0.0133 - acc: 0.9961
Epoch 5/5
60000/60000 [==============================] - 5s 80us/step - loss: 0.0100 - acc: 0.9970
'''# 保存模型
import os
model_name = "keras_mnist_trained_model.h5"
save_dir = os.path.join(os.getcwd(), 'saved_models')if not os.path.isdir(save_dir):os.makedirs(save_dir)model_path = os.path.join(save_dir, model_name)
net.save(model_path)
print("模型保存在:%s" % model_path)# 评估模型
test_loss, test_acc = net.evaluate(test_images, test_labels_final)
print('test_loss: ', test_loss) # test_loss:  0.06558757138366236,交叉熵损失函数,分类用
print('test_acc: ', test_acc) # test_acc:  0.9831# 使用模型进行测试数据集预测
# test_images.shape # (10000, 784)
test_images[0].shape # (784,)
to_be_predicted = test_images[0].reshape((1,784)) # .reshape([1,784])
res = net.predict(to_be_predicted) # array([[2.7967730e-13, 4.2122917e-16, 6.3757026e-09, 1.7213833e-07,# 6.7121612e-19, 6.7293619e-12, 3.6417281e-21, 9.9999988e-01,#  7.6961736e-12, 5.2838995e-09]], dtype=float32)
res.argmax() # 下标从0开始,这个结果是7

注意单个图片加载进来预测,一定要扩展一个维度,因为训练时也预留了位置,给批量输入。

画图看看

# 显示图片
import matplotlib.pyplot as plt
# plt.imshow(test_images[0].reshape((28,28,1)))im = test_images[0].reshape(28, 28)
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.axis('off')
plt.imshow(im, cmap='gray')
plt.show()
# plt.savefig("test.png")  # 保存成文件
plt.close()

加载模型以预测

上面我们在saved_models中保存了keras_mnist_trained_model.h5,现在我们加载这个模型,并看看模型的summary.

# 加载模型进行预测
from keras.models import load_model
model = load_model('./saved_models/keras_mnist_trained_model.h5')model.summary()
'''
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_3 (Dense)              (None, 512)               401920
_________________________________________________________________
dense_4 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
'''

模型加载完毕,现在开始处理。

from PIL import Image
from keras.preprocessing.image import img_to_array# 整合为一个cell
def load_image_to_array(path):image = Image.open(path)image = image.resize((28, 28)) # resize(28,28)是错的 image = img_to_array(image) # 此时是4个通道,加载进来的是png图像image = image[:,:,0]image = image.reshape([1,28*28]) # reshape到网络可以接收return imagedef softmax_to_label(res):return res.argmax()image_to_predict = load_image_to_array('./9.png')
res = model.predict(image_to_predict)label = softmax_to_label(res)print("The number is: ", label)

这种通过模型输出的很准,测试下来都是准确预测。

这种我随便在网上找的,然后拆分预测,结果都是2,取了其中的一个通道。

这也可以看出模型并非完全智能,模型学到的是像素的分布。

总之,本篇的目的是通过最简单的MNIST数据集来打通使用Keras做训练、预测的Pipeline。

END.

基于Keras搭建mnist数据集训练识别的Pipeline相关推荐

  1. 基于Keras搭建cifar10数据集训练预测Pipeline

    基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...

  2. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  3. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  4. 基于Keras搭建CNN、TextCNN文本分类模型

    基于Keras搭建CNN.TextCNN文本分类模型 一.CNN 1.1 数据读取分词 1.2.数据编码 1.3 数据序列标准化 1.4 构建模型 1.5 模型验证 二.TextCNN文本分类 2.1 ...

  5. 基于Keras搭建LSTM网络实现文本情感分类

    基于Keras搭建LSTM网络实现文本情感分类 一.语料概况 1.1 数据统计 1.1.1 查看样本均衡情况,对label进行统计 1.1.2 计句子长度及长度出现的频数 1.1.3 绘制句子长度累积 ...

  6. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

  7. 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%

    基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...

  8. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  9. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

最新文章

  1. Android CheckedTextView 实现单选与多选
  2. Http Module 介绍(转)
  3. MapReduce执行过程
  4. Android NDK开发——人脸检测与静默活体检测
  5. 成功解决WARNING:tensorflow:From :read_data_sets (from tensorflow.contrib.learn.python.learn.
  6. Windows环境下利用VS和mingw编译LLVM
  7. 【转】LUA内存分析
  8. 在数组中查找第k个最大元素_查找数组中每个元素的最近最大邻居
  9. 实现后台高级查询(中级版)
  10. AndroidStudio_从Eclipse到AndroidStudio开发工具_两者使用的区别_通过向导新建项目和引入module---Android原生开发工作笔记68
  11. .net winform 的 OnKeyDown 与 方向键
  12. 代码很烂,所以离职。
  13. 《微信背后的产品观》一书
  14. 计算机主板别称是什么城,上海别称什么城?
  15. 苦心研究两周,我特么终于搞懂啥是「元宇宙」了
  16. FusionComputer密码修改
  17. 福特汉姆大学计算机科学专业,福特汉姆大学优势专业
  18. Git 配置别名 —— 让命令变得更简单
  19. C语言实现二叉平衡树
  20. 热烈祝贺|酒事有鲤盛装亮相2023中国(山东)精酿啤酒产业发展创新论坛暨展览会

热门文章

  1. visual studio 2012 下配置OPENcv3.1 和CMAKE问题总结
  2. phpboot使用mysql_PHP MySQL 插入数据
  3. java调用js模板引擎_JavaScript模板引擎应用场景及实现原理详解
  4. 全局中断_【安全圈】微软更新造成Office 365等多个在线服务中断!
  5. python列表生成式和map效率_Python列表生成式12个小功能,你常用哪几个?
  6. ssh整合mysql不能自动生成表_ssh整合思想 Spring与Hibernate的整合 项目在服务器启动则自动创建数据库表...
  7. 静态配置_配置静态LSP示例
  8. python如何查看类信息_关于如何查看本地python类库详细信息的方法
  9. c语言指针++_C和C ++中的指针
  10. 如何在C ++中实现内联函数?