用神经网络实现手写数字识别

这是我的第一篇关于神经网络的博客,我们的目的是建立一个全连接的神经网络模型来识别手写数字,希望通过写博客记录自己学习的过程,不断提高。本文主要参考这篇博文
一文弄懂神经网络中的反向传播法

1.准备工作

环境:python 3.7 , pycharm
在开始之前我们需要导入一些模块:

  • numpy python中用来进行科学计算的基本软件包
  • scipy.special 这是一个常见的激活函数,sigmoid函数
  • matplotlib:用于在Python中绘制图表。

如果没有以上库请自行安装。

import numpy as np
import scipy.special
import matplotlib

我们这次实验用到的数据集是MNIST数据集,MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。研究人员已经把图片处理为了字节的形式,所以我们在使用时非常方便。

2.建立模型


这是我们这次实验的模型,从图中可以看出,我们建立的神经网络由3层组成,分别是输入层,隐藏层,输出层。神经元的个数分别为:

  • 输入层神经元:784
  • 隐藏层神经元:200
  • 输出层神经元:10

神经网络的计算主要分为两个部分,一是正向传播,计算输出值与期望值的误差。我们使用的损失(误差)函数是最简单的函数,激活函数用的是sigmoid函数:
E = T -Y
正向传播过程如上图所示,输入数据分别乘以与之相连的线上的权重再相加,最后再放到sigmoid激活函数中即得到隐藏层的输出,依次类推即计算出了每个输出神经元的输出。通过计算损失函数我们可以知道神经网络的准确性,损失函数越小说明模型拟合的越好。
第二个部分是反向传播过程,这部分主要是通过调整神经网络中的参数,从而提高神经网络的准确性。更新权重利用的是梯度下降法,调整公式如下:

其中n为学习率,在公式中是非常关键的,关系到参数下降的程度,选的太大可能找不到最优值,选的太小训练速度会非常慢。在本文中我们选取n=0.1。如果你需要了解推导过程可以看文章开头提供的博文。

3.代码实现

本实验通过python进行实现,

import numpy as np
import scipy.special
import matplotlibclass neuralNetwork :# 用于神经网络初始化def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):# 输入层节点数self.inodes = inputnodes# 隐层节点数self.hnodes = hiddennodes# 输出层节点数self.onodes = outputnodes# 学习率self.lr = learningrate# 初始化输入层与隐层之间的权重( -1 到设置权重 1)self.wih = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))# 初始化隐层与输出层之间的权重self.who = np.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))# 激活函数(Sigmoid(x)=1/1+exp(-x)函数)self.activation_function = lambda x: scipy.special.expit(x)# 神经网络学习训练def train(self, inputs_list, targets_list):# 将输入数据转化成二维矩阵(ndmin指定返回的最小维度)inputs = np.array(inputs_list, ndmin=2).T# 将输入标签转化成二维矩阵targets = np.array(targets_list, ndmin=2).T# 计算隐层的输入hidden_inputs = np.dot(self.wih, inputs)# 计算隐层的输出hidden_outputs = self.activation_function(hidden_inputs)# 计算输出层的输入final_inputs = np.dot(self.who, hidden_outputs)# 计算输出层的输出final_outputs = self.activation_function(final_inputs)# 计算输出层误差output_errors = targets - final_outputs# 计算隐层误差hidden_errors = np.dot(self.who.T, output_errors)# 更新隐层与输出层之间的权重#w = w -德尔塔w = w - E对w的偏导self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),np.transpose(hidden_outputs))# 更新隐层与输出层之间的权重self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))# 神经网络测试def test(self, inputs_list):# 将输入数据转化成二维矩阵inputs = np.array(inputs_list, ndmin=2).T# 计算隐层的输入hidden_inputs = np.dot(self.wih, inputs)# 计算隐层的输出hidden_outputs = self.activation_function(hidden_inputs)# 计算输出层的输入final_inputs = np.dot(self.who, hidden_outputs)# 计算输出层的输出final_outputs = self.activation_function(final_inputs)return final_outputsif __name__ == "__main__":# 初始化 784(28 * 28)个输入节点,100个隐层节点,10个输出节点(0~9)input_nodes = 784hidden_nodes = 200output_nodes = 10# 学习率0.1learning_rate = 0.1# 训练次数epochs = 5# 初始化神经网络实例n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)# 读取训练集training_data_file = open('mnist_dataset/mnist_train_100.csv', 'r')training_data_list = training_data_file.readlines()training_data_file.close()# 训练数据for e in range(epochs):for record in training_data_list:all_values = record.split(',')# 输入数据范围(0.01~1)inputs = np.asfarray(all_values[1:]) / 255.0 * 0.99 + 0.01# 标记数据(相应标记为0.99,其余0.01)targets = np.zeros(output_nodes) + 0.01targets[int(all_values[0])] = 0.99n.train(inputs, targets)# 读取测试数据test_data_file = open('mnist_dataset/mnist_test.csv', 'r')test_data_list = test_data_file.readlines()test_data_file.close()# 打印测试数据标签count = 0.0for record in test_data_list:test_data = record.split(',')print('原标签:', test_data[0])# 生成标签图片# image_array = np.asfarray(test_data[1:]).reshape(28, 28)#plt.imshow(image_array, cmap='Greys', interpolation='None')# plt.show()# 利用神经网络预测results = n.test(np.asfarray(test_data[1:]) / 255.0 * 0.99 + 0.01)pre_label = np.argmax(results)print('预测结果:', pre_label)if int(pre_label) == int(test_data[0]):count = count + 1#print(results)print(count)rating = count/10000print("correct rating: %f" % rating)

