原文地址:http://suanfazu.com/t/topic/13742

上一次的分享我们提到了神经网络的几个基本概念,其中提到了随机梯度下降(SGD)算法是神经网络学习(或者更通用的,一般性参数优化问题)的主流方法。概念上,神经网络的学习非常简单,可以被归纳为下面的步骤:

(a) 构造神经网络结构(选择层数、激活函数等)
(b) 初始化构造出的神经网络参数
(c) 对于给定的训练样本与当前的,计算梯度
(d) 通过(随机)梯度下降算法更新
例如,不考虑任何正则化因子情况的最简单参数更新为

神经网络的初学者往往会发现,上述四个步骤当中,对于给定样本,计算其梯度是最 不直观 的一个步骤。本文我们玻森(bosonnlp.com)的讨论就围绕解决梯度的核心算法:后向传播 算法来展开。

首先理清一个概念,步骤(d)的梯度下降算法是一种优化算法,而我们要讨论的后向传播算法,是计算步骤(c)中所需要梯度的一种算法。下面的讨论,我们首先完成单参数(即只有一个参数需要学习)的特例情况下的推导,进而通过 动态规划(Dynamic programming) 思想,将其推导泛化到多变量的情况。需要注意的是,虽然后向传播概念上并不复杂,所用到的数学工具也很基本,但由于所涉及的变量较多、分层等特点,在推导的时候需要比较仔细,类似绣花。

单参数情况

特例
在讨论后向传播算法之前,我们简单回顾一下单变量微积分中的求导规则。来看个例子,假设我们有一个极端简化的网络,其中只有一个需要学习的参数,形式如下

并且假设损失函数Cost为平方误差(MSE)。

假设我们只有一个训练样本。因为这个形式非常简单,我们试试将样本直接带入损失函数:

显然当时,我们可以让损失函数为0,达到最优。下面让我们 假装 不知道最优解,考虑如何用梯度下降方法来求解。假设我们猜为最优,带入计算得到

嗯,不算太坏的一个初始值。让我们计算其梯度,或者损失函数关于的导数。

设置学习率参数,我们可以通过梯度下降方法来不断改进,以达到降低损失函数的目的。三十个迭代的损失函数变化如下:

生成上图采用的是如下Python代码

import matplotlib.pyplot as plt
w0, eta, n_iter = 2, 0.02, 30
gradient_w = lambda w: 2*(w**3)
cost = lambda w: 0.5*(w**4)
costs = []
w = w0
for i in range(n_iter):
costs.append(cost(w))
w = w – eta*gradient_w(w) # SGD
plt.plot(range(n_iter), costs)

可以发现,经过30次迭代后,我们的参数从初始的2改进到了0.597,正在接近我们的最优目标

对于一般的情况
回忆一下,上面的结果是基于我们给定 下得到的,注意这里我们假设输入信号为常量。我们将上面的求解步骤做一点点泛化。

重复上面的求解

关于w求导,

注意,上面求导用到了 链式法则(Chain Rule),即

或者写成偏导数形式:

对于一般性损失函数的情况
上式推导基于损失函数为平方最小下得出,那么我们再泛化一点,对于任意给定的可导损失函数,其关于的梯度:

其中是损失函数关于的导数。实际上这个形式很通用,对于 任意 特定的损失函数和神经网络的激活函数,都可以通过这个式子进行梯度计算。譬如,对于一个有三层的神经网络

同样通过链式法则,

上式看上去比较复杂,我们可以在符号上做一点简化。令每一层网络得到的激活函数结果为,即, 那么:

即:不论复合函数本身有多么复杂,我们都可以将其导数拆解成每一层函数的导数的乘积。

上面的推导我们给出了当神经网络仅仅被一个可学习参数所刻画的情况。一句话总结,在单参数的网络推导中,我们真正用到的唯一数学工具就是 链式法则。实际问题中,我们面对的参数往往是数以百万计的,这也就是为什么我们无法采用直觉去“猜”到最优值,而需要用梯度下降方法的原因。下面我考虑在多参数情况下,如何求解梯度。

多参数情况

首先,不是一般性的,我们假设所构建的为一个层的神经网络,其中每一层神经网络都经过线性变换和非线性变换两个步骤(为简化推导,这里我们略去对bias项的考虑):

定义网络的输入,而作为输出层。一般的,我们令网络第层具有个节点,那么。注意此时我们网络共有个参数需要优化。

为了求得梯度,我们关心参数关于损失函数的的导数:,但似乎难以把简单地与损失函数联系起来。问题在哪里呢?事实上,在单参数的情况下,我们通过链式法则,成功建立第一层网络的参数与最终损失函数的联系。譬如,的改变影响函数的值,而连锁反应影响到的函数结果。那么,对于值的改变,会影响,从而影响。通过的线性变换(因为),的改变将会影响到每一个

将上面的过程写下来:

可以通过上式不断展开进行其梯度计算。这个方式相当于我们枚举了 每一条 改变对最终损失函数影响的 路径。通过简单使用链式法则,我们得到了一个 指数级 复杂度的梯度计算方法。稍仔细观察可以发现,这个是一个典型的递归结构(为什么呢?因为定义的是一个递归结构),可以采用动态规划(Dynamic programming)方法,通过记录子问题答案的进行快速求解。设用于动态规划的状态记录。我们先解决最后一层的边界情况:

上式为通用形式。对于Sigmoid, Tanh等形式的element-wise激活函数,因为可以写成的形式,所示上式可以简化为:

