PyTorch计算KL散度详解

最近在进行方法设计时,需要度量分布之间的差异,由于样本间分布具有相似性,首先想到了便于实现的KL-Divergence,使用PyTorch中的内置方法时,踩了不少坑,在这里详细记录一下。

简介

首先简单介绍一下KL散度(具体的可以在各种技术博客看到讲解,我这里不做重点讨论)。
从名称可以看出来,它并不是严格意义上的距离(所以才叫做散度~),原因是它并不满足距离的对称性,为了弥补这种缺陷,出现了JS散度(这就是另一个故事了…)
我们先来看一下KL散度的形式:
DKL(P∣∣Q)=∑i=1Npilog⁡piqi=∑i=1Npi∗(log⁡pi−log⁡qi)DKL(P||Q) = \sum_{i=1}^{N} {p_i\log{\frac{p_i}{q_i}}} = \sum_{i=1}^{N} { p_i*(\log{p_i}-\log{q_i})} DKL(P∣∣Q)=i=1∑N​pi​logqi​pi​​=i=1∑N​pi​∗(logpi​−logqi​)

手动代码实现

可以看到,KL散度形式上还是比较直观的,我们先手撸一个试试:
这里我们随机设定两个随机变量P和Q

import torch
P = torch.tensor([0.4, 0.6])
Q = torch.tensor([0.3, 0.7])

快速算一下答案:
DKL(P∣∣Q)=0.4∗(log⁡0.4−log⁡0.3)+0.6∗(log⁡0.6−log⁡0.7)≈0.0226\begin{aligned} DKL(P||Q) &= 0.4* (\log{0.4} - \log{0.3}) + 0.6 * (\log{0.6} - \log{0.7}) \\ & \approx 0.0226 \end{aligned} DKL(P∣∣Q)​=0.4∗(log0.4−log0.3)+0.6∗(log0.6−log0.7)≈0.0226​

数值计算实现版:

def DKL(_p, _q):"""calculate the KL divergence between _p and _q"""return  torch.sum(_p * (_p.log() - _q.log()), dim=-1)divergence = DKL(P, Q)
print(divergence)
# tensor(0.0226)

上面的代码中,之所以求和时dim=-1是因为我在使用的过程中,考虑到有时是对batch中feature进行计算,所以这里只对特征维度进行求和。
接下来,就到了今天介绍的主角~

torch代码实现

torch中提供有两种不同的api用于计算KL散度,分别是torch.nn.functional.kl_div()torch.nn.KLDivLoss(),两者计算效果类似,区别无非是直接计算和作为损失函数类。

先介绍一下torch.nn.functional.kl_div()

注意,该方法的inputtarget与KL(P∣∣Q)KL(P||Q)KL(P∣∣Q)中PPP、QQQ的位置正好相反,从参数名称就可以看出来(target为目标分布PPP,input为待度量分布QQQ)。为了防止指代混乱,我后面统一用PPP、QQQ指代targetinput

这里重点关注几个对计算结果有影响的参数:

reduction:该参数是结果应该以什么规约形式进行呈现,sum即为我们定义式中的效果,batchmean:按照batch大小求平均,mean:按照元素个数进行求平均

再看看log_target的效果:

if not log_target: # defaultloss_pointwise = target * (target.log() - input)
else:loss_pointwise = target.exp() * (target - input)

也就是说,如果log_target=False,此时计算方式为
res=P∗(log⁡P−Q)res = P * ( \log{P}-Q) res=P∗(logP−Q)
这和我们熟悉的定义式的计算方式是不同的,如果想要和定义式的效果一致,需要对input取对数操作(在官方文档中也有提及,建议将input映射到对数空间,防止数值下溢):

import torch.nn.Functional as Fprint(F.kl_div(Q.log(), P, reduction='sum'))
#tensor(0.0226)

而当log_target=True时,此时的计算方式变为
res=eP∗(P−Q)res=e^{P}*(P-Q) res=eP∗(P−Q)
也就是说,此时我们对PPP取对数操作即可得到定义式的效果:

print(F.kl_div(Q.log(), P.log(), log_target=True, reduction='sum'))
#tensor(0.0226)

这样设计的目的也是为了防止数值下溢。

torch.nn.KLDivLoss()的参数列表与torch.nn.functional.kl_div()类似,这里就不过多赘述。

