训练代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Apr 13 09:18:19 2020@author: tianx
"""
import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Dense,Flatten,MaxPool2Dfrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数from keras.models import Sequential,Model
from keras import layers,Input
from keras.utils import plot_model# 定义加载数据的函数,data_folder为保存gz数据的文件夹,该文件夹下有4个文件
# 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
# 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'def load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')
print(x_train.shape,y_train.shape) # 60000张28*28的单通道灰度图
print(x_test.shape,y_test.shape)
import matplotlib.pyplot as plt # 导入可视化的包
im = plt.imshow(x_train[0],cmap='gray')
plt.show()
y_train[0]x_train=x_train[:,:,:,np.newaxis]x_test=x_test[:,:,:,np.newaxis]
#x_train = x_train.reshape(60000,784) # 将图片摊平,变成向量
#x_test = x_test.reshape(10000,784) # 对测试集进行同样的处理x_train = x_train / 255
x_test = x_test / 255y_train = keras.utils.to_categorical(y_train,10)  # 转换成one-hot格式
y_test = keras.utils.to_categorical(y_test,10)# 定义网络模型
input_tensor=Input(shape=(28,28,1))
x=layers.Conv2D(32,(3,3),activation='relu')(input_tensor)
x=MaxPool2D((2,2),name='pool1')(x)
x=layers.Conv2D(64,3,activation='relu')(x)
x=MaxPool2D((2,2),name='pool2')(x)
x=Flatten()(x)
x=layers.Dense(1000,activation='relu')(x)
out_tensor=layers.Dense(10,activation='softmax')(x)model=Model(input_tensor,out_tensor)
model.summary()
#plot_model(model,to_file='convolutional_neural_network.png')
path = "D:/Data/Model/model_mnist_cnn.h5"
model.save(path)model.compile(optimizer=SGD(),loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=1,validation_data=(x_test,y_test)) # 此处直接将测试集用作了验证集score = model.evaluate(x_test,y_test)
print("loss:",score[0])
print("accu:",score[1])

预测代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Apr 13 10:45:28 2020@author: tianx
"""
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 10 14:22:09 2020
@author: tianx
"""import numpy as np
import os
import gzip
import keras
from keras.models import Sequential # 导入序贯模型,可以通过顺序的方式,叠加神经网络层
from keras.layers import Densefrom keras import optimizers
from keras.optimizers import SGD # 导入优化函数
from keras.models import load_modeldef load_data(data_folder):files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']paths = []for fname in files:paths.append(os.path.join(data_folder,fname))with gzip.open(paths[0], 'rb') as lbpath:y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[1], 'rb') as imgpath:x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)with gzip.open(paths[2], 'rb') as lbpath:y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)with gzip.open(paths[3], 'rb') as imgpath:x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)return (x_train, y_train), (x_test, y_test)#读取数据集
(x_train,y_train), (x_test, y_test) = load_data('D:/Data/Mnist/MNIST/raw/')
# 模型的路径
path = "D:/Data/Model/model_mnist_cnn.h5"
#加载模型
model = load_model(path)# 取一个测试集数据
test_data=x_test[0]
test_data=test_data[:,:,np.newaxis] # 增加一个维度
#test_data=test_data.reshape(1,784) # 改变形状
test_data=test_data[np.newaxis,:,:,:]test_data=np.tile(test_data,[64,1,1,1]) # 在0维度上复制64份,拟合网络的输入# 预测结果
result=model.predict(test_data)
print(np.argmax(result[0]))

结果:

参考博客:

函数式API变成(超详细)

https://blog.csdn.net/ting0922/article/details/94437540

kears编写CNN网络,实现对mnist的识别相关推荐

  1. Python机器学习实验二:1.编写代码,实现对iris数据集的KNN算法分类及预测

    Python机器学习实验二:编写代码,实现对iris数据集的KNN算法分类及预测 1.编写代码,实现对iris数据集的KNN算法分类及预测,要求: (1)数据集划分为测试集占20%: (2)n_nei ...

  2. 使用ResNet18网络实现对Cifar-100数据集分类

    使用ResNet18网络实现对Cifar-100数据集分类 简介 本次作业旨在利用ResNet18实现对于Cifar-100数据集进行图像识别按照精细类进行分类. Cifar-100数据集由20个粗类 ...

  3. 利用胶囊网络实现对CIFAR10分类

    利用胶囊网络实现对CIFAR10分类 数据集:CIFAR-10数据集由10个类中的60000个32x32彩色图像组成,每个类有6000个图像.有50000个训练图像和10000个测试图像. 实验:搭建 ...

  4. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  5. 写给初学者的深度学习教程之 MNIST 数字识别

    一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程,几乎是所有的教程都会把它放在最开始的地方.这是因为,这个简单的工程包含了大致的机器学习流程,通过练习这个工程 ...

  6. 外包 | LBP/HOG/CNN 实现对 CK/jaffe/fer2013 人脸表情数据集分类

    外包 | LBP/HOG/CNN 实现对 CK/jaffe/fer2013 人脸表情数据集分类 文章目录 外包 | LBP/HOG/CNN 实现对 CK/jaffe/fer2013 人脸表情数据集分类 ...

  7. CNN网络实现手写数字(MNIST)识别 代码分析

    CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import num ...

  8. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  9. CNN网络介绍与实践-王者荣耀英雄图片识别

    作者介绍:高成才,腾讯Android开发工程师,2016.4月校招加入腾讯,主要负责企鹅电竞推流SDK.企鹅电竞APP的功能开发和技术优化工作. 本文主要是对CS231n课程学习笔记的提炼,添加了一些 ...

  10. CNN网络介绍与实践:王者荣耀英雄图片识别

    欢迎大家前往腾讯云社区,获取更多腾讯海量技术实践干货哦~ 作者介绍:高成才,腾讯Android开发工程师,2016.4月校招加入腾讯,主要负责企鹅电竞推流SDK.企鹅电竞APP的功能开发和技术优化工作 ...

最新文章

  1. Stack and queue.
  2. CENTOS6.3下zabbix安装部署
  3. poi写入Excel
  4. 为什么需要使用到多线程
  5. 生成器设计模式的应用
  6. 端口可以随便设置吗_驱动可以随便更新吗?
  7. 实时事理学习与搜索平台DemoV1.0正式对外发布
  8. virtualenv environment怎么选_2020年阿里云双11内容安全怎么选? - 云计算分享家
  9. python爬虫爬取公众号_Python爬虫案例:爬取微信公众号文章
  10. java短语音聊天室_实现一个简单的语音聊天室(多人语音聊天系统)
  11. Elasticsearch - Fuzzy query
  12. 毒论--不要再面向对象(续)
  13. Xshell6|Xftp6 要继续使用此程序,您必须应用最新的更新或使用新版本
  14. 数据仓库Build The Data Warehouse(William H.Inmon)学习笔记 --- 第六章、分布式数据仓库
  15. 百度云PCS调试过程
  16. 【SQL】格式为yyyymmddhh:mm:ss的时间格式转换
  17. Office 彻底卸载
  18. 鹅得了腺病毒用什么药治疗小鹅摇头晃脑不吃食怎么办
  19. GAN项目实战 使用CycleGAN将苹果变成橙子Pytorch版
  20. Openstack“T版“全组件手动部署

热门文章

  1. C ++ 扑克牌洗牌
  2. mindmanager2020版下载激活码序列号密钥版及使用教程
  3. org.springframework.beans.factory.BeanDefinitionStoreException: Failed to process import candidates
  4. c# 十六进制数据转十六进制字符串
  5. LG V10距离感应器失效后的解决办法
  6. Mozilla5.0的意思
  7. html网页中图片展示为碎片,基于HTML代码实现图片碎片化加载功能
  8. python 删除字典none_python – 从字典中删除NoneTypes
  9. 性感美女陪你读名言——《经典双语名言警句十篇》 (图)
  10. halcon分割区域的方法