引言

本着“凡我不能创造的,我就不能理解”的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导。

要深入理解深度学习,从零开始创建的经验非常重要,从自己可以理解的角度出发,尽量不适用外部完备的框架前提下,实现我们想要的模型。本系列文章的宗旨就是通过这样的过程,让大家切实掌握深度学习底层实现,而不是仅做一个调包侠。
本系列文章首发于微信公众号:JavaNLP

今天来好好探讨下交叉熵损失(Cross Entropy Loss),为什么逻辑回归需要用到交叉熵损失。

信息熵

信息的价值在于消除事件的不确定性,那事件的不确定性要怎么度量呢?答案就是信息熵(information entropy)。

比如你告诉别人你中了500万彩票,别人会大吃一惊,因为他被消除了大量的不确定性。但如果你告诉别人你没中彩票,别人基本熵没有反应,因为他估计你这小子十有八九不会中彩票。相当于你几乎没有消除他对你没有中彩票这件事的不确定性。或者说你传达的信息量太少。我们知道概率只能在0到1之间,也就是说,最好在概率为1的时候,信息量为0,且概率越小,信息量越大。后来人们发现,对数函数很符合这样的规律,某个事件的信息量与概率的关系是i=log⁡(1p)i = \log(\frac{1}{p})i=log(p1​),这里的对数是以2为底的,ppp是事件发生的概率。

上面最后这个式子是怎么来的呢?以抛硬币游戏为例,如果有一枚理想的硬币,其出现正面和反面的概率相等,假设我们相隔很远,只能通过电位信号(0或1)进行交流,如何把这个硬币的结果告诉我呢。显然,此时只需要发送一个信号就可以,用1表示正面,用0表示反面。信息的价值在于消除事件的不确定性,传递一枚硬币结果的信息,帮我们消除了它是哪一个面的不确定性。

我们再来看一个转盘游戏,这个转盘被均等地分为8个区域,如果我们要把转盘的结果发送出去,那么需要多少个信号呢?答案是3个信号。

在这两个例子中,我们发现一件事情,把一个游戏系统中所有可能出现的等概率事件数量取以2为底的对数,就是我们要传递事件结果所需要的信号数量。比如在抛硬币游戏中是log⁡2(2)=1\log_2(2)=1log2​(2)=1,在转盘游戏中是log⁡2(8)=3\log_2(8)=3log2​(8)=3​​。这个数量就是信息量

即信息量=log⁡(N)\text{信息量}=\log(N)信息量=log(N),这里的NNN是等可能事件数量。

可以把这种度量不确定性的信息量称为信息熵,但严格来说这并不是香农所说的信息熵,这只是信息熵的一个特例,即所有的事件是等可能的。我们遇到的更多情况是事件发生的可能性不一样的系统。 比如现实生活中就无法制作出来正反面概率都是50%的硬币。

实际上,我们总是可以把一个事件的概率值转换为一个等可能事件系统中发生某个事件的概率。举例来说,我们总是可以把一个概率值转换为“在N个球中随机摸一个球”这个等可能事件系统中摸出某个球的概率。

假设中彩票大奖的概率很低,只有两千万分之一,我们可以把这个概率值转换为在两千万个球中摸出中奖球的概率。在这个摸球系统中,就有两千万个等可能事件。所以只需要用1除以概率值就可以想象出等可能事件系统中事件等数量,即N=1pN=\frac{1}{p}N=p1​​。

假设我们有一个动了手脚的不均匀硬币,它正面朝上的概率是0.8,反面朝上的概率是0.2。

如上图,反面朝上0.2的概率可以想象有5个球的摸球系统中摸出某个球的概率。

而正面朝上0.8的概率可以想象成在有1.25个球的系统中摸出某个球的概率。我们通过想象把一个非等概率事件的系统拆成了两个等概率事件的系统。

而面对等概率事件系统,我们就可以很容易地计算它们的信息量。再把这两个想象出来的摸球系统的信息量加起来,就是这个不均匀硬币的信息量:log⁡5+log⁡1.25\log 5 + \log 1.25log5+log1.25,由于这两个我们想象出来的等概率系统本身出现的概率也不一样,因此我们需要分别乘上它们出现的概率,得0.2⋅log⁡5+0.8⋅log⁡1.250.2 \cdot \log 5 + 0.8 \cdot \log 1.250.2⋅log5+0.8⋅log1.25。

如果我们用符号去抽象这些具体的值,就是
p1⋅log⁡1p1+p2⋅log⁡1p2(1)p_1 \cdot \log \frac{1}{p_1} + p_2 \cdot \log \frac{1}{p_2} \tag{1} p1​⋅logp1​1​+p2​⋅logp2​1​(1)
对于有更多事件的一般情况,我们可以这么表示:
∑ipilog⁡1pi(2)\sum_i p_i \log \frac{1}{p_i} \tag{2} i∑​pi​logpi​1​(2)
我们整理下这个式子
∑ipilog⁡1pi=∑i(pi⋅(log⁡1−log⁡pi))=∑i−pilog⁡pi=−∑ipilog⁡pi(3)\begin{aligned} \sum_i p_i \log \frac{1}{p_i} &= \sum_i (p_i \cdot (\log1 - \log p_i)) \\ &= \sum_i - p_i \log p_i \\ &= - \sum_i p_i \log p_i \end{aligned} \tag{3} i∑​pi​logpi​1​​=i∑​(pi​⋅(log1−logpi​))=i∑​−pi​logpi​=−i∑​pi​logpi​​(3)

就得到了香农所提出的信息熵公式 $ -\sum_i p_i \log p_i$

我们可以看出,信息熵实际熵就是我们给每个概率值想象出来的某球系统的信息量的平均值,或者说是信息量的期望。

