神经网络是由以下基本函数组成:传播函数(包括前向传播,反向传播),激活函数,损失函数

这几个函数的作用: 前向传播:向前预测结果(由已知的参数预测) 激活函数:加入非线性因素,实现映射关系
反向传播:通过前向传播的预测值与实际的值之间的误差利用链式求导法则反向更新参数实现优化 损失函数:损失计算

一个最简单的神经网络就是由上述四个函数构成的,类似于大脑的工作原理。

简单说一下前向传播:前向传播很好理解,就是将上一层的输出作为下一层的输入并计算下一层的输出一直运算到输出层为止,得到输出层的预测结果。

反向传播也很好理解,主要是用到了链式法则,这个知识点在高数下册第九章多元函数微分法及其应用中的第四节中有详细的讲解,至于为什么用链式求导法则来实现反向传播以致于实现参数的更新,我个人理解是导数的一个几何性质可以解释:


这是个导数,f对a求偏导,他的值反应了f在a处变化的快慢,假如这个导数等于0了,从图像上看斜率就是0,他没有变化了,这个时候的参数也就更新得跟实际得一样了。

反向传播理解点这里

整个过程大概是这样:


接下来实现四个函数:

激活函数这里采用了sigmoid,也可以替换成其他比如tanh,sigmod函数的取值范围在(0, 1)之间,可以将网络的输出映射在这一范围,方便分析。层数小可以用,层数大容易出现梯度消失的情况,根据求导可见每一次向前传播都会损失3/4。而且计算量很大。

def sigmoid(x):s = 1 / (1 + np.exp(-x))return s

初始化参数w,b

def initialize_with_zeros(dim):w = np.zeros((dim, 1))  b = 0return w, b

损失函数
这里用了交叉熵损失
原理看这里

def costCAL(img, Y):m = img.shape[1]cost = -np.sum(train_label * np.log(Y) + (1 - train_label) * np.log(1 - Y)) / mreturn cost

y是使用激活函数处理后的数值,在0~1上的概率分布。

前向传播:
前向传播的公式如下:

权重w的意义:体现出x的那些像素比较重要,体现在图片上则是x1更多可能是数字的组成部分,x2更多可能是背景。
b的意义:调整神经元向后传递信号的难易程度

在实际写代码时,训练时需要输入6000个图片,一开始我是想用个for循环一个个处理,实际上python提供了一个函数np.dot(矩阵乘积)
A = np.dot(W.T, X)+b

W的数据格式是(784,1),转置后是(1,784)

X的数据格式是(784,6000),6000代表一次计算6000个图片的结果

A的数据格式是(1,6000),A是6000个图片的计算结果合集

def propagate(w, b, img):# 向前传播A = np.dot(w.T, img) + b# 使用激活函数将A的值映射到0~1的区间Y = sigmoid(A)return Y

反向传播:
传播函数通过w、b计算出A,通过激活函数又计算出Y
反向传播函数通过Y计算出dw和db
使用dw和db,通过梯度下降的方法得到新的w和b
使用更新后的w和b重复前面的运算过程

def back_propagate(Y, img, label):m = img.shape[1]dZ = Y - labeldw = np.dot(img, dZ.T) / mdb = np.sum(dZ) / mreturn dw, db

更新参数

def optimize(img, label, w, b, num_iterations, learning_rate, print_cost):# 梯度下降法,循环num_iterations次找到最优w和bfor i in range(num_iterations):# 向前传播一次Y = propagate(w, b, img)# 计算成本cost = costCAL(img, label, Y)# 反向传播得到dw、dbdw, db = back_propagate(Y, img, label)# 更新w和bw = w - learning_rate * dwb = b - learning_rate * db# 每100次输出一下成本if i % 100 == 0:if print_cost:print('优化%i次后成本是:%f' % (i, cost))return w, b

预测函数

def predict(w, b, img):m = img.shape[1]Y_prediction = np.zeros((1, m))# 向前传播得到YY = propagate(w, b, img)for i in range(Y.shape[1]):# 若Y的值大于等于0.5就认为该图片为数字9,否则不是数字9if Y[0, i] >= 0.5:Y_prediction[0, i] = 1return Y_prediction

