交叉熵损失函数

一般我们学习交叉熵损失函数是在二元分类情况下:

L=−[ylogy^+(1−y)log(1−y^)]L=−[ylog ŷ +(1−y)log (1−ŷ )]L=−[ylogy^​+(1−y)log(1−y^​)]

推而广之,我们可以得到下面这个交叉熵损失函数公式:

E=−∑ktklog(yk)E=-\sum_k{t_k}log(y_k) E=−k∑​tk​log(yk​)

从机器学习的角度看,这里的yky_kyk​是神经网络的输出,tkt_ktk​是正确解的标签。

而分类标签有两种方式:

  • One-Hot编码
  • 非One-Hot编码

One-Hot编码下的损失函数实现

使用One-Hot编码时,tkt_ktk​中,只有正确解的标签才为1,其他的都是0,所以在相乘时,这项就为0,但是我们知道log(0)log(0)log(0)是负无穷,显然我们需要特别在代码中处理一下。

先不看负无穷的问题,在One-Hot编码时,tkt_ktk​中只有为1的这项,才有输出,也就是说,我们计算交叉熵损失函数,只用计算对应正确解的输出的自然对数即可

代码如下:

def cross_entropy_error(y, t):delta = 1e-7return -np.sum(t * np.log(y + delta))

比如:

t = [0,0,1,0,0,0,0,0,0,0]
y = [0.1,0.05,0.6,0.0,0.05,0.1,0.0,0.1,0.0,0.0]
cross_entropy_error(y,t) # ==> 0.510825...
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
cross_entropy_error(y,t) # ==> 2.30258...

第一个案例中,正确的标签是2,输出的Softmax概率中2对应的标签的概率最大,为0.6,由此计算出来的损失函数值为0.51;第二个案例,预测的概率最大为0.1,以第一个作为预测结果,即0是预测值,得出损失函数值为2.3,可见预测错了损失函数值偏大。

总之,用One-Hot编码,是将
标签值和预测值的编码一一对应,按照交叉熵的公式处理。

非One-hot编码

如果只有一个值,单个样本的损失函数计算如下:

def cross_entropy_error(y, t):delta = 1e-7return -np.log(y + delta)

这是从前面的One-Hot编码那里推导来的,我们只需要神经网络在正确标签处的输出,就可以计算交叉熵误差。

如果是Mini-Batch呢,需要做哪些变化?

Mini-Batch下的交叉熵函数

One-Hot编码

def cross_entropy_error(y, t):if y.ndim == 1:t = t.reshape(1, t.size) # ndarray的size属性是存在的y = y.reshape(1, y.size)batch_size = y.shape[0]return -np.sum(t * np.log(y+ 1e-7)) / batch_size

这里既是y和t是小批量的形式,即二维矩阵,按照Numpy的调性,矩阵直接相乘是按照元素相乘,最后聚和再除以总体个数即可。看起来就除了batch_size,其实是聚和了二维矩阵相乘的结果。

非One-Hot编码

def cross_entropy_error(y, t):if y.ndim == 1:t = t.reshape(1, t.size) # ndarray的size属性是存在的y = y.reshape(1, y.size)batch_size = y.shape[0]return -np.sum(np.log(y[np.arange(batch_size),t] + 1e-7)) / batch_size

这里还是需要注意这句话:我们只需要神经网络在正确标签处的输出,就可以计算交叉熵误差。所以看起来很复杂的y[np.arange(batch_size),t]目的也是为了获得神经网络的输出,取出的是多行与多列的组合。

END.

参考:
《深度学习入门:基于Python的理论和实现》

https://jamesmccaffrey.wordpress.com/2013/11/05/why-you-should-use-cross-entropy-error-instead-of-classification-error-or-mean-squared-error-for-neural-network-classifier-training/

https://www.jianshu.com/p/474439106874

