binary_cross_entropybinary_cross_entropy_with_logits都是来自torch.nn.functional的函数,首先对比官方文档对它们的区别:

函数名 解释
binary_cross_entropy Function that measures the Binary Cross Entropy between the target and the output
binary_cross_entropy_with_logits Function that measures Binary Cross Entropy between target and output logits

区别只在于这个logits,那么这个logits是什么意思呢?以下是从网络上找到的一个答案:

有一个(类)损失函数名字中带了with_logits. 而这里的logits指的是,该损失函数已经内部自带了计算logit的操作,无需在传入给这个loss函数之前手动使用sigmoid/softmax将之前网络的输入映射到[0,1]之间

再看看官方给的示例代码:
binary_cross_entropy:

input = torch.randn((3, 2), requires_grad=True)
target = torch.rand((3, 2), requires_grad=False)
loss = F.binary_cross_entropy(F.sigmoid(input), target)
loss.backward()
# input is  tensor([[-0.5474,  0.2197],
#         [-0.1033, -1.3856],
#         [-0.2582, -0.1918]], requires_grad=True)
# target is  tensor([[0.7867, 0.5643],
#         [0.2240, 0.8263],
#         [0.3244, 0.2778]])
# loss is  tensor(0.8196, grad_fn=<BinaryCrossEntropyBackward>)

binary_cross_entropy_with_logits:

input = torch.randn(3, requires_grad=True)
target = torch.empty(3).random_(2)
loss = F.binary_cross_entropy_with_logits(input, target)
loss.backward()
# input is  tensor([ 1.3210, -0.0636,  0.8165], requires_grad=True)
# target is  tensor([0., 1., 1.])
# loss is  tensor(0.8830, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)

的确binary_cross_entropy_with_logits不需要sigmoid函数了。

事实上,官方是推荐使用函数带有with_logits的,解释是
This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
翻译一下就是说将sigmoid层和binaray_cross_entropy合在一起计算比分开依次计算有更好的数值稳定性,这主要是运用了log-sum-exp技巧。

那么这个log-sum-exp主要就是讲如何防止数值计算溢出的问题:
l o g s u m e x p ( x 1 , x 2 , . . . , x n ) = l o g ( ∑ i = 1 n e x i ) logsumexp(x_1,x_2,...,x_n) = log(\sum_{i=1}^{n}e^{x_i}) logsumexp(x1​,x2​,...,xn​)=log(i=1∑n​exi​)针对上述式子,如果 x i x_i xi​很大,那么 e x i e^{x_i} exi​很有可能会溢出,为了避免这样的问题,上式可以进行如下变换:
l o g ( ∑ i = 1 n e x i ) = l o g ( e c ∑ i = 1 n e x i − c ) = c l o g e + l o g ( ∑ i = 1 n e x i − c ) log(\sum_{i=1}^{n}e^{x_i})=log(e^c\sum_{i=1}^{n}e^{x_i-c})=cloge+log(\sum_{i=1}^{n}e^{x_i-c}) log(i=1∑n​exi​)=log(eci=1∑n​exi​−c)=cloge+log(i=1∑n​exi​−c)于是乎,这样就可以避免数据溢出了。

pytorch损失函数binary_cross_entropy和binary_cross_entropy_with_logits的区别相关推荐

  1. Pytorch损失函数cross_entropy、binary_cross_entropy和binary_cross_entropy_with_logits的区别

    在做分类问题时我们经常会遇到这几个交叉熵函数: cross_entropy.binary_cross_entropy和binary_cross_entropy_with_logits. 那么他们有什么 ...

  2. pytorch中CrossEntropyLoss和NLLLoss的区别与联系

    pytorch中CrossEntropyLoss和NLLLoss的区别与联系 CrossEntropyLoss和NLLLoss主要是用在多分类问题的损失函数,他们两个既有不同,也有不浅的联系.先分别看 ...

  3. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  4. pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...

  5. Pytorch损失函数losses简介

    一般来说,监督学习的目标函数由损失函数和正则化项组成.(Objective = Loss + Regularization) Pytorch中的损失函数一般在训练模型时候指定. 注意Pytorch中内 ...

  6. Pytorch损失函数篇

    点击关注我哦 一篇文章带你了解pytorch中常用的损失函数 Q:什么是损失函数? 训练神经网络类似于人类的学习方式.我们将数据提供给模型,它可以预测某些内容,并告诉其预测是否正确.然后,模型纠正其错 ...

  7. 损失函数/成本函数/目标函数的区别

    https://mp.weixin.qq.com/s/nkfQnXIwDNPtcZEVQyTvZw 导读 在我刚开始学机器学习的时候也是闹不懂这三者的区别,当然,嘿嘿,初学者的你们是不是也有那么一丢丢 ...

  8. pytorch 损失函数总结

    PyTorch深度学习实战 4 损失函数 损失函数,又叫目标函数,是编译一个神经网络模型必须的两个参数之一.另一个必不可少的参数是优化器. 损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习 ...

  9. Pytorch —— 损失函数(二)

    目录 5.nn.L1Loss 6.nn.MSELoss 7.nn.SmoothL1Loss 8.nn.PoissonNLLLoss 9.nn.KLDivLoss 10.nn.MarginRanking ...

最新文章

  1. 进程(process)和线程
  2. jsp的<a>标签中怎么传递参数
  3. macos安装盘第三方工具制作_一步一步教你为macOS创建系统安装盘
  4. 阿里云服务器mysql莫名丢失_mysql数据库丢失
  5. i2c-toos 交互数据_什么是CD-i(交互式光盘)?
  6. 95-38-045-Buffer-UnpooledByteBuf
  7. Kudu : NonRecoverableException: illegal replication factor 2 (replication factor must be odd)
  8. Nodejs获取MySQL数据_nodejs同步调用获取mysql数据时遇到的大坑
  9. BZOJ1001: [BeiJing2006]狼抓兔子
  10. 企业员工管理系统 一:项目介绍
  11. MaxProxy可以成为永久关闭的911S5代理的新选择吗?
  12. [BZOJ3993]-[SDOI2015]星际战争-二分答案+最大流
  13. 微信小程序跳转公众号
  14. @linux安装及使用(压缩|解压)工具RAR
  15. 安装lux:推荐一款网页视频下载工具。并简单使用。(win)
  16. matlab仿真中pv,PV的matlab仿真
  17. 人工智能应用的细分领域有哪些
  18. 国内有名的汽车与交通调查研究咨询公司情况
  19. instruction-tuning
  20. 网站文章采集、撰写、推广注意要点

热门文章

  1. 数据增强 data augmentation
  2. nmon监控资源工具下载以及安装
  3. 个人向前端知识“复健”
  4. 风险与收益并存——新书《利益攸关》解读
  5. 济宁网络诈骗立案_在节日期间避免网络犯罪和诈骗
  6. 看门狗喂狗实验(有问题)
  7. 20155314 2016-2017-2 《Java程序设计》第6周学习总结
  8. 3d打印热床的PEI/玻璃/晶格玻璃/柔性平台/弹簧钢板如何选择
  9. 关于百度飞浆安装不成功的坑
  10. java基础-多态-多态的理解及使用