交叉熵介绍

交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小,常见的Loss函数就是均方平方差(Mean Squared Error),定义如下:

平方差很好理解,预测值与真实值直接相减,为了避免得到负数取绝对值或者平方,再做平均就是均方平方差。注意这里预测值需要经过sigmoid激活函数,得到取值范围在0到1之间的预测值。

平方差可以表达预测值与真实值的差异,但在分类问题种效果并不如交叉熵好,原因可以参考James D. McCaffrey 的 Why You Should Use Cross-Entropy Error Instead Of Classification Error Or Mean Squared Error For Neural Network Classifier Training

交叉熵的定义如下:

截图来自

https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s1.html

上面的文章也介绍了交叉熵可以作为Loss函数的原因,首先是交叉熵得到的值一定是正数,其次是预测结果越准确值越小,注意这里用于计算的“a”也是经过sigmoid激活的,取值范围在0到1。如果label是1,预测值也是1的话,前面一项y * ln(a)就是1 * ln(1)等于0,后一项(1 - y) * ln(1 - a)也就是0 * ln(0)等于0,Loss函数为0,反之Loss函数为无限大非常符合我们对Loss函数的定义。

这里多次强调sigmoid激活函数,是因为在多目标或者多分类的问题下有些函数是不可用的,而TensorFlow本身也提供了多种交叉熵算法的实现。

TensorFlow的交叉熵函数

TensorFlow针对分类问题,实现了四个交叉熵函数,分别是

  • tf.nn.sigmoid_cross_entropy_with_logits

  • tf.nn.softmax_cross_entropy_with_logits

  • tf.nn.sparse_softmax_cross_entropy_with_logits

  • tf.nn.weighted_cross_entropy_with_logits

详细内容请参考API文档

https://www.tensorflow.org/versions/master/api_docs/python/nn.html#sparse_softmax_cross_entropy_with_logits

sigmoid_cross_entropy_with_logits

我们先看sigmoid_cross_entropy_with_logits,为什么呢,因为它的实现和前面的交叉熵算法定义是一样的,也是TensorFlow最早实现的交叉熵算法。这个函数的输入是logits和targets,logits就是神经网络模型中的 W * X矩阵,注意不需要经过sigmoid,而targets的shape和logits相同,就是正确的label值,例如这个模型一次要判断100张图是否包含10种动物,这两个输入的shape都是[100, 10]。注释中还提到这10个分类之间是独立的、不要求是互斥,这种问题我们成为多目标,例如判断图片中是否包含10种动物,label值可以包含多个1或0个1,还有一种问题是多分类问题,例如我们对年龄特征分为5段,只允许5个值有且只有1个值为1,这种问题可以直接用这个函数吗?答案是不可以,我们先来看看sigmoid_cross_entropy_with_logits的代码实现吧。

可以看到这就是标准的Cross Entropy算法实现,对W * X得到的值进行sigmoid激活,保证取值在0到1之间,然后放在交叉熵的函数中计算Loss。对于二分类问题这样做没问题,但对于前面提到的多分类,例如年轻取值范围在0~4,目标值也在0~4,这里如果经过sigmoid后预测值就限制在0到1之间,而且公式中的1 - z就会出现负数,仔细想一下0到4之间还不存在线性关系,如果直接把label值带入计算肯定会有非常大的误差。因此对于多分类问题是不能直接代入的,那其实我们可以灵活变通,把5个年龄段的预测用onehot encoding变成5维的label,训练时当做5个不同的目标来训练即可,但不保证只有一个为1,对于这类问题TensorFlow又提供了基于Softmax的交叉熵函数。

softmax_cross_entropy_with_logits

Softmax本身的算法很简单,就是把所有值用e的n次方计算出来,求和后算每个值占的比率,保证总和为1,一般我们可以认为Softmax出来的就是confidence也就是概率,算法实现如下。

softmax_cross_entropy_with_logits和sigmoid_cross_entropy_with_logits很不一样,输入是类似的logits和lables的shape一样,但这里要求分类的结果是互斥的,保证只有一个字段有值,例如CIFAR-10中图片只能分一类而不像前面判断是否包含多类动物。想一下问什么会有这样的限制?在函数头的注释中我们看到,这个函数传入的logits是unscaled的,既不做sigmoid也不做softmax,因为函数实现会在内部更高效得使用softmax,对于任意的输入经过softmax都会变成和为1的概率预测值,这个值就可以代入变形的Cross Entroy算法- y * ln(a) - (1 - y) * ln(1 - a)算法中,得到有意义的Loss值了。如果是多目标问题,经过softmax就不会得到多个和为1的概率,而且label有多个1也无法计算交叉熵,因此这个函数只适合单目标的二分类或者多分类问题,TensorFlow函数定义如下。