【深度学习原理】交叉熵损失函数的实现相关推荐

  1. 深度学习中交叉熵_深度计算机视觉,用于检测高熵合金中的钽和铌碎片

    深度学习中交叉熵 计算机视觉 (Computer Vision) Deep Computer Vision is capable of doing object detection and image ...

  2. softmax交叉熵损失函数深入理解(二)

    0.前言 前期博文提到经过两步smooth化之后,我们将一个难以收敛的函数逐步改造成了softmax交叉熵损失函数,解决了原始的目标函数难以优化的问题.Softmax 交叉熵损失函数是目前最常用的分类 ...

  3. 交叉熵损失函数的通用性(为什么深度学习DL普遍用它):预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大是一种类似指数增长的级别,结论:它对结果有引导性

    交叉熵损失函数的通用性(为什么深度学习DL普遍用它):预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 " 惩罚 " 越大,而且是非线性增大是一种类似指数增长的级别,结 ...

  4. 二分类交叉熵损失函数python_【深度学习基础】第二课:softmax分类器和交叉熵损失函数...

    [深度学习基础]系列博客为学习Coursera上吴恩达深度学习课程所做的课程笔记. 本文为原创文章,未经本人允许,禁止转载.转载请注明出处. 1.线性分类 如果我们使用一个线性分类器去进行图像分类该怎 ...

  5. 深度学习基础入门篇[五]:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

    [深度学习入门到进阶]必看系列,含激活函数.优化策略.损失函数.模型调优.归一化算法.卷积模型.序列模型.预训练模型.对抗神经网络等 专栏详细介绍:[深度学习入门到进阶]必看系列,含激活函数.优化策略 ...

  6. 深度学习 交叉熵损失函数

    交叉熵损失函数 前言 交叉熵损失函数 信息量 信息熵 交叉熵 求导过程 应用 扩展 Binary_Crossentropy 均方差损失函数(MSE) 前言 深度学习中的损失函数的选择,需要注意一点是直 ...

  7. 深度学习中softmax交叉熵损失函数的理解

    1. softmax层的作用 通过神经网络解决多分类问题时,最常用的一种方式就是在最后一层设置n个输出节点,无论在浅层神经网络还是在CNN中都是如此,比如,在AlexNet中最后的输出层有1000个节 ...

  8. 【深度学习】——分类损失函数、回归损失函数、交叉熵损失函数、均方差损失函数、损失函数曲线、

    目录 代码 回归问题的损失函数 分类问题的损失函数 1. 0-1损失 (zero-one loss) 2.Logistic loss 3.Hinge loss 4.指数损失(Exponential l ...

  9. 深度学习-tensorflow1.x之交叉熵损失函数(softmax_cross_entropy_with_logits)代码实现 Tensorflow1.x 和 Numpy

    交叉熵损失函数 神经网络(机器学习)中作为损失函数 具体的理解可以看 https://blog.csdn.net/SIGAI_CSDN/article/details/86554230 交叉熵损失函数 ...

最新文章

  1. Timer 的简单介绍
  2. Java图形化:布局方式
  3. css 设置table样式
  4. 51Nod 1242 斐波那契数列的第N项
  5. SQL Server 的通用分页显示存储过程
  6. Radware为夏威夷电信公司全新的DDoS攻击缓解服务提供支持
  7. vue项目打包与配置-学习笔记
  8. 2017年第八届蓝桥杯国赛B组试题A-36进制-进制转换
  9. select2删除选中项,allowClear设置
  10. Outh2协议有哪四种授权模式?
  11. Java集合系列---TreeMap源码解析(巨好懂!!!)
  12. Python+pandas时间序列对象常用操作
  13. Lettuce替换Jedis操作Redis缓存
  14. 切换至 root 身份
  15. 观电影《头号玩家》有感
  16. error_reporting() 错误级别详解
  17. Jacobi法求特征值特征向量
  18. hdu 5211 Mutiple 数学
  19. 3.32 小猪短租的爬虫-
  20. 认识c语言程序,认识C语言

热门文章

  1. CSS样式表操作及选择器定义
  2. 深度学习——卷积块回顾
  3. 利用PCL做点云的平面拟合
  4. java对mysql进行查找替换_Java对MySQL数据库进行连接、查询和修改【转载】
  5. qps多少才算高并发_AGV小车价格多少才算合适?
  6. 三菱q系列特殊继电器一览表_2020山西三菱Q系列PLC模块回收购销
  7. c语言中初始值的作用,初始C语言学习
  8. python getattr用法_python3,定制类,getattr相关用法
  9. 把string时间取出月份_农村集市上现杀活蚌取出来的珍珠,是真的吗?为何价格这么便宜?...
  10. freebsd mysql5.7_FreeBSD 环境下Mysql问题解决方法集锦