全连接层解决MNIST:只是一层全连接层解决MNIST数据集
神经网络的传播:讲解了权重更新的过程
这个系列的文章都是为了总结我目前学习的积累。

损失函数

在我文章的网络中,我利用MSE(mean-square error,均方误差)作为损失函数,softmax作为激活函数。

prediction = tf.nn.softmax(tf.matmul(x, W)+b)
loss = tf.reduce_mean(tf.square(y-prediction))

在我的理解中,样本是堆放在一个空间的。假设我们的理想模型是一个函数,那么图片经过它得出的值跟图片经过我们构建的模型得出的值之间的距离,可以通过MSE来近似表示。当值的距离无限缩小时,我们的模型也就越接近理想模型。

但是,其实在应用分类问题的过程中,我们偏向于应用交叉熵(损失函数)而不是MSE。

在监督学习(supervised learning)中,我们把问题分成回归和分类。两者的本质都是相同的,但是输出不一样。我们可以认为分类的输出是离散的,而回归的输出是连续的。举个例子:

我们来测量小明的温度。那么回归是输出他的体温,如37.5度、38度等等。而分类是着重在他发烧亦或者正常。

或许例子有点奇怪,但这是我认为的它们的区别。

要介绍熵,我们需要先从信息量讲起。我们需要明确一点,越难发生的事情,它能提供的信息也就越多,信息量也就越高。越容易发生的事情,它能提供的信息也就越小,信息量也就越低。再举个例子,当你设置了个闹钟,它响了,你理所当然觉得很正常,自然也不会提供任何信息给你。但是过了时间它还不响,那就说明了可能没电了、可能坏了。(例子真烂,哈哈哈)

由此看来,信息量是跟概率挂钩的存在。因此,相信我们都知道一件事情的概率都记作p(xi)p(xi)p(x_i),那么信息量的定义如下:

I(xi)=−ln(p(xi))I(xi)=−ln(p(xi))

I(x_i) = -ln(p(x_i))

这是−ln(p(xi))−ln(p(xi))-ln(p(x_i))横坐标在0.0 - 1.0的图像(0<p(xi)<10<p(xi)<10

),很形象的体现了信息量跟概率的关系。即随着概率的增加,能够提供的信息量逐步减少。

在信息论里面,熵是对不确定性的测量。但是在信息世界,熵越高,则能传输越多的信息,熵越低,则意味着传输的信息越少。

H(x)=E[I(x)]=E[−ln(p(x))]x={x1,x2,...,xn}H(xi)=∑ip(xi)I(xi)=−∑ip(xi)lnp(xi)H(x)=E[I(x)]=E[−ln(p(x))]x={x1,x2,...,xn}H(xi)=∑ip(xi)I(xi)=−∑ip(xi)lnp(xi)

H(x) = E[I(x)] = E[-ln(p(x))] \qquad x = \left\{ x_1, x_2, ..., x_n \right\} \\ \\ H(x_i) = \sum_i p(x_i)I(x_i) = -\sum_i p(x_i)lnp(x_i)

E为期望函数,而I(x)为x的信息量。即熵会等于信息量的期望,也就是所有x的概率乘以对应的信息量的总和。

我们需要再引入一个概念,相对熵(KL散度)。

KL散度是两个概率分布P和Q差别的非对称性的度量。(from Wiki)

那么我们可以知道,它是描述两个概率分布的差别。所谓的概率分布,也就是我们的标签和预测值了。我们在第一篇文章提到标签的one-hot格式是[0, 1, 0, …, 0]的类型,这是一个对MNIST数据集的准确的描述,因为它肯定它的某一个分类概率一定为1,而其他为0。但我们的预测值是一定同标签存在一定误差的,这也是我们评价这个模型的一个参数,损失值。

所以说,KL散度是用来描述误差很好的指标。那么我们为什么会用到交叉熵?

假设我们定义P(x)P(x)P(x)(真实分布,即标签)和Q(x)Q(x)Q(x)(理论分布,即模型)为两个概率分布,那么对于他们的KL散度,我们可以有:

DKL(P||Q)=−∑iP(x)lnQ(x)P(x)x={x1,x2,...xn}DKL(P||Q)=∑i=1nP(xi)ln(P(xi))−∑i=1nP(xi)ln(Q(xi))DKL(P||Q)=−H(P(x))+[−∑i=1nP(xi)ln(Q(xi))]DKL(P||Q)=−∑iP(x)lnQ(x)P(x)x={x1,x2,...xn}DKL(P||Q)=∑i=1nP(xi)ln(P(xi))−∑i=1nP(xi)ln(Q(xi))DKL(P||Q)=−H(P(x))+[−∑i=1nP(xi)ln(Q(xi))]