如果我们要比较两个概率模型的距离,最简单的办法就是把它们的信息熵都算出来,直接比较两个结果就好了。但是问题是,在机器学习中,我们往往不知道训练样本的概率模型。此时呢,我们就需要用到相对熵,也称为KL散度(KL Divergence)。

但是在这之前,为了知识的完整性,我们需要了解极大似然估计的概念。

极大似然估计

极大似然估计里面有三个概念,极大、似然和估计。通俗来说,就是用已知的样本结果信息,去反推最有可能导致这些样本结果出现的模型参数值。

反推说的是一种推理、估计,我们无法保证完全能从已知样本去推出产生这些样本的概率分布,只能说是一种估计。似然值说的是,真实样本已经看到,假设有很多(概率)模型,每个模型产生这些真实样本的可能性就叫似然值。极大似然估计就是选择似然值最高的模型来估计真实(概率)模型。

还是以抛硬币为例,我们记硬币的正面为H(Head),反面为T(Tail)。

假设我们不知道这个硬币产生正反面的概率,但是我们可以做10次实验,假设产生这样一组结果:HHHHHHHTTT。即前7次是正面,后3次是反面。

假设有三个产生这组结果的模型(概率分布),

  • 模型A产生正面的概率p=0.1p=0.1p=0.1,产生反面的概率就是1−p=0.91-p=0.91−p=0.9​
  • 模型B产生正面的概率p=0.7p=0.7p=0.7,产生反面的概率是1−p=0.31-p=0.31−p=0.3
  • 模型CCC产生正面的概率p=0.8p=0.8p=0.8,产生反面的概率是1−p=0.21-p=0.21−p=0.2

计算某个概率模型产生这组结果的可能性是可以计算出来的,公式为:
P(C1,C2,⋯,C10∣θ)=∏i=110P(Ci∣θ)(4)P(C_1,C_2,\cdots,C_{10}|\theta) = \prod_{i=1}^{10} P(C_i|\theta) \tag{4} P(C1​,C2​,⋯,C10​∣θ)=i=1∏10​P(Ci​∣θ)(4)
其中Ci∈{0,1}C_i \in \{0,1\}Ci​∈{0,1}是第iii次抛硬币的结果,整个式子说的是由参数θ\thetaθ确定的模型同时发生C1,C2,⋯,C10C_1,C_2,\cdots,C_{10}C1​,C2​,⋯,C10​的概率。

同时发生就是连乘。

这样的可能性就叫似然值

因此我们只需要计算每个模型的似然值,然后选择似然值最大的模型来估计真实模型。

模型AAA的似然值:0.170.93≈7.29e−080.1^70.9^3 \approx 7.29e-080.170.93≈7.29e−08

模型BBB的似然值:0.770.33≈0.002220.7^70.3^3\approx 0.002220.770.33≈0.00222

模型CCC的似然值:0.170.93≈0.001680.1^70.9^3\approx 0.001680.170.93≈0.00168

挑出似然值最大的模型就叫最大似然估计法。

极大似然法

我们从极大似然估计的角度来看一下损失函数的选择。

以上图为例,把一些图片,输入到神经网络,神经网络会输出这张图片是猫的可能性。假设这些图片是训练数据,我们已经这些图片是否是猫。

在抛硬币中,我们通过θ\thetaθ来表示参数,在神经网络这里可以具体地用W,bW,bW,b来表示。

即可以写成:
P(x1,x2,x3,⋯,xn∣W,b)(5)P(x_1,x_2,x_3,\cdots,x_n|W,b) \tag{5} P(x1​,x2​,x3​,⋯,xn​∣W,b)(5)
nnn这些图片的个数;xi∈{0,1}x_i \in \{0,1\}xi​∈{0,1}代表输入的这张图片是否为猫,111代表是猫。

这样我们也可以把上式改成连乘的形式:
P(x1,x2,x3,⋯,xn∣W,b)=∏i=1nP(xi∣W,b)(6)P(x_1,x_2,x_3,\cdots,x_n|W,b) =\prod_{i=1}^n P(x_i|W,b) \tag{6} P(x1​,x2​,x3​,⋯,xn​∣W,b)=i=1∏n​P(xi​∣W,b)(6)
我们就可以得到基于这些图片的模型的似然值,我们要找到使得这个似然值最大的W,bW,bW,b。

但是W,bW,bW,b​是一个确定的值,而我们知道神经网络可以看成是由W,bW,bW,b这组参数确定的一个函数,该函数的输出结果yiy_iyi​表示输入图片xix_ixi​是猫的可能性有多大,即yi=NNW,b(xi)y_i = NN_{W,b}(x_i)yi​=NNW,b​(xi​)​​。这里我们就可以用可能性yiy_iyi​来替代上面的参数W,bW,bW,b:
P(x1,x2,x3,⋯,xn∣W,b)=∏i=1nP(xi∣yi)(7)P(x_1,x_2,x_3,\cdots,x_n|W,b) =\prod_{i=1}^n P(x_i|y_i) \tag{7} P(x1​,x2​,x3​,⋯,xn​∣W,b)=i=1∏n​P(xi​∣yi​)(7)
这样输入不同猫的图片xix_ixi​,我们可以得到不同的概率值yiy_iyi​。

这个式子我们要如何展开呢,我们这个连乘时的写法与xix_ixi​的取值有关,当xi=1x_i=1xi​=1时,输出的应该是判断为猫的概率,取yiy_iyi​;当xi=0x_i=0xi​=0时,输出的应该是判断不是猫的概率,取1−yi1-y_i1−yi​。

好在还是有解法的,这个式子可以通过伯努利分布展开的,因为xi∈{0,1}x_i \in \{0,1\}xi​∈{0,1}两种情况,同时yiy_iyi​又是一个概率。

