原文地址:点击访问

许久未更,是因为开学之后学习任务太充实了。每天都有做不完的事情,每件事情都又想把它做好。

我航中秋国庆假期长达8天,真应了那句话:该放的假一天不少,该补的课一次没有。期间,有多门作业要完成。今天,为大家推送简单神经网络的实现,是我的《人工智能加速器》的作业。

实验内容

搭建基本的多层神经网络,并在给定测试集上进行精度测试。

  • 注1:不使用深度学习框架完成网络搭建。

  • 注2:不限制编程语言,推荐使用python进行神经网络搭建,允许使用numpy等工具包。

  • 注3:使用给定的训练集和测试集,可使用提供的代码模板(bp_template.py)并在其基础上进行修改,也可以重新进行编写。

实验要求

  • 网络输入:784 个输入节点(每个节点对应图片的一个像素)
  • 网络输出:10 个输出节点(分别代表0~9 这10 个数字)
  • 网络深度建议为3 至5 层即可,如果太深则需要太长运行时间。
  • 使用给定训练集(mnist_train.csv)进行权重训练,使用测试集(mnist_test.csv)测试并给出测试精度。(不对精度做特别的要求,只需在合理范围内即可)

网络架构

本文参考tutorial学习了后向传播的原理, 参考tutorial学习了后向传播的设计, 从而设计出了两个隐含层的简单神经网络. 具体网络架构见图1所示.

主要代码实现

完整代码点击阅读原文跳转到我的Github,后向与前向传播代码如下:

# neural network class definition
class neuralNetwork:param = {}# initialise the neural networkdef __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):'''The network consists of three types of layers: input layer(784 nodes), hidden layers and output layer(10 nodes).You can define the number of hidden nodes and layers you want.'''self.hiddennodes_1 = hiddennodes[0]self.hiddennodes_2 = hiddennodes[1]self.param['W1'] = np.random.randn(hiddennodes[0], inputnodes) * np.sqrt(1 / hiddennodes[0])self.param['b1'] = np.random.randn(hiddennodes[0], 1) * np.sqrt(1 / hiddennodes[0])self.param['W2'] = np.random.randn(hiddennodes[1], hiddennodes[0]) * np.sqrt(1 / outputnodes)self.param['b2'] = np.random.randn(hiddennodes[1], 1) * np.sqrt(1 / hiddennodes[1])self.param['W3'] = np.random.randn(outputnodes, hiddennodes[1]) * np.sqrt(1 / outputnodes)self.param['b3'] = np.random.randn(outputnodes, 1) * np.sqrt(1 / outputnodes)self.learningrate = learningrateself.inputnodes = inputnodesself.hiddennodes = hiddennodesself.outputnodes = outputnodesdef forward(self, inputs_list):'''forward the neural network'''inputs_list = inputs_list.reshape(-1, 1)self.inputs_list = inputs_listz1 = np.dot(self.param['W1'], inputs_list) + self.param['b1']h1 = sigmoid(z1)z2 = np.dot(self.param['W2'], h1) + self.param['b2']h2 = sigmoid(z2)z3 = np.dot(self.param['W3'], h2) + self.param['b3']h3 = sigmoid(z3)self.final_outputs = h3self.z1 = z1self.h1 = h1self.z2 = z2self.h2 = h2self.z3 = z3self.h3 = h3def Backpropagation(self, targets_list):'''propagate backword'''change = {}targets_list = targets_list.reshape(-1, 1)loss_val = mse_loss(targets_list, self.final_outputs)# calculate W3 updateerror = -2 * (targets_list - self.final_outputs)error = np.multiply(error, sigmoid(self.z3, derivative=True))change['W3'] = np.dot(error, self.h2.T)change['b3'] = error# calculate W2 updateerror = np.multiply(np.dot(self.param['W3'].T, error), sigmoid(self.z2, derivative=True))change['W2'] = np.dot(error, self.h1.T)change['b2'] = error# calculate W1 updateerror = np.multiply(np.dot(self.param['W2'].T, error), sigmoid(self.z1, derivative=True))change['W1'] = np.dot(error, self.inputs_list.T)change['b1'] = errorself.param['W1'] -= self.learningrate * change['W1']self.param['b1'] -= self.learningrate * change['b1']self.param['W2'] -= self.learningrate * change['W2']self.param['b2'] -= self.learningrate * change['b2']self.param['W3'] -= self.learningrate * change['W3']self.param['b3'] -= self.learningrate * change['b3']return loss_val

精度与损失

本实验中, 训练了20代, 共耗时3590.8710s. 在训练集上的损失和验证精度如下图, 可以看到, 随着训练代数增多, losslossloss值逐渐减低, 精度逐渐升高, 最高可以达到100%.

最终, 在测试集上的精度达到了98%. 图3展示了测试集中前10个测试样本的预测结果, 期预测结果与其真实标签基本吻合. 但是, 对于图3(i)这样人为都难以辨别出来的测试样本, 神经网络就更加难以预测准确. 现实中, 这种脏数据往往是无意义的.

训练时间

整个实验运行时间如下表所示, 其中训练时间最长, 平均每一代的训练时间为179.5436s, 相当耗时.

图4显示了各个时间的占比, 平均每代训练时间占比高达93.6%.

实验环境

本文使用Python 3.6, 在配置为Intel® Xeon® Gold 5120T CPU @2.20GHz 2.19 GHz (2 processors)的PC机上进行实验,