最终结果:


踩坑日记:
我想换激活函数relu但是忽略了后面损失函数log是不能取0的,那么就会出现下面这种情况。


完整代码

import math
import sys, os
import numpy as np
from PIL import Image
from mnist import init_mnist
from mnist import load_mnist# 设置np输出数组时不做缩写
np.set_printoptions(edgeitems=1000000)# 下载数据集,会下载6W个训练数据核1W个测试数据
init_mnist()# 加载数据集 数据集包括一张图片和一个正确标注
(train_img, train_label), (test_img, test_label) = load_mnist(normalize=True, flatten=True, one_hot_label=False)# 将label不是9的数据全部转为0,将9转为1
train_label = np.where(train_label == 9, 1, 0)
test_label = np.where(test_label == 9, 1, 0)# 修改数组的大小
train_img = np.resize(train_img, (6000, train_img.shape[1]))
train_label = np.resize(train_label, (6000, 1))
test_img = np.resize(test_img, (1000, test_img.shape[1]))
test_label = np.resize(test_label, (1000, 1))# 需要将数据进行转置
train_img = train_img.T
train_label = train_label.T
test_img = test_img.T
test_label = test_label.T# sigmoid函数
def sigmoid(x):s = 1 / (1 + np.exp(-x))return s# 初始化权重数组w和偏置b(默认为0)
def initialize_with_zeros(dim):w = np.zeros((dim, 1))  # 全零的数组b = 0return w, b# 通过Y和标注计算成本
def costCAL(img, label, Y):m = img.shape[1]cost = -np.sum(train_label * np.log(Y) + (1 - train_label) * np.log(1 - Y)) / mreturn cost# 向前传播得到Y
def propagate(w, b, img):m = img.shape[1]# 向前传播A = np.dot(w.T, img) + b# 使用激活函数将A的值映射到0~1的区间Y = sigmoid(A)return Y# 反向传播得到dw、db
def back_propagate(Y, img, label):m = img.shape[1]dZ = Y - labeldw = np.dot(img, dZ.T) / mdb = np.sum(dZ) / mreturn dw, db# 通过梯度下降法更新w和b
def optimize(img, label, w, b, num_iterations, learning_rate, print_cost):# 梯度下降法,循环num_iterations次找到最优w和bfor i in range(num_iterations):# 向前传播一次Y = propagate(w, b, img)# 计算成本cost = costCAL(img, label, Y)# 反向传播得到dw、dbdw, db = back_propagate(Y, img, label)# 更新w和bw = w - learning_rate * dwb = b - learning_rate * db# 每100次输出一下成本if i % 100 == 0:if print_cost:print('优化%i次后成本是:%f' % (i, cost))return w, b# 预测函数
def predict(w, b, img):m = img.shape[1]Y_prediction = np.zeros((1, m))# 向前传播得到YY = propagate(w, b, img)for i in range(Y.shape[1]):# 若Y的值大于等于0.5就认为该图片为数字9,否则不是数字9if Y[0, i] >= 0.5:Y_prediction[0, i] = 1return Y_prediction# 按训练图片的数量生成w和b,
w, b = initialize_with_zeros(train_img.shape[0])# 通过梯度下降法更新w和b
w, b = optimize(train_img, train_label, w, b, 2000, 0.005, True)Y_prediction_train = predict(w, b, train_img)
Y_prediction_test = predict(w, b, test_img)print('对训练图片的预测准确率为:{}%'.format(100 - np.mean(np.abs(Y_prediction_train - train_label)) * 100))
print('对测试图片的预测准确率为:{}%'.format(100 - np.mean(np.abs(Y_prediction_test - test_label)) * 100))