和抛硬币的例子类似,当x=1x=1x=1时,我们用硬币是正面的概率ppp去乘;当x=0x=0x=0是,我们用硬币是反面的概率1−p1-p1−p去乘。

我们就可以通过伯努利分布的这个式子来展开式(7)(7)(7),得:
P(x1,x2,x3,⋯,xn∣W,b)=∏i=1nyixi(1−yi)1−xi(8)P(x_1,x_2,x_3,\cdots,x_n|W,b) = \prod_{i=1}^n y_i^{x_i}(1-y_i)^{1-x_i} \tag{8} P(x1​,x2​,x3​,⋯,xn​∣W,b)=i=1∏n​yixi​​(1−yi​)1−xi​(8)
我们通过在等式右边取对数,把连乘变成连加,因为取对数不改变单调性的。
log⁡(∏i=1nyixi(1−yi)1−xi)=∑i=1nlog⁡(yixi(1−yi)1−xi)=∑i=1n(xi⋅log⁡yi+(1−xi)⋅log⁡(1−yi))(9)\begin{aligned} \log \left( \prod_{i=1}^n y_i^{x_i}(1-y_i)^{1-x_i} \right )&= \sum_{i=1}^n \log \left(y_i^{x_i}(1-y_i)^{1-x_i} \right) \\ &= \sum_{i=1}^n \left(x_i \cdot \log y_i + (1-x_i)\cdot \log (1-y_i) \right) \\ \end{aligned} \tag{9} log(i=1∏n​yixi​​(1−yi​)1−xi​)​=i=1∑n​log(yixi​​(1−yi​)1−xi​)=i=1∑n​(xi​⋅logyi​+(1−xi​)⋅log(1−yi​))​(9)

我们要求极大似然值,是取的最大值,而损失函数是取最小值,我们把等式两边乘以一个负号,变成了求最小值。
min⁡−∑i=1n(xi⋅log⁡yi+(1−xi)⋅log⁡(1−yi))(10)\min - \sum_{i=1}^n \left(x_i \cdot \log y_i + (1-x_i)\cdot \log (1-y_i) \right) \tag{10} min−i=1∑n​(xi​⋅logyi​+(1−xi​)⋅log(1−yi​))(10)
虽然这个式子看起来很像交叉熵,但实际上还是有很大不同的,主要的区别是它们的量纲不同。

这里面的对数是我们故意加上去的,并且负号也是为了凑成求最小值。

下面我们就来了解相对熵。

KL散度

KL散度,也被称为相对熵,是用来衡量两个分布的距离,设PPP和QQQ是两个概率分布,则PPP对QQQ的相对熵为:
DKL(P∣∣Q)=∑iP(i)log⁡P(i)Q(i)(11)D_{KL}(P||Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)} \tag{11} DKL​(P∣∣Q)=i∑​P(i)logQ(i)P(i)​(11)
这里iii代表分布中的所有类别。

性质

  1. 不具备对称性,即D(P∣∣Q)≠D(Q∣∣P)D(P||Q) \neq D(Q||P)D(P∣∣Q)​=D(Q∣∣P)
  2. 非负性,即D(P∣∣Q)≥0D(P||Q) \geq 0D(P∣∣Q)≥0

举个例子,还是以抛硬币为例,假设我们有一个公平的硬币,即正反概率都是50%;我们还有一个有偏差的硬币,其正面概率为ppp,反面概率为qqq。

我们要如何判断这两个分布的相似性呢?

可能不好回答,但是我们知道,如果p=0.55p=0.55p=0.55,它肯定比p=0.95p=0.95p=0.95要更相似。

我们可以从抛硬币的结果来看,

​ 假设公平硬币的抛掷结果为:HHTHHTTHTHTH

假设p=0.55p=0.55p=0.55硬币的抛掷结果为:HHTHHTTHHHTH

假设p=0.95p=0.95p=0.95硬币的抛掷结果为:HHHHHHTHHHHH

我们可以简单的计算不相等的结果个数,但是更严谨的做法是计算产生某个结果的似然值。

如果似然值很接近,那么说明这两个概率分布很接近。

基于(公平硬币抛出的)观察结果,我们就可以计算公平硬币的似然值和其他硬币的似然值的比值:
P(观察结果∣公平硬币)P(观察结果∣偏差硬币)\frac{P(\text{观察结果}|\text{公平硬币})}{P(\text{观察结果}|\text{偏差硬币})} P(观察结果∣偏差硬币)P(观察结果∣公平硬币)​
我们再举一个例子,假设有一枚硬币,其正面概率为p1p_1p1​,反面概率为p2p_2p2​;

假设我们抛掷这枚硬币12次,产生的结果为:HHTHHTHHHTHT

我们可以很容易计算出这枚硬币产生这个结果的概率: p1⋅p1⋅p2⋅p1⋅p1⋅p2⋅p1⋅p1⋅p1⋅p2⋅p1⋅p2p_1\cdot p_1 \cdot \color{red}p_2 \cdot \color{black}p_1 \cdot p_1 \cdot \color{red}p_2 \cdot \color{black}p_1 \cdot p_1 \cdot p_1 \cdot \color{red}p_2 \cdot \color{black}p_1 \cdot \color{red}p_2p1​⋅p1​⋅p2​⋅p1​⋅p1​⋅p2​⋅p1​⋅p1​⋅p1​⋅p2​⋅p1​⋅p2​

我们再拿一枚硬币,它产生正面的概率为q1q_1q1​,反面概率为q2\color{red}q_2q2​

