笔者在学习各种分类模型和损失函数的时候发现了一个问题,类似于Logistic Regression模型和Softmax模型,目标函数都是根据最大似然公式推出来的,但是在使用pytorch进行编码的时候,却发现根本就没有提供softmax之类的损失函数,而提供了CrossEntropyLoss,MSELoss之类的。本文将介绍我们在学习LR模型和Softmax模型的时候接触到的目标函数,与实际应用中的经常用到的CrossEntropyLoss函数之间的关系。

弄懂了这个关系之后,笔者突然发现以前的一篇介绍LR模型和softmax模型基础的文章里存在一个十分傻的bug。本着线上有bug偷偷改,文章有bug坚决不改,不但不改还要四处宣扬的游街示众要不然怎么记得住的原则,笔者打算让那个bug保留在文章里,请各位朋友到评论区帮笔者找找这个bug吧。出bug的文章在这里:浅谈线性回归与softmax分类器。

1. 交叉熵函数(Cross Entropy)

对于一个训练样本集,我们可以把损失函数理解为一个关于训练数据的模型输出 a a a,与该样本的标签 a ˙ \dot{a} a˙的函数,标记为 L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙),该函数用于计算所有训练样本的 a a a值和 a ˙ \dot{a} a˙值之间的关系,当 a a a值和 a ˙ \dot{a} a˙值越接近, L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙)越小,反之 L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙)值越大。很多情况下,交叉熵公式(Cross Entropy)是一个很好的选择。这里写出交叉熵公式:
C r o s s E n t r o p y ( a , a ˙ ) = − ∑ a ˙ ⋅ l o g ( a ) CrossEntropy(a,\dot{a})=- \sum\dot{a} \cdot log(a) CrossEntropy(a,a˙)=−∑a˙⋅log(a)
交叉熵函数的图像为:

可以看到,当预测结果与实际结果越相符时,交叉熵越低;否则交叉熵会快速飙高以达到一个较大的惩罚。有人可能会有疑问:这如何解释LR模型和softmax模型的损失函数呢?

2. LR模型损失函数与CrossEntropy的关系

我们把LR模型的损失函数贴一下:
J ( x ; w , b ) = − 1 n ∑ i = 1 n ( q ( x i ) log ⁡ p ( x i ) + ( 1 − q ( x i ) ) log ⁡ ( 1 − p ( x i ) ) ) J(x;w,b) = -\frac1n\sum_{i=1}^n (q(x_i) \log p(x_i)+(1-q(x_i)) \log (1-p(x_i))) J(x;w,b)=−n1​i=1∑n​(q(xi​)logp(xi​)+(1−q(xi​))log(1−p(xi​)))
提取出核心的部分:
− ( q ( x i ) log ⁡ p ( x i ) + ( 1 − q ( x i ) ) log ⁡ ( 1 − p ( x i ) ) ) (1) -(q(x_i) \log p(x_i)+(1-q(x_i)) \log (1-p(x_i)) \tag{1}) −(q(xi​)logp(xi​)+(1−q(xi​))log(1−p(xi​)))(1)

设:该LR模型的标签集为 { T r u e , F a l s e } \{True,False\} {True,False},我们用 q ( T r u e ∣ x ) q(True|x) q(True∣x)和 q ( F a l s e ∣ x ) q(False|x) q(False∣x)表示样本数据 x x x的实际标签数据。当 x x x的标签取 T r u e True True时, q ( T r u e ∣ x ) = 1 , q ( F a l s e ∣ x ) = 0 q(True|x)=1,q(False|x)=0 q(True∣x)=1,q(False∣x)=0;当 x x x的标签取 F a l s e False False时, q ( T r u e ∣ x ) = 0 , q ( F a l s e ∣ x ) = 1 q(True|x)=0,q(False|x)=1 q(True∣x)=0,q(False∣x)=1。式子 ( 1 ) (1) (1)可以改写为:
− ( q ( T r u e ∣ x i ) log ⁡ p ( T r u e ∣ x i ) + q ( F a l s e ∣ x i ) log ⁡ p ( F a l s e ∣ x i ) ) = − ∑ y = T r u e F a l s e q ( y ∣ x ) log ⁡ ( p ( y ∣ x ) ) -(q(True|x_i) \log p(True|x_i)+q(False|x_i) \log p(False|x_i)) = - \sum_{y=True}^{False}q(y|x)\log(p(y|x)) −(q(True∣xi​)logp(True∣xi​)+q(False∣xi​)logp(False∣xi​))=−y=True∑False​q(y∣x)log(p(y∣x))
这个式子是交叉熵公式在二分类场景下的形式。因此这个LR模型的损失公式,其实是关于预测值与标签值之间的交叉熵公式。

