在学习手写体识别的时候,看到一些B站的教学视频发现,很多用TensorFlow完成的手写体识别,在下载数据集的时候会报错,无法使用,这是因为TensorFlow在维护的时候,处理的不是很好,无法使用input_data,有解决办法,但是太麻烦了,Keras是在TensorFlow之上 运行的,采用Keras能省去很多麻烦。

手写体的具体实现直接上代码;具体操作看注释。

from numpy import mean
from numpy import std
from matplotlib import pyplot
from sklearn.model_selection import KFoldimport tensorflow as tf
from tensorflow.keras import utils
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.optimizers import SGD#load train & test data
def load_dataset():#加载minister手写体数据  训练集和测试集都有两个数据一个是图像一个是标签表示这个是什么数字(trainX,trainY),(testX,testY)=tf.keras.datasets.mnist.load_data()print("trainX shape",trainX.shape)  #(60000, 28, 28)print("trainY shape",trainY.shape)#reshape dataset to have a signal channeltrainX=trainX.reshape((trainX.shape[0],28,28,1))testX=testX.reshape((testX.shape[0],28,28,1))#one hot编码trainY=to_categorical(trainY)testY=to_categorical(testY)#打印前5行one hot represent是什么样的for i in range(5):print("trainY",trainY[i])return trainX,trainY,testX,testY#线性处理 scale pixels
def prep_pixels(train,test):#把样本中的值转化为浮点数train_norm=train.astype('float32')test_norm=test.astype('float32')#可以像我们第一段程序那样把0转为0.01 1转为0.99,也可以不转#normallize to range 0-1train_norm=train_norm/255.0test_norm=test_norm/255.0return train_norm,test_norm#配置学习模型
def define_model():model=Sequential()#8个(3,3)的卷积核  激活函数relu,kernel_initializer初始化卷积核的方法model.add(Conv2D(8,(3,3),activation='relu',kernel_initializer='he_uniform',input_shape=(28,28,1)))#convolution output number of parameter = 26 x 26 x 8model.add(MaxPooling2D((2,2)))  #这里的stride相当于2,就像卷积一样在这里会以2的步长进行移动#(26-2)/2+1=13   13x13x8   没有padding,所以不用加padding#formula for calucate the number of output for each layer#output = (input - kernel+2*padding)/Stride + 1   输出矩阵大小的求解公式model.add(Flatten())#flatten作用:Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。Flatten不影响batch的大小。#隐含层model.add(Dense(120,activation='relu',kernel_initializer='he_uniform'))model.add(Dense(10,activation='softmax'))   #输出节点10个#compile the model#定义优化器为SGD#出SGD外常用的优化器还有RMSprop,Adam,adadelte,adagrad,adamax,Nadam,Ftrl#monentum项能够在相关方向加速SGD#对于monentum参数可以看看这篇博客;https://blog.csdn.net/fengbingchun/article/details/124648766?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522166576412016800182168388%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=166576412016800182168388&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-124648766-null-null.142^v56^opensearch_v2,201^v3^add_ask&utm_term=sgd%E4%B8%ADmomentum%E5%8F%82%E6%95%B0%E8%AE%BE%E7%BD%AE&spm=1018.2226.3001.4187#metircs是对模型有效性,performance的测量classification 问题与regression 问题的测量方法不同#Keras对classification问题支持的测量包括:Binary Accoracy,Categorical Accuracy,Saprese Categorical Accuracy,Top K,Sparese Top Kopt=SGD(lr=0.01,momentum=0.9)   #lr指学习率model.compile(optimizer=opt,loss='categorical_crossentropy',metrics=['accuracy'])print(model.summary())return modeldef evaluate_model(dataX,dataY,n_folds=5):scores,histories=list(),list()#prepare cross validation   准备交叉验证kfold=KFold(n_folds,shuffle=True,random_state=1)    #random_state相当于随机数种子,保证每次随机运行的结果一样#enumerate splitefor train_ix,test_ix in kfold.split(dataX):model=define_model()#select rows for train or testtrainX,trainY,testX,testY=dataX[train_ix],dataY[train_ix],dataX[test_ix],dataY[test_ix]#fit modelhistory=model.fit(trainX,trainY,epochs=10,batch_size=60,validation_data=(testX,testY),verbose=0)#增加几个打印语句方便调试和程序理解print(history.history.keys())#evalute modelloss,acc=model.evaluate(testX,testY,verbose=0)print('>%.3f'%(acc*100))#stores scoresscores.append(acc)histories.append(history)print("scores",scores)print("histories.len",len(histories))return scores,histories#plot diagnostic learning curves
def summarize_diagnostics(histories):for i in range(len(histories)):#plot losspyplot.subplot(2,1,1)pyplot.title('Cross Entropy loss')pyplot.plot(histories[i].histories['loss'],color='blue',label='train')pyplot.plot(histories[i].histories['val_loss'],color='orange',lable='test')pyplot.ylabel('loss')pyplot.xlabel('epoch')pyplot.legend(['train','test'],loc='upper right')#plot accuracypyplot.subplot(2,1,2)pyplot.title("classification accuracy")pyplot.plot(histories[i].histories['accuracy'],color='blue',lable='train')pyplot.plot(histories[i].histories['val_accuracy'],color='orange',lable='test')pyplot.ylabel('accuracy')pyplot.xlabel('epoch')pyplot.legend(['train', 'test'], loc='upper right')pyplot.show()def summarize_performance(scores):#print summaryprint('Accuracy: mean= %.3f, n=%d '% (mean(scores)*100,std(scores)*100),len(scores))# box and whisker plots of resultspyplot.boxplot(scores)pyplot.show()#run the test harness for evaluating a model
def run_mymodel_test():#加载数据集trainX,trainY,testX,testY = load_dataset()#数据集,像素预处理,转化为浮点数,并压缩到0-1之前trainX,testX=prep_pixels(trainX,testX)#模型评估,其中先构造模型再调用学习scores,histories=evaluate_model(trainX,trainY)#打印学习曲线,看学习的过程趋势summarize_diagnostics(histories)#总结模型的performancesummarize_performance(scores)#主程序入口
run_mymodel_test()

