1 读取本地MNIST数据,训练,保存模型

# -*- coding: utf-8 -*-
"""
Created on Thu Apr  9 21:10:34 2020@author: tianx
"""
import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Densefrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数# 定义加载数据的函数,data_folder为保存gz数据的文件夹,该文件夹下有4个文件
# 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
# 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'def load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')print(x_train.shape,y_train.shape) # 60000张28*28的单通道灰度图
print(x_test.shape,y_test.shape)
import matplotlib.pyplot as plt # 导入可视化的包
im = plt.imshow(x_train[0],cmap='gray')
plt.show()
y_train[0]x_train = x_train.reshape(60000,784) # 将图片摊平,变成向量
x_test = x_test.reshape(10000,784) # 对测试集进行同样的处理x_train = x_train / 255
x_test = x_test / 255y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10)model = Sequential() # 构建一个空的序贯模型
# 添加神经网络层
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dense(256,activation='relu'))
model.add(Dense(10,activation='softmax'))
model.summary()model.compile(optimizer=SGD(),loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=1,validation_data=(x_test,y_test)) # 此处直接将测试集用作了验证集score = model.evaluate(x_test,y_test)
print("loss:",score[0])
print("accu:",score[1])# 保存模型
path = "D:/Data/Model/model_file_path.h5"
model.save(path)# 取一个测试集数据
test_data=x_test[0]
test_data=test_data[np.newaxis,:]
test_data=np.tile(test_data,[64,1])
#test_data=test_data.reshape(1,784)
#test_28=test_data.reshape(28,28)
#plt.imshow(test_28,cmap='gray')
#plt.show()# 预测结果
result=model.predict(test_data)
result=result.tolist()
#print(result[0].index(max(result[0])
print(result[0].index(max(result[0])))
# model_save_path = "model_file_path.h5"

2 对单个MNIST数据进行预测

# -*- coding: utf-8 -*-
"""
Created on Fri Apr 10 14:22:09 2020@author: tianx
"""import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Densefrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数
from keras.models import load_modeldef load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)读取数据集
(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')
# 模型的路径
path = "D:/Data/Model/model_file_path.h5"
加载模型
model = load_model(path)# 取一个测试集数据
test_data=x_test[0]
test_data=test_data[np.newaxis,:] # 增加一个维度
test_data=test_data.reshape(1,784) # 改变形状
test_data=np.tile(test_data,[64,1]) # 在0维度上复制64份,拟合网络的输入# 预测结果
result=model.predict(test_data)
print(np.argmax(result[0]))

参考博客:

01 本地加载MNIST数据:

https://www.cnblogs.com/ypzhai/p/9997856.html

02 模型的保存和加载:

https://www.cnblogs.com/Mrzhang3389/p/10746300.html

MNIST手写数据,从训练到数据预测(keras)相关推荐

  1. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  2. 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体

    感想 首先我是首先看了一下莫凡pyhton教程中tensorflow python搭建自己的神经网络教程以及查看了官方的教程TensorFlow中文社区-MNIST进阶教程,这里面只是有简单的测试出来 ...

  3. mnist手写数字模型训练、保存、加载及图片预测

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 主要过程 导入 加载数据 创建模型和训练 模型应用 总结 前言 非专业程序员,主业PLC单片机,2019年想扩充知识体 ...

  4. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

  5. 用python的numpy实现mnist手写数字识别

    完整代码的文章底部(Optimization_mnist.py和lr_utils.py),原理和公式部分可以看前面文章,转载文章请附上本文链接 学完前面(1到6)文章就完成了吴恩达deeplearni ...

  6. Caffe MNIST 手写数字识别(全面流程)

    目录 1.下载MNIST数据集 2.生成MNIST图片训练.验证.测试数据集 3.制作LMDB数据库文件 4.准备LeNet-5网络结构定义模型.prototxt文件 5.准备模型求解配置文件_sol ...

  7. ANN原来如此简单!——用Excel实现的MNIST手写数字识别(之三)

    ANN原来如此简单 人工神经网络目前仍然是一个火热的话题,许多人都对它充满了兴趣.然而,对于想了解ANN具体是怎么回事的同学来说,往往缺乏一个足够简单可视化的方法去了解神经网络的内部构造.网络上的各种 ...

  8. pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练

    文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...

  9. Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略

    Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介+数据增强(将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5倍) 目录 MNIST ...

  10. Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略

    Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介.下载.使用方法(包括数据增强,将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5 ...

最新文章

  1. JavaScript —— 如何判断一个非数字输入
  2. win10计算机无限弹网页,win10系统浏览网页时频繁弹出广告怎么办 Window10阻止网页弹出广告的四种方法...
  3. 【人脸识别】初识人脸识别
  4. 安装keras and theano于google colab上
  5. 【java图文趣味版】数组元素的访问与遍历
  6. 数据挖掘原理与算法_资料 | 数据挖掘:概念、模型、方法和算法(第2版)/ 国外计算机科学经典教材...
  7. 运维测试工作笔记0001---单台普通8G内存的服务器-可以达到的http并发量
  8. Python 安装skimage即Scikit-Image
  9. Java实现人脸识别(各项目结构都有案例说明)
  10. MySql 磁盘满了的处理
  11. 计算机清除服务命令,快速清理电脑垃圾用什么命令
  12. H5 video 自动播放(autoplay)不生效解决方案
  13. intellij idea 工具栏的隐藏和显示
  14. Excel函数之~计算日期、天数、星期
  15. 鲨鱼抓包(Wireshark)简易操作说明
  16. 股市几个常用基本面指标介绍
  17. Scrapy:运行爬虫程序的方式
  18. 景观雕塑商城搭建应该注意些什么
  19. UVa232 Crossword Answers(纵横字谜的答案)
  20. Fdog系列(一):思来想去,不如写一个聊天软件,那就从仿QQ注册页面开始吧。

热门文章

  1. GDB 调试 Nginx 磨刀不误砍柴工
  2. 给定一个无重复元素的数组 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合
  3. 差分贴片晶振使最强军事武器出世
  4. 制作TTF格式的字体
  5. 刨根系列 之 Unity3D UGUI 背后的工作原理
  6. b站python_python学习 —— B站抢楼原理
  7. 中台建设利器-SPI插件机制
  8. 天才小毒妃 第945章 龙非夜心情很不好
  9. android远程指纹认证流程的猜测
  10. 浅谈:百度竞价恶意点击汇总及处理方法