【深度学习】Numpy实现简单神经网络相关推荐

  1. 计算机视觉与深度学习 | 基于MATLAB 深度学习工具实现简单的数字分类问题(卷积神经网络)

    博主github:https://github.com/MichaelBeechan 博主CSDN:https://blog.csdn.net/u011344545 %% Time:2019.3.7 ...

  2. 【深度学习笔记1】神经网络的搭建与简单应用

    目录 推荐阅读 前言 神经网络与深度学习 使用Tensorflow搭建神经网络 环境搭建和导包遇到的问题: 问题1:Duplicate registrations for type 'optimize ...

  3. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-04-基于Python的LeNet之MLP

    原文地址可以查看更多信息 本文主要参考于:Multilayer Perceptron  python源代码(github下载 CSDN免费下载) 本文主要介绍含有单隐层的MLP的建模及实现.建议在阅读 ...

  4. 深度学习(DL)与卷积神经网络(CNN)学习笔记随笔-03-基于Python的LeNet之LR

    原地址可以查看更多信息 本文主要参考于:Classifying MNIST digits using Logistic Regression  python源代码(GitHub下载 CSDN免费下载) ...

  5. 深度学习21天——卷积神经网络(CNN):实现mnist手写数字识别(第1天)

    目录 一.前期准备 1.1 环境配置 1.2 CPU和GPU 1.2.1 CPU 1.2.2 GPU 1.2.3 CPU和GPU的区别 第一步:设置GPU 1.3 MNIST 手写数字数据集 第二步: ...

  6. 深度学习初级阶段——全连接神经网络(MLP、FNN)

    在前面的数学原理篇中,已经提到了各种深度学习的数学知识基本框架,那么从这篇文章开始,我将和大家一起走进深度学习的大门(部分图片和描述取自其他文章). 目录 一.首先我们需要知道什么是深度学习? 二.了 ...

  7. 深度学习笔记:卷积神经网络的可视化--卷积核本征模式

    目录 1. 前言 2. 代码实验 2.1 加载模型 2.2 构造返回中间层激活输出的模型 2.3 目标函数 2.4 通过随机梯度上升最大化损失 2.5 生成滤波器模式可视化图像 2.6 将多维数组变换 ...

  8. Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类)

    Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类) 1.卷积神经网络 1.1卷积神经网络简介 1.2卷积运算 1.3 深度学习与小数据问题的相关性 2.下载数据 2.1下载原始数据 ...

  9. 深度学习与围棋:神经网络入门

    本文主要内容 介绍人工神经网络的基础知识. 指导神经网络学习如何识别手写数字. 组合多个层来创建神经网络. 理解神经网络从数据中学习的原理. 从零开始实现一个简单的神经网络. 本章介绍人工神经网络(A ...

  10. 普通视频转高清:10个基于深度学习的超分辨率神经网络

    原文:http://www.tinymind.cn/articles/1176 在 AlphaGo 对弈李世石.柯洁之后,更多行业开始尝试通过机器学习优化现有技术方案.其实对于实时音视频来讲,对机器学 ...

最新文章

  1. .net卸载程序制作
  2. 【怎样写代码】参数化类型 -- 泛型(八):泛型委托
  3. python处理csv数据-Python处理csv文件
  4. 谁是谷歌想要的人才:智商高不见得总是好员工
  5. SVN Could not open the requested SVN filesystem解决办法
  6. LeetCode 1238. 循环码排列(格雷编码+旋转数组)
  7. 017 在SecureCRT中安装rz小工具
  8. WORD如何创建三线表样式?
  9. 深度学习-吴恩达-笔记-2-神经网络的编程基础
  10. 大白话5分钟带你走进人工智能-第二十四节决策树系列之分裂流程和Gini系数评估(3)...
  11. 微信浏览器禁止页面下拉查看网址(不影响页面内部scroll)
  12. 渗透测试教程(基础篇)-2
  13. 横向合计代码 锐浪报表_报表开发常见问题解答 - 锐浪报表工具
  14. Nginx下的反向代理 双层代理 负载均衡
  15. 电脑文档误删除怎么恢复,恢复误删除电脑文档的方法
  16. 单链表 尾插法 C语言
  17. potatso lite怎么添加代理_「科技犬」除了苹果AirPods,真无线蓝牙耳机到底怎么选?_蓝牙耳机...
  18. 1 --> 以太网 PHY 层简介
  19. 计算一个月有几天并且有几个周六日的小函数
  20. 浅谈企业微信公域到私域流量玩法

热门文章

  1. flutter 九宫格菜单_flutter九宫格图片查看器
  2. 世界上第一台数字计算机图片大全,第二章 计算机中的图世界
  3. 对敏捷宣言的原则进行风险评估
  4. 罗杨美慧 20190919-3 效能分析
  5. html文件右键没有打开方式,一个文件打不开,点右键,怎么在打开方式中加入Word,Excel的打开方式,打开方式中有Word的打开方式?...
  6. FPGA——SPI总线控制flash(3)含代码
  7. Gvim计数器模板经典练习
  8. 物联网萤石云获取登录的accessToken工具类
  9. android百度地图定位跳转中心点,百度地图,拖动地图,定位marker固定在屏幕中心位置...
  10. 信息系统项目管理师(2022年)—— 重点内容:项目质量管理(8)