DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测

导读
          计算图在神经网络算法中的作用。计算图的节点是由局部计算构成的。局部计算构成全局计算。计算图的正向传播进行一般的计算。通过计算图的反向传播,可以计算各个节点的导数。

目录

输出结果

设计思路

核心代码


输出结果


设计思路

核心代码

class TwoLayerNet:def __init__(self, input_size, hidden_size, output_size, weight_init_std = 0.01):self.params = {}self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)self.params['b1'] = np.zeros(hidden_size)self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size) self.params['b2'] = np.zeros(output_size)self.layers = OrderedDict()self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])self.layers['Relu1'] = Relu()self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])self.lastLayer = SoftmaxWithLoss()def predict(self, x):for layer in self.layers.values():x = layer.forward(x)return x# x:输入数据, t:监督数据def loss(self, x, t):y = self.predict(x)return self.lastLayer.forward(y, t)def accuracy(self, x, t):y = self.predict(x)y = np.argmax(y, axis=1)if t.ndim != 1 : t = np.argmax(t, axis=1)accuracy = np.sum(y == t) / float(x.shape[0])return accuracydef gradient(self, x, t):self.loss(x, t)dout = 1dout = self.lastLayer.backward(dout)layers = list(self.layers.values())layers.reverse()for layer in layers:dout = layer.backward(dout)grads = {}grads['W1'], grads['b1'] = self.layers['Affine1'].dW, self.layers['Affine1'].dbgrads['W2'], grads['b2'] = self.layers['Affine2'].dW, self.layers['Affine2'].dbreturn grads

相关文章
DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测

DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测相关推荐

  1. DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、GC对比

    DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练.GC对比 导读           神经网络算法封装为层级结构的作用.在神经网络算法中,通过将 ...

  2. DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练、预测

    DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)利用MNIST数据集进行训练.预测 导读 利用python的numpy计算库,进行自定义搭建2层神经网络TwoLayerN ...

  3. DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程

    DL之DNN优化技术:自定义MultiLayerNetExtend算法(BN层使用/不使用+权重初始值不同)对Mnist数据集训练评估学习过程 目录 输出结果 设计思路 核心代码 更多输出 相关文章: ...

  4. DL之DNN优化技术:神经网络算法简介之GD/SGD算法的简介、代码实现、代码调参之详细攻略

    DL之DNN优化技术:神经网络算法简介之GD/SGD算法的简介.代码实现.代码调参之详细攻略 目录 GD算法的简介 GD/SGD算法的代码实现 1.Matlab编程实现 GD算法的改进算法 GD算法中 ...

  5. DL之CNN:自定义SimpleConvNet【3层,im2col优化】利用mnist数据集实现手写数字识别多分类训练来评估模型

    DL之CNN:自定义SimpleConvNet[3层,im2col优化]利用mnist数据集实现手写数字识别多分类训练来评估模型 目录 输出结果 设计思路 核心代码 更多输出 输出结果 设计思路 核心 ...

  6. DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介、理解、代码实现、SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略

    DL之DNN优化技术:神经网络算法简介之GD/SGD算法(BP的梯度下降算法)的简介.理解.代码实现.SGD缺点及改进(Momentum/NAG/Ada系列/RMSProp)之详细攻略 目录 GD算法 ...

  7. DL之DNN之BP:神经网络算法简介之BP算法/GD算法之不需要额外任何文字,只需要八张图讲清楚BP类神经网络的工作原理

    DL之DNN之BP:神经网络算法简介之BP算法/GD算法之不需要额外任何文字,只需要八张图讲清楚BP类神经网络的工作原理 目录 BP类神经网络理解 1.信号正向传播FP 2.误差反向传播BP+GD B ...

  8. DL之DNN优化技术:神经网络算法简介之数据训练优化【mini-batch技术+etc】

    DL之DNN优化技术:神经网络算法简介之数据训练优化[mini-batch技术+etc] 目录 1.mini-batch技术 输出结果 实现代码 1.mini-batch技术 输出结果 实现代码 # ...

  9. 排序层-深度模型-2015:AutoRec【单隐层神经网络推荐模型】

    AutoRec模型是2015年由澳大利亚国立大学提出的. 它将 自编码器(AutoEncoder ) 的思想和协同过滤结合,提出了一种单隐层神经网络 推荐模型.因其简洁的网络结构和清晰易懂的模型原理, ...

最新文章

  1. 使用 IntraWeb (8) - 系统模板
  2. Github 的使用
  3. linux 下 select 函数的用法
  4. 影响计算机的速度有哪些,影响电脑速度的硬件有哪些
  5. ArcGIS Engine开发模板及C#代码
  6. PYTOHN1.day14
  7. ios 数字键盘左下角添加按钮_ios数字键盘添加完成按钮
  8. ssh 根据指定端口登录远程服务器
  9. 家庭医疗系统-基于蓝牙无线通信技术
  10. 微商爆粉2.0全自动批量加人模拟手动操作
  11. NR PUSCH(一)configured grant Type1 or Type 2
  12. 软件资产管理重在license
  13. Adobe PhotoShop(PS) for Windows 快捷键/PS快捷键
  14. 扒一扒 ScheduledThreadPoolExecutor
  15. 人脸识别系统——Face recognition 人脸识别
  16. 物联网学习之旅:微信小程序控制STM32(三)--STM32代码编写
  17. IntelliJ IDEA V2022.1版本亮点——改进框架与技术
  18. matlab中有解耦指令吗,powertrain-mounting_Opti 发动机悬置系统解耦率、固有频率以及参数优化程序 matlab 266万源代码下载- www.pudn.com...
  19. 【Microsoft】Project Oxford
  20. jquery 插件 rater 星星评论

热门文章

  1. mysql锁的一些理解简书_MySQL锁系列之锁的种类和概念
  2. 尼康d7200拍照_尼康D7500适合入门吗
  3. 正确的 send recv 行为
  4. 比特币现金被3.1万多家餐厅接受
  5. 用CSS实现首字下沉效果,仿word的首字下沉
  6. 无法添加类型为“mimeMap”的重复集合项
  7. 机房收费--修改密码
  8. Java多线程:线程属性
  9. 美团配送数据治理实践
  10. 详解 Java 的八大基本类型,写得非常好!