3. softmax模型的损失函数与CrossEntropy的关系

同样贴下softmax的损失函数:
J ( x ; w , b ) = − 1 n ∑ i = 1 n log ⁡ exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) J(x;w,b) = -\frac1n \sum_{i=1}^n \log \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} J(x;w,b)=−n1​i=1∑n​log∑c​exp(wcT​xi​)exp(wyT​xi​)​

上边这个函数是建立在一个前提上,即:测试数据集中所有数据的分类标签都是确定到一个具体分类。假设我们的标签集为 C = { c 1 , c 2 , . . . , c k } C=\{c_1,c_2,...,c_k\} C={c1​,c2​,...,ck​},一共有k个分类,那么针对测试集中的样本数据 x x x,其标签数据 y y y为一个k维独热向量。也就是说,不允许有标签表示某个测试数据 x x x有一半可能属于 c 1 c_1 c1​,另一半可能属于 c 2 c_2 c2​。
我们把这个公式的关键部分提取一下:
− ∑ log ⁡ exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) (2) -\sum \log \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} \tag{2} −∑log∑c​exp(wcT​xi​)exp(wyT​xi​)​(2)
由于:
exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) = p ( y ∣ x i ) \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} = p(y|x_i) ∑c​exp(wcT​xi​)exp(wyT​xi​)​=p(y∣xi​)
用 p ( y ∣ x i ) p(y|x_i) p(y∣xi​)替换可得:
− ∑ log ⁡ p ( y ∣ x i ) (3) -\sum \log p(y|x_i) \tag{3} −∑logp(y∣xi​)(3)
已知 y ∈ C y\in C y∈C,设 y = c k y=c_k y=ck​,则式 ( 3 ) (3) (3)可以扩写为
− ∑ ( 0 ⋅ log ⁡ p ( c 1 ∣ x i ) + 0 ⋅ log ⁡ p ( c 2 ∣ x i ) + ⋯ + 0 ⋅ log ⁡ p ( c k − 1 ∣ x i ) + 1 ⋅ log ⁡ p ( y ∣ x i ) ) -\sum (0 \cdot \log p(c_1|x_i) + 0 \cdot \log p(c_2|x_i) + \cdots + 0 \cdot \log p(c_{k-1}|x_i ) + 1 \cdot \log p(y|x_i )) −∑(0⋅logp(c1​∣xi​)+0⋅logp(c2​∣xi​)+⋯+0⋅logp(ck−1​∣xi​)+1⋅logp(y∣xi​))
上式可以写成交叉熵公式的形式:
− ∑ j = 1 k q ( y ∣ x i ) ⋅ log ⁡ p ( y ∣ x i ) -\sum_{j=1}^{k} q(y|x_i) \cdot \log p(y|x_i) −j=1∑k​q(y∣xi​)⋅logp(y∣xi​)

4. 结论

CrossEntropy函数就是我们在学习LR模型和Softmax模型的时候经常遇到的目标函数的更加通用化的表示。不仅适用于多分类场景,也使用于训练数据的标签不唯一的情况,也就是某个训练数据 x x x的标签有50%的可能性为 c 1 c_1 c1​,也有50%的可能性为 c 2 c_2 c2​的情况。