小结

由于本文是一篇深度学习入门的小实验,目的在于以最快的时间了解深度学习的方法和框架,所以选用的模型,损失函数等比较简单,因此手写数字的识别率不高,只有66%左右。接下来会通过调整超参数,模型的选择等方法来提高识别率。

用神经网络实现手写数字识别相关推荐

  1. 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别

    "我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...

  2. 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)

    1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...

  3. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  4. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  5. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  6. 神经网络实现手写数字识别(MNIST)

    一.缘起 原本想沿着 传统递归算法实现迷宫游戏 --> 遗传算法实现迷宫游戏 --> 神经网络实现迷宫游戏的思路,在本篇当中也写如何使用神经网络实现迷宫的,但是研究了一下, 感觉有些麻烦不 ...

  7. 深度学习笔记:07神经网络之手写数字识别的经典实现

    神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字.由于网络需要用0-9一共十个数字中挑选出一个,所以我们的网 ...

  8. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  9. 基于BP神经网络的手写数字识别

    基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...

  10. 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作

    卷积神经网络与循环神经网络实战 - 手写数字识别及诗词创作 文章目录 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作 一.神经网络相关知识 1. 深度学习 2. 人工神经网络回顾 3. ...

最新文章

  1. 在实践中深入理解IP协议
  2. 接口java_JAVA 初识接口
  3. linux下安装glibc-2.14,解决“`GLIBC_2.14' not found”问题
  4. 一个好用的markdown表格生成工具
  5. 数据结构 | B树、B+树、B*树
  6. 如何导出树结构清晰的代码机构目录
  7. 自动化运维工具Ansible实战(四)常用模块
  8. java file的用法_Java 关于File使用
  9. 不同网段的局域网怎么互通_智能化工程中,局域网IP地址不够用怎么解决?
  10. 查看SQL Server被锁的表以及如何解锁【转】
  11. 点击按钮返回上一个页面_零基础跟老陈一起学WordPress 《第四课》用WP半小时建一个商业网站...
  12. centos7.x/RedHat7.x重命名网卡名称
  13. Unity3D导出Android工程(Android中应用Unity3D)
  14. 用友u8怎么导出凭证_用友U8怎么导入凭证?
  15. html源代码中 图像的属性标记,HTML图像标签img和源属性src及Alt属性、宽高、对齐...
  16. 磁盘分区,格式化,挂载
  17. 常见的服务器报错数字的意思
  18. 推荐几个免费好用的毕业论文(设计)文献查找网站包括外文文献(亲测有用)
  19. 计算机开机图片怎么换,如何把电脑开机画面换成自己的图片?
  20. java-final关键字修饰变量

热门文章

  1. 4.1 模拟低通滤波器设计
  2. VSS的基本使用操作介绍
  3. 配置IDEA运行环境
  4. 计算机文化基础十一版百度云,计算机文化基础(高职高专版 第十一版)第一章答案...
  5. matlab神经网络原理应用实例pdf,MATLAB神经网络原理与实例精解
  6. 聚类分析入门(理论)
  7. 2001年李彦宏DoNews三篇搜索引擎Blog
  8. (Django开发)免费HTML模板资源集合
  9. 网易游戏岗位大揭秘(我是文案策划)
  10. 计算机操作系统哪几部分组成,计算机操作系统的组成部分