总结

总的来说,当需要计算KL散度时,默认情况下需要对input取对数,并设置reduction='sum'方能得到与定义式相同的结果:

divergence = F.kl_div(Q.log(), P, reduction='sum')

由于我们度量的是两个分布的差异,因此通常需要对输入进行softmax归一化(如果已经归一化则无需此操作):

divergence = F.kl_div(Q.softmax(-1).log(), P.softmax(-1), reduction='sum')

PyTorch中计算KL散度详解相关推荐

  1. pytorch中的kl散度,为什么kl散度是负数?

    F.kl_div()或者nn.KLDivLoss()是pytroch中计算kl散度的函数,它的用法有很多需要注意的细节. 输入 第一个参数传入的是一个对数概率矩阵,第二个参数传入的是概率矩阵.并且因为 ...

  2. 机器学习:KL散度详解

    KL 散度,是一个用来衡量两个概率分布的相似性的一个度量指标. 我们知道,现实世界里的任何观察都可以看成表示成信息和数据,一般来说,我们无法获取数据的总体,我们只能拿到数据的部分样本,根据数据的部分样 ...

  3. pytorch中的卷积操作详解

    首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...

  4. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  5. PyTorch中的matmul函数详解

    PyTorch中的两个张量的乘法可以分为两种: 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者∗*∗运算符)实现 两个张量矩阵相乘(Matr ...

  6. PyTorch中squeeze()和unsqueeze()详解

    pytorch中squeeze()和unsqueeze()作用 squeeze() squeeze() 用于在张量的指定维度插入新的维度 (为1) 得到维度提升的张量. unsqueeze() uns ...

  7. PyTorch中的topk函数详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sor ...

  8. pytorch中scatter()、scatter_()详解

    scatter().scatter_() scatter() 和 scatter_() 的作用一样. 不同之处在于 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会在 ...

  9. python variable_PyTorch中的Variable变量详解

    一.了解Variable 顾名思义,Variable就是 变量 的意思.实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性. 具体来说,在pyt ...

最新文章

  1. 【C++】多线程(链式、循环队列)实现生产者消费者模式
  2. 一个比较保守的404页面
  3. ce修改器传奇刷元宝_真原始传奇刷元宝方法 不封号刷元宝技巧
  4. pytorch 模型可视化_PyTorch Tips(FLOPs计算/参数量/计算图可视化/相关性分析)
  5. 建模大师怎么安装到revit中_全面解析Revit软件在装配式建筑项目中的建模思路...
  6. Lotus Notes 和 Crystal Report 的整合應用
  7. python爬虫----(4. scrapy框架,官方文档以及例子)
  8. python tkinter 弹窗_tkinter主窗口和子窗口同时弹出该怎么办?
  9. Guns 查询列表_入门试炼03
  10. 使用实体框架核心和C#创建具有Dotnet核心的自定义Web爬虫程序
  11. JAVA程序设计第十版第七章_java程序设计第七章答案
  12. 【linux】gcc命令
  13. 顺序栈基本操作的C语言实现(含全部代码实现)--- 数据结构之顺序栈
  14. 近世代数——Part2 群:基础与子群 课后习题
  15. 白鹭引擎 android9,【安卓】手把手教你Egret引擎一键发布华为快游戏
  16. Kaggle泰坦尼克号提升准确率探索
  17. [SSL_CHX][2021-08-20]幸运数字们
  18. 《Adobe Premiere Pro CS4经典教程》——1.7 Adobe Premiere Pro工作区
  19. [NodeJS] Jest 环境下 Axios 请求报错: Cross origin http://localhost forbidden
  20. AliOS Things的SDK ESP8266 连接阿里生活物联网平台 配网失败解决方案

热门文章

  1. 全国统计专业技术高级资格考试大纲(2021年)
  2. POI报表入门及百万数据报表导出和读取
  3. 服务器数据挂载与解挂
  4. 如何在微信公众号编辑器发布免费好看的排版内容
  5. DB2复制表结构及数据
  6. HTML5仿微信公众号界面
  7. C# 倍福ADS的正确打开方式,使用AdsRemote组件优雅的通过ADS通讯
  8. 排版工具:gnu indent 【转】
  9. 如何实现一个简单的网络帧同步方案
  10. 基于UDP的帧同步网络方案(基础)