手搓GPT系列之 - Logistic Regression模型,Softmax模型的损失函数与CrossEntropyLoss的关系相关推荐

  1. ​“从0到1手搓GPT”教程来了!李飞飞高徒出品,马斯克点赞!

    来源:量子位 "从0到1手搓GPT"教程来了! 视频1个多小时,从原理到代码都一一呈现,训练微调也涵盖在内,手把手带着你搞定. 该内容刚发出来,在Twitter已吸引400万关注量 ...

  2. 【温故知新】Linner Regression、Logistic Regression、Softmax Regression区别与联系

    先来回顾一下Linner Regression和Logistic Regression,而Softmax Regression可以认为是多类别的Logistic Regression. 因为看过李宏毅 ...

  3. Tensorflow【实战Google深度学习框架】—Logistic regression逻辑回归模型实例讲解

    文章目录 1.前言 2.程序详细讲解 环境设定 数据读取 准备好placeholder,开好容器来装数据 准备好参数/权重 拿到每个类别的score 计算多分类softmax的loss functio ...

  4. python机器学习算法(赵志勇)学习笔记( Logistic Regression,LR模型)

    Logistic Regression(逻辑回归) 分类算法是典型的监督学习,分类算法通过对训练样本的学习,得到从样本特征到样本的标签之间的映射关系,也被称为假设函数,之后可利用该假设函数对新数据进行 ...

  5. Logistic Regression 模型简介

    https://tech.meituan.com/intro_to_logistic_regression.html 逻辑回归(Logistic Regression)是机器学习中的一种分类模型,由于 ...

  6. 逻辑回归Logistic Regression 模型简介

    逻辑回归(Logistic Regression)是机器学习中的一种分类模型,由于算法的简单和高效,在实际中应用非常广泛.本文作为美团机器学习InAction系列中的一篇,主要关注逻辑回归算法的数学模 ...

  7. 逻辑回归Logistic Regression 之基础知识准备

    0. 前言   这学期 Pattern Recognition 课程的 project 之一是手写数字识别,之二是做一个网站验证码的识别(鸭梨不小哇).面包要一口一口吃,先尝试把模式识别的经典问题-- ...

  8. Logistic Regression 之基础知识准备

    0. 前言   这学期 Pattern Recognition 课程的 project 之一是手写数字识别,之二是做一个网站验证码的识别(鸭梨不小哇).面包要一口一口吃,先尝试把模式识别的经典问题-- ...

  9. 逻辑回归(Logistic Regression)简介及C++实现

    逻辑回归(Logistic Regression):该模型用于分类而非回归,可以使用logistic sigmoid函数( 可参考:http://blog.csdn.net/fengbingchun/ ...

最新文章

  1. Saiku_学习_01_saiku安装与运行
  2. Windows服务ServicesDependedOn的奇怪问题?
  3. c语言ascii码表数字,求教!我想显示数字但是现在显示的却是数字在ASCII码中对应的符...
  4. android 等待按钮框架,Android 开发 MaterialDialog框架的详解
  5. table表格 html 1128
  6. 更新CentOS中的python(从2.6.X到2.7.X)
  7. Java 面向对象 之 多态实例2
  8. 图片转换成base64编码格式展示
  9. android m4a转mp3格式转换,音频提取格式转换app
  10. 模糊综合评价模型 ——第四部分,三级模糊综合评价模型应用:例题5,陶瓷厂六种产品销量的评判
  11. 【ME909】华为ME909 4G LTE模块在树莓派下通过minicom进行发送短信演示
  12. 李佳琦如果直播卖保险,你敢不敢买?
  13. c4d语言包英文,Maxon Cinema 4D R23(C4D R23)中英文安装及设置详细教程(附下载)
  14. web前端基础——超链接(dw笔记版)
  15. JAVA实现Excel文件的导入导出
  16. # 学习记录1(C#-解决内存泄漏的几种方法)
  17. 微信小程序——简单饮食推荐(四)
  18. CKA考试习题:存储管理-普通卷、PV、PVC
  19. wptx64能卸载吗_Win10如何卸载应用?Win10内置应用卸载方法
  20. 技嘉显卡性能测试软件,你好六啊!GTX 1660 Ti深度测试:升吧

热门文章

  1. linux基础-快速入门
  2. Android用命令行查看手机架构
  3. -I (大写i)、-L、-l(小写L) 的使用
  4. 为什么 ChatGPT 会引起 Google 的恐慌?
  5. day18_雷神_django第一天
  6. UNIX发展史(BSD,GNU,linux)(转)
  7. 【青少年编程】【三级】接苹果
  8. 求助:Appium 如何实现登录手机淘宝时拖动苹果到购物车的验证
  9. 新一代 IT 服务管理平台 DOSM,助力企业数字化转型
  10. 解决MAC OS X不识别Kindle Fire