Pytorch以及tensorflow中KLdivergence的计算
1. KL divergence是什么
KL 散度是一个距离衡量指标,衡量的是两个概率分布之间的差异。
y p r e d y_{pred} ypred指的是模型的输出的预测概率,形如[0.35,0.25,0.4]; y t r u e y_{true} ytrue是一个one-hot形式的真实概率,形如[0,1,0]。
神经网络的训练目标是使得输出的预测概率尽可能接近真实的概率分布。
KL divergence loss的计算公式为: K L ( y p r e d , y t r u e ) = y t r u e l o g ( y t r u e y p r e d ) KL(y_{pred},y_{true}) = y_{true}log(\frac{y_{true}}{y_{pred}} ) KL(ypred,ytrue)=ytruelog(ypredytrue)
2. logits是什么?
logits是几率,神经网络最后一层的输出如果不经过激活函数,比如softmax的话,那么这个输出就叫做logits。
logits经过softamx激活函数得到概率值,比如:logits = [4,3.9,1],经过softmax激活后,得到 probability = [0.5116072 0.46292138 0.02547142]
p i = e i ∑ j e j p_{i}=\frac{e^{i}}{\sum_{j} e^{j}} pi=∑jejei,
比如上面probability的每一个元素的计算过程为:
e 4 e 4 + e 3.9 + e 1 = 0.5116072 , e 3.9 e 4 + e 3.9 + e 1 = 0.46292138 , e 1 e 4 + e 3.9 + e 1 = 0.0254714 \frac{e^{4}}{e^{4}+e^{3.9}+e^{1}} = 0.5116072, \frac{e^{3.9}}{e^{4}+e^{3.9}+e^{1}} = 0.46292138,\frac{e^{1}}{e^{4}+e^{3.9}+e^{1}}= 0.0254714 e4+e3.9+e1e4=0.5116072,e4+e3.9+e1e3.9=0.46292138,e4+e3.9+e1e1=0.0254714
3.使用Pytorch计算KL divergence loss
Pytorch计算KL divergence loss的 [官方文档在这里]。(https://pytorch.org/docs/master/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss)
import torch
import torch.nn as nn
from torch.nn import functional as F
p_logits = torch.tensor([4,3.9,1],dtype = torch.float32)
p = F.log_softmax(p_logits,dim=-1)q_logits = torch.tensor([5,4,0.1],dtype = torch.float32)
q = F.softmax(q_logits,dim=-1)
q_soft = F.softmax(q_logits/5,dim=-1)loss_1 = nn.KLDivLoss(reduction='sum')(F.log_softmax(p_logits ,dim=0), F.softmax(q_logits,dim=0))
loss_2 = nn.KLDivLoss(reduction='sum')(p,q)soft_loss1 = nn.KLDivLoss(reduction='sum')(F.log_softmax(p_logits ,dim=0), F.softmax(q_logits/5,dim=0))
soft_loss2 = nn.KLDivLoss(reduction='sum')(p,q_soft)import numpy as np
print(f'student predict: {np.exp(p.numpy())}')
print(f'q target: {q.numpy()}')
print(f'q soft probility: {q_soft.numpy()}')print(f'loss_1: {loss_1}')
print(f'loss_2: {loss_2}')print(f'soft loss1: {soft_loss1}')
print(f'soft loss2: {soft_loss2}')
打印一下输出可以看到:
student predict: [0.5116072 0.46292138 0.02547142]
q target: [0.7271004 0.2674853 0.00541441]
q soft probility: [0.4557798 0.37316093 0.1710592 ]
loss_1: 0.10048329085111618
loss_2: 0.10048329085111618
soft loss1: 0.1926761120557785
soft loss2: 0.1926761120557785
4.使用Tensorflow计算KL divergence loss
import tensorflow as tf
import numpy as np
onehot_labels = tf.nn.softmax(q_logits)
logits = [0.5116072, 0.46292138, 0.02547142]tf_loss1 = tf.keras.losses.KLDivergence()(tf.nn.softmax(q_logits),tf.nn.softmax(p_logits))
print(tf_loss1.numpy())
tensorflow loss: 0.1004832461476326
可以看到,pytorch以及tensorflow中计算的结果几乎是一致的。
Pytorch以及tensorflow中KLdivergence的计算相关推荐
- 深度学习PyTorch,TensorFlow中GPU利用率较低,CPU利用率很低,且模型训练速度很慢的问题总结与分析
在深度学习模型训练过程中,在服务器端或者本地pc端,输入nvidia-smi来观察显卡的GPU内存占用率(Memory-Usage),显卡的GPU利用率(GPU-util),然后采用top来查看CPU ...
- 深度学习PyTorch,TensorFlow中GPU利用率较低,使用率周期性变化的问题
在用tensorflow训练神经网络时,发现训练迭代的速度时而快时而慢,监督的GPU使用率也是周期性变化,通过了解,发现原因是: GPU在等待CPU读取,预处理,并传输数据过来,因此要提高GPU的使用 ...
- pytorch和tensorflow中实现SMU激活函数
在Pytorch中实现SMU激活函数 本文代码来源于githubuSMU源码链接 # coding=utf-8import torch from torch import nnclass SMU(nn ...
- 深度学习PyTorch、TensorFlow中GPU利用率与内存占用率很低的问题
上周,在一个使用Pytorch搭建的目标训练项目中,训练时,通过使用命令行执行NVIDIA-SMI(仅支持英伟达显卡)命令发现GPU的利用率基本一直停留在0%,并且显存占用率也较低.CSDN上有一篇分 ...
- 编写同时在PyTorch和Tensorflow上工作的代码
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 ❝ "库开发人员不再需要在框架之间进行选择." ...
- 【深度学习】编写同时在PyTorch和Tensorflow上工作的代码
作者 | Ram Sagar 编译 | VK 来源 | Analytics In Diamag ❝ "库开发人员不再需要在框架之间进行选择." ❞ 来自德国图宾根人工智能中心的研究 ...
- Pytorch和Tensorflow在10000*1000数据规模线性回归算法中的运算速度对比
Pytorch和Tensorflow在10000*1000数据规模线性回归算法中的运算速度对比 因为在学习人工智能相关知识,于是将学习过程与程序放在这里,希望对大家有帮助,共同学习,共同进步(不喜勿喷 ...
- lstm中look_back的大小选择_[Pytorch和Tensorflow对比(二)]:LSTM
1. LSTM网络输入和输出的区别 1.1 单LSTM Tensorflow # state_is_tuple表示输出的state是 (c_state, m_state)形式的tuple Pytorc ...
- 掌握深度学习,为什么要用PyTorch、TensorFlow框架?
作者 | Martin Heller 译者 | 弯月 责编 | 屠敏 来源 | CSDN(ID:CSDNnews) [导读]如果你需要深度学习模型,那么 PyTorch 和 TensorFlow 都是 ...
最新文章
- 开源MongoDB管理工具MongoCola1.20 发布 离开IBM GDC的最后一个版本
- linux安装软件命令1003无标题,linux系统安装OFED(infiniband)
- IOS开发基础知识--碎片32
- 京东产品负责人:数据如何高效驱动供应链?
- (转)Unity3d UnityEditor编辑器定制和开发插件
- Redis高效性探索--线程IO模型,通信协议
- javaml_一些基于Java的AI框架:Encog,JavaML,Weka
- angularjs 表单验证 和 页面初始化闪烁
- 计算机网络的功能分布计算,网络中心的分布计算(转帖)
- 今日头条ocpm计费规则_今日头条广告投放推广新产品选OCPM还是CPA好?是新的计费方式吗?...
- c语言中L''作用,L/C问题: 请问L/C上的49:Confirmation Instruction 有什么作用啊[1]
- html5 video cache,手机里的videoCache文件夹什么意思?可以删除吗?
- Ubuntu下制作deb包的方法详解
- 关于Pidgin和webqq
- UPC2022/3/18 晚训练赛补题
- Bootstrap框架——栅格系统
- android raw相机,最高大上的安卓相机App?专业拍摄ProShot
- mysql 某个日期加七天_Mysql时间操作(当天,昨天,7天,30天,半年,全年,季度)...
- YOLOv3 物体识别显示中文标签
- 每天坚持“踮脚尖”,时间久了,身体会收获什么?每天踮多久?
热门文章
- std::lock_guard使用案例及常用系统函数调用案例
- SEM托管技术人员如何精准把握客户需求
- AMD发运首批“推土机”架构处理器
- 广州虚拟动力:虚拟主播动作捕捉设备原理是什么呢?
- 婚庆公司七夕情人节策划活动PPT模板
- android 系统时间改变颜色吗,安卓手机通知栏时间、日期、通知颜色修改教程
- 专利申请的详细流程和时间
- 50 + 你值得收藏的 Kubernetes 生态工具 (2020 最新版)
- bzoj 2732 射箭 半平面交 解题报告
- Codeforces Round #590 (Div. 3)题解