D_{KL}(P||Q) = -\sum_i P(x) ln \frac {Q(x)}{P(x)} \qquad x = \left\{ x_1, x_2, ... x_n \right\} \\ D_{KL}(P||Q) = \sum^n_{i=1}P(x_i)ln(P(x_i)) - \sum^n_{i=1}P(x_i)ln(Q(x_i)) \\ D_{KL}(P||Q) = -H(P(x)) + [-\sum^n_{i=1}P(x_i)ln(Q(x_i))]

即,−H(P(x))−H(P(x))-H(P(x))为标签的熵,一个固定值。那么我们在优化标签和理论分布的KL散度的时候,不如直接优化后面的部分。我们将后面的部分称作交叉熵。

参考链接的第一个博客,还讲解了如何简化计算交叉熵。例子如下:
对于one-hot,p = [0, 1, 0],q = [0.2, 0.8, 0.3],有:

H(p,q)=−∑i=1np(xi)ln(q(xi))H(p,q)=−p(x2)∗ln(q(x2))=−1∗ln(0.8)H(p,q)=−∑i=1np(xi)ln(q(xi))H(p,q)=−p(x2)∗ln(q(x2))=−1∗ln(0.8)

H(p, q) = -\sum^n_{i=1}p(x_i)ln(q(x_i)) \\ H(p, q) = -p(x_2) * ln(q(x_2)) = -1 * ln(0.8)
对于n-hot(多分类), p = [1, 1, 0], q = [0.8, 0.6, 0.3],真实分布中不止有一个为1,则:

yiH(p,q)lossxilossx1lossx2lossx3=p(xi)y^i=q(xi)=lossx1+lossx2+lossx3=−yi ln(y^i)−(1−yi)ln(1−y^i)=− y1 ln(y^1)−(1− y1)ln(1−y^1)=−1∗ln(0.8)−0∗ln(1−0.8)=−ln(0.8)=−ln(0.6)=−0∗ln(0.3)−1∗ln(0.7)=−ln(0.7)yi=p(xi)y^i=q(xi)H(p,q)=lossx1+lossx2+lossx3lossxi=−yiln(y^i)−(1−yi)ln(1−y^i)lossx1=−y1ln(y^1)−(1−y1)ln(1−y^1)=−1∗ln(0.8)−0∗ln(1−0.8)=−ln(0.8)lossx2=−ln(0.6)lossx3=−0∗ln(0.3)−1∗ln(0.7)=−ln(0.7)

\begin{aligned}y_i & = p(x_i) \qquad \hat y_i = q(x_i) \\H(p, q) & = loss_{x_1} + loss_{x_2} + loss_{x_3} \\ loss_{x_i} & = -y_i\ ln(\hat y_i) - (1-y_i)ln(1-\hat y_i)\\ \\loss_{x_1} & = - \ y_1\ ln(\hat y_1) - (1- \ y_1)ln(1-\hat y_1)\\ & = -1 * ln(0.8) - 0 * ln(1 - 0.8) = - ln(0.8) \\ loss_{x_2} & = - ln(0.6) \\ loss_{x_3} & = -0 * ln(0.3) - 1 * ln(0.7) \\ & = -ln(0.7)\end{aligned}

loss函数公式的理解不是很难。我们需要明确,假如我们网络最后输出的是三个节点,那么,三个xixix_i节点的loss值加在一起就是全部的loss。计算xixix_i节点时,假如xixix_i的概率yiyiy_i为1,则其他的类别我们不需要判断。假如xixix_i的概率yiyiy_i为0时,我们计算loss值要计算其他真实存在的分类。

换了一个损失函数之后,从第一篇里面的最高准确率0.9179变成:


emmm,好像没提升多少,但最高值0.9258比0.9179多了0.079,也就是多了7.9%啦。

[参考]
https://blog.csdn.net/tsyccnh/article/details/79163834 (关于交叉熵,很好的教程)
https://www.zhihu.com/question/65288314/answer/244557337 (知乎上的大佬)
https://zh.wikipedia.org/wiki/%E7%9B%B8%E5%AF%B9%E7%86%B5 (Wikipedia)