参考:ANN第三课 -卷积神经网络实现手写体识别的python程序 - 使用keras+tensorflow_哔哩哔哩_bilibili

cnn卷积神经网络手写体识别keras和tensorflow相关推荐

  1. 1700X + GTX950 跑 CNN卷积神经网络面部表情识别实例代码

    网站评论功能维护中,对文章的评论记录于此: 文章: http://blog.csdn.net/sqh4587/article/details/74507010 tensorflow机器学习之利用CNN ...

  2. CNN卷积神经网络十二生肖识别项目(一)数据下载篇

    文章目录 前言 一.前提准备 二.代码部分 1.引入库 2.发送请求,解析数据,并保存到本地 3.全部代码 总结 前言 接触深度学习有一段时间了,我们利用CNN卷积神经网络做一个十二生肖动物图片识别的 ...

  3. 卷积神经网络手写体识别

    CNN 背景 卷积 LeNet网络结构 C1 S2 C3 S4 C5 F6 输出 数据集 代码及运行结果 测试 CNN 1995年, Yann LeCun 与Yoshua Bengio 提出了conv ...

  4. python机器学习库keras——CNN卷积神经网络人脸识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 github地址:https://github.com/626626cdllp/kears/tree/master/Face_Recognit ...

  5. 深度学习--TensorFlow(项目)识别自己的手写数字(基于CNN卷积神经网络)

    目录 基础理论 一.训练CNN卷积神经网络 1.载入数据 2.改变数据维度 3.归一化 4.独热编码 5.搭建CNN卷积神经网络 5-1.第一层:第一个卷积层 5-2.第二层:第二个卷积层 5-3.扁 ...

  6. plt保存图片_人工智能Keras CNN卷积神经网络的图片识别模型训练

    CNN卷积神经网络是人工智能的开端,CNN卷积神经网络让计算机能够认识图片,文字,甚至音频与视频.CNN卷积神经网络的基础知识,可以参考:CNN卷积神经网络 LetNet体系结构是卷积神经网络的&qu ...

  7. TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

    TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...

  8. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  9. CNN卷积神经网络—LeNet原理以及tensorflow实现mnist手写体训练

    CNN卷积神经网络-LeNet原理以及tensorflow实现minst手写体训练 1. LeNet原理 2.tensorflow实现Mnist手写体识别 1.安装tensorflow 2.代码实现手 ...

最新文章

  1. Alphabet量子公司横空出世!Sandbox将与谷歌、DeepMind成姊妹
  2. Entity Framework 学习初级篇7--基本操作:增加、更新、删除、事务
  3. js中Blob对象一般用法
  4. 将本地源代码程序推送远程Github仓库
  5. qc中的流程图怎么画_QC流程图参考
  6. 矩池云上安装caffe gpu教程
  7. 【转】B树的插入和删除
  8. 第三百九十一节,Django+Xadmin打造上线标准的在线教育平台—404,403,500页面配置...
  9. cygwin的离线安装包
  10. 在Windows环境下搭建Nginx文件服务器(简单实用版)
  11. LDAP 统一认证 单点登录学习
  12. 10-新闻发布系统数据库-新闻管理数据操作
  13. 每日一算法:杨辉三角形
  14. Android学习笔记--菜单
  15. 在c++程序中执行DOS命令
  16. Web3赋能创作者经济:NFT,DAO和永续收入
  17. 拼多多面试——机器学习岗位面经
  18. 第九届蓝桥杯 螺旋折线
  19. linux 安装git 教程
  20. 图像处理之角点检测与亚像素角点定位

热门文章

  1. google driver 上传文件等操作
  2. 如何获取适用于 Azure 的 EV 代码签名证书?
  3. Excel 2010 VBA 入门 130 利用窗体创建实时筛选浮动工具栏
  4. 【外网访问学校服务器】阿里云服务器+frp+内网服务器
  5. 【190416】BS结构的VC++消息发送程序源代码
  6. sigrity前仿真,DDR地址线仿真。
  7. 员工绩效考核管理PPT模板-优页文档
  8. pytorch节省显存_节省新房子的照明
  9. RK3568在智能融合终端的应用
  10. php模拟get提交 字符串截取 字符串替换 示例源码