1.这里首先需要声明一下,我的这篇文章参考来自以下两位博主:
界面:
https://blog.csdn.net/weixin_39964552/article/details/82937144
猫狗图像识别:
https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/106166653
2.数据集下载:以下的数据集也是在网上下载之后我自己整理的:
(1)测试集数据:
链接:https://pan.baidu.com/s/1fQg27nBMsL0xlm6JaQIerg
提取码:uags
(2)训练集数据:
链接:https://pan.baidu.com/s/1fkb4x4zmTKVgJKJ32CxjAA
提取码:gtwv
3.
(1)以下是我放置数据集的位置:

(2)代码运行环境:Windows10; 运行平台:jupyter notebook
4.训练集程序代码:
(1)导入相关的库:

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,BatchNormalization,Flatten
from tensorflow.keras.layers import  Conv2D,MaxPooling2D,GlobalAveragePooling2Dimport cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import  train_test_split
from tensorflow.keras import  Input
from tensorflow.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau
from tensorflow.keras.layers import PReLU,Activation
from tensorflow.keras.models import Model

(2)设置全局参数:

#设置图像的像素大小
image_Size=64
#设置训练图像的路径
dataPath='train/train_images'
#定义训练的代数
EPOCHES=10
#定义学习率
Learing_Rate=0.001

(3)定义类别:

#定义类别的个数
Label_Class=10

(4)设置batch_size大小:

Batch_Size=32

(5)

LabelName=[]

(6)

class_Number={'0':0,'1':1,'2':2,'3':3,'4':4,'5':5,'6':6,'7':7,'8':8,'9':9
}

(7)加载数据集:

def LoadImageData():imageList=[]ListImage=os.listdir(dataPath)print(ListImage)for img in ListImage:#这个地方需要注意的是需要处理前后的空格strip()#注意这里我为什么要用‘(’作为分割看下面一张图我的图片处理的形式,就是为了将前面的数字0给切分出来,其他的数字也是类似的LabelNames=class_Number[img.split('(')[0].strip(' ')]print(LabelNames)LabelName.append(LabelNames)#图像的当前路径dataImagePath=os.path.join(dataPath,img)print(dataImagePath)#读取图像image=cv2.imread(dataImagePath)#改变图像的大小image=cv2.resize(image,(image_Size,image_Size),interpolation=cv2.INTER_LANCZOS4)#转换为array类型image=img_to_array(image)imageList.append(image) imageList=np.array(imageList,dtype="int")/255.0return imageList


(8)

print('开始加载数据集: ')
imageArray=LoadImageData()
LabelName=np.array(LabelName)
print('加载数据集结束: ')


(9)

import tensorflow as tf

(10.1)定义模型,这个模型我只是将原文的模型改成自己喜欢的形式了,功能是一样的:

model=Sequential([tf.keras.layers.InputLayer(input_shape=(image_Size,image_Size,3)),tf.keras.layers.Conv2D(32,kernel_size=[3,3],strides=[2,2],padding='same'),#参数计算=(3*3*3+1)*32=896tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.Conv2D(32,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),#tf.keras.layers.MaxPooling2D(pool_size=[2,2]),这个地方一定要注意,不能写成tf.keras.layers.MaxPooling2D(pool_size=[2,2],strides=[1,1])#虽然原文什么都没有,并不代表这默认为1tf.keras.layers.MaxPooling2D(pool_size=[2,2]),tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.MaxPooling2D(pool_size=[2,2]),tf.keras.layers.Conv2D(128,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.Conv2D(128,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.MaxPooling2D(pool_size=[2,2]),tf.keras.layers.Conv2D(256,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.Conv2D(256,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.BatchNormalization(epsilon=1e-5),tf.keras.layers.PReLU(),tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(Label_Class),tf.keras.layers.Activation('softmax')
])

(11.1)输出模型具体的每一层情况和参数:

model.summary()


(11.2)当采用VGG13模型时:

#VGG13模型
model=Sequential([tf.keras.layers.InputLayer(input_shape=(image_Size,image_Size,3)),tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.Conv2D(64,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Conv2D(128,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.Conv2D(128,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Conv2D(256,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.Conv2D(256,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Conv2D(512,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.Conv2D(512,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.Conv2D(512,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.Conv2D(512,kernel_size=[3,3],strides=[1,1],padding='same'),tf.keras.layers.Activation('relu'),tf.keras.layers.MaxPool2D(pool_size=[2,2],strides=2,padding='same'),tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(256),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(128),tf.keras.layers.Activation('relu'),tf.keras.layers.Dense(10),tf.keras.layers.Activation('softmax')
])

(11.3)VGG13输出层的情况:

VGG13训练16代之后的训练集和验证集准确率都很不错:

其他的地方都不变,只是这个模型的地方可以采用上面的模型,也可以采用这个VGG13模型。

(12)切分训练集和验证集数据:

trainX,ValX,trainY,ValY=train_test_split(imageArray,LabelName,test_size=0.3,random_state=42)

(13)

from tensorflow.keras.preprocessing.image import ImageDataGenerator

(14)这里我不做数据增强:

train=ImageDataGenerator()
val=ImageDataGenerator()
train_Gen=train.flow(trainX,trainY,batch_size=Batch_Size,shuffle=True)
val_Gen=val.flow(ValX,ValY,batch_size=Batch_Size,shuffle=True)

(15)打印数据集的划分情况:

print(np.shape(trainX))
print(np.shape(trainY))
print(tf.shape(trainX))
print(np.shape(ValX))
print(np.shape(ValY))

(16)

(17)定义回调函数中调用的函数:

checkpointer=ModelCheckpoint(filepath='weights_best_simple_model.hdf5',#该回调函数将在每个epoch后保存模型到filepathmonitor='val_accuracy',#检测数据为准确率verbose=1,#信息展示模式,0或1save_best_only=True,#当设置为True时,将只保存在验证集上性能最好的模型mode='max'#表示得到的数据是验证集上最大准确率,与monitor检测的数据相对应
)

(18)

#当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。
#该回调函数检测指标的情况,如果在patience个epoch中看不到模型性能提升,则减少学习率
#当评价指标不在提升时,减少学习率
reduce=ReduceLROnPlateau(monitor='val_accuracy',patience=10,#当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发verbose=1,#是否展示信息模式factor=0.5,#每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少min_lr=1e-6#学习率的下限
)

(19)

from tensorflow.keras import losses

(20)

optimizer=Adam(learning_rate=Learing_Rate)

(21)

model.compile(optimizer=optimizer,loss=losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

(22)训练:

history=model.fit_generator(train_Gen,steps_per_epoch=trainX.shape[0]/Batch_Size,validation_data=val_Gen,epochs=100,validation_steps=ValX.shape[0]/Batch_Size,callbacks=[checkpointer,reduce],verbose=1,shuffle=True
)
#这个地方保存训练完之后的模型,就是为了后面进行测试直接加载这个模型
model.save('models_100.h5')

(23)画图这个地方我完全使用的是原文的内容:

#history保存的参数{'accuracy':[],'loss':{},'val_accuracy':{},'val_loss':[]}import os
import matplotlib.pyplot as pltloss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
print("Now,we start drawing the loss and acc trends graph...")
# summarize history for accuracyfig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for lossfig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")

5.测试集程序代码:
(1)导入相关的库:

import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time

(2)设置全局参数:

image_Size=64
imagelist=[]
class_Test={'0':0,'1':1,'2':2,'3':3,'4':4,'5':5,'6':6,'7':7,'8':8,'9':9
}

(3)加载模型:

test_model_classifier=load_model('models_10.h5')
t1=time.time()

(4)

class_Numbers={0:'0',1:'1',2:'2',3:'3',4:'4',5:'5',6:'6',7:'7',8:'8',9:'9'
}

(5)查看当前目录情况:

import os
os.listdir()


(6)对图像进行处理:

def Deal_Pic(dataImgPath):imageList=[]image=cv2.imread(dataImgPath)image=cv2.resize(image,(image_Size,image_Size),interpolation=cv2.INTER_LANCZOS4)image=img_to_array(image)imageList.append(image)imageList=np.array(imageList,dtype="int")/255.0return imageList

(7)预测函数:

def Predict_Class(imageList):print(test_model_classifier.predict(imageList))pre=np.argmax(test_model_classifier.predict(imageList))print(pre)print(type(pre))test=class_Numbers[pre]t2=time.time()print(test,'预测的概率: ',test_model_classifier.predict(imageList)[0][pre])t3=t2-t1print('time: ',t3)return test,test_model_classfiier.predict(imageList)[0][pre]

(8)运行代码->弹出界面->选择图像(就是文件中给的测试集图像)->根据原文博主给出的代码,可以对图像进行放大和缩小:

import sys
import os
from PyQt5 import QtGui,QtCore,QtWidgets
from PyQt5.QtWidgets import QApplication,QTextEdit,QPushButtonclass Ui_MainWindow(object):def __init__(self):super(Ui_MainWindow, self).__init__()def setupUi(self, MainWindow):#设置窗口的标题MainWindow.setObjectName("MainWindow")#设置窗口的大小MainWindow.resize(800, 600)self.centralWidget = QtWidgets.QWidget(MainWindow)self.centralWidget.setObjectName("centralWidget")#在水平的方向上排列控件 左右排列self.horizontalLayout = QtWidgets.QHBoxLayout(self.centralWidget)self.horizontalLayout.setObjectName("horizontalLayout")#网格布局self.gridLayout = QtWidgets.QGridLayout()self.gridLayout.setObjectName("gridLayout")self.picshow = QtWidgets.QGraphicsView(self.centralWidget)self.picshow.setObjectName("picshow")self.gridLayout.addWidget(self.picshow, 0, 0, 5, 1)self.zoomout = QtWidgets.QPushButton(self.centralWidget)self.zoomout.setObjectName("zoomout")self.gridLayout.addWidget(self.zoomout, 0, 1, 1, 1)self.zoomin = QtWidgets.QPushButton(self.centralWidget)self.zoomin.setObjectName("zoomin")#self.gridLayout.addWidget(组件,r,c,对齐方式)self.gridLayout.addWidget(self.zoomin, 1, 1, 1, 1)self.horizontalLayout.addLayout(self.gridLayout)MainWindow.setCentralWidget(self.centralWidget)# 创建多行文本框,这个文本框用来显示类别的self.textEdit_s = QTextEdit()self.textEdit_s.setGeometry(2,1,100,30)self.gridLayout.addWidget(self.textEdit_s,3,1,1,1)# 创建多行文本框,这个文本框用来显示准确率的self.textEdit_a = QTextEdit()self.textEdit_a.setGeometry(2, 1, 100, 30)self.gridLayout.addWidget(self.textEdit_a, 4, 1, 1, 1)#放置按钮的位置self.button=QPushButton('上传图像')self.gridLayout.addWidget(self.button,2,1,1,1)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.zoomout.setText(_translate("MainWindow", "放大"))self.zoomin.setText(_translate("MainWindow", "缩小"))

以下的代码是继承了上面的那个类:Ui_MainWindow(这个主要是为了布局,下面的才是正真的操作)

import sys
import os
import cv2
import numpy as np
from Ui_MainWindow import *
from PyQt5 import QtWidgets,QtCore,QtGui
from PyQt5.QtCore import pyqtSlot
from PyQt5.QtWidgets import QApplication,QPushButton,QFileDialog,\QGraphicsScene,QGraphicsPixmapItem,QTextEdit,QVBoxLayout
from PyQt5.QtGui import  QImage,QPixmap,QFontnorm_size=100
class File_Select(Ui_MainWindow,QtWidgets.QMainWindow,QtWidgets.QWidget):def __init__(self):super(File_Select, self).__init__()QtWidgets.QWidget.__init__(self)def initUi(self):self.setupUi(self)#编辑文本框self.zoomscale=1 self.setToolTip(u'<b>程序</b>提示')  # 调用setToolTip()方法,该方法接受富文本格式的参数,css之类。QtWidgets.QToolTip.setFont(QFont('华文楷体', 10))  # 设置字体以及字体大小icon = QtGui.QIcon()icon.addPixmap(QtGui.QPixmap("D:\\1.jpg"), QtGui.QIcon.Normal, QtGui.QIcon.Off)self.setWindowIcon(icon)  # 设置窗口的图标self.setGeometry(30,30,700,500)self.setWindowTitle('选择图像')
#         self.button=QPushButton('上传图像',self)self.button.setGeometry(595,200,100,30)self.button.clicked.connect(self.file)#对图像的路径进行处理def file_Path(self,filename):file_name = ''for i in filename[0]:file_name+=ireturn file_namedef file(self):filename=QFileDialog.getOpenFileNames(self,'选择图像',os.getcwd(), "All Files(*);;Text Files(*.txt)")file_name=self.file_Path(filename[0])#处理图像imageLists=Deal_Pic(file_name)#输出预测类别classify,acc=Predict_Class(imageLists)acc=str(acc)self.textEdit_s.setPlainText('数字的类别: '+classify)self.textEdit_a.setPlainText('预测的准确率: '+acc)print(file_name)img=cv2.imread(file_name)
#         img=cv2.cvtColor(img,cv2.COLOR_RGB2BGR)x=img.shape[1]y=img.shape[0]frame=QImage(img,x,y,QImage.Format_BGR888)pix=QPixmap.fromImage(frame)self.item=QGraphicsPixmapItem(pix)self.scene=QGraphicsScene()self.scene.addItem(self.item)self.picshow.setScene(self.scene)#重写关闭Mmainwindow窗口def closeEvent(self, event):replp=QtWidgets.QMessageBox.question(self,u'警告',u'确认退出?',QtWidgets.QMessageBox.Yes|QtWidgets.QMessageBox.No)if replp==QtWidgets.QMessageBox.Yes:event.accept()else:event.ignore()@pyqtSlot()def on_zoomin_clicked(self):"""点击缩小图像"""self.zoomscale=self.zoomscale-0.05if self.zoomscale<=0:self.zoomscale=0.2self.item.setScale(self.zoomscale)                                #缩小图像@pyqtSlot()def on_zoomout_clicked(self):"""点击方法图像"""# TODO: not implemented yetself.zoomscale=self.zoomscale+0.05if self.zoomscale>=1.2:self.zoomscale=1.2self.item.setScale(self.zoomscale)                             #放大图像if __name__=='__main__':app=QApplication(sys.argv)file_select=File_Select()file_select.initUi()file_select.show()sys.exit(app.exec_())

(9)最后的效果:

训练了100代之后的结果是较好的,主要是训练的图像太少了,我也只能给出那些图像了,这篇文章还会更新,我之后将使用其他的模型进行测试,看一下效果怎么样。关于上面我所参考的内容我已经在文章开头给出了参考的文章的链接。

手写体数字识别+界面相关推荐

  1. 基于MATLAB的手写体数字识别算法的实现

    基于MATLAB的手写体数字识别 一.课题介绍 手写数字识别是模式识别领域的一个重要分支,它研究的核心问题是:如何利用计算机自动识别人手写在纸张上的阿拉伯数字.手写体数字识别问题,简而言之就是识别出1 ...

  2. Matlab深度学习-手写体数字识别

    Matlab深度学习 文章目录 Matlab深度学习 前言 一.MNIST手写体数字数据 二.用到的深度学习框架-LeNet5 2-0 LeNet5的网络架构 2-1 框架实现-通过Matlab GU ...

  3. bp神经网络_BP 神经网络驱动的手写体数字识别软件 EasyOCR

    EasyOCR 项目介绍 本软件是一个手写体数字识别软件,采用BP神经网络,基于colt数学库,有完整源码,可以保存训练结果,基于开源例程neuralnetwork-sample,原作可以在GitHu ...

  4. Tensorflow 改进的MNIST手写体数字识别

    上篇简单的Tensorflow解决MNIST手写体数字识别可扩展性并不好.例如计算前向传播的函数需要将所有的变量都传入,当神经网络的结构变得复杂.参数更多时,程序的可读性变得非常差.而且这种方式会导致 ...

  5. Tensorflow解决MNIST手写体数字识别

    这里给出的代码是来自<Tensorflow实战Google深度学习框架>,以供参考和学习. 首先这个示例应用了几个基本的方法: 使用随机梯度下降(batch) 使用Relu激活函数去线性化 ...

  6. bp神经网络测试_BP 神经网络驱动的手写体数字识别软件 EasyOCR

    EasyOCR 项目介绍 本软件是一个手写体数字识别软件,采用BP神经网络,基于colt数学库,有完整源码,可以保存训练结果,基于开源例程neuralnetwork-sample,原作可以在GitHu ...

  7. 基于matlab的手写体数字识别系统,基于matlab的手写体数字识别系统研究

    基于matlab的手写体数字识别系统研究 丁禹鑫1,丁会2,张红娟2,杨彤彤1 [摘要]随着科学技术的发展,机器学习成为一大学科热门领域,是一门专门研究计算机怎样模拟或实现人类的学习行为的交叉学科.文 ...

  8. 【机器学习实验二】k-NN算法—改进约会网站以及手写体数字识别

    目录 一.改进约会网站 1.项目背景 2.数据收集 3.在约会网站中使用k-近邻算法的流程 4.代码实现 二.手写体数字识别 1.了解手写体数字识别 2.手写体数字识别思路 3.1.导入模块 3.2. ...

  9. 基于MATLAB手写体数字识别程序设计

    基于MATLAB手写体数字识别程序设计 手写体识别由于其实用性,一直处于研究进步的阶段,本文主要针对的是对0-9十个手写数字体脱机识别,在Matlab中对样本部分为进行16特征的提取,分别采用最小距离 ...

最新文章

  1. intellij 打开node项目 一直停留在scanning files to index....,或跳出内存不够的提示框...
  2. 月光博客 - 再谈软件保护中软加密和硬加密的安全强度
  3. 迷途の荣耀 Chapter Ⅱ
  4. 官方文档翻译-ESP32-SPI Flash
  5. 全面了解Nginx到底能做什么
  6. 【技术+某度面经】Jenkins 内容+百度面经分享
  7. angularjs 模块化
  8. BZOJ 2724: [Violet 6]蒲公英
  9. 【Log4j】Jboss下配置log4j简记
  10. 【Erlang开源项目】HTTP客户端ibrowse
  11. Gradle下载及安装
  12. KVM虚拟化技术(理论知识+搭建虚拟化平台实验步骤)
  13. Software.Cradle.Suite.V11 X64 热流体模拟软件
  14. I/O流(万流齐发、万流归宗) 本章目标: 掌握 讲  解:★★★★★ http://kuaibao.qq.com/s/20200527A0LR3000?refer=spider 1.I/O流概
  15. 使用zlib对字符串进行压缩
  16. linux系统篇 -- 一、系统概要
  17. OSChina 周六乱弹 ——对,假期的最后一天咯~!
  18. 数据流图 visio
  19. Replicate Brogaard Stock Volatility Decomposition
  20. 机器人建模中移动关节如何建立坐标系_机器人工程师进阶之路(二)6轴机械臂D-H法建模...

热门文章

  1. 爬虫之selenium对cookie的处理
  2. 深度神经网络是否过拟合?
  3. laravel中Crypt加密方法
  4. 用 .NET Memory Profiler 跟踪.net 应用内存使用情况--基本应用篇
  5. SAX解析XML文档——(二)
  6. C语言中的typedef
  7. windows phone (26) ApplicationBar应用程序栏
  8. Java Swing 树状组件JTree的使用方法【图】
  9. 写一个ArrayList类的动态代理类
  10. 搞死了 报错【libc-client.a: could not read symbols: ...