再补充一点,对于多分类问题,例如我们的年龄分为5类,并且人工编码为0、1、2、3、4,因为输出值是5维的特征,因此我们需要人工做onehot encoding分别编码为00001、00010、00100、01000、10000,才可以作为这个函数的输入。理论上我们不做onehot encoding也可以,做成和为1的概率分布也可以,但需要保证是和为1,和不为1的实际含义不明确,TensorFlow的C++代码实现计划检查这些参数,可以提前提醒用户避免误用。

sparse_softmax_cross_entropy_with_logits

sparse_softmax_cross_entropy_with_logits是softmax_cross_entropy_with_logits的易用版本,除了输入参数不同,作用和算法实现都是一样的。前面提到softmax_cross_entropy_with_logits的输入必须是类似onehot encoding的多维特征,但CIFAR-10、ImageNet和大部分分类场景都只有一个分类目标,label值都是从0编码的整数,每次转成onehot encoding比较麻烦,有没有更好的方法呢?答案就是用sparse_softmax_cross_entropy_with_logits,它的第一个参数logits和前面一样,shape是[batch_size, num_classes],而第二个参数labels以前也必须是[batch_size, num_classes]否则无法做Cross Entropy,这个函数改为限制更强的[batch_size],而值必须是从0开始编码的int32或int64,而且值范围是[0, num_class),如果我们从1开始编码或者步长大于1,会导致某些label值超过这个范围,代码会直接报错退出。这也很好理解,TensorFlow通过这样的限制才能知道用户传入的3、6或者9对应是哪个class,最后可以在内部高效实现类似的onehot encoding,这只是简化用户的输入而已,如果用户已经做了onehot encoding那可以直接使用不带“sparse”的softmax_cross_entropy_with_logits函数。

weighted_sigmoid_cross_entropy_with_logits

weighted_sigmoid_cross_entropy_with_logits是sigmoid_cross_entropy_with_logits的拓展版,输入参数和实现和后者差不多,可以多支持一个pos_weight参数,目的是可以增加或者减小正样本在算Cross Entropy时的Loss。实现原理很简单,在传统基于sigmoid的交叉熵算法上,正样本算出的值乘以某个系数接口,算法实现如下。

总结

这就是TensorFlow目前提供的有关Cross Entropy的函数实现,用户需要理解多目标和多分类的场景,根据业务需求(分类目标是否独立和互斥)来选择基于sigmoid或者softmax的实现,如果使用sigmoid目前还支持加权的实现,如果使用softmax我们可以自己做onehot coding或者使用更易用的sparse_softmax_cross_entropy_with_logits函数。

TensorFlow提供的Cross Entropy函数基本cover了多目标和多分类的问题,但如果同时是多目标多分类的场景,肯定是无法使用softmax_cross_entropy_with_logits,如果使用sigmoid_cross_entropy_with_logits我们就把多分类的特征都认为是独立的特征,而实际上他们有且只有一个为1的非独立特征,计算Loss时不如Softmax有效。这里可以预测下,未来TensorFlow社区将会实现更多的op解决类似的问题,我们也期待更多人参与TensorFlow贡献算法和代码 :)

