你对YOLOV3损失函数真的理解正确了吗?
1. 前言
昨天行云大佬找到我提出了他关于GiantPandaCV公众号出版的《从零开始学YOLOV3》电子书中关于原版本的YOLOV3损失的一个质疑,并给出了他的理解。昨天晚上我仔细又看了下原始论文和DarkNet源码,发现在YOLOV3的原版损失函数的解释上我误导了不少人。所以就有了今天这篇文章,与其说是文章不如说是一个错误修正吧。
2. 在公众号里面的YOLOV3损失函数
在我们公众号出版的YOLOV3的PDF教程里对原始的DarkNet的损失函数是这样解释的,这个公式也是我参照源码(https://github.com/BBuf/Darknet/blob/master/src/yolo_layer.c
)进行总结的,。 我的总结截图如下:
其中 S S S表示 g r i d s i z e grid\ size grid size, S 2 S^2 S2表示 13 × 13 13\times 13 13×13, 26 × 26 26\times 26 26×26, 52 × 52 52\times 52 52×52。B代表box, 1 i j o b j 1_{ij}^{obj} 1ijobj表示如果在 i , j i,j i,j处的box有目标,其值为 1 1 1,否则为 0 0 0。 1 i j n o o b j 1_{ij}^{noobj} 1ijnoobj表示如果 i , j i,j i,j处的box没有目标,其值为 1 1 1,否则为 0 0 0。
BCE(binary cross entropy)的具体公式计算如下:
B C E ( ( ^ c i ) , c i ) = − c i ^ × l o g ( c i ) − ( 1 − c i ^ ) × l o g ( 1 − c i ) BCE(\hat(c_i),c_i)=-\hat{c_i}\times log(c_i)-(1-\hat{c_i})\times log(1-c_i) BCE((^ci),ci)=−ci^×log(ci)−(1−ci^)×log(1−ci)
另外,针对YOLOV3,回归损失会乘以一个 2 − w ∗ h 2-w*h 2−w∗h的比例系数, w w w和 h h h代表Ground Truth box
的宽高,如果没有这个系数AP会下降明显,大概是因为COCO数据集小目标很多的原因。
我根据DarkNet的源码对每一步进行了梯度推导发现损失函数的梯度是和上面的公式完全吻合的,所以当时以为这是对的,感谢行云大佬提醒让我发现了一个致命理解错误,接下来我们就说一下。
3. 行云大佬的损失函数公式
接下来我们看一下行云大佬的损失函数公式,形式如下:
可以看到我的损失函数理解和行云大佬的损失函数理解在回归损失以及分类损失上是完全一致的,只有obj loss表示形式完全不同。对于obj loss,我的公式里面是方差损失,而行云大佬是交叉熵损失。那么这两种形式哪一种是正确的呢?
其实只要对交叉熵损失和方差损失求个导问题就迎刃而解了。
4. 交叉熵损失求导数
推导过程如下:
(1)softmax函数
首先再来明确一下softmax函数,一般softmax函数是用来做分类任务的输出层。softmax的形式为:
S i = e z i ∑ k e z k S_i = \frac{e^{z_i}}{\sum_ke^{z_k}} Si=∑kezkezi
其中 S i S_i Si表示的是第i个神经元的输出,接下来我们定义一个有多个输入,一个输出的神经元。神经元的输出为
z i = ∑ i j x i j + b z_i = \sum_{ij}x_{ij}+b zi=∑ijxij+b
其中 w i j w_{ij} wij是第 i i i个神经元的第 j j j个权重,b是偏移值. z i z_i zi表示网络的第 i i i个输出。给这个输出加上一个softmax函数,可以写成:
a i = e z i ∑ k e z k a_i = \frac{e^{z_i}}{\sum_ke^{z_k}} ai=∑kezkezi,
其中 a i a_i ai表示softmax函数的第 i i i个输出值。这个过程可以用下图表示:
(2)损失函数
softmax的损失函数一般是选择交叉熵损失函数,交叉熵函数形式为:
C = − ∑ i y i l n a i C=-\sum_i{y_i lna_i} C=−∑iyilnai
其中y_i表示真实的标签值
(3)需要用到的高数的求导公式
c'=0(c为常数)
(x^a)'=ax^(a-1),a为常数且a≠0
(a^x)'=a^xlna
(e^x)'=e^x
(logax)'=1/(xlna),a>0且 a≠1
(lnx)'=1/x
(sinx)'=cosx
(cosx)'=-sinx
(tanx)'=(secx)^2
(secx)'=secxtanx
(cotx)'=-(cscx)^2
(cscx)'=-csxcotx
(arcsinx)'=1/√(1-x^2)
(arccosx)'=-1/√(1-x^2)
(arctanx)'=1/(1+x^2)
(arccotx)'=-1/(1+x^2)
(shx)'=chx
(chx)'=shx
(uv)'=uv'+u'v
(u+v)'=u'+v'
(u/)'=(u'v-uv')/^2
(4)进行推导
我们需要求的是loss对于神经元输出 z i z_i zi的梯度,求出梯度后才可以反向传播,即是求:
∂ C ∂ z i \frac{\partial C}{\partial z_i} ∂zi∂C, 根据链式法则(也就是复合函数求导法则) ∂ C ∂ a j ∂ a j ∂ z i \frac{\partial C}{\partial a_j}\frac{\partial a_j}{\partial z_i} ∂aj∂C∂zi∂aj,初学的时候这个公式理解了很久,为什么这里是 a j a_j aj而不是 a i a_i ai呢?这里我们回忆一下softmax的公示,分母部分包含了所有神经元的输出,所以对于所有输出非i的输出中也包含了 z i z_i zi,所以所有的a都要参与计算,之后我们会看到计算需要分为 i = j i=j i=j和 i ≠ j i \neq j i=j两种情况分别求导数。
首先来求前半部分:
∂ C ∂ a j = − ∑ j y i l n a j ∂ a j = − ∑ j y j 1 a j \frac{\partial C}{ \partial a_j} = \frac{-\sum_jy_ilna_j}{\partial a_j} = -\sum_jy_j\frac{1}{a_j} ∂aj∂C=∂aj−∑jyilnaj=−∑jyjaj1
接下来求第二部分的导数:
- 如果 i = j i=j i=j, ∂ a i ∂ z i = ∂ ( e z i ∑ k e z k ) ∂ z i = ∑ k e z k e z i − ( e z i ) 2 ( ∑ k e z k ) 2 = ( e i z ∑ k e z k ) ( 1 − e z i ∑ k e z k ) = a i ( 1 − a i ) \frac{\partial a_i}{\partial z_i} = \frac{\partial(\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=\frac{\sum_ke^{z_k}e^{z_i}-(e^{z_i})^2}{(\sum_ke^{z_k})^2}=(\frac{e^z_i}{\sum_ke^{z_k}})(1 - \frac{e^{z_i}}{\sum_ke^{z_k}})=a_i(1-a_i) ∂zi∂ai=∂zi∂(∑kezkezi)=(∑kezk)2∑kezkezi−(ezi)2=(∑kezkeiz)(1−∑kezkezi)=ai(1−ai)
- 如果 i ≠ j i \neq j i=j, ∂ a i ∂ z i = ∂ e z j ∑ k e z k ∂ z i = − e z j ( 1 ∑ k e k z ) 2 e z i = − a i a j \frac{\partial a_i}{\partial z_i}=\frac{\partial\frac{e^{z_j}}{\sum_ke^{z_k}}}{\partial z_i} = -e^{z_j}(\frac{1}{\sum_ke^z_k})^2e^{z_i}=-a_ia_j ∂zi∂ai=∂zi∂∑kezkezj=−ezj(∑kekz1)2ezi=−aiaj。
接下来把上面的组合之后得到:
∂ C ∂ z i \frac{\partial C}{\partial z_i} ∂zi∂C
= ( − ∑ j y j 1 a j ) ∂ a j ∂ z i =(-\sum_{j}y_j\frac{1}{a_j})\frac{\partial a_j}{\partial z_i} =(−∑jyjaj1)∂zi∂aj
= − y i a i a i ( 1 − a i ) + ∑ j ≠ i y j a j a i a j =-\frac{y_i}{a_i}a_i(1-a_i)+\sum_{j \neq i}\frac{y_j}{a_j}a_ia_j =−aiyiai(1−ai)+∑j=iajyjaiaj
= − y i + y i a i + ∑ j ≠ i y j a i =-y_i+y_ia_i+\sum_{j \neq i}\frac{y_j}a_i =−yi+yiai+∑j=iayji
= − y i + a i ∑ j y j =-y_i+a_i\sum_{j}y_j =−yi+ai∑jyj。
推导完成!
(5)对于分类问题来说,我们给定的结果 y i y_i yi最终只有一个类别是1,其他是0,因此对于分类问题,梯度等于:
∂ C ∂ z i = a i − y i \frac{\partial C}{\partial z_i}=a_i - y_i ∂zi∂C=ai−yi
5. L2损失求导数
推导如下:
我们写出L2损失函数的公式:
L 2 l o s s = ( y i − a i ) 2 L2_{loss}=(y_i-a_i)^2 L2loss=(yi−ai)2,其中 y i y_i yi仍然代表标签值, a i a_i ai表示预测值,同样我们对输入神经元(这里就是 a i a_i ai了,因为它没有经过任何其它的函数),那么 ∂ C ∂ z i = 2 ( a i − y i ) \frac{\partial C}{\partial z_i}=2(a_i - y_i) ∂zi∂C=2(ai−yi),其中 z i = a i z_i=a_i zi=ai。
注意到,梯度的变化由于有学习率的存在所以系数是无关紧要的(只用关心数值梯度),所以我们可以将系数省略,也即是:
∂ C ∂ z i = a i − y i \frac{\partial C}{\partial z_i}=a_i - y_i ∂zi∂C=ai−yi
6. 在原论文求证
可以看到无论是L2损失还是交叉熵损失,我们获得的求导形式都完全一致,都是 o u t p u t − l a b e l output-label output−label的形式,换句话说两者的数值梯度趋势是一致的。接下来,我们去原论文求证一下:
上面标红的部分向我们展示了损失函数的细节,我们可以发现原本YOLOV3的损失函数在obj loss部分应该用二元交叉熵损失的,但是作者在代码里直接用方差损失代替了。
至此,可以发现我之前的损失函数解释是有歧义的,作者的本意应该是行云大佬的损失函数理解那个公式(即obj loss应该用交叉熵,而不是方法差损失),不过恰好训练的时候损失函数是我写出的公式(obj loss用方差损失)。。。神奇吧。
7. 总结
本文根据行云大佬的建议,通过手推梯度并在原始论文找证据的方式为大家展示了YOLOV3的损失函数的深入理解,如果有任何疑问可以在留言区留言交流。
8. 参考
- YOLOV3论文:https://pjreddie.com/media/files/papers/YOLOv3.pdf
- DarkNet原始代码:https://github.com/BBuf/Darknet/blob/master/src/yolo_layer.c
- 行云大佬博客:https://blog.csdn.net/qq_34795071/article/details/92803741
欢迎关注GiantPandaCV, 在这里你将看到独家的深度学习分享,坚持原创,每天分享我们学习到的新鲜知识。( • ̀ω•́ )✧
有对文章相关的问题,或者想要加入交流群,欢迎添加BBuf微信:
为了方便读者获取资料以及我们公众号的作者发布一些Github工程的更新,我们成立了一个QQ群,二维码如下,感兴趣可以加入。
你对YOLOV3损失函数真的理解正确了吗?相关推荐
- 你真的理解“吃亏是福”么?
你真的理解"吃亏是福"么?且看 一个10几年的运维老鸟老男孩的随笔! 供朋友参考! 一定不要计较一时的得失! 有一次老男孩老师看电视,一位老大爷(当时70岁)感慨的说," ...
- [转载] Java内存管理-你真的理解Java中的数据类型吗(十)
参考链接: Java中的字符串类String 1 做一个积极的人 编码.改bug.提升自己 我有一个乐园,面向编程,春暖花开! 推荐阅读 第一季 0.Java的线程安全.单例模式.JVM内存结构等知识 ...
- 您真的理解了SQLSERVER的日志链了吗?
您真的理解了SQLSERVER的日志链了吗? 先感谢宋沄剑给本人指点迷津,还有郭忠辉童鞋今天在QQ群里抛出的问题 这个问题跟宋沄剑讨论了三天,再次感谢宋沄剑 一直以来,SQLSERVER提供了一个非常 ...
- Android Binder通信一次拷贝你真的理解了吗?
Android Binder通信一次拷贝你真的理解了吗? Android Binder框架实现目录: Android Binder框架实现之Binder的设计思想 Android Binder ...
- 您真的理解了SQLSERVER的日志链了吗
您真的理解了SQLSERVER的日志链了吗? 先感谢宋沄剑给本人指点迷津,还有郭忠辉童鞋今天在QQ群里抛出的问题 这个问题跟宋沄剑讨论了三天,再次感谢宋沄剑 一直以来,SQLSERVER提供了一个非常 ...
- 神经网络中的激活函数与损失函数深入理解推导softmax交叉熵
神经网络中的激活函数与损失函数&深入理解softmax交叉熵 前面在深度学习入门笔记1和深度学习入门笔记2中已经介绍了激活函数和损失函数,这里做一些补充,主要是介绍softmax交叉熵损失函数 ...
- 什么叫「真的理解」?我们对 AI 的要求或许有点过分
2020-02-10 11:33:45 作者 | Thomas G. Dietterich 编译 | Fawn 编辑 | 丛末 AI 领域所取得的最新进展给 AI 系统带来的进步,举世瞩目,但是仍有一 ...
- TCP 三次握手原理,你真的理解吗
转载自 TCP 三次握手原理,你真的理解吗 最近,阿里中间件小哥哥蛰剑碰到一个问题--client端连接服务器总是抛异常.在反复定位分析.并查阅各种资料文章搞懂后,他发现没有文章把这两个队列以及怎么 ...
- 简单人物画像_你真的理解用户画像吗?| 船说
" 「设计师沙龙」是ARK下半年开始逐渐形成的传统,由ARKers自发组织,分为视觉和交互两类,每月各举办一次.大家围绕一个话题展开,聊聊行业最新案例和工作上的心得,帮助大家共同进步. AR ...
最新文章
- 利用ajax.dll进行Ajax的开发2007-07-15 15:38
- BIEE汇总数据如何放在后面
- 自制仿360首页支持拼音输入全模糊搜索和自动换肤
- 打印速度快点的打印机_SLM推出了功能强大的新型金属3D打印机,速度快20倍
- ansible常用命令
- iic总线从机仲裁_I2C总线的仲裁问题
- 1年经验却拿总监薪资?看到他做的数据可视化报表,我彻底服了
- impalahive大数据平台数据血缘与数据地图(四)-impala血缘架构图及功能介绍
- Shiro整合JWT实现认证和权限鉴定(执行流程清晰详细)
- 春季必买明星款流行春装
- 关于蓝桥杯大赛,你应该了解的那些事!
- 来自Naval Ravikant 的十句话
- python综合实验心得体会_综合实验心得体会
- Hive元数据库中各个表的含义(十)
- Android 百度地图定位工具类
- 仿某板兔网站源码 laysns模版 基于laysns系统开发 2.55可用
- Android 如何选择城市-CityPicker
- 创建TypeScript工程错误排查
- chrome与12306
- linux批量拷贝文件脚本,把文件复制N份的2个Shell脚本代码