文章目录

  • 1 什么是KL Divergence(KL散度,也说是KL距离)
  • 2 一个简单的例子
  • 3 KL的性质
  • 4 KL散度的公式介绍
  • 5 Pytorch实现KL散度——F.kl_div()
    • 5.1 函数原型
    • 5.2 简单代码
  • 6 KL散度在R- Drop中的应用
    • 6.1 什么是Dropout?
    • 6.2 引入 R-drop
    • 6.3 在R-Drop的代码

1 什么是KL Divergence(KL散度,也说是KL距离)

KL散度是一种概率分布和另一种概率分布的差异的距离。
公式如下:

2 一个简单的例子

3 KL的性质

  1. 具有非对称性,且不满足三角不等式形式,即是DKL(P∣∣Q)≠DKL(Q∣∣P)D_{KL}(P||Q)≠D_{KL}(Q||P)DKL​(P∣∣Q)​=DKL​(Q∣∣P),说明KL散度不是用来衡量距离的。
  2. 如图:

4 KL散度的公式介绍

公式如下:

其中,DKL(P∣∣Q)D_{KL}(P||Q)DKL​(P∣∣Q)——概率PPP与概率QQQ之间的差异,散度越小,则两个概率越接近,那么估计的概率分布也就与真实的概率分布越接近。

KL散度,计算出来的是两者的概率差,可以说KL散度是一种损失,来计算两者的概率差异。

5 Pytorch实现KL散度——F.kl_div()

5.1 函数原型

torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
  • 参数介绍:
input – Tensor of arbitrary shapetarget – Tensor of the same shape as inputsize_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: Truereduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: Truereduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied 'batchmean': the sum of the output will be divided by the batchsize 'sum': the output will be summed 'mean': the output will be divided by the number of elements in the output Default: 'mean'
{'none' | 'batchmean' | 'sum' | 'mean'}
'none':不进行压缩
'batchmean':输出的总和除以batchsize
'sum':输出的总和
'mean':输出的输出除以输出中的元素数量
默认值:'mean'
  1. 第一个参数是一个对数概率矩阵,第二个参数是概率矩阵。这里很重要,不然求出来的kl散度可能是个负值。

  2. 如果现在想用Y指导X,第一个参数要传X,第二个要传Y。就是被指导的放在前面,然后求相应的概率和对数概率就可以了。例如DKL(P∣∣Q)D_{KL}(P||Q)DKL​(P∣∣Q),就是求Q的对数概率,求P的概率。

  3. reduce bool类型,可选{‘True’, ‘False’},默认值:TRUE.
    默认情况下,根据大小的平均值,对每个小批处理的损失进行平均或求和。当reduce为False时,返回每个批元素的损失值,并忽略平均大小。

  4. reduction string类型,可选{‘None’, ‘batchmean’, ‘sum’, ‘mean’},默认值:‘mean’
    ‘none’:不进行压缩
    ‘batchmean’:输出的总和除以batchsize
    ‘sum’:输出的总和
    ‘mean’:输出的输出除以输出中的元素数量

5.2 简单代码

import torch
import torch.nn.functional as F# 定义两个矩阵
P = torch.tensor([0.25] * 4 + [0])
Q = torch.tensor([0.2] * 5)print(P, Q)# 因为要用P指导Q,所以求Q的对数概率,P的概率
log_Q = F.log_softmax(P, dim=-1)
_P = F.softmax(Q, dim=-1)kl_sum = F.kl_div(log_Q, _P, reduction='sum')print("SUM:", kl_sum)

-输出如下:

tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.0000]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
SUM: tensor(0.2231)

6 KL散度在R- Drop中的应用

6.1 什么是Dropout?

  • 下面这篇文章我已经详细介绍过了,感兴趣的小伙伴可以移步下方链接:

https://blog.csdn.net/weixin_42521185/article/details/124359544

  • Dropout的使用场景一般是神经元较多的神经网络,因为神经元较多可能会使得预测结果过拟合,采用Dropout随机丢弃一些神经元,可以很好的处理这个问题。

  • 但是,Dropout的缺点也很明显:dropout是随机丢弃,所以训练和预测阶段使用的神经网络不同!

6.2 引入 R-drop

  • step1:首先计算LLKL_{LK}LLK​。
    不同模型训练经过不同的dropout可能会产生较大的差异,而我们这里通过KL散度将这些差异缩小,从而能够达到预测较准的情况。
  • step2:新的损失函数

6.3 在R-Drop的代码

from tensorflow.keras.losses import kullback_leibler_divergence as kld
def categorical_crossentropy_with_rdrop(y_true, y_pred,alpha=1):"""配合上述生成器的R-Drop Loss其实loss_kl的除以4,是为了在数量上对齐公式描述结果。"""loss_ce = K.sparse_categorical_crossentropy(y_true, y_pred)  # 原来的loss#这里调用K.Sparse,一部分是常规的交叉熵loss_kl = kld(y_pred[::2], y_pred[1::2]) + kld(y_pred[1::2], y_pred[::2])#另一部分是两个模型的对称KL散度return K.mean(loss_ce) + K.mean(loss_kl) / 4 * alpha
  • step1 首先调用交叉熵损失函数