神经网络实现手写数字识别相关推荐

  1. 我的Go+语言初体验——Go+语言构建神经网络实战手写数字识别

    "我的Go+语言初体验" | 征文活动进行中- 我的Go+语言初体验--Go+语言构建神经网络实战手写数字识别 0. 前言 1. 神经网络相关概念 2. 构建神经网络实战手写数字识 ...

  2. 读书笔记-深度学习入门之pytorch-第四章(含卷积神经网络实现手写数字识别)(详解)

    1.卷积神经网络在图片识别上的应用 (1)局部性:对一张照片而言,需要检测图片中的局部特征来决定图片的类别 (2)相同性:可以用同样的模式去检测不同照片的相同特征,只不过这些特征处于图片中不同的位置, ...

  3. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  4. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  5. 基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)

    基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明) 配置环境 1.前言 2.问题描述 3.解决方案 4.实现步骤 4.1数据集选择 4.2构建网络 4.3训练网络 4.4测试网络 4.5图 ...

  6. 神经网络实现手写数字识别(MNIST)

    一.缘起 原本想沿着 传统递归算法实现迷宫游戏 --> 遗传算法实现迷宫游戏 --> 神经网络实现迷宫游戏的思路,在本篇当中也写如何使用神经网络实现迷宫的,但是研究了一下, 感觉有些麻烦不 ...

  7. 深度学习笔记:07神经网络之手写数字识别的经典实现

    神经网络之手写数字识别的经典实现 上一节完成了简单神经网络代码的实现,下面我们将进行最终的实现:输入一张手写图片后,网络输出该图片对应的数字.由于网络需要用0-9一共十个数字中挑选出一个,所以我们的网 ...

  8. 基于matlab BP神经网络的手写数字识别

    摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入.灰度化以及二值化等处理,通过神 ...

  9. 基于BP神经网络的手写数字识别

    基于BP神经网络的手写数字识别 摘要 本文实现了基于MATLAB关于神经网络的手写数字识别算法的设计过程,采用神经网络中反向传播神经网络(即BP神经网络)对手写数字的识别,由MATLAB对图片进行读入 ...

  10. 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作

    卷积神经网络与循环神经网络实战 - 手写数字识别及诗词创作 文章目录 卷积神经网络与循环神经网络实战 --- 手写数字识别及诗词创作 一.神经网络相关知识 1. 深度学习 2. 人工神经网络回顾 3. ...

最新文章

  1. 第6章:可维护性软件构建方法 6.2可维护性设计模式
  2. CSP2020洛谷P7077:函数调用
  3. python 操作系统学习_操作系统学习
  4. 代码质量差,啥都干不好!丨技术大牛:你的代码正在毁掉你!
  5. C语言 链表 3个结点,一个关于C语言链表头结点的问题
  6. TCA9539 IO扩展芯片
  7. jsp数据库中文乱码处理
  8. VLAN 虚拟局域网 搭建
  9. 哪些机器学习模型需要归一化
  10. 一文搞懂什么是Hadoop?Hadoop的前世今生,Hadoop的优点有哪些?Hadoop面试考查重点,大数据技术生态体系
  11. 微星笔记本安装Ubuntu桌面版
  12. 高数_第3章重积分_在柱面坐标下计算三重积分
  13. 磁共振功能成像BOLD-fMRI原理
  14. vue+element实现滚动公告栏效果
  15. Javascript 与 或 非 符号
  16. 文件服务器审计---首选Netwrix文件服务器审计工具
  17. 是谁用Python弹奏一曲东风破
  18. MATLAB绘制SOI指数
  19. 帮北航小妹妹做的一道她的C++的作业题.
  20. 蓝桥杯2015年第六届真题-穿越雷区

热门文章

  1. Charles 的简单使用
  2. 怎样可以在线将pdf转换成jpg格式
  3. 电脑android模拟器下载地址,原神电脑版怎么下载 安卓模拟器电脑版下载地址
  4. Allegro添加Logo方法
  5. Python 实现哥德巴赫猜想
  6. Sklearn实现非线性回归
  7. 【无标题】2022年施工员-设备方向-通用基础(施工员)考试模拟100题及模拟考试
  8. 案件被终本后,失信被执行人会从黑名单中移除吗?
  9. java打印2到10000的所有素数(质数),每行显示8个素数
  10. linux socket write()函数阻塞卡住线程问题(线程无法结束)write()非阻塞代码