概述

计算机神经网络则是人工智能中最为基础的也是较为重要的部分,它使用深度学习的方式模拟了人的神经元的工作,是一种全新的计算方法。本文的目标就是通过学习神经网络的相关知识,了解并掌握BP神经网络的实现原理和构造方法,建立一个简单的BP神经网络,并用MNIST数据集训练该网络,使训练后的网络能够成功的分类出MNIST测试数据集上的数字,并能识别从文件中读入的图片上的数字。

开发环境

CPU:英特尔 Core i7-7700HQ
GPU :Nvidia GeForce GTX 1060
内存:16GB

操作系统:Windows 10 x64
开发环境:PyCharm Community Edition 2020.2 x64 + Python3.8

需要我们完成的功能

总体目标:
构建一个简单的BP神经网络,让这个网络可以被训练,可以测试数据,并实现识别用户打开的图片中的数字
功能需求:

  1. 可以自己定义神经网络的输入层、隐藏层、学习率、训练世代等参数
  2. 可以训练、测试该网络
  3. 可以从文件夹中读取一张图片并进行判断分类

非功能需求:

  1. 需要一个可以使操作更为简便的图形交互
  2. GUI界面设计应该简单明了
  3. 对测试数据的识别精确度要高

BP网络设计

反向传播算法,即Back Propagation是建立在梯度下降算法基础上,适用多层神经网络的参数训练方法。由于隐藏层节点的预测误差无法直接计算,因此,反向传播算法直接利用输出层节点的预测误差反向估计上一层隐藏节点的预测误差,即从后往前逐层从输出层把误差反向传播到输入层,从而实现对链接权重调整,这也是反向传播算法名称的由来。
一个典型的3层BP神经网络模型如下图所示:


Mnist数据集的测试图片像素是28X28的,所以输入节点的个数就是28X28=784;识别出的数字有0-9十个数字,所以输出的节点的个数设置为10个;因为输入层的节点较多,所以隐藏层的节点个数设置为100;
考虑到梯度下降算法能够较好的消除产生的误差,所以激活函数设置为sigmoid函数;学习率设置为0.2,太高或太低都会导致不同的问题(梯度爆炸、梯度消失);训练世代设置为5个世代
因此,网络设计的参数为:

输入层节点数:784;
隐藏层节点数:200;
输出层节点数:10
学习率:0.1;
训练世代:5;
激活函数:sigmoid函数

代码实现

import PIL
import numpy as np
import pandas as pd
import imageio
from PIL import Image, ImageTk # 导入图像处理函数库
import tkinter as tk
from tkinter import constants, ttk
from tkinter import filedialog   #导入文件对话框函数库#————————————————————————神经网络构建,三层结构——————————————————#
#激活函数
def sigmoid(x):return 1.0 / (1.0 + np.exp(-x))
#定义神经网络函数
class neuralNetwork:#初始化神经网络def __init__(self,inputnodes,hiddennodes,outputnodes,learnrate):#设立每个神经网络的输入、隐藏、输出层的节点数self.inodes = inputnodesself.hnodes = hiddennodesself.onodes = outputnodes#设置学习率self.lrate = learnrateself.wi_h = (np.random.rand(self.hnodes,self.inodes)-0.5)self.wh_o = (np.random.rand(self.onodes,self.hnodes)-0.5)pass#训练神经网络def train(self,inputs_list,targets_list):#输入与标准结果inputs = np.array(inputs_list,ndmin=2).Ttargets= np.array(targets_list, ndmin=2).T#计算隐藏层的信号值hidden_inputs = np.dot(self.wi_h,inputs)hidden_outputs = sigmoid(hidden_inputs)#计算输出层的信号值outputs_inputs  = np.dot(self.wh_o, hidden_outputs)outputs_outputs = sigmoid(outputs_inputs)#计算误差:精确值-实际值output_errors = targets - outputs_outputshidden_errors = np.dot(self.wh_o.T,output_errors)#根据公式得出的表达式,直接用self.wh_o += self.lrate * np.dot((output_errors*outputs_outputs*(1.0-outputs_outputs)),np.transpose(hidden_outputs))self.wi_h += self.lrate * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),np.transpose (inputs))pass#接受输入,返回输出#将输出进行激活,归一化def query(self,input_list):inputs = np.array(input_list,ndmin=2).Thidden_inputs = np.dot(self.wi_h, inputs)hidden_outputs = sigmoid(hidden_inputs)outputs_inputs = np.dot(self.wh_o,hidden_outputs)outputs_outputs = sigmoid(outputs_inputs)return outputs_outputspasspass

