用神经网络实现手写数字识别
用神经网络实现手写数字识别
这是我的第一篇关于神经网络的博客,我们的目的是建立一个全连接的神经网络模型来识别手写数字,希望通过写博客记录自己学习的过程,不断提高。本文主要参考这篇博文
一文弄懂神经网络中的反向传播法
1.准备工作
环境:python 3.7 , pycharm
在开始之前我们需要导入一些模块:
- numpy python中用来进行科学计算的基本软件包
- scipy.special 这是一个常见的激活函数,sigmoid函数
- matplotlib:用于在Python中绘制图表。
如果没有以上库请自行安装。
import numpy as np
import scipy.special
import matplotlib
我们这次实验用到的数据集是MNIST数据集,MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。研究人员已经把图片处理为了字节的形式,所以我们在使用时非常方便。
2.建立模型
这是我们这次实验的模型,从图中可以看出,我们建立的神经网络由3层组成,分别是输入层,隐藏层,输出层。神经元的个数分别为:
- 输入层神经元:784
- 隐藏层神经元:200
- 输出层神经元:10
神经网络的计算主要分为两个部分,一是正向传播,计算输出值与期望值的误差。我们使用的损失(误差)函数是最简单的函数,激活函数用的是sigmoid函数:
E = T -Y
正向传播过程如上图所示,输入数据分别乘以与之相连的线上的权重再相加,最后再放到sigmoid激活函数中即得到隐藏层的输出,依次类推即计算出了每个输出神经元的输出。通过计算损失函数我们可以知道神经网络的准确性,损失函数越小说明模型拟合的越好。
第二个部分是反向传播过程,这部分主要是通过调整神经网络中的参数,从而提高神经网络的准确性。更新权重利用的是梯度下降法,调整公式如下:
其中n为学习率,在公式中是非常关键的,关系到参数下降的程度,选的太大可能找不到最优值,选的太小训练速度会非常慢。在本文中我们选取n=0.1。如果你需要了解推导过程可以看文章开头提供的博文。
3.代码实现
本实验通过python进行实现,
import numpy as np
import scipy.special
import matplotlibclass neuralNetwork :# 用于神经网络初始化def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):# 输入层节点数self.inodes = inputnodes# 隐层节点数self.hnodes = hiddennodes# 输出层节点数self.onodes = outputnodes# 学习率self.lr = learningrate# 初始化输入层与隐层之间的权重( -1 到设置权重 1)self.wih = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))# 初始化隐层与输出层之间的权重self.who = np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))# 激活函数(Sigmoid(x)=1/1+exp(-x)函数)self.activation_function = lambda x: scipy.special.expit(x)# 神经网络学习训练def train(self, inputs_list, targets_list):# 将输入数据转化成二维矩阵(ndmin指定返回的最小维度)inputs = np.array(inputs_list, ndmin=2).T# 将输入标签转化成二维矩阵targets = np.array(targets_list, ndmin=2).T# 计算隐层的输入hidden_inputs = np.dot(self.wih, inputs)# 计算隐层的输出hidden_outputs = self.activation_function(hidden_inputs)# 计算输出层的输入final_inputs = np.dot(self.who, hidden_outputs)# 计算输出层的输出final_outputs = self.activation_function(final_inputs)# 计算输出层误差output_errors = targets - final_outputs# 计算隐层误差hidden_errors = np.dot(self.who.T, output_errors)# 更新隐层与输出层之间的权重#w = w -德尔塔w = w - E对w的偏导self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),np.transpose(hidden_outputs))# 更新隐层与输出层之间的权重self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))# 神经网络测试def test(self, inputs_list):# 将输入数据转化成二维矩阵inputs = np.array(inputs_list, ndmin=2).T# 计算隐层的输入hidden_inputs = np.dot(self.wih, inputs)# 计算隐层的输出hidden_outputs = self.activation_function(hidden_inputs)# 计算输出层的输入final_inputs = np.dot(self.who, hidden_outputs)# 计算输出层的输出final_outputs = self.activation_function(final_inputs)return final_outputsif __name__ == "__main__":# 初始化 784(28 * 28)个输入节点,100个隐层节点,10个输出节点(0~9)input_nodes = 784hidden_nodes = 200output_nodes = 10# 学习率0.1learning_rate = 0.1# 训练次数epochs = 5# 初始化神经网络实例n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)# 读取训练集training_data_file = open('mnist_dataset/mnist_train_100.csv', 'r')training_data_list = training_data_file.readlines()training_data_file.close()# 训练数据for e in range(epochs):for record in training_data_list:all_values = record.split(',')# 输入数据范围(0.01~1)inputs = np.asfarray(all_values[1:]) / 255.0 * 0.99 + 0.01# 标记数据(相应标记为0.99,其余0.01)targets = np.zeros(output_nodes) + 0.01targets[int(all_values[0])] = 0.99n.train(inputs, targets)# 读取测试数据test_data_file = open('mnist_dataset/mnist_test.csv', 'r')test_data_list = test_data_file.readlines()test_data_file.close()# 打印测试数据标签count = 0.0for record in test_data_list:test_data = record.split(',')print('原标签:', test_data[0])# 生成标签图片# image_array = np.asfarray(test_data[1:]).reshape(28, 28)#plt.imshow(image_array, cmap='Greys', interpolation='None')# plt.show()# 利用神经网络预测results = n.test(np.asfarray(test_data[1:]) / 255.0 * 0.99 + 0.01)pre_label = np.argmax(results)print('预测结果:', pre_label)if int(pre_label) == int(test_data[0]):count = count + 1#print(results)print(count)rating = count/10000print("correct rating: %f" % rating)
小结
由于本文是一篇深度学习入门的小实验,目的在于以最快的时间了解深度学习的方法和框架,所以选用的模型,损失函数等比较简单,因此手写数字的识别率不高,只有66%左右。接下来会通过调整超参数,模型的选择等方法来提高识别率。
用神经网络实现手写数字识别相关推荐
- 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别
"我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...
- 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)
1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...
- 深度学习 卷积神经网络-Pytorch手写数字识别
深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...
- MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试
文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...
- 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)
基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...
- 神经网络实现手写数字识别(MNIST)
一.缘起 原本想沿着 传统递归算法实现迷宫游戏 --> 遗传算法实现迷宫游戏 --> 神经网络实现迷宫游戏的思路,在本篇当中也写如何使用神经网络实现迷宫的,但是研究了一下, 感觉有些麻烦不 ...
- 深度学习笔记:07神经网络之手写数字识别的经典实现
神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字.由于网络需要用0-9一共十个数字中挑选出一个,所以我们的网 ...
- 基于matlab BP神经网络的手写数字识别
摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...
- 基于BP神经网络的手写数字识别
基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...
- 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作
卷积神经网络与循环神经网络实战 - 手写数字识别及诗词创作 文章目录 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作 一.神经网络相关知识 1. 深度学习 2. 人工神经网络回顾 3. ...
最新文章
- 在实践中深入理解IP协议
- 接口java_JAVA 初识接口
- linux下安装glibc-2.14,解决“`GLIBC_2.14' not found”问题
- 一个好用的markdown表格生成工具
- 数据结构 | B树、B+树、B*树
- 如何导出树结构清晰的代码机构目录
- 自动化运维工具Ansible实战(四)常用模块
- java file的用法_Java 关于File使用
- 不同网段的局域网怎么互通_智能化工程中,局域网IP地址不够用怎么解决?
- 查看SQL Server被锁的表以及如何解锁【转】
- 点击按钮返回上一个页面_零基础跟老陈一起学WordPress 《第四课》用WP半小时建一个商业网站...
- centos7.x/RedHat7.x重命名网卡名称
- Unity3D导出Android工程(Android中应用Unity3D)
- 用友u8怎么导出凭证_用友U8怎么导入凭证?
- html源代码中 图像的属性标记,HTML图像标签img和源属性src及Alt属性、宽高、对齐...
- 磁盘分区,格式化,挂载
- 常见的服务器报错数字的意思
- 推荐几个免费好用的毕业论文(设计)文献查找网站包括外文文献(亲测有用)
- 计算机开机图片怎么换,如何把电脑开机画面换成自己的图片?
- java-final关键字修饰变量