那么这枚新的硬币产生这个结果的概率为: q1⋅q1⋅q2⋅q1⋅q1⋅q2⋅q1⋅q1⋅q1⋅q2⋅q1⋅q2q_1\cdot q_1 \cdot \color{red}q_2 \cdot \color{black}q_1 \cdot q_1 \cdot \color{red}q_2 \cdot \color{black}q_1 \cdot q_1 \cdot q_1 \cdot \color{red}q_2 \cdot \color{black}q_1 \cdot \color{red}q_2q1​⋅q1​⋅q2​⋅q1​⋅q1​⋅q2​⋅q1​⋅q1​⋅q1​⋅q2​⋅q1​⋅q2​

即基于观察结果,有

P(观察结果∣硬币1)=p1NHp2NTP(\text{观察结果}|\text{硬币1}) = p_1^{N_H}\color{red}p_2^{N_T}P(观察结果∣硬币1)=p1NH​​p2NT​​

P(观察结果∣硬币2)=q1NHq2NTP(\text{观察结果}|\text{硬币2}) = q_1^{N_H}\color{red}q_2^{N_T}P(观察结果∣硬币2)=q1NH​​q2NT​​

其中NHN_HNH​​表示观察结果中为正面的次数,NTN_TNT​​为反面的次数。我们计算它们的比值:
P(观察结果∣真实硬币)P(观察结果∣硬币2)=p1NHp2NTq1NHq2NT(12)\frac{P(\text{观察结果}|\text{真实硬币})}{P(\text{观察结果}|\text{硬币2})} = \frac{p_1^{N_H}\color{red}p_2^{N_T}}{q_1^{N_H}\color{red}q_2^{N_T}} \tag{12} P(观察结果∣硬币2)P(观察结果∣真实硬币)​=q1NH​​q2NT​​p1NH​​p2NT​​​(12)
这样就能计算出来这两个硬币的相似性。

其实KL散度衡量的是类似的东西。怎么说?

我们把上式右边取对数,并除以实验总数N=NH+NTN=N_H+\color{red}N_TN=NH​+NT​:
1Nlog⁡(p1NHp2NTq1NHq2NT)=1Nlog⁡p1NH+1Nlog⁡p2NT−1Nlog⁡q1NH−1Nlog⁡q2NT=p1log⁡p1+p2log⁡p2−p1log⁡q1−p2log⁡q2=p1log⁡p1q1+p2log⁡p2q2\begin{aligned} \frac{1}{N}\log \left( \frac{p_1^{N_H}\color{red}p_2^{N_T}}{q_1^{N_H}\color{red}q_2^{N_T}} \right) &= \frac{1}{N}\log p_1^{N_H} + \frac{1}{N}\log \color{red}p_2^{N_T} \color{black} - \frac{1}{N}\log q_1^{N_H} -\frac{1}{N}\log \color{red} q_2^{N_T} \\ &= p_1\log p_1 + p_2 \log \color{red}p_2 \color{black} - p_1 \log q_1 - \color{red}p_2 \color{black}\log \color{red}q_2 \\ &= p_1 \log \frac{p_1}{q_1} + \color{red}p_2 \color{black}\log \frac{\color{red}p_2}{\color{red}q_2} \\ \end{aligned} N1​log(q1NH​​q2NT​​p1NH​​p2NT​​​)​=N1​logp1NH​​+N1​logp2NT​​−N1​logq1NH​​−N1​logq2NT​​=p1​logp1​+p2​logp2​−p1​logq1​−p2​logq2​=p1​logq1​p1​​+p2​logq2​p2​​​

其中NHN=p1NTN=p2\frac{N_H}{N}=p_1 \,\,\,\, \frac{\color{red}N_T}{N}=\color{red}p_2NNH​​=p1​NNT​​=p2​。这里ppp是一个概率分布,qqq是另一个概率分布,该式子和KL散度的式子一模一样。

即我们通过计算真实分布的似然值除以第二个分布的似然值,再取归一化的对数,就得到了KL散度的表达式。

我们可以看到,KL散度是一种衡量两个概率分布距离的方式,通过观察第二个概率分布产生第一个概率分布样本的可能性。

KL散度非常适用于深度学习的场景,因为深度学习模型基本上是关于为已知样本的真实分布建模。

实际上,交叉熵损失(cross entroy loss)就等于KL损失,最小化交叉熵就是最小化两个分布的距离。

我们先来看下交叉熵的定义。

交叉熵

交叉熵(Cross Entropy)主要衡量两个概率分布之间的差异性。交叉熵可在神经网络中作为损失函数,有:
H(P∗∣P)=−∑iP∗(i)log⁡P(i)(13)H(P^*|P)=- \sum_i P^*(i) \log P(i) \tag{13} H(P∗∣P)=−i∑​P∗(i)logP(i)(13)
其中P∗P^*P∗表示真实分布;PPP表示预测分布;iii表示分布中的所有类别。

KL散度和交叉熵

我们已经了解了KL散度和交叉熵,我们本小节来看它们之间的关系。