损失函数(MSE和交叉熵)相关推荐

  1. 【交叉熵损失函数】关于交叉熵损失函数的一些理解

    目录 0. 前言 1.损失函数(Loss Function) 1.1 损失项 1.2 正则化项 2. 交叉熵损失函数 2.1 softmax 2.2 交叉熵 0. 前言 有段时间没写博客了,前段时间主 ...

  2. 交叉熵损失函数优缺点_交叉熵损失函数的优点(转载)

    第一篇: 利用一些饱和激活函数的如sigmoid激活时,假如利用均方误差损失,那么损失函数向最后一层的权重传递梯度时,梯度公式为 可见梯度与最后一层的激活函数的导数成正比,因此,如果起始输出值比较大, ...

  3. 机器学习常见损失函数,二元交叉熵,类别交叉熵,MSE,稀疏类别交叉熵

    一 损失函数介绍 损失函数用于描述模型预测值与真实值的差距大小.一般有有两种常见的算法--均值平方差(MSE)和交叉熵.下面来分别介绍每个算法的具体内容. 1 均值平方差 均值平方差(Mean Squ ...

  4. 经典损失函数——均方误差(MSE)和交叉熵误差(CEE)的python实现

    损失函数(loss function)用来表示当前的神经网络对训练数据不拟合的程度.这个损失函数有很多,但是一般使用均方误差和交叉熵误差等. 1.均方误差(mean squared error) 先来 ...

  5. LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数

    在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...

  6. 交叉熵损失函数公式_交叉熵损失函数对其参数求导

    1.Sigmoid 二分类交叉熵 交叉熵公式: 其中y是laebl:0 或1. hθ(xi)是经过sigmoid得到的预测概率.θ为网络的参数, m为样本数. hθ()函数如下所示, J(θ) 对参数 ...

  7. 交叉熵损失函数分类_交叉熵损失函数

    我们先从逻辑回归的角度推导一下交叉熵(cross entropy)损失函数. 从逻辑回归到交叉熵损失函数 这部分参考自 cs229-note1 part2. 为了根据给定的 预测 (0或1),令假设函 ...

  8. 语义分割损失函数系列(1):交叉熵损失函数

    最近一直在做一些语义分割相关的项目,找损失函数的时候发现网上这些大佬的写得各有千秋,也没说怎么用,在此记录一下自己在训练过程中使用损失函数的一些心得.本人是使用的Pytorch框架,故这一系列都会基于 ...

  9. 交叉熵损失函数优缺点_交叉熵损失函数

    交叉熵代价函数(Cross-entropy cost function)是用来衡量人工神经网络(ANN)的预测值与实际值的一种方式.与二次代价函数相比,它能更有效地促进ANN的训练.在介绍交叉熵代价函 ...

  10. [TensorFlow] 交叉熵损失函数,加权交叉熵损失函数

    写在前面 在文章[TensorFlow] argmax, softmax_cross_entropy_with_logits, sparse_softmax_cross_entropy_with_lo ...

最新文章

  1. linux下面显示所有正在运行的线程
  2. windows API 开发飞机订票系统 图形化界面 (一)
  3. 【Pytorch神经网络基础理论篇】 02 pytorch环境的安装
  4. Alibaba Cloud Linux 2 开源后又有什么新动作?
  5. 输入法全屏_五笔输入法那么方便,为什么败给了拼音?如今,我可算是明白了...
  6. LeetCode 701 二叉搜索树中的插入操作
  7. Java后台生成小程序二维码
  8. 网站优化工具-YUI Compressor
  9. MockingBrid(AI拟声)教程
  10. Final Scrum
  11. SG90舵机的电路连接和驱动(树莓派)
  12. 批量删除新浪微博关注
  13. dm9000数据速率_STM32F103战舰DM9000的LWIP例程TCP速度慢,发送间隔太长
  14. 【最小开发板】Attiny85开发与实践
  15. Android 购物选择颜色、尺码实现(二)
  16. 梦龙物联卡冻结_四川梦龙科技物联卡哪个划算
  17. 未来的全能保姆机器人作文_未来的保姆机器人
  18. 快速刷QQ空间访问量QQ军刀
  19. 网站的配色应该如何做
  20. Linux 系统信息查看命令

热门文章

  1. 编解码学习笔记(九):QuickTime系列
  2. 超级计算机浪漫展览,这是最独特的“中国式浪漫”
  3. linux就是这个范儿之特种文件系统(1)
  4. java毕业设计基于ssm框架的生鲜超市进销存管理系统
  5. python支付宝二维码支付源代码
  6. 中线提取算法_基于Guided Filter的地形图中线要素提取算法
  7. 文献管理软件Mendeley优缺点分析
  8. rust 连接mysql数据库_Dlang、Rust 以及 Golang 数据库操作方式对比
  9. SYNPROXY抵御DDoS攻击的原理和优化
  10. java ean13_【求大神指导】java实现EAN13条形码识别