KLDivLoss

作用:

用于连续分布的距离度量;并且对离散采用的连续输出空间分布进行回归通常很有用;用label_smoothing就采用这个;

公式:

公式理解:

p(x)是真实分布,q(x)是拟合分布;实际计算时;通常p(x)作为target,只是概率分布;而xn则是把输出做了LogSoftmax计算;即把概率分布映射到log空间;所以

K-L散度值实际是看log(p(x))-log(q(x))的差值,差值越小,说明拟合越相近。

pytorch使用:

当前版本torch(1.3.1)要想获得真正的KL散度;设置:

reduce=False;size_average=False

(reduce默认也是True,返回所有元素loss的和;size_average=默认是True,是对batch中每个元素进行求平均,当为False时,返回各样本各维度的loss之和;

因为reduce为False会忽略size_average参数,所以其实只需要把reduce=False即可)

代码验证:

import torchimport torch.nn as nnimport numpy as np# -----------------------------------  KLDiv lossloss_f = nn.KLDivLoss(size_average=False, reduce=False)loss_f_mean = nn.KLDivLoss(size_average=True, reduce=True)# 生成网络输出 以及 目标输出output = torch.from_numpy(np.array([[0.1132, 0.5477, 0.3390]])).float()output.requires_grad = Truetarget = torch.from_numpy(np.array([[0.8541, 0.0511, 0.0947]])).float()loss_1 = loss_f(output, target)loss_mean = loss_f_mean(output, target)print('\nloss: ', loss_1)print('\nloss_mean: ', loss_mean)# 熟悉计算公式,手动计算样本的第一个元素的loss,注意这里只有一个样本,是 element-wise计算的output = output[0].detach().numpy()output_1 = output[0]  # 第一个样本的第一个元素target_1 = target[0][0].numpy()loss_1 = target_1 * (np.log(target_1) - output_1)print('\n第一个样本第一个元素的loss:', loss_1

链接:https://github.com/TingsongYu/PyTorch_Tutorial/blob/master/Code/3_optimizer/3_1_lossFunction/6_KLDivLoss.py

Note:

D(p||q) ≠ D(q||p) ,不具有对称性所以不能称之为K-L距离

nn.KLDivLoss相关推荐

  1. loss函数之KLDivLoss

    KL散度 KL散度,又叫相对熵,用于衡量两个分布(离散分布和连续分布)之间的距离. 设p(x)p(x)p(x) .q(x)q(x)q(x) 是离散随机变量XXX的两个概率分布,则ppp 对qqq 的K ...

  2. torch.nn、(二)

    参考 torch.nn.(二) - 云+社区 - 腾讯云 目录 Recurrent layers RNN LSTM GRU RNNCell LSTMCell GRUCell Transformer l ...

  3. Pytorch之KLDivLoss

    理论基础 KL散度:衡量两个概率分布之间的相似性,其值越小,概率分布越接近.公式表达如下. DKL(P∥Q)=∑i=1N[p(xi)log⁡p(xi)−p(xi)log⁡q(xi)]=∑i=1N[p( ...

  4. [Pytorch系列-28]:神经网络基础 - torch.nn模块功能列表

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  5. 25_PyTorch的十九个损失函数(L1Loss、MSELoss、CrossEntropyLoss 、CTCLoss、NLLLoss、PoissonNLLLoss 、KLDivLoss等)

    1.20.PyTorch的十九个损失函数 1.20.1.L1Loss(L1范数损失) 1.20.2.MSELoss(均方误差损失) 1.20.3.CrossEntropyLoss (交叉熵损失) 1. ...

  6. Pytorch入门实战(5):基于nn.Transformer实现机器翻译(英译汉)

    使用Google Colab运行(open In Colab) 源码地址 文章目录 本文涉及知识点 本文内容 环境配置 数据预处理 文本分词与构造词典 Dataset and Dataloader 模 ...

  7. PyTorch的十七个损失函数

    20220113 选损失函数的标准:能使得真实值和预测值越相近的时候总损失越小 20220303 机器学习大牛是如何选择回归损失函数的? MSE,MAE,huber loss 20210925 交叉熵 ...

  8. 实操教程|Pytorch常用损失函数拆解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://lhyxx.top 编辑 ...

  9. 损失函数理解汇总,结合PyTorch1.7和TensorFlow2

    作者丨DengBoCong@知乎 来源丨https://zhuanlan.zhihu.com/p/271911178 编辑丨极市平台 本文仅用于学术分享,如有侵权,请联系后台作删文处理. 本文打算讨论 ...

最新文章

  1. Wrong permissions on configuration file, should not be world writable
  2. python调用shell命令-用Python调用Shell命令
  3. JavaScript精进篇
  4. 树莓派:VNC远程控制
  5. Git remote 修改源
  6. iphonex如何关机_iphonex常用手势操作有哪些 iphonex常用手势操作介绍【详解】
  7. java guice_java – Guice:如何为一个类型获得多个@Provides?
  8. Oracle-洛总脚本--查询相关慢SQL
  9. ubuntu18.04安装mysql8
  10. 查询使用NoLock
  11. 数据库 读锁(共享锁)、 写锁(排他锁)
  12. android逆向工程dex2jar使用
  13. PHP实现文件下载两种方式(a标签和header标签)
  14. CPU性能的三大主要参数
  15. 【xtku】铜雀台张馨予xp主题_8.2
  16. [教程]安装青鸟云Web服务器
  17. No changes detected报错解决方案
  18. SQLserver未发现数据源名称并且未指定默认驱动程序
  19. PADS-Layout学习笔记
  20. Windows SVN迁移实操笔记

热门文章

  1. 看PG10文档的笔记
  2. 【路径规划】基于蚁群算法求解机器人栅格地图路径规划matlab代码
  3. 仿QQ多级折叠、展开菜单,三级下拉导航
  4. html5绘制好看的时钟,利用纯html5绘制出来的一款非常漂亮的时钟
  5. java四则运算思路_java四则运算
  6. 酷炫一款动态背景(HTML +js canvas)
  7. 基于springboot的化妆品美妆销售商城网站
  8. SRS(简单实时视频服务) 笔记(3)- 配置文件和Http回调
  9. 山西宗教文化漫谈(四)——云冈:东方艺术宝库
  10. 基于RK3568开源鸿蒙的助农金融服务终端设计方案