这部分对BP神经网络类进行了参数的定义和对训练、激活函数进行了定义

#——————————————初始化GUI界面——————————————--#
window = tk.Tk()
window.title('神经网络识别MNIST数据集')
window.geometry('600x500')
global img_png  # 定义全局变量 图像的
var = tk.StringVar()  # 这时文字变量储存器
text = tk.Text(window,width=20,height=17)
text.pack(fill=tk.X,side=tk.BOTTOM)
text.insert(tk.END, '请输入相关数据,构建一个网络\n')def craet_BPNN():global nglobal input_nodesglobal hidden_nodesglobal output_nodesglobal learning_rateglobal epochsglobal training_data_list
#—————————————创建神经网络对象并用数据集训练网络——————————#
# 输入、隐藏、输出节点数input_nodes =int(var_inputs.get())hidden_nodes= int(var_hidden.get())output_nodes = int( var_outputs.get())# 学习率learning_rate = float(var_lrate.get())epochs = int(var_epochs.get())# 创建神经网络对象n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)text.insert(tk.END, 'BP网络构建成功!\n')text.insert(tk.END, '输入层节点数:'+var_inputs.get()+',隐藏层节点数:'+var_hidden.get()+',输出层节点数:'+var_outputs.get()+'\n')text.insert(tk.END, '学习率:' + var_lrate.get() + ',训练世代:' + var_epochs.get()+ '\n')text.insert(tk.END, '可以开始训练了!\n')#加载mnist数据集training_data_file = open("C:\\Users\\EASKWON\\Desktop\\mnist_train_100.csv", 'r')training_data_list = training_data_file.readlines()training_data_file.close()

以上代码的作用是使用函数创建神经网络对象,并加载数据集

#开始训练函数,训练MNist数据集
def beg_train():for e in range(epochs):#   训练的世代,一次训练完成表示训练一个世代print("训练中,第", e, "个世代")text.insert(tk.END, '训练中,第' + str(e) + '个世代\n')text.update()t=0for record in training_data_list:t+=1print("已训练",t,"个数据")# 用”,“来区分数据all_values = record.split(',')# 将输入缩放和转换inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01# 将目标的输出值的0改为0.01,1改为0.99targets = np.zeros(output_nodes) + 0.01targets[int(all_values[0])] = 0.99n.train(inputs, targets)passtext.insert(tk.END, '训练完毕!\n')text.insert(tk.END, '可以开始测试你的网络了!\n')
pass#       打开测试数据集MNIST-test# 开始测试函数,遍历所有测试集中的测试数据,得出准确率#————————————————————————测试MNIst数据集————————————————————#
def beg_test():global test_data_list#数据集的文件路径由自己定义,这是我自己的路径test_data_file = open("C:\\Users\\EASKWON\\PycharmProjects\\BpNetwork\\mnist_test_10.csv", 'r')test_data_list = test_data_file.readlines()test_data_file.close()all_values = test_data_list[0].split(',')print(all_values[1])n.query((np.asfarray(all_values[1:])/255.0*0.99)+0.01)#用来存放分数,即正确率scorecard = []text.insert(tk.END, '开始测试MNIST数据集.....\n')for record in test_data_list:#用”,“号分开数据all_values = record.split(',')# 用准确值标签记录数字准确值correct_label = int(all_values[0])print("---------")print("正确结果",correct_label)# 缩放inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01# 计算输出outputs = n.query(inputs)# 输出的最大值即为判断值label = np.argmax(outputs)print( "神经网络判断",label)# 将正确和错误的判断形成一个列表if (label == correct_label):# 正确为1scorecardscorecard.append(1)else:# 错误为0scorecardscorecard.append(0)print(scorecard)scorecard_array = np.asarray(scorecard)#正确率right_rate = (scorecard_array.sum() / scorecard_array.size) * 100text.insert(tk.END, '数据测试完毕\n')text.insert(tk.END, '正确率='+str(right_rate)+'%\n')text.update()print("正确率= ", right_rate, "%")pass

