这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下:

import torch
import torch.nn as nnlabel = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))

输出结果分别为:

tensor(0.4963)
tensor(0.4963)
tensor(0.5990)

可以看到,nn.BCEWithLogitsLoss()相当于是在nn.BCELoss()中预测结果pred的基础上先做了个sigmoid,然后继续正常算loss。所以这就涉及到一个比较奇葩的bug,如果网络本身在输出结果的时候已经用sigmoid去处理了,算loss的时候用nn.BCEWithLogitsLoss()…那么就会相当于预测结果算了两次sigmoid,可能会出现各种奇奇怪怪的问题——

比如网络收敛不了(流泪猫猫头.jpg)

Ref

[1] https://zhuanlan.zhihu.com/p/170558960

Pytorch nn.BCEWithLogitsLoss()的简单理解与用法相关推荐

  1. Pytorch nn.Fold()的简单理解与用法

    官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html 这个东西基本上就是绑定Unfold使用的.实际上,在没有overla ...

  2. Pytorch Tensor.unfold()的简单理解与用法

    unfold的作用就是手动实现的滑动窗口操作,也就是只有卷,没有积:不过相比于nn.functional中的unfold而言,其窗口的意味更浓,只能是一维的,也就是不存在类似2×2窗口的说法. ret ...

  3. Pytorch nn.functional.unfold()的简单理解与用法

    unfold的作用就是手动实现(卷积中)的滑动窗口操作,也就是只有卷,没有积 ret = F.unfold(inp, size) inp:输入Tensor,必须是四维的(B, C, H, W) siz ...

  4. Pytorch forward()的简单理解与用法

    1.基本用法 在pytorch中,使用torch.nn包来构建神经网络,我们定义的网络继承自nn.Module类.而一个nn.Module包含神经网络的各个层(放在__init__里面)和前向传播方式 ...

  5. Pytorch nn.DataParallel()的简单用法

    简单来说就是使用单机多卡进行训练. 一般来说我们看到的代码是这样的: net = XXXNet() net = nn.DataParallel(net) 这样就可以让模型在全部GPU上训练. 方法定义 ...

  6. Pytorch nn.Transformer的mask理解

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨林小平@知乎(已授权) 来源丨https://zhuanlan ...

  7. 收藏 | Pytorch nn.Transformer的mask理解

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨林小平@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/35336542 ...

  8. pytorch卷积操作nn.Conv中的groups参数用法解释

    MobileNetV1<MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications> ...

  9. Focal Loss 分类问题 pytorch实现代码(简单实现)

    ps:由于降阳性这步正负样本数量在差距巨大.正样本1500多个,而负样本750000多个.要用 Focal Loss来解决这个问题. 首先感谢Code_Mart的博客把理论汇总了下https://bl ...

最新文章

  1. ROC曲线是什么?ROC曲线是怎么绘制的?ROC曲线的横纵坐标是什么?如何用Python绘制?AUC又是什么?
  2. linux mysql 二进制包_mysql5.7 二进制包安装
  3. linux库函数mmap()原理及用法详解
  4. Lighttpd日志打印格式
  5. java sleep方法_一文搞懂 Java 线程中断!
  6. 微信知识付费小程序博客源码(带299条数据)
  7. MTK 驱动---(8)emmc 介绍
  8. 【零基础学Java】—字符串的概述和特点(十五)
  9. Android 实例解说Application类
  10. JAVA怎么在函数内改变传入的值
  11. *第七周*数据结构实践项目三【负数把整数赶出队列】
  12. KEIL使用教程——KEIL常用配置技巧
  13. Java 汉字繁体转简体
  14. python画网络图_python3中NetworkX网络图绘制
  15. 泛函分析极简笔记(2)——Mahalanobis distance
  16. 监控mysql锁定状态_mysql InnoDB锁等待的查看及分析
  17. Confluence 6 的小型文字档案(Cookies)
  18. EBS-自动获取/创建CCID
  19. Persistence Query
  20. 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练

热门文章

  1. 王道 计算机网络试题讲解_计算机考研-统考计算机网络(2009-2012)试题及精心解析...
  2. linux与linux传文件乱码,关于Linux与windows传递文件乱码问题
  3. 开方计算与浮点数的问题
  4. 计算机类中外合作办学情况,郑州大学2021年河南省本科一批各专业录取分数统计...
  5. Java实验3 方法与数组
  6. AcWing 100. 增减序列
  7. ##CSP 201803-2 碰撞的小球(C语言)100分
  8. 东大OJ-Prim算法
  9. Spring中Bean管理操作基于XML配置文件方法实现
  10. Python:列表list对应项求和