对于大多数人来说,犯错是一件让人很不开心的事情。但反过来想,犯错可以让我们意识到自己的不足,然后我们很快就学会下次不能再犯错了。犯的错越多,我们学习进步就越快。

同样的,在神经网络训练当中,当神经网络的输出与标签不一样时,也就是神经网络预测错了,这时我们希望神经网络可以很快地从错误当中学习,然后避免再预测错了。那么现实中,神经网络真的会很快地纠正错误吗?

我们来看一个简单的例子:

上图是一个只有一个神经元的模型。我们希望输入1的时候,模型会输出0(也就是说,我们只有一个样本(x=1, y=0))。假设我们随机初始化权重参数w=2.0,偏置参数b=2.0。激活函数为sigmoid函数。所以模型的第一次输出为:

output=σ(w⋅x+b)=σ(2.0×1+2.0)=0.98output=σ(w⋅x+b)=σ(2.0×1+2.0)=0.98

output = \sigma(w \cdot x + b) = \sigma(2.0 \times 1 + 2.0) = 0.98
可见,模型的第一次输出跟标签相差很大,很错误的一个输出。然后我们不断地使用梯度下降算法更新参数,重复训练。于是我们得到了下面这个图:

从图中可以看出,随着训练的次数增加,模型的输出越来越接近0。但是有没有发现一个问题?在训练的前部分,cost并没有显著的减少,也就是权重参数w和偏置参数b的变化不明显。我们前面说了,当我们知道错了,而且错误很大时,我们通常会很快地将错误降下来。但是图中的曲线一开始却是很缓慢地变化。这跟我们想要的不一样呀。虽然最终的结果是会收敛,但是我们希望的是在一开始训练的时候,模型可以收敛得更快。究竟是什么原因使得模型的cost在一开始的时候下降很慢呢?

我们知道在用梯度下降更新参数的时候,我们是计算了下面这两个偏导数:

∂C∂w    ∂C∂b∂C∂w∂C∂b

\frac{\partial C}{\partial w} \ \ \ \ \frac{\partial C}{\partial b}
其中: C=(y−a)22C=(y−a)22C = \frac{(y-a)^2}{2},a为模型的输出。所以,上面说cost的变化不明显,也就是这两个偏导数的值很小。我们将a用上面计算output的公式代替,即 a=σ(z),z=w⋅x+ba=σ(z),z=w⋅x+ba = \sigma(z) , z = w \cdot x + b。我们就可以得到:

∂C∂w=(a−y)σ′(z)x=aσ′(z)x∂C∂w=(a−y)σ′(z)x=aσ′(z)x∂C∂w=(a−y)σ′(z)x=aσ′(z)x∂C∂w=(a−y)σ′(z)x=aσ′(z)x

\frac{\partial C}{\partial w} = (a-y){\sigma}'(z)x = a{\sigma}'(z)x \\ \frac{\partial C}{\partial w} = (a-y){\sigma}'(z)x = a{\sigma}'(z)x
为了直观一点,我们可以看一下sigmoid函数的曲线图:

我们可以看到,当模型的输出a(sigmoid的输出)接近于1的时候,曲线变得很平滑(曲线的右上角),所以σ′(z)σ′(z){\sigma}'(z)也就很小了(斜率很小)。因此,上面两个偏导数的结果就很小了。这就是为什么一开始cost曲线下降很慢的原因了。

交叉熵代价函数

那么,如何解决学习速度不够快这个问题呢?

要想解决这个问题,就是说我们的cost函数不能使用二次平方这种形式了。我们应该使用一种叫做交叉熵的函数。交叉熵代价函数的公式如下:

C=−1n∑x[yIn a+(1−y)In(1−a)]C=−1n∑x[yIna+(1−y)In(1−a)]

C = -\frac{1}{n}\sum_{x}[y In\ a + (1-y)In(1-a)]
其中,n是训练样本的总数。

从这个公式我们并不能很清晰地看出解决了学习速度慢的问题。我们对一个权重参数求导:

将上述公式继续化简可以得到:

我们知道σ′(z)=σ(z)(1−σ(z))σ′(z)=σ(z)(1−σ(z)){\sigma}'(z) = \sigma(z)(1-\sigma(z)),所以我们可以将上面的公式分子分母部分约掉,得到:

从这里我们就可以看到,权重参数的偏导数由σ(z)−yσ(z)−y\sigma(z)-y控制,模型的输出与标签y之间的偏差越大,也就是σ(z)−yσ(z)−y\sigma(z)-y的值越大,那么偏导数就会越大,学习就会越快。这正是我们想要的结果。我们用一开始那个例子,不错这一次cost函数使用交叉熵代价函数了,可以得到下面的曲线:

可以看到,这一次,在一开始的时候,曲线下降的速度变快了。这就是为什么我们在大部分的机器学习模型中,经常使用交叉熵函数作为代价函数的原因了。

参考:
http://neuralnetworksanddeeplearning.com/chap3.html

为什么要用交叉熵作为代价函数相关推荐

  1. 信息量、熵、相对熵与交叉熵的理解

    一.信息量 信息奠基人香农(Shannon)认为"信息是用来消除随机不确定性的东西".也就是说衡量信息量大小就看这个信息消除不确定性的程度. "太阳从东方升起了" ...

  2. 为什么需要交叉熵代价函数

    为什么需要交叉熵代价函数 人类却能够根据明显的犯错快速地学习到正确的东西.相反,在我们的错误不是很好地定义的时候,学习的过程会变得更加缓慢.但神经网络却不一定如此,这种行为看起来和人类学习行为差异很大 ...

  3. 交叉熵代价函数——当我们用sigmoid函数作为神经元的激活函数时,最好使用交叉熵代价函数来替代方差代价函数,以避免训练过程太慢...

    交叉熵代价函数 machine learning算法中用得很多的交叉熵代价函数. 1.从方差代价函数说起 代价函数经常用方差代价函数(即采用均方误差MSE),比如对于一个神经元(单输入单输出,sigm ...

  4. 【深度学习】sigmoid - 二次代价函数 - 交叉熵 - logistic回归 - softmax

    1. sigmoid函数:σ(z) = 1/(1+e^(-z)) sigmoid函数有个性质:σ'(z) =σ(z) * ( 1 - σ(z) ) sigmoid函数一般是作为每层的激活函数,而下边的 ...

  5. 交叉熵代价函数cross-entropy

    交叉熵代价函数(Cross-entropy cost function)是用来衡量人工神经网络(ANN)的预测值与实际值的一种方式.与二次代价函数相比,它能更有效地促进ANN的训练.在介绍交叉熵代价函 ...

  6. 为什么使用交叉熵代替二次代价函数_Softmax回归与交叉熵损失的理解

    0.sigmoid.softmax和交叉熵损失函数的概念理解 sigmoid.softmax和交叉熵损失函数 1.使用场景 在二分类问题中,神经网络输出层只有一个神经元,表示预测输出 是正类 的概率 ...

  7. 交叉熵代价函数(损失函数)及其求导推导

    转自:http://blog.csdn.net/jasonzzj/article/details/52017438 前言 交叉熵损失函数 交叉熵损失函数的求导 前言 说明:本文只讨论Logistic回 ...

  8. BP神经网络——从二次代价函数(Quadratic cost)到交叉熵(cross-entropy cost)代价函数

    通过下文的阐述我们可以获得以下信息: 反向传播(back propagation)算法是一个计算框架(或者计算流程) 既然是一个计算框架,便与代价函数的具体形式(无论是二次代价还是交叉熵代价函数,只要 ...

  9. 机器学习基础(六)—— 交叉熵代价函数(cross-entropy error)

    交叉熵代价函数 1. 交叉熵理论 交叉熵与熵相对,如同协方差与方差. 熵考察的是单个的信息(分布)的期望: H(p)=−∑i=1np(xi)logp(xi) H(p)=-\sum_{i=1}^n p( ...

  10. python3 23.keras使用交叉熵代价函数进行MNIST数据集简单分类 学习笔记

    文章目录 前言 一.交叉熵代价函数简介 二.交叉熵代价函数使用 前言 计算机视觉系列之学习笔记主要是本人进行学习人工智能(计算机视觉方向)的代码整理.本系列所有代码是用python3编写,在平台Ana ...

最新文章

  1. 视图函数中进行sql查询,防止sql注入
  2. VTK:PolyData之DeletePoint
  3. 763. Partition Labels 划分字母区间
  4. Silverlight专题(10)- WatermarkedTextBox使用
  5. 虚拟化技术--桌面虚拟化(VDI)
  6. activity绑定service
  7. 电子信息工程这个专业学的是什么内容,就业怎么样?
  8. 快进来,详解MySQL游标
  9. 不使用中国手机号码注册网易云音乐
  10. Jquery 漂浮广告的插件
  11. Windows使用CMD命令查看WIFI密码
  12. Mendix装备制造业应用 | 质量统计分析人工智能应用APP
  13. 牛客寒假算法基础集训营1 C. 小a与星际探索(dp或者各种姿势)
  14. django改变用户头像
  15. 语法基础——Objective-C语法基础
  16. python 爬取某音乐各排行榜【简易版本】
  17. 初试Office 365企业版E3
  18. bim建筑绘图计算机要求,BIM技术人才需要达到哪些要求呢?
  19. Linux数据库挂载空间
  20. ET钱包1月21日早报|EOS钱包插件Scatter正尝试增加新功能

热门文章

  1. java实现网站的访问量_java统计网站访问量
  2. 双绞线的制作T568A线序,T568B线序
  3. 全国耳鼻喉科 医院排名
  4. python 加速运算
  5. 晶振的负载电容、寄生电容和动态电容及参考值
  6. 使用UCSC基因组浏览器可视化测序深度分布数据
  7. CentOS7安装MySQL8报错mariadb-libs is obsoleted by mysql-community-libs-8.0.xx-1.el7.x86_64
  8. 电脑底部任务栏没反应怎么办?
  9. 20年研发管理经验谈(二)
  10. FTRL之FM和LR实战(使用稀疏数据进行实战)