理论基础

KL散度:衡量两个概率分布之间的相似性,其值越小,概率分布越接近。公式表达如下。

DKL(P∥Q)=∑i=1N[p(xi)log⁡p(xi)−p(xi)log⁡q(xi)]=∑i=1N[p(xi)log⁡p(xi)log⁡q(xi)]\begin{aligned} D_{K L}(P \| Q) & =\sum_{i=1}^{N}\left[p\left(x_{i}\right) \log p\left(x_{i}\right)-p\left(x_{i}\right) \log q\left(x_{i}\right)\right] \\ & = \sum_{i=1}^{N}\left[p\left(x_{i}\right) \frac{\log p\left(x_{i}\right)}{\log q\left(x_{i}\right)} \right] \end{aligned} DKL​(P∥Q)​=i=1∑N​[p(xi​)logp(xi​)−p(xi​)logq(xi​)]=i=1∑N​[p(xi​)logq(xi​)logp(xi​)​]​

注:对于两个概率分布 PPP 、QQQ,PPP 为真实事件的概率分布,QQQ 为随机事件拟合出来的该事件的概率分布,即 DKL(P∥Q)D_{K L}(P \| Q)DKL​(P∥Q) 表示使用 PPP 来拟合 QQQ, 或者说使用 QQQ 来指导 PPP。

实现

import torch
import torch.nn as nn
import torch.nn.functional as F# 预测值
input = torch.tensor([0.7, .1, .2], requires_grad=True)  # dim=0 每一行为一个样本# 真实值
target = torch.tensor([.2, .5, .3])# 计算KL散度
# 方式1
kl_loss = nn.KLDivLoss(reduction="batchmean")
output = kl_loss(F.log_softmax(input, dim=0), F.softmax(target, dim=0))
print(output)# 方式2
print(F.kl_div(F.log_softmax(input, dim=0), F.softmax(target, dim=0), reduction="batchmean"))# 方式3
my_kl_loss = F.softmax(target, dim=0) * (torch.log(F.softmax(target, dim=0)) - F.log_softmax(input, dim=0))
my_kl_loss = my_kl_loss.mean()
print(my_kl_loss)# 方式4
my_kl_loss2 = F.softmax(target, dim=0) * (F.log_softmax(target, dim=0) - F.log_softmax(input, dim=0))
my_kl_loss2 = my_kl_loss2.mean()
print(my_kl_loss2)# ----------------输出--------------------
# tensor(0.0239, grad_fn=<DivBackward0>)
# tensor(0.0239, grad_fn=<MeanBackward0>)
# tensor(0.0239, grad_fn=<MeanBackward0>)
# tensor(0.0239, grad_fn=<DivBackward0>)
# ----------------------------------------

几个要点

  1. KL散度的原理
  2. KL实现为什么要做log和softmax
  3. 上溢出和下溢出的情况
  4. 在pytorch的log函数中,默认是以 eee 为底数的

参考:

  • loss函数之KLDivLoss
  • KL散度理解以及使用pytorch计算KL散度
  • 有效防止softmax计算时上溢出(overflow)和下溢出(underflow)的方法
  • 交叉熵

Pytorch之KLDivLoss相关推荐

  1. 损失函数理解汇总,结合PyTorch和TensorFlow2

    点击下方标题,迅速定位到你感兴趣的内容 前言 交叉熵损失(CrossEntropyLoss) KL散度 平均绝对误差(L1范数损失) 均方误差损失(L2范数损失) Hinge loss 余弦相似度 前 ...

  2. 损失函数理解汇总,结合PyTorch1.7和TensorFlow2

    作者丨DengBoCong@知乎 来源丨https://zhuanlan.zhihu.com/p/271911178 编辑丨极市平台 本文仅用于学术分享,如有侵权,请联系后台作删文处理. 本文打算讨论 ...

  3. 收藏 | 损失函数理解汇总,结合PyTorch1.7和TensorFlow2

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:极市平台 AI博士笔记系列推荐 周志华<机器学习> ...

  4. PyTorch的十七个损失函数

    20220113 选损失函数的标准:能使得真实值和预测值越相近的时候总损失越小 20220303 机器学习大牛是如何选择回归损失函数的? MSE,MAE,huber loss 20210925 交叉熵 ...

  5. 实操教程|Pytorch常用损失函数拆解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://lhyxx.top 编辑 ...

  6. Pytorch Document学习笔记

    Pytorch Document学习笔记 Pytorch Document学习笔记 1. 网络层 1.1 torch.nn.Conv2d 1.2 torch.nn.MaxPool2d / torch. ...

  7. 一文看懂Transformer内部原理(含PyTorch实现)

    Transformer注解及PyTorch实现 原文:http://nlp.seas.harvard.edu/2018/04/03/attention.html 作者:Alexander Rush 转 ...

  8. Step-by-step to Transformer:深入解析工作原理(以Pytorch机器翻译为例)

    大家好,我是青青山螺应如是,大家可以叫我青青,工作之余是一名独立摄影师.喜欢美食.旅行.看展,偶尔整理下NLP学习笔记,不管技术文还是生活随感,都会分享本人摄影作品,希望文艺的技术青年能够喜欢~~如果 ...

  9. PyTorch框架学习十二——损失函数

    PyTorch框架学习十二--损失函数 一.损失函数的作用 二.18种常见损失函数简述 1.L1Loss(MAE) 2.MSELoss 3.SmoothL1Loss 4.交叉熵CrossEntropy ...

最新文章

  1. python3spark文本分类_如何用Spark深度集成Tensorflow实现文本分类?
  2. 求任意大小矩阵的转置矩阵
  3. django中,kindeditor存到数据库的html,前台html标签被自动转义的解决办法
  4. python assert的作用
  5. SpaceVim 语言模块 lua
  6. java报错 日志_java 日志报错
  7. 面试精讲之面试考点及大厂真题 - 分布式专栏 11 Redis热点key大Value解决方案
  8. RabbitMQ的应用场景以及基本原理简介
  9. 基于JAVA+SpringMVC+MYSQL的小说管理系统
  10. 阿里巴巴Java开发 之 编程规约
  11. UI(用户界面)设计规则和规范
  12. ToolBar控件详解
  13. Creating Your First Mac App--Implementing Action Methods 实现动作方法
  14. 教你彻底禁止暴风影音后门进程自己启动
  15. mysql八大知识点_MySQL索引八大法则之上篇
  16. 老调长谈的Flex 4.6 可视组件的生命周期
  17. char * 与char []区别总结
  18. Chapter16/17-项目2:数据可视化
  19. 天啦噜,项目上使用InputStream,我被坑了一把!
  20. 解决“ImportError: cannot import name ‘_validate_lengths‘”问题

热门文章

  1. 腾讯云直播代码 java_JAVA 对接腾讯云直播的实现
  2. word中审阅和修订、批注
  3. (一)通用定时器的相关介绍
  4. asp.net办公自动化OA系统
  5. Zuul1与Spring Cloud Gateway的区别
  6. 【Qt实战派学习群】 建立啦!
  7. 常用cursor光标说明
  8. 基于JAVA黑白图片和上色处理系统(Springboot框架+AI人工智能) 开题报告
  9. 【广告系列一】广告相关名词 CTR/CVR/eCPM...
  10. spark常用RDD算子 - take(),takeOrdered(),top(),first()