pytorch损失函数binary_cross_entropy和binary_cross_entropy_with_logits的区别
binary_cross_entropy
和binary_cross_entropy_with_logits
都是来自torch.nn.functional
的函数,首先对比官方文档对它们的区别:
函数名 | 解释 |
---|---|
binary_cross_entropy | Function that measures the Binary Cross Entropy between the target and the output |
binary_cross_entropy_with_logits | Function that measures Binary Cross Entropy between target and output logits |
区别只在于这个logits,那么这个logits是什么意思呢?以下是从网络上找到的一个答案:
有一个(类)损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间
再看看官方给的示例代码:
binary_cross_entropy:
input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(input), target)
loss.backward()
# input is tensor([[-0.5474, 0.2197],
# [-0.1033, -1.3856],
# [-0.2582, -0.1918]], requires_grad=True)
# target is tensor([[0.7867, 0.5643],
# [0.2240, 0.8263],
# [0.3244, 0.2778]])
# loss is tensor(0.8196, grad_fn=<BinaryCrossEntropyBackward>)
binary_cross_entropy_with_logits:
input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
loss = F.binary_cross_entropy_with_logits(input, target)
loss.backward()
# input is tensor([ 1.3210, -0.0636, 0.8165], requires_grad=True)
# target is tensor([0., 1., 1.])
# loss is tensor(0.8830, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
的确binary_cross_entropy_with_logits不需要sigmoid函数了。
事实上,官方是推荐使用函数带有with_logits的,解释是
This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
翻译一下就是说将sigmoid层和binaray_cross_entropy合在一起计算比分开依次计算有更好的数值稳定性,这主要是运用了log-sum-exp技巧。
那么这个log-sum-exp主要就是讲如何防止数值计算溢出的问题:
l o g s u m e x p ( x 1 , x 2 , . . . , x n ) = l o g ( ∑ i = 1 n e x i ) logsumexp(x_1,x_2,...,x_n) = log(\sum_{i=1}^{n}e^{x_i}) logsumexp(x1,x2,...,xn)=log(i=1∑nexi)针对上述式子,如果 x i x_i xi很大,那么 e x i e^{x_i} exi很有可能会溢出,为了避免这样的问题,上式可以进行如下变换:
l o g ( ∑ i = 1 n e x i ) = l o g ( e c ∑ i = 1 n e x i − c ) = c l o g e + l o g ( ∑ i = 1 n e x i − c ) log(\sum_{i=1}^{n}e^{x_i})=log(e^c\sum_{i=1}^{n}e^{x_i-c})=cloge+log(\sum_{i=1}^{n}e^{x_i-c}) log(i=1∑nexi)=log(eci=1∑nexi−c)=cloge+log(i=1∑nexi−c)于是乎,这样就可以避免数据溢出了。
pytorch损失函数binary_cross_entropy和binary_cross_entropy_with_logits的区别相关推荐
- Pytorch损失函数cross_entropy、binary_cross_entropy和binary_cross_entropy_with_logits的区别
在做分类问题时我们经常会遇到这几个交叉熵函数: cross_entropy.binary_cross_entropy和binary_cross_entropy_with_logits. 那么他们有什么 ...
- pytorch中CrossEntropyLoss和NLLLoss的区别与联系
pytorch中CrossEntropyLoss和NLLLoss的区别与联系 CrossEntropyLoss和NLLLoss主要是用在多分类问题的损失函数,他们两个既有不同,也有不浅的联系.先分别看 ...
- 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用
首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...
- pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...
- Pytorch损失函数losses简介
一般来说,监督学习的目标函数由损失函数和正则化项组成.(Objective = Loss + Regularization) Pytorch中的损失函数一般在训练模型时候指定. 注意Pytorch中内 ...
- Pytorch损失函数篇
点击关注我哦 一篇文章带你了解pytorch中常用的损失函数 Q:什么是损失函数? 训练神经网络类似于人类的学习方式.我们将数据提供给模型,它可以预测某些内容,并告诉其预测是否正确.然后,模型纠正其错 ...
- 损失函数/成本函数/目标函数的区别
https://mp.weixin.qq.com/s/nkfQnXIwDNPtcZEVQyTvZw 导读 在我刚开始学机器学习的时候也是闹不懂这三者的区别,当然,嘿嘿,初学者的你们是不是也有那么一丢丢 ...
- pytorch 损失函数总结
PyTorch深度学习实战 4 损失函数 损失函数,又叫目标函数,是编译一个神经网络模型必须的两个参数之一.另一个必不可少的参数是优化器. 损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习 ...
- Pytorch —— 损失函数(二)
目录 5.nn.L1Loss 6.nn.MSELoss 7.nn.SmoothL1Loss 8.nn.PoissonNLLLoss 9.nn.KLDivLoss 10.nn.MarginRanking ...
最新文章
- 进程(process)和线程
- jsp的<a>标签中怎么传递参数
- macos安装盘第三方工具制作_一步一步教你为macOS创建系统安装盘
- 阿里云服务器mysql莫名丢失_mysql数据库丢失
- i2c-toos 交互数据_什么是CD-i(交互式光盘)?
- 95-38-045-Buffer-UnpooledByteBuf
- Kudu : NonRecoverableException: illegal replication factor 2 (replication factor must be odd)
- Nodejs获取MySQL数据_nodejs同步调用获取mysql数据时遇到的大坑
- BZOJ1001: [BeiJing2006]狼抓兔子
- 企业员工管理系统 一:项目介绍
- MaxProxy可以成为永久关闭的911S5代理的新选择吗?
- [BZOJ3993]-[SDOI2015]星际战争-二分答案+最大流
- 微信小程序跳转公众号
- @linux安装及使用(压缩|解压)工具RAR
- 安装lux:推荐一款网页视频下载工具。并简单使用。(win)
- matlab仿真中pv,PV的matlab仿真
- 人工智能应用的细分领域有哪些
- 国内有名的汽车与交通调查研究咨询公司情况
- instruction-tuning
- 网站文章采集、撰写、推广注意要点