以上代码的作用则是对数据集测试功能的实现,它可以遍历测试集中的所有测试图片,并得出最终正确率。
到这一步,其实我们已经实现了识别手写数字的功能了,我们的目标已经完成了。接下来就是完善它,让它更加实用啦

#打开图片的函数,并尝试识别自己的图片
#编写GUI,让其更容易交互
def Open_Img():global img_pngglobal pathglobal label_ImgOpenFile = tk.Tk()  # 创建新窗口OpenFile.withdraw()file_path = filedialog.askopenfilename()print("训练已结束,开始测试图片")text.insert(tk.END, '开始测试图片\n')path=file_pathImg =Image.open(file_path)img_png = ImageTk.PhotoImage(Img)label_Img = tk.Label(window, image=img_png)Label_Show = tk.Label(window, image=img_png,# 使用 textvariable 替换 text, 因为这个可以变化bg='white', font=('Arial', 12), width=60, height=60)Label_Show.place(x=80, y=80)var.set('图像已打开')#自己图片的数据存放在这里our_own_dataset = []image_file_name = pathprint("加载中 ... ", image_file_name)text.insert(tk.END, '加载中....'+image_file_name+'\n')# 用文件名来设置准确值标签label = int(image_file_name[20])#将图片转换为数组img_array = imageio.imread(image_file_name, as_gray=True)# 将图片从28X28的数组转换成长为784的arrayimg_data = img_array.reshape(784)# 缩放灰度值为0-1范围内img_data = (img_data / 255.0 * 0.99) + 0.01print("图像最小值为",np.min(img_data))print("图像最大值为",np.max(img_data))# 将标签值放到数组第一个record = np.append(label, img_data)our_own_dataset.append(record)item = 0correct_label = our_own_dataset[item][0]# 将转换值作为输入inputs = our_own_dataset[item][1:]# 计算网络的输出outputs = n.query(inputs)print("输出节点的输出为:",outputs)text.insert(tk.END, '输出节点的输出为\n'+str(outputs)+'\n')text.update()# 最高输出值所在的数字作为识别标签label = np.argmax(outputs)print("神经网络说:“它是", label, "”")text.insert(tk.END, '神经网络认为图中的数字是' + str(label) + '\n')text.see(tk.END)if (label == correct_label):print("恭喜你,匹配成功!")text.insert(tk.END, '恭喜你,识别成功了!\n')else:print("很遗憾,识别失败了")text.insert(tk.END, '很遗憾,识别失败了!再试一次吧\n')pass#显示图片的函数
def SHOW():global img_pngLabel_Show = tk.Label(window, image=img_png,# 使用 textvariable 替换 text, 因为这个可以变化bg='white', font=('Arial', 12), width=60, height=60)Label_Show.place(x=80, y=80)passimg_frame = tk.LabelFrame(window, text='图像显示', padx=10, pady=10,width=120,height=120)
img_frame.place(x=55,y=50)
# 创建文本窗口,显示当前操作8状态
in_lable=tk.Label(window,text='输入层节点数:')
in_lable.pack()
in_lable.place(x=300,y=40)var_inputs=tk.StringVar()
var_inputs.set('784')
entry_inputs=tk.Entry(window,textvariable=var_inputs,width=10)
entry_inputs.place(x=380,y=40)hi_lable=tk.Label(window,text='隐藏层节点数:')
hi_lable.pack()
hi_lable.place(x=300,y=70)var_hidden=tk.StringVar()
var_hidden.set('50')
entry_hidden=tk.Entry(window,textvariable=var_hidden,width=10)
entry_hidden.place(x=380,y=70)out_lable=tk.Label(window,text='输出层节点数:')
out_lable.pack()
out_lable.place(x=300,y=100)var_outputs=tk.StringVar()
var_outputs.set('10')
entry_outputs=tk.Entry(window,textvariable=var_outputs,width=10)
entry_outputs.place(x=380,y=100)rate_lable=tk.Label(window,text='学习率:')
rate_lable.pack()
rate_lable.place(x=300,y=130)var_lrate=tk.StringVar()
var_lrate.set('0.1')
entry_lrate=tk.Entry(window,textvariable=var_lrate,width=10)
entry_lrate.place(x=380,y=130)epochs_lable=tk.Label(window,text='训练世代:')
epochs_lable.pack()
epochs_lable.place(x=300,y=160)var_epochs=tk.StringVar()
var_epochs.set('5')
entry_epochs=tk.Entry(window,textvariable=var_epochs,width=10)
entry_epochs.place(x=380,y=160)#训练数据集按钮
btn_train = tk.Button(window,text='构建网络',width=15, height=2,command=craet_BPNN)
btn_train.pack()
btn_train.place(x=30,y = 210)
#测试数据集按钮
btn_test = tk.Button(window,text='训练数据集',width=15, height=2,command=beg_train)
btn_test.pack()
btn_test.place(x=170,y=210)
# 创建打开图像按钮
btn_Open = tk.Button(window,text='测试数据集',  # 显示在按钮上的文字width=15, height=2,command=beg_test)  # 点击按钮式执行的命令
btn_Open.pack()
# 按钮位置
btn_Open.place(x=310,y=210)
# 创建显示图像按钮
btn_Show = tk.Button(window,text='打开测试图片',  # 显示在按钮上的文字width=15, height=2,command=Open_Img)  # 点击按钮式执行的命令btn_Show.pack()
# 按钮位置
btn_Show.place(x=450,y=210)
# 运行整体窗口
window.mainloop()
pass

以上代码实现了GUI的编写,让用户自行从文件中读取图片并识别,并由网络给出识别结果。因为GUI比较繁琐,所以代码看起来偏长。不过和最终效果比起来,这点付出是值得的。我们的功能也已经实现完毕了,接下来就看一下具体效果了。

实现效果


这是程序运行时的界面,可以对参数进行自定义的输入

在Pycharm的控制台界面也能看到每张图片的识别结果,1是识别正确,0是识别错误。可以看到识别的正确率还是挺高的,由97%。

这里是用户自行选择图片进行识别,可以自己写然后进行识别,但前提是图像尺寸必须是28X28,如果尺寸打了就必须对其进行池化到28X28的大小,否则就会导致输入参数量巨大(几万乃至几十万个输入参数),那就不是BP神经网络可以解决的问题了,就必须要用到卷积深度神经网络进行特征提取再来分类了。

可以看到,我们建立的神经网络已经成果的识别了我们手写的数字了,我们的目标成功了!
到这里,我们就成功的建立了一个神经网络了,可以说是实现了最基本的人工智能。但要明白,这只是人工智能中最基础的部分,要想实现真正的强人工智能,我们还有很长的路要走。但现阶段,不管复杂还是简单的神经网络,都是基于这个网络衍生而来的,所谓万变不离其宗。万丈高楼从地起,一步一步来,终有一天会达到我们心中的目标的。

文中所需要的数据集和png图片已经放在这里啦,如果觉得有帮助的话点赞或者留言

链接:https://pan.baidu.com/s/1_1kvvV4xkUvCNh9wUZ91Sw
提取码:1874

用Python实现BP神经网络识别MNIST手写数字数据集(带GUI)相关推荐

  1. Python实现bp神经网络识别MNIST数据集

    title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...

  2. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  3. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

  4. 将MNIST手写数字数据集导入NumPy数组(《深度学习入门:基于Python的理论与实现》实践笔记)

    将MNIST手写数字数据集导入NumPy数组(<深度学习入门:基于Python的理论与实现>实践笔记) 一.下载MNIST数据集(使用urllib.request.urlretrieve( ...

  5. PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类

    )全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...

  6. Tensorflow之 CNN卷积神经网络的MNIST手写数字识别

    点击"阅读原文"直接打开[北京站 | GPU CUDA 进阶课程]报名链接 作者,周乘,华中科技大学电子与信息工程系在读. 前言 tensorflow中文社区对官方文档进行了完整翻 ...

  7. PyTorch入门一:卷积神经网络实现MNIST手写数字识别

    先给出几个入门PyTorch的好的资料: PyTorch官方教程(中文版):http://pytorch123.com <动手学深度学习>PyTorch版:https://github.c ...

  8. 卷积神经网络mnist手写数字识别代码_搭建经典LeNet5 CNN卷积神经网络对Mnist手写数字数据识别实例与注释讲解,准确率达到97%...

    LeNet-5卷积神经网络是最经典的卷积网络之一,这篇文章就在LeNet-5的基础上加入了一些tensorflow的有趣函数,对LeNet-5做了改动,也是对一些tf函数的实例化笔记吧. 环境 Pyc ...

  9. 全连接神经网络实现MNIST手写数字识别

    在对全连接神经网络的基本知识(全连接神经网络详解)学习之后,通过MNIST手写数字识别这个小项目来学习如何实现全连接神经网络. MNIST数据集 对于深度学习的任何项目来说,数据集是其中最为关键的部分 ...

最新文章

  1. 2019年2月26日 Unique Email Addresses、To Lower Case、Encode and Decode TinyURL
  2. English in 999
  3. 线段树segment_tree go语言实现
  4. python文件写入字典格式输出_Python把对应格式的csv文件转换成字典类型存储脚本的方法...
  5. 【嵌入式】Libmodbus源码分析(一)-类型和结构体
  6. eclipse dorado plugin
  7. python对比图片
  8. python入门——P42魔法方法:算数运算1
  9. 德鲁伊 oltp oltp_内存中OLTP系列–表创建和类型
  10. 装饰模式-包装request和response
  11. C# MP3操作类,能播放指定的mp3文件,或播放嵌入的资源中的Mp3文件
  12. ospf 指定dr_OSPF中DR、BDR竞选机制
  13. 数据挖掘案例分析(1)-Apriori算法
  14. 关于扫码点餐多人实时共享订单的思考
  15. stm32定时器配置与时间计算公式
  16. 学学Gnuplot(常用命令及参数)
  17. 【COPOD】Suppressing Poisoning Attacks on Federated Learning for Medical Imaging
  18. 苹果App被置病毒 网友:安卓无压力
  19. 字库芯片学习之汉字内码
  20. FZU Monthly-201910 tutorial

热门文章

  1. Linux李哥私房菜——open、close和fd
  2. 【汇智学堂】-python小游戏(太空阻击之三-场景创建)
  3. H3CIE RS+——ospf(1)
  4. elementui中,下拉框设置,既可以从下拉框中选择,又可以自己添加选项
  5. java新生代 老年代比例_JVM老年代和新生代的比例
  6. cocosd-x 下 2D 骨骼动画编辑器选择的闲聊
  7. 首次亮相就用中文讲故事,三星新任总裁为啥这么拼?
  8. 计算机组装与维修统测试卷7,计算机组装与维修》课程学业水平测试卷样卷答案...
  9. PHP正则表达式过滤非主流特殊字符
  10. 冬至快乐,各位IT程序员们,你们吃对饺子了吗?