即该情况下,最后一层的关于的导数与损失函数在导数和最后一层激活函数在的导数相关。注意当选择了具体的损失函数和每层的激活函数后,也被唯一确定了。下面我们看看动态规划的 状态转移 情况:

成功建立的递推关系,所以整个网络的可以被计算出。在确定了后,我们的对于任意参数的导数可以被简单表示出:

至此,我们通过链式法则和动态规划的思想,不失一般性的得到了后向传播算法的推导。

转载于:https://www.cnblogs.com/davidwang456/articles/5607057.html

当我们在谈深度学习时,到底在谈论什么(二)--转相关推荐

  1. 当我们在谈深度学习时,到底在谈论什么(三)--转

    原文:http://suanfazu.com/t/topic/13744 正则化 相信对机器学习有一定了解的朋友对正则化(Regularization)这个概念都不会陌生.可以这么说,机器学习中被讨论 ...

  2. 当我们在谈深度学习时,到底在谈论什么(一)--转

    原文地址:http://suanfazu.com/t/topic/13741 深度学习最近两年在音频分析,视频分析,游戏博弈等问题上取得了巨大的成果.由于微软,谷歌等科技巨头的推动及应用上的可见突破, ...

  3. 当我们谈深度学习时,我们用它落地了什么?

    摘要: 近日,阿里云在深度学习方面动作频频,先后发布了OCR证件识别,声纹检测,人脸搜索,视频鉴黄服务以及相似图片搜索功能,下面小编就一一为大家介绍五大功能应用. 现今伴随人工智能在技术上的不断突破, ...

  4. 当我们谈深度学习时,我们用它落地了什么?阿里云内容安全功能全新升级

    现今伴随人工智能在技术上的不断突破,一些领域如计算机视觉,已开始与各个行业进行了深度融合.例如保险行业已通过人脸识别这种新时代的认证方式,来对用户身份信息进行识别与审核.深度学习对人工智能的发展起着至 ...

  5. 浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现)

    浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现) 总包含文章: 一个完整的机器学习模型的流程 浅谈深度学习:了解RNN和构建并预测 浅谈深度学习:基于对LS ...

  6. 浅谈深度学习:基于对LSTM项目`LSTM Neural Network for Time Series Prediction`的理解与回顾

    浅谈深度学习:基于对LSTM项目LSTM Neural Network for Time Series Prediction的理解与回顾#### 总包含文章: 一个完整的机器学习模型的流程 浅谈深度学 ...

  7. 嵌入式AI —— 6. 为糖葫芦加糖,浅谈深度学习中的数据增广

    没有读过本系列前几期文章的朋友,需要先回顾下已发表的文章: 开篇大吉 集成AI模块到系统中 模型的部署 CMSIS-NN介绍 从穿糖葫芦到织深度神经网络 又和大家见面了,上次本程序猿介绍了CMSIS- ...

  8. 周志华:浅谈深度学习

    我们都知道直接掀起人工智能热潮的最重要的技术之一,就是深度学习技术.今天,其实深度学习已经有各种各样的应用,到处都是它,不管图像也好,视频也好,声音自然语言处理等等.那么我们问一个问题,什么是深度学习 ...

  9. 浅谈深度学习的基础——神经网络算法(科普)

    浅谈深度学习的基础--神经网络算法(科普) 神经网络算法是一门重要的机器学习技术.它是目前最为火热的研究方向--深度学习的基础.学习神经网络不仅可以让你掌握一门强大的机器学习方法,同时也可以更好地帮助 ...

最新文章

  1. 无法访问您试图使用的功能所在的网络位置
  2. 【深度学习】CNN图像分类:从LeNet5到EfficientNet
  3. UIModalPresentationStyle 各种类型的区别
  4. 机器学习中用到的概率知识_机器学习中有关概率论知识的小结
  5. 英国Carmarthen Learning Centre校长Mr Stuart来华访问,与荣新IT培训中心达成教学合作关系...
  6. 工作流添加跟踪后,实例一启动就会自动关闭
  7. vb与数据库(二)之迟到的学生信息管理系统总结
  8. python plt 色卡
  9. 0/1背包总结(持续更新...)
  10. projectwbs表_从Project 2007导出WBS图表到Visio 2007
  11. Python 3 网络爬虫 个人笔记 (未完待续)
  12. ElasticSearch知识概括
  13. 【2022年法定工作日,周末,节假日类型使用Java存入sql】
  14. PDF转Word怎么转?教你三招快速实现PDF转Word
  15. Halcon——颜色识别提取
  16. Java 程序获取本机 ip 地址
  17. 《微信公众平台与小程序开发——从零搭建整套系统》——第1章,第1.2节微信公众平台...
  18. MySQL学习_数据库和表的基本操作
  19. informatica添加MySQL表,Informatica 简单使用
  20. kubesphere_越南 ZaloPay 使用 KubeSphere 构建核心商户平台支持亿级用户

热门文章

  1. linux怎么命令设置网络连接,Linux网络操作命令
  2. python中的字典推导式_python 字典推导式(经典代码)(22)
  3. jq 点击按钮跳转到微信_【看这里】教你用微信小程序登陆全国青少年普法网,方便快捷!...
  4. cas无法使用_一文彻底搞懂CAS实现原理
  5. matplotlib 制作不等间距直方图
  6. php不能加载oci8,无法加载动态库'oci8.so'(PHP 7.2)
  7. android 之Dialog对话框(简易版)
  8. c语言文件发送程序,C语言程序例程的文件结构
  9. C++引用作为函数参数
  10. crontab 运行pyhon脚本