nn.KLDivLoss
KLDivLoss
作用:
用于连续分布的距离度量;并且对离散采用的连续输出空间分布进行回归通常很有用;用label_smoothing就采用这个;
公式:
公式理解:
p(x)是真实分布,q(x)是拟合分布;实际计算时;通常p(x)作为target,只是概率分布;而xn则是把输出做了LogSoftmax计算;即把概率分布映射到log空间;所以
K-L散度值实际是看log(p(x))-log(q(x))的差值,差值越小,说明拟合越相近。
pytorch使用:
当前版本torch(1.3.1)要想获得真正的KL散度;设置:
reduce=False;size_average=False
(reduce默认也是True,返回所有元素loss的和;size_average=默认是True,是对batch中每个元素进行求平均,当为False时,返回各样本各维度的loss之和;
因为reduce为False会忽略size_average参数,所以其实只需要把reduce=False即可)
代码验证:
import torchimport torch.nn as nnimport numpy as np# ----------------------------------- KLDiv lossloss_f = nn.KLDivLoss(size_average=False, reduce=False)loss_f_mean = nn.KLDivLoss(size_average=True, reduce=True)# 生成网络输出 以及 目标输出output = torch.from_numpy(np.array([[0.1132, 0.5477, 0.3390]])).float()output.requires_grad = Truetarget = torch.from_numpy(np.array([[0.8541, 0.0511, 0.0947]])).float()loss_1 = loss_f(output, target)loss_mean = loss_f_mean(output, target)print('\nloss: ', loss_1)print('\nloss_mean: ', loss_mean)# 熟悉计算公式,手动计算样本的第一个元素的loss,注意这里只有一个样本,是 element-wise计算的output = output[0].detach().numpy()output_1 = output[0] # 第一个样本的第一个元素target_1 = target[0][0].numpy()loss_1 = target_1 * (np.log(target_1) - output_1)print('\n第一个样本第一个元素的loss:', loss_1
链接:https://github.com/TingsongYu/PyTorch_Tutorial/blob/master/Code/3_optimizer/3_1_lossFunction/6_KLDivLoss.py
Note:
D(p||q) ≠ D(q||p) ,不具有对称性所以不能称之为K-L距离
nn.KLDivLoss相关推荐
- loss函数之KLDivLoss
KL散度 KL散度,又叫相对熵,用于衡量两个分布(离散分布和连续分布)之间的距离. 设p(x)p(x)p(x) .q(x)q(x)q(x) 是离散随机变量XXX的两个概率分布,则ppp 对qqq 的K ...
- torch.nn、(二)
参考 torch.nn.(二) - 云+社区 - 腾讯云 目录 Recurrent layers RNN LSTM GRU RNNCell LSTMCell GRUCell Transformer l ...
- Pytorch之KLDivLoss
理论基础 KL散度:衡量两个概率分布之间的相似性,其值越小,概率分布越接近.公式表达如下. DKL(P∥Q)=∑i=1N[p(xi)logp(xi)−p(xi)logq(xi)]=∑i=1N[p( ...
- [Pytorch系列-28]:神经网络基础 - torch.nn模块功能列表
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...
- 25_PyTorch的十九个损失函数(L1Loss、MSELoss、CrossEntropyLoss 、CTCLoss、NLLLoss、PoissonNLLLoss 、KLDivLoss等)
1.20.PyTorch的十九个损失函数 1.20.1.L1Loss(L1范数损失) 1.20.2.MSELoss(均方误差损失) 1.20.3.CrossEntropyLoss (交叉熵损失) 1. ...
- Pytorch入门实战(5):基于nn.Transformer实现机器翻译(英译汉)
使用Google Colab运行(open In Colab) 源码地址 文章目录 本文涉及知识点 本文内容 环境配置 数据预处理 文本分词与构造词典 Dataset and Dataloader 模 ...
- PyTorch的十七个损失函数
20220113 选损失函数的标准:能使得真实值和预测值越相近的时候总损失越小 20220303 机器学习大牛是如何选择回归损失函数的? MSE,MAE,huber loss 20210925 交叉熵 ...
- 实操教程|Pytorch常用损失函数拆解
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 小新 来源 | https://lhyxx.top 编辑 ...
- 损失函数理解汇总,结合PyTorch1.7和TensorFlow2
作者丨DengBoCong@知乎 来源丨https://zhuanlan.zhihu.com/p/271911178 编辑丨极市平台 本文仅用于学术分享,如有侵权,请联系后台作删文处理. 本文打算讨论 ...
最新文章
- Wrong permissions on configuration file, should not be world writable
- python调用shell命令-用Python调用Shell命令
- JavaScript精进篇
- 树莓派:VNC远程控制
- Git remote 修改源
- iphonex如何关机_iphonex常用手势操作有哪些 iphonex常用手势操作介绍【详解】
- java guice_java – Guice:如何为一个类型获得多个@Provides?
- Oracle-洛总脚本--查询相关慢SQL
- ubuntu18.04安装mysql8
- 查询使用NoLock
- 数据库 读锁(共享锁)、 写锁(排他锁)
- android逆向工程dex2jar使用
- PHP实现文件下载两种方式(a标签和header标签)
- CPU性能的三大主要参数
- 【xtku】铜雀台张馨予xp主题_8.2
- [教程]安装青鸟云Web服务器
- No changes detected报错解决方案
- SQLserver未发现数据源名称并且未指定默认驱动程序
- PADS-Layout学习笔记
- Windows SVN迁移实操笔记