干货回顾丨TensorFlow四种Cross Entropy算法的实现和应用相关推荐

  1. TensorFlow四种Cross Entropy算法实现和应用

    交叉熵介绍 交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小,常见的Loss函数就是均方平方差(Mean Squared Er ...

  2. TensorFlow学习笔记(二十三)四种Cross Entropy交叉熵算法实现和应用

    交叉熵(Cross-Entropy) 交叉熵是一个在ML领域经常会被提到的名词.在这篇文章里将对这个概念进行详细的分析. 1.什么是信息量? 假设是一个离散型随机变量,其取值集合为,概率分布函数为 p ...

  3. C++基础代码--20余种数据结构和算法的实现

    C++基础代码--20余种数据结构和算法的实现 过年了,闲来无事,翻阅起以前写的代码,无意间找到了大学时写的一套C++工具集,主要是关于数据结构和算法.以及语言层面的工具类.过去好几年了,现在几乎已经 ...

  4. unity回顾之力的四种ForceMode

    给具有刚体的物体添加力,常使用的方法void Rigidbody.AddForce();有四个重载:void Rigidbody.AddForce(vector3 Force);.void Rigid ...

  5. Tensorflow四种交叉熵函数计算公式

    Tensorflow交叉熵函数:cross_entropy 注意:tensorflow交叉熵计算函数输入中的logits都不是softmax或sigmoid的输出,而是softmax或sigmoid函 ...

  6. 干货回顾丨深度学习性能提升的诀窍

    Pedro Ribeiro Simoes拍摄 原文: How To Improve Deep Learning Performance 作者: Jason Brownlee 你是如何提升深度学习模型的 ...

  7. 干货课堂丨分享一种LCD驱动电路方案【飞凌嵌入式】

    在一次项目定制中,客户要求我们将 CPU 主控和 LCD 显示屏电压驱动电路做成一体板, LCD 显示屏所需要的 AVDD,VGH,VGL 等电压需要主控板提供,因为这几路电压所输出的电流都很小(一般 ...

  8. 干货回顾丨深度学习应用大盘点

      当首次介绍深度学习时,我们认为它是一个要比机器学习更好的分类器.或者,我们亦理解成大脑神经计算. 第一种理解大大低估了深度学习构建应用的种类,而后者又高估了它的能力,因而忽略了那些不是一般人工智能 ...

  9. 干货回顾丨机器学习笔记-----AP(affinity propagat)算法讲解及matlab实现

    在统计和数据挖掘中,亲和传播(AP)是基于数据点之间"消息传递"概念的聚类算法.与诸如k-means或k-medoids的聚类算法不同,亲和传播不需要在运行算法之前确定或估计聚类的 ...

最新文章

  1. esp freertos_如何开始使用FreeRTOS和ESP8266
  2. ELK不香了!我用Graylog
  3. python使用pandas通过聚合获取时序数据的最后一个指标数据(例如长度指标、时间指标)生成标签并与原表连接(join)进行不同标签特征的可视化分析
  4. 和逛微博、刷朋友圈一样玩转 GitHub
  5. Apache Camel简介与入门
  6. 超参数优化 贝叶斯优化框架_mlmachine-使用贝叶斯优化进行超参数调整
  7. java中true转换为int_在Java中将字节转换为int的最优雅的方式
  8. 全国计算机网络教学研讨会,历届全国高校计算机网络教学研讨会
  9. 轻触开源(一)-Java泛型Type类型的应用和实践
  10. jQuery插件实现网页底部自动加载-类似新浪微博
  11. ES6、ES7、ES8、ES9、ES10 新特性ECMAScript版本简介
  12. emule服务器无响应,全部服务器无响应!!!
  13. 定时器 - 延时函数
  14. 进销存货物管理系统 论文
  15. windows命令行将应用程序加入环境变量
  16. Eli Lilly(礼来) | RPA在医疗行业的应用案例
  17. vue+element ui设置默认头像
  18. 计算机网络-DHCP的工作原理,IP地址如何获取
  19. Emacs 安装与使用
  20. 帝国CMS 7.2 蓝色响应式网站模板自适应宽屏智能整站源码 A1

热门文章

  1. 《预训练周刊》第6期:GAN人脸预训练模型、通过深度生成模型进行蛋白序列设计
  2. 程序员必读10本算法书推荐
  3. MoviePy - 中文文档4-MoviePy实战案例-把多个clip放置在一个画面中(超美)
  4. 任何网络都能山寨!新型黑盒对抗攻击可模拟未知网络进行攻击 | CVPR 2021
  5. 9款超赞的AI开源项目!| 本周Github精选
  6. 一文读懂生成对抗网络GANs(附学习资源)
  7. 从零开始教你训练神经网络(附公式学习资源)
  8. 2020年最新!百度、微软、浪潮、谷歌企业级综述更新!
  9. 国产深度学习框架迎来高光时刻,继清华 Jittor开源后,旷视「天元」纷纷重磅开源!...
  10. MySQL内存预估_mysql时该如何估算内存的消耗,公式如何计算?