loss_ce = K.sparse_categorical_crossentropy(y_true,y_pred)
  • step2 然后调用KL散度的内容

这里参考“唐僧爱吃唐僧肉的解释”


  • 本研究生已疲惫了…

KL Divergence ——衡量两个概率分布之间的差异相关推荐

  1. 衡量两个概率分布之间的差异性的指标

    衡量两个概率分布之间的差异性的指标 衡量两个概率分布之间的差异性的指标 KL散度(Kullback–Leibler divergence) JS散度(Jensen-Shannon divergence ...

  2. 如何查找两个列表之间的差异?

    1. 概述 查找相同数据类型的对象集合之间的差异是一项常见的编程任务.举个例子,假设我们有一份申请考试的学生名单和另一份通过考试的学生名单.这两张名单的区别会告诉我们那些没有通过考试的学生. 在Jav ...

  3. 「从 Windows 到 macOS」快速理顺两大系统之间的差异

    虽然从熟悉的平台转移到另一个陌生平台的做法一般不会经常发生,但如果你已经决定从 Windows 转移到 macOS,那么在踏入「新世界」的大门之前,或许这份「从 Windows 到 macOS」的入门 ...

  4. java如何比较两个date_在Java中,如何获得两个date之间的差异秒?

    不熟悉DateTime - 如果你有两个date,你可以调用getTime来得到毫秒,得到差异并除以1000.例如 Date d1 = ...; Date d2 = ...; long seconds ...

  5. 图片上两点之间的距离和两组图片之间的差异的关系

    制作两个矩阵A和D 用神经网络分类A和D,让x和y都是(0,1)之间的小数,则A与D的第三点A:x[2]=0,D:x[2]=a(x-y)的距离为 所以如果变化a对分类A和D有什么影响? 实验过程 制作 ...

  6. python两组数的差异_Python中两个日期之间的差异

    我尝试了上面larsmans发布的代码,但是有两个问题: 1)原样的代码将引发mauguerra提到的错误2)如果将代码更改为以下内容: ... d1 = d1.strftime("%Y-% ...

  7. git 对比两个commit 之间的差异

    git log 查看commit记录 git log --pretty=format:"%h %s" 查看commit记录并以commit_short_id commit_mess ...

  8. 如何快速找出找出两个数组中的_找出JavaScript中两个数组之间的差异

    LeetCode今天面临的挑战是在数组中查找所有消失的数字. 蛮力 我们的输入包括一个缺少数字的实际数组.我们想将该数组与相同长度的数组进行比较,其中没有遗漏的数字.所以如果给定的话[4,3,2,7, ...

  9. git查看两次提交之间的差异_如何在同一分支的两个不同提交之间区分同一文件?...

    如果已配置" difftool",则可以使用 git difftool revision_1:file_1 revision_2:file_2 示例:将文件的最后一次提交与同一分支 ...

  10. java 比较2个时间大小写_date - Java 8:计算两个LocalDateTime之间的差异

    Tapas Bose代码和Thomas代码存在一些问题. 如果时间差异为负,则数组获得负值. 例如,如果 LocalDateTime toDateTime = LocalDateTime.of(201 ...

最新文章

  1. Java 命名规范(非常全面)
  2. 优化程序性能的策略汇总
  3. 2018秋招面经:斗鱼、滴滴、百度、美团、小米、腾讯
  4. matlab 二维高斯滤波 傅里叶_光电图像处理 | 傅里叶变换(二)
  5. qtreewidget点击空白处时取消以选项_VUE+elementUI 点击页面空白处弹窗不隐藏
  6. OpenShift 4 - 容器应用备份和恢复
  7. Java-Runtime
  8. 网络编程之OSI七层协议
  9. vmware workstation 不可恢复错误 vcpu-0
  10. 用网速作为手机信号强度
  11. javascript:理解try...catch...finally
  12. 03_Snaker流程demo
  13. GPS北斗卫星时钟同步系统的原理和技术
  14. android 九宫格带删除,Android--选择多张图片,支持拖拽删除、排序、预览图片
  15. AI、量子计算引爆硬科技创新,雷鸣、王海峰、施尧耘等北大120周年论道信科最前沿...
  16. GPT-4王者加冕!读图做题性能炸天,凭自己就能考上斯坦福
  17. 视频直播源码_直播平台搭建_直播程序源码——技术架构解析
  18. 照度/感光度(Lux)
  19. 明日方舟公式计算机,明日方舟公开招募公式汇总
  20. SAP MM 供应商无英文名称,ME21N里却带出了英文名字?

热门文章

  1. 【数据结构】串(一)—— 串的基础知识
  2. Logistic-Sine-Cosine混沌映射(提供文献及Matlab代码)
  3. Rust之crate
  4. CF1467B Hills And Valleys 题解
  5. 戴尔卡耐基《人性的弱点》
  6. 解决Navicat远程服务器2013-Lost connection to MYSQL server at 'waitting for initial communication packet'
  7. evolution ubuntu邮箱_Ubuntu evolution 邮件客户端配置详解(图)
  8. 【笔记整理】通信原理第九章复习——线性分组码
  9. linux之OPERATION(运维)一
  10. 修正蹩脚的Scratch汉化