我们知道,交叉熵可以用来衡量预测分布和真实分布的差异(距离)。我们观察到的样本都是由真实分布产生的,所以我们可以这样描述KL散度:
DKL(P∗∣∣P)=DKL(P∗(y∣xi)∣∣P(y∣xi;θ)(14)D_{KL}(P^*||P) =D_{KL}\left ( P^* (y|x_i) || P(y|x_i;\theta\right) \tag{14} DKL​(P∗∣∣P)=DKL​(P∗(y∣xi​)∣∣P(y∣xi​;θ)(14)
其中P∗P^*P∗是真实分布,PPP是我们的预测分布;xix_ixi​是第iii个样本,yyy是其对应的标签;θ\thetaθ是模型的参数。
DKL(P∗∣∣P)=∑yP∗(y∣xi)log⁡P∗(y∣xi)P(y∣xi;θ)=∑yP∗(y∣xi)[log⁡P∗(y∣xi)−log⁡P(y∣xi;θ)]=∑yP∗(y∣xi)log⁡P∗(y∣xi)−∑yP∗(y∣xi)log⁡P(y∣xi;θ)(15)\begin{aligned} D_{KL}(P^*||P) &= \sum_y P^* (y|x_i) \log \frac{P^*(y|x_i)}{P(y|x_i;\theta)} \\ &=\sum_y P^* (y|x_i) \left [\log P^*(y|x_i) - \log P(y|x_i;\theta) \right] \\ &= \sum_y P^* (y|x_i)\log P^*(y|x_i) - \sum_y P^* (y|x_i) \log P(y|x_i;\theta) \\ \end{aligned} \tag{15} DKL​(P∗∣∣P)​=y∑​P∗(y∣xi​)logP(y∣xi​;θ)P∗(y∣xi​)​=y∑​P∗(y∣xi​)[logP∗(y∣xi​)−logP(y∣xi​;θ)]=y∑​P∗(y∣xi​)logP∗(y∣xi​)−y∑​P∗(y∣xi​)logP(y∣xi​;θ)​(15)
观察上面最终的式子,其中P∗(y∣xi)log⁡P∗(y∣xi)P^* (y|x_i)\log P^*(y|x_i)P∗(y∣xi​)logP∗(y∣xi​)与参数θ\thetaθ无关,实际上是真实分布的信息熵,是一个常数;而 −P∗(y∣xi)log⁡P(y∣xi;θ)-P^* (y|x_i)\log P(y|x_i;\theta)−P∗(y∣xi​)logP(y∣xi​;θ)就是我们熟悉的交叉熵的式子。

如果看不明白的话,或者我们换一种写法:DKL(P∗∣∣P)=−S(P∗)+H(P∗,P)D_{KL}(P^*||P)= -S(P^*) + H(P^*,P)DKL​(P∗∣∣P)=−S(P∗)+H(P∗,P),S(P∗)S(P^*)S(P∗)是P∗P^*P∗的信息熵;H(P∗,P)H(P^*,P)H(P∗,P)是交叉熵,KL散度 = 交叉熵 - 熵

因此,我们最小化关于参数θ\thetaθ的KL散度,就相当于最小化式(15)(15)(15)中的第二项,即:
arg⁡min⁡θDKL(P∗∣∣P)≡arg⁡min⁡θ−∑iP∗(y∣xi)log⁡P(y∣xi;θ)(16)\arg\,\min_{\theta}D_{KL}(P^*||P) \equiv \arg\,\min_{\theta} - \sum_i P^* (y|x_i) \log P(y|x_i;\theta) \tag{16} argθmin​DKL​(P∗∣∣P)≡argθmin​−i∑​P∗(y∣xi​)logP(y∣xi​;θ)(16)

arg⁡min⁡θDKL(P∗∣∣P)≡arg⁡min⁡θH(P∗,P)(17)\arg\,\min_{\theta}D_{KL}(P^*||P) \equiv \arg\,\min_{\theta} H(P^*,P) \tag{17} argθmin​DKL​(P∗∣∣P)≡argθmin​H(P∗,P)(17)
因此,在机器学习中,我们要评估预测模型和真实模型之间的差距,可以使用KL散度,而KL散度中的信息熵那一部分不变,所以只需要关注交叉熵就可以了。

基于KL散度恒不小于零的特性,博主找到了一个很好的图示:

红色曲线代表真实概率分布;橙色曲线代表预测概率分布;紫线代表蓝色曲线下的面积,代表这两个分布的交叉熵。

交叉熵的大小与预测分布和真实分布的偏离程度相关。

当两个分布重叠时,此时交叉熵最小,为真实分布的信息熵。

交叉熵损失

在机器学习中,我们需要评估标签值yyy和预测值y^\hat yy^​之间的差距,我们知道只需要关注交叉熵。一般在机器学习中直接用交叉熵做损失函数来评估模型。

loss=−∑j=1nyjlog⁡y^j(18)loss=-\sum_{j=1}^n y_j \log \hat y_j \tag{18} loss=−j=1∑n​yj​logy^​j​(18)
这里yjy_jyj​是真实样本的标签;y^j\hat y_jy^​j​是预测值,通常是一个概率;nnn是分类的个数;因此这是针对单个样本的情况,如果对于批量样本,那么交叉熵计算公式为:
L=−∑i=1m∑j=1nyijlog⁡y^ij(19)\mathcal L = -\sum_{i=1}^m \sum_{j=1}^n y_{ij} \log \hat y_{ij} \tag{19} L=−i=1∑m​j=1∑n​yij​logy^​ij​(19)
其中mmm是样本数;nnn是分类数;yijy_{ij}yij​表示第iii个样本在类别jjj上的真实标签;y^ij\hat y_{ij}y^​ij​表示第iii个样本在类别jjj上的预测概率。

二分类

有一种特殊问题,即分类数为222,就是二分类问题。对于这种问题,由于n=2n=2n=2,y1=1−y2y_1=1-y_2y1​=1−y2​,y^1=1−y^2\hat y_1 = 1- \hat y_2y^​1​=1−y^​2​,所以交叉熵可以简化为:
loss=−[y1log⁡y^1+(1−y1)log⁡(1−y^1)](20)loss = - \left[ y_1 \log \hat y_1 + (1-y_1)\log (1-\hat y_1) \right] \tag{20} loss=−[y1​logy^​1​+(1−y1​)log(1−y^​1​)](20)
对于批量样本的交叉熵为:
L=−∑i=1m[yilog⁡y^i+(1−yi)log⁡(1−y^i)](21)\mathcal L = - \sum_{i=1}^m \left [ y_i \log \hat y_i + (1-y_i)\log(1-\hat y_i) \right] \tag{21} L=−i=1∑m​[yi​logy^​i​+(1−yi​)log(1−y^​i​)](21)
通常对于二分类问题,记正例为111,负例为000。因此上式的两个相加项只会有一个存在。

多分类

常见的是多分类问题,即分类数n≥3n \geq 3n≥3。多分类问题对于批量样本的交叉熵损失即为式(19)(19)(19):
L=−∑i=1m∑j=1nyijlog⁡y^ij(22)\mathcal L = -\sum_{i=1}^m \sum_{j=1}^n y_{ij} \log \hat y_{ij} \tag{22} L=−i=1∑m​j=1∑n​yij​logy^​ij​(22)
这里有必要指出的是,对于多分类问题,标签值(真实类别)一般采用独热编码(one-hot encoding),预测值在输出之前会经过Softmax转换为概率分布。这样交叉熵损失只会关注预测正确的类别的概率。

这种特性使得代码编写也比计较直观。

负对数似然与交叉熵

本小节来看一下负对数似然与交叉熵的关系。

这里考虑的是多分类的情况。二分类的在上文公式(10)(10)(10)中已经证明了。

我们知道在多分类中,会经过Softmax得到概率,有
y^=exp⁡(zi)∑j=1nexp⁡(zj)1≤j≤n(23)\hat y =\frac{\exp(z_i)}{\sum_{j=1}^n \exp(z_j)} \quad 1 \leq j \leq n \tag{23} y^​=∑j=1n​exp(zj​)exp(zi​)​1≤j≤n(23)
这里假设有nnn个类别,即y^\hat yy^​向量是一个长度为nnn个向量,向量中每个元素代表属于一个类别的概率。

假设我们已经有一些观测数据,我们如何计算这些数据的似然?

我们先看一个数据的情况,假设该数据对应的真实类别为ccc,那么该数据的似然就是y^c\hat y_cy^​c​。

而真实类别一般通过ont-hot向量表示(假设在真实类别yyy中只有第ccc个元素为111,其他元素都为000)。那么该数据的似然可以表示为:
∏j=1ny^jyj\prod_{j=1}^n \hat y_j ^{y_j} j=1∏n​y^​jyj​​

虽然一个连乘的形式,假设真实类别为ccc,只有yc=1y_c=1yc​=1,其他都是000,最终的似然依然是y^c\hat y_cy^​c​。

那么负对数似然就是先取对数,再加上负号,即:
−∑j=1nyjlog⁡y^j(24)- \sum_{j=1}^n y_j \log \hat y_j \tag{24} −j=1∑n​yj​logy^​j​(24)
这里yyy是一个one-hot向量,可以看成是真实标签的概率分布,而y^\hat yy^​​就是模型预测的概率。

我们来看多个数据的情况,如果我们有mmm个独立同分布的样本,那么产生这些样本的似然就是产生每个样本的似然之积:
∏i=1m∏j=1ny^ijyij\prod_{i=1}^m \prod_{j=1}^n \hat y_{ij} ^{y_{ij}} i=1∏m​j=1∏n​y^​ijyij​​

yijy_{ij}yij​表示第iii个样本在类别jjj上的真实标签(0或1);y^ij\hat y_{ij}y^​ij​表示第iii个样本在类别jjj上的预测概率。

负对数似然则为:

−∑i=1m∑j=1nyijlog⁡y^ij(25)- \sum_{i=1}^m \sum_{j=1}^n y_{ij} \log \hat y_{ij} \tag{25} −i=1∑m​j=1∑n​yij​logy^​ij​(25)

所以负对数似然就是真实类别乘上预测类别的对数,和上面多分类的交叉熵公式(22)(22)(22)是一模一样的。

均方误差和交叉熵

我们知道,线性回归的损失函数是均方误差,而逻辑回归的损失函数为交叉熵损失。为什么呢?

先看逻辑回归为什么用交叉熵损失而不是均方误差。

逻辑回归其实是分类问题,输出的是一个概率,交叉熵就是用于衡量概率距离的函数,所以选用交叉熵损失。如果把概率值看成是一个数值的话,也可以用均方误差啊。那到底为什么呢?

我们可以从均方误差和交叉熵的函数图形入手。

以二分类问题为例,先看交叉熵的函数图形。

import numpy as np
import matplotlib.pyplot as pltdef cross_entropy(y_hat, y):return -np.log(y_hat) if y == 1 else -np.log(1 - y_hat)y_hat = np.arange(0.01,1,0.01)plt.plot(y_hat, cross_entropy(y_hat, 1), label='y=1')
plt.plot(y_hat, cross_entropy(y_hat, 0), label='y=0')
plt.legend()
plt.show()

其中蓝线代表真实标签y=1y=1y=1时的交叉熵损失函数图形,橙线代表真实标签y=0y=0y=0时的图形。横坐标代表预测值,纵坐标代表损失值。

可以看到,当y=1y=1y=1时(蓝线),如果预测的越正确(预测值与1越近),则损失(惩罚)越小,在越接近0的位置,损失越大。

反过来,当y=0y=0y=0时(橙线),如果预测的越正确(预测值与0越近),则损失越小,在越接近1的位置,损失越大。

我们来看下,当y=1y=1y=1时,预测结果为y^=0.1\hat y=0.1y^​=0.1时的损失:

> cross_entropy(0.1, 1)
2.3025850929940455

大约是2.32.32.3。

我们再来看均方误差的图形:

def mse(y_hat, y):return (y - y_hat)**2plt.plot(y_hat,mse(y_hat, 1) , label='y=1')
plt.plot(y_hat, mse(y_hat, 0), label='y=0')
plt.legend()

其中蓝线代表真实标签y=1y=1y=1时的均方误差损失函数图形,橙线代表真实标签y=0y=0y=0时的图形。横坐标代表预测值,纵坐标代表损失值。

上面纵轴最大值也只是1.01.01.0,整个函数图像看起来也没有特别大的梯度。

我们也来看下,当y=1y=1y=1时,预测结果为y^=0.1\hat y=0.1y^​=0.1​时的损失:

> mse(0.1, 1)
0.81

其损失值也不大,如果选用均方误差作为逻辑回归的损失函数。

上图是Sigmoid和它的导数的图像。

如果用均方误差,那么
z=w⋅x+ba=σ(z)C=(y−a)22dCdw=(y−a)⋅σ(z)′⋅xz = w\cdot x + b\\ a = \sigma(z)\\ C = \frac{(y-a)^2}{2}\\ \frac{dC}{dw} = (y-a) \cdot \sigma(z)^\prime \cdot x z=w⋅x+ba=σ(z)C=2(y−a)2​dwdC​=(y−a)⋅σ(z)′⋅x
可以看到,其中包含σ(z)′\sigma(z)^\primeσ(z)′,而当zzz取值的绝对值很大时,其对应的梯度几乎就是000,因此导致dCdw=0\frac{dC}{dw}=0dwdC​=0,很可能训练不起来。

这样我们就明白了为什么逻辑回归要选择交叉熵。

我们再来看线性回归为什么不选择交叉熵。直接说结论,假设概率分布为高斯分布的情况下,采用交叉熵损失等同于采用均方误差损失。相关证明可以网上查找。

交叉熵损失的梯度

当然,对于梯度下降,我们不需要损失,我们需要它的梯度。单个样本的梯度与我们在前文中看到的逻辑回归的梯度(y^−y)x(\hat y - y)x(y^​−y)x非常相似。让我们考虑一下梯度的一部分,即单个权重的导数。对于每个类kkk,输入xxx的第iii个元素的权重是wk,iw_{k,i}wk,i​,假设xxx共有nnn个特征。与wk,iw_{k,i}wk,i​有关的损失的偏导数是多少?因为kkk被占用了,因此我们用新的符号lll。

由于分母∑j=1Kexp⁡(wj⋅x+bj)\sum_{j=1}^K \exp(w_j \cdot x+ b_j)∑j=1K​exp(wj​⋅x+bj​)中包含wkw_kwk​,因此我们推导如下:

其中去掉偏导符号的等式(28)(28)(28)拿出来展开,第一项:

−∂∂wk,i[∑j=1Kyl(wl⋅x+bl)]=−∂∂wk,i[y1(w1⋅x+b1)+⋯+yk(wk⋅x+bk)+⋯+yK(wK⋅x+bK)]=−∂∂wk,i[y1(w1⋅x+b1)+⋯+yk(wk⋅x+bk)+⋯+yK(wK⋅x+bK)]=−∂∂wk,iyk(wk,1⋅x1+⋯+wk,i⋅xi+⋯+wk,n⋅xn+bk)=−∂∂wk,iyk(wk,1⋅x1+⋯+wk,i⋅xi+⋯+wk,n⋅xn+bk)=−ykxi\begin{aligned} - \frac{\partial}{\partial w_{k,i}} \left[ \sum_{j=1}^K y_l (w_l \cdot x + b_l) \right] &= - \frac{\partial}{\partial w_{k,i}} \left[ y_1 (w_1 \cdot x + b_1) + \cdots +y_k (w_{k} \cdot x + b_k) + \cdots + y_K (w_K \cdot x + b_K) \right] \\ &= - \frac{\partial}{\partial w_{k,i}} \left[ \cancel{y_1 (w_1 \cdot x + b_1)} + \cdots +y_k (w_{k} \cdot x + b_k) + \cdots + \cancel{y_K (w_K \cdot x + b_K)} \right] \\ &= - \frac{\partial}{\partial w_{k,i}} y_k (w_{k,1}\cdot x_1 + \cdots + w_{k,i} \cdot x_i + \cdots + w_{k,n} \cdot x_n + b_k) \\ &= - \frac{\partial}{\partial w_{k,i}} y_k (\cancel{w_{k,1}\cdot x_1} + \cdots + w_{k,i} \cdot x_i + \cdots + \cancel{w_{k,n} \cdot x_n + b_k}) \\ &= - y_k x_i \end{aligned} −∂wk,i​∂​[j=1∑K​yl​(wl​⋅x+bl​)]​=−∂wk,i​∂​[y1​(w1​⋅x+b1​)+⋯+yk​(wk​⋅x+bk​)+⋯+yK​(wK​⋅x+bK​)]=−∂wk,i​∂​[y1​(w1​⋅x+b1​)​+⋯+yk​(wk​⋅x+bk​)+⋯+yK​(wK​⋅x+bK​)​]=−∂wk,i​∂​yk​(wk,1​⋅x1​+⋯+wk,i​⋅xi​+⋯+wk,n​⋅xn​+bk​)=−∂wk,i​∂​yk​(wk,1​⋅x1​​+⋯+wk,i​⋅xi​+⋯+wk,n​⋅xn​+bk​​)=−yk​xi​​

因为只有wk,i⋅xiw_{k,i}\cdot x_iwk,i​⋅xi​项与wk,iw_{k,i}wk,i​有关,其他的偏导数都是000​,所以上面进行了简化。

同理,第二项:
∂∂wk,i[∑l=1Kyllog⁡∑j=1Kexp⁡(wj⋅x+bj)]=∑l=1Kyl∂∂wk,i[∑j=1Kexp⁡(wj⋅x+bj)]∑j=1Kexp⁡(wj⋅x+bj)=∑l=1Kylexp⁡(wk⋅x+bk)⋅xi∑j=1Kexp⁡(wj⋅x+bj)\begin{aligned} \frac{\partial}{\partial w_{k,i}} \left[ \sum_{l=1}^K y_l \log \sum_{j=1}^K \exp(w_j \cdot x + b_j) \right] &= \sum_{l=1}^K y_l \frac{\frac{\partial }{\partial w_{k,i}} \left[ \sum_{j=1}^K \exp(w_j \cdot x + b_j) \right] }{\sum_{j=1}^K \exp(w_j \cdot x + b_j)} \\ &= \sum_{l=1}^K y_l \frac{\exp(w_k \cdot x + b_k) \cdot x_i}{\sum_{j=1}^K \exp(w_j \cdot x + b_j)} \end{aligned} ∂wk,i​∂​[l=1∑K​yl​logj=1∑K​exp(wj​⋅x+bj​)]​=l=1∑K​yl​∑j=1K​exp(wj​⋅x+bj​)∂wk,i​∂​[∑j=1K​exp(wj​⋅x+bj​)]​=l=1∑K​yl​∑j=1K​exp(wj​⋅x+bj​)exp(wk​⋅x+bk​)⋅xi​​​

(29)(29)(29)到(30)(30)(30)是因为exp⁡(wk⋅x+bk)⋅xi∑j=1Kexp⁡(wj⋅x+bj)\frac{\exp(w_k \cdot x + b_k) \cdot x_i}{\sum_{j=1}^K \exp(w_j \cdot x + b_j)}∑j=1K​exp(wj​⋅x+bj​)exp(wk​⋅x+bk​)⋅xi​​ 与lll无关,因此可以提到求和符号左边。而∑l=1Kyl=1\sum_{l=1}^K y_l=1∑l=1K​yl​=1,因此变成了(31)(31)(31)。

事实证明,这个导数只是kkk类的真实值(即1或0)和kkk类分类器输出的概率之间的差额。

Reference

  1. 如何理解信息熵
  2. 损失函数是如何设计出来的
  3. Softmax与Cross-entropy的求导
  4. Softmax回归简介

从零实现深度学习框架——深入浅出交叉熵相关推荐

  1. 从零实现深度学习框架——深入浅出Word2vec(下)

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导. 要深入理解深度学 ...

  2. python学习框架图-从零搭建深度学习框架(二)用Python实现计算图和自动微分

    我们在上一篇文章<从零搭建深度学习框架(一)用NumPy实现GAN>中用Python+NumPy实现了一个简单的GAN模型,并大致设想了一下深度学习框架需要实现的主要功能.其中,不确定性最 ...

  3. 深度学习相关概念:交叉熵损失

    深度学习相关概念:交叉熵损失 交叉熵损失详解 1.激活函数与损失函数 1.1激活函数: 1.2损失函数: 2.对数损失函数(常用于二分类问题): 3.交叉熵.熵.相对熵三者之间的关系 4.交叉熵损失函 ...

  4. 从零实现深度学习框架——GloVe从理论到实战

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  5. 从零实现深度学习框架——Seq2Seq从理论到实战【实战】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  6. 从零实现深度学习框架——RNN从理论到实战【理论】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  7. 从零实现深度学习框架——从共现矩阵到点互信息

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  8. 从零实现深度学习框架——LSTM从理论到实战【理论】

    引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.

  9. 【深度学习原理】交叉熵损失函数的实现

    交叉熵损失函数 一般我们学习交叉熵损失函数是在二元分类情况下: L=−[ylogy^+(1−y)log(1−y^)]L=−[ylog ŷ +(1−y)log (1−ŷ )]L=−[ylogy^​+ ...

  10. 深度学习中softmax交叉熵损失函数的理解

    1. softmax层的作用 通过神经网络解决多分类问题时,最常用的一种方式就是在最后一层设置n个输出节点,无论在浅层神经网络还是在CNN中都是如此,比如,在AlexNet中最后的输出层有1000个节 ...

最新文章

  1. video 微信 标签层级过高_基于大数据的用户标签体系建设思路和应用
  2. MATLAB画图:改变坐标轴刻度的显示数值
  3. Spring Autowire自动装配
  4. shell 编程学习笔记(一)
  5. heroku java_部署Java Web项目到Heroku
  6. hadoop和spark的区别
  7. Windows操作系统安全配置缺陷自动检测技术
  8. android实体键盘输入法,推荐一个实体键盘专用输入法,是对 autotext的改进
  9. 公务员计算机基本操作知识培训,计算机基础知识:计算机中窗口的基本操作
  10. 计算机潮流算法一般采用,计算机潮流计算
  11. mysql 建库建表模板 权限管理
  12. 有哪些好的科研工具软件?
  13. ubuntu 20.04 不能鼠标双击打开 .desktop (桌面快捷方式图标)文件(双击变为使用文本编辑器打开)的解决办法
  14. 小猪短租网一个网页上的单个价格
  15. python查看list的shape_列表list、数组np.array等的len,size,shape操作
  16. quartus 13.0 之四位全加器(不需要用modelism的歪门邪道)
  17. Mac电脑如何快速回到桌面?
  18. Kahan求和公式原理
  19. 项目中成功的运用proxool连接池
  20. 5款不妨一试的硬盘碎片整理工具

热门文章

  1. 开源非英文关键词编程语言
  2. Kali Linux 更新源 操作完整版教程
  3. Java反射机制demo(三)—获取类中的构造函数
  4. Swing 显示良好JPanel保存为图片
  5. C#调用c++Dll结构体数组指针的问题
  6. C# Win32API
  7. 推荐一个程序员阅读文章资料时的辅助神器
  8. Intel Edison学习笔记(一)—— 刷系统
  9. 贪心整理一本通1431:钓鱼题解
  10. 4.2 优化数据访问