Pytorch nn.BCEWithLogitsLoss()的简单理解与用法
这个东西,本质上和nn.BCELoss()没有区别,只是在BCELoss上加了个logits函数(也就是sigmoid函数),例子如下:
import torch
import torch.nn as nnlabel = torch.Tensor([1, 1, 0])
pred = torch.Tensor([3, 2, 1])
pred_sig = torch.sigmoid(pred)
loss = nn.BCELoss()
print(loss(pred_sig, label))loss = nn.BCEWithLogitsLoss()
print(loss(pred, label))loss = nn.BCEWithLogitsLoss()
print(loss(pred_sig, label))
输出结果分别为:
tensor(0.4963)
tensor(0.4963)
tensor(0.5990)
可以看到,nn.BCEWithLogitsLoss()相当于是在nn.BCELoss()中预测结果pred的基础上先做了个sigmoid,然后继续正常算loss。所以这就涉及到一个比较奇葩的bug,如果网络本身在输出结果的时候已经用sigmoid去处理了,算loss的时候用nn.BCEWithLogitsLoss()…那么就会相当于预测结果算了两次sigmoid,可能会出现各种奇奇怪怪的问题——
比如网络收敛不了(流泪猫猫头.jpg)
Ref
[1] https://zhuanlan.zhihu.com/p/170558960
Pytorch nn.BCEWithLogitsLoss()的简单理解与用法相关推荐
- Pytorch nn.Fold()的简单理解与用法
官方文档:https://pytorch.org/docs/stable/generated/torch.nn.Fold.html 这个东西基本上就是绑定Unfold使用的.实际上,在没有overla ...
- Pytorch Tensor.unfold()的简单理解与用法
unfold的作用就是手动实现的滑动窗口操作,也就是只有卷,没有积:不过相比于nn.functional中的unfold而言,其窗口的意味更浓,只能是一维的,也就是不存在类似2×2窗口的说法. ret ...
- Pytorch nn.functional.unfold()的简单理解与用法
unfold的作用就是手动实现(卷积中)的滑动窗口操作,也就是只有卷,没有积 ret = F.unfold(inp, size) inp:输入Tensor,必须是四维的(B, C, H, W) siz ...
- Pytorch forward()的简单理解与用法
1.基本用法 在pytorch中,使用torch.nn包来构建神经网络,我们定义的网络继承自nn.Module类.而一个nn.Module包含神经网络的各个层(放在__init__里面)和前向传播方式 ...
- Pytorch nn.DataParallel()的简单用法
简单来说就是使用单机多卡进行训练. 一般来说我们看到的代码是这样的: net = XXXNet() net = nn.DataParallel(net) 这样就可以让模型在全部GPU上训练. 方法定义 ...
- Pytorch nn.Transformer的mask理解
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨林小平@知乎(已授权) 来源丨https://zhuanlan ...
- 收藏 | Pytorch nn.Transformer的mask理解
点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨林小平@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/35336542 ...
- pytorch卷积操作nn.Conv中的groups参数用法解释
MobileNetV1<MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications> ...
- Focal Loss 分类问题 pytorch实现代码(简单实现)
ps:由于降阳性这步正负样本数量在差距巨大.正样本1500多个,而负样本750000多个.要用 Focal Loss来解决这个问题. 首先感谢Code_Mart的博客把理论汇总了下https://bl ...
最新文章
- ROC曲线是什么?ROC曲线是怎么绘制的?ROC曲线的横纵坐标是什么?如何用Python绘制?AUC又是什么?
- linux mysql 二进制包_mysql5.7 二进制包安装
- linux库函数mmap()原理及用法详解
- Lighttpd日志打印格式
- java sleep方法_一文搞懂 Java 线程中断!
- 微信知识付费小程序博客源码(带299条数据)
- MTK 驱动---(8)emmc 介绍
- 【零基础学Java】—字符串的概述和特点(十五)
- Android 实例解说Application类
- JAVA怎么在函数内改变传入的值
- *第七周*数据结构实践项目三【负数把整数赶出队列】
- KEIL使用教程——KEIL常用配置技巧
- Java 汉字繁体转简体
- python画网络图_python3中NetworkX网络图绘制
- 泛函分析极简笔记(2)——Mahalanobis distance
- 监控mysql锁定状态_mysql InnoDB锁等待的查看及分析
- Confluence 6 的小型文字档案(Cookies)
- EBS-自动获取/创建CCID
- Persistence Query
- 菜鸟笔记-DuReader阅读理解基线模型代码阅读笔记(八)—— 模型训练-训练
热门文章
- 王道 计算机网络试题讲解_计算机考研-统考计算机网络(2009-2012)试题及精心解析...
- linux与linux传文件乱码,关于Linux与windows传递文件乱码问题
- 开方计算与浮点数的问题
- 计算机类中外合作办学情况,郑州大学2021年河南省本科一批各专业录取分数统计...
- Java实验3 方法与数组
- AcWing 100. 增减序列
- ##CSP 201803-2 碰撞的小球(C语言)100分
- 东大OJ-Prim算法
- Spring中Bean管理操作基于XML配置文件方法实现
- Python:列表list对应项求和