MNIST手写数据,从训练到数据预测(keras)
1 读取本地MNIST数据,训练,保存模型
# -*- coding: utf-8 -*-
"""
Created on Thu Apr 9 21:10:34 2020@author: tianx
"""
import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Densefrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数# 定义加载数据的函数,data_folder为保存gz数据的文件夹,该文件夹下有4个文件
# 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
# 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'def load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')print(x_train.shape,y_train.shape) # 60000张28*28的单通道灰度图
print(x_test.shape,y_test.shape)
import matplotlib.pyplot as plt # 导入可视化的包
im = plt.imshow(x_train[0],cmap='gray')
plt.show()
y_train[0]x_train = x_train.reshape(60000,784) # 将图片摊平,变成向量
x_test = x_test.reshape(10000,784) # 对测试集进行同样的处理x_train = x_train / 255
x_test = x_test / 255y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10)model = Sequential() # 构建一个空的序贯模型
# 添加神经网络层
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dense(256,activation='relu'))
model.add(Dense(10,activation='softmax'))
model.summary()model.compile(optimizer=SGD(),loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=1,validation_data=(x_test,y_test)) # 此处直接将测试集用作了验证集score = model.evaluate(x_test,y_test)
print("loss:",score[0])
print("accu:",score[1])# 保存模型
path = "D:/Data/Model/model_file_path.h5"
model.save(path)# 取一个测试集数据
test_data=x_test[0]
test_data=test_data[np.newaxis,:]
test_data=np.tile(test_data,[64,1])
#test_data=test_data.reshape(1,784)
#test_28=test_data.reshape(28,28)
#plt.imshow(test_28,cmap='gray')
#plt.show()# 预测结果
result=model.predict(test_data)
result=result.tolist()
#print(result[0].index(max(result[0])
print(result[0].index(max(result[0])))
# model_save_path = "model_file_path.h5"
2 对单个MNIST数据进行预测
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 10 14:22:09 2020@author: tianx
"""import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Densefrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数
from keras.models import load_modeldef load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)读取数据集
(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')
# 模型的路径
path = "D:/Data/Model/model_file_path.h5"
加载模型
model = load_model(path)# 取一个测试集数据
test_data=x_test[0]
test_data=test_data[np.newaxis,:] # 增加一个维度
test_data=test_data.reshape(1,784) # 改变形状
test_data=np.tile(test_data,[64,1]) # 在0维度上复制64份,拟合网络的输入# 预测结果
result=model.predict(test_data)
print(np.argmax(result[0]))
参考博客:
01 本地加载MNIST数据:
https://www.cnblogs.com/ypzhai/p/9997856.html
02 模型的保存和加载:
https://www.cnblogs.com/Mrzhang3389/p/10746300.html
MNIST手写数据,从训练到数据预测(keras)相关推荐
- DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化
DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...
- 用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体
感想 首先我是首先看了一下莫凡pyhton教程中tensorflow python搭建自己的神经网络教程以及查看了官方的教程TensorFlow中文社区-MNIST进阶教程,这里面只是有简单的测试出来 ...
- mnist手写数字模型训练、保存、加载及图片预测
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 主要过程 导入 加载数据 创建模型和训练 模型应用 总结 前言 非专业程序员,主业PLC单片机,2019年想扩充知识体 ...
- 使用tf.keras搭建mnist手写数字识别网络
使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...
- 用python的numpy实现mnist手写数字识别
完整代码的文章底部(Optimization_mnist.py和lr_utils.py),原理和公式部分可以看前面文章,转载文章请附上本文链接 学完前面(1到6)文章就完成了吴恩达deeplearni ...
- Caffe MNIST 手写数字识别(全面流程)
目录 1.下载MNIST数据集 2.生成MNIST图片训练.验证.测试数据集 3.制作LMDB数据库文件 4.准备LeNet-5网络结构定义模型.prototxt文件 5.准备模型求解配置文件_sol ...
- ANN原来如此简单!——用Excel实现的MNIST手写数字识别(之三)
ANN原来如此简单 人工神经网络目前仍然是一个火热的话题,许多人都对它充满了兴趣.然而,对于想了解ANN具体是怎么回事的同学来说,往往缺乏一个足够简单可视化的方法去了解神经网络的内部构造.网络上的各种 ...
- pytorch MNIST 手写数字识别 + 使用自己的测试集 + 数据增强后再训练
文章目录 1. MNIST 手写数字识别 2. 聚焦数据集扩充后的模型训练 3. pytorch 手写数字识别基本实现 3.1完整代码及 MNIST 测试集测试结果 3.1.1代码 3.1.2 MNI ...
- Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略
Dataset之MNIST:MNIST(手写数字图片识别+ubyte.gz文件)数据集简介+数据增强(将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5倍) 目录 MNIST ...
- Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介、下载、使用方法(包括数据增强)之详细攻略
Dataset之MNIST:MNIST(手写数字图片识别及其ubyte.gz文件)数据集简介.下载.使用方法(包括数据增强,将已有MNIST数据集通过移动像素上下左右的方法来扩大数据集为初始数据集的5 ...
最新文章
- JavaScript —— 如何判断一个非数字输入
- win10计算机无限弹网页,win10系统浏览网页时频繁弹出广告怎么办 Window10阻止网页弹出广告的四种方法...
- 【人脸识别】初识人脸识别
- 安装keras and theano于google colab上
- 【java图文趣味版】数组元素的访问与遍历
- 数据挖掘原理与算法_资料 | 数据挖掘:概念、模型、方法和算法(第2版)/ 国外计算机科学经典教材...
- 运维测试工作笔记0001---单台普通8G内存的服务器-可以达到的http并发量
- Python 安装skimage即Scikit-Image
- Java实现人脸识别(各项目结构都有案例说明)
- MySql 磁盘满了的处理
- 计算机清除服务命令,快速清理电脑垃圾用什么命令
- H5 video 自动播放(autoplay)不生效解决方案
- intellij idea 工具栏的隐藏和显示
- Excel函数之~计算日期、天数、星期
- 鲨鱼抓包(Wireshark)简易操作说明
- 股市几个常用基本面指标介绍
- Scrapy:运行爬虫程序的方式
- 景观雕塑商城搭建应该注意些什么
- UVa232 Crossword Answers(纵横字谜的答案)
- Fdog系列(一):思来想去,不如写一个聊天软件,那就从仿QQ注册页面开始吧。