nn.BCELoss和nn.CrossEntropyloss总结

nn.BCEloss

公式如下:

1.输入的X 代表模型的最后输出 y 代表你的label 我们的目的就是为了让模型去更好的学习label 所以loss 越小的话 x越接近label 我们的学习效果也越好
2. 使用这个公式前 x需要先通过sigmoid 激活函数 归一化到0-1之间
3. 一般二分类都是用的nn.BCELoss 因为二分类只有0 1 之分 正样本是1 负样本是0 看这个公式 当是正样本的时候 公式为 -w(ylogx) 不看w 的话 Loss的值域应该是 0到正无穷 所以loss最小是0 也就是x为1的时候 所以 x越大loss越小 也就是x越接近正样本1 loss越小 这就是这个公式的意义 反之一样
4. 使用BCELoss input和target shape 是一样的 nn.CrossEntropyloss和这个不同

nn.CrossEntropyloss

公式如:

logsoftmax 公式:

nll loss 公式:

这个公式简单来说就是 logsoftmax+ nllloss的结合体,不明白的先看我参考的那几篇博文

log里面实际上就是softmax 公式 所以输入不需要像BECLoss 一样先经过激活函数 这里面自带激活函数
加上log 就是 logsoftmax 了
再取负数 就是nllloss的概念了 这里nllloss里面有参数 mean 和 sum 实际上就是对应的nn.CrossEntropyloss 里面的reduction参数 mean代表取均值 sum代表取总和
input和target 的shape 不一样 input 是N*C C代表种类个数 target 是N
这里的原因就是上面的nllloss的缘故 他的作用是把对应标签位置的值拿出来取负数
举个例子

比如输入是3*3 代表 3张图片预测3类 每一张图片都预测他属于每一类的概率
因为经过了softmax所以概率和为1 我们假设是
[[0.2,0.3,0.5],
[0.8,0.1,0.1],
[0.7,0.2,0.1] 可以看出来 每一行的和为1 行代表图片个数 列代表种类 而我们的标签 是3 和输入不对应
比如是[0,1,2] 这时候 会自动one-hot编码 比如0 会变成[1,0,0]他会吧每一行对应的标签的数拿出来 第一个0 应该是第0类 所以吧0.8拿出来 第二个是1 吧0.1拿出来 以此类推。 这样就拿出来了3个数 根据reduction的设置 取平均或者总和 代表了最后的损失 可以看出 只有loss越小 说明标签是对应的 学习的越好。

总结

总结一下nn.CrossEntropyloss
看整个公式 实际就是交叉熵公式

原文链接:https://blog.csdn.net/weixin_50249353/article/details/119239907

nn.BCELoss和nn.CrossEntropyloss相关推荐

  1. nn.BCELoss与nn.CrossEntropyLoss的区别

    以前我浏览博客的时候记得别人说过,BCELoss与CrossEntropyLoss都是用于分类问题.可以知道,BCELoss是Binary CrossEntropyLoss的缩写,BCELoss Cr ...

  2. 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用

    在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...

  3. Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

    目录 nn.Softmax和nn.LogSoftmax nn.NLLLoss nn.CrossEntropy nn.BCELoss 总结 在Pytorch中的交叉熵函数的血泪史要从nn.CrossEn ...

  4. torch.nn.BCELoss

    import torch import torch.nn as nn a = torch.tensor([0.1,0.3,0.6]) b = torch.tensor([0,0,1]) loss = ...

  5. torch.nn.BCELoss用法

    1. 定义 数学公式为Loss = -w * [p * log(q) + (1-p) * log(1-q)],其中p.q分别为理论标签.实际预测值,w为权重.这里的log对应数学上的ln. PyTor ...

  6. torch.nn.BCELoss are unsafe to autocast

    torch.nn.BCELoss are unsafe to autocast 默认初始loss: lobj = torch.zeros(1, device=device) 临时解决方法: bce_l ...

  7. nn.BCELoss总结

    nn.BCELoss总结 本章内容 nn.BCELoss nn.BCEWithLogitsLoss 本章内容 版本 pytorch 1.0 nn.BCELoss 用于计算预测值和真实值之间的二元交叉熵 ...

  8. nn.functional 和 nn.Module入门讲解

    本文来自<20天吃透Pytorch> 一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的AP ...

  9. 速成pytorch学习——5天nn.functional 和 nn.Module

    一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模 ...

最新文章

  1. 【大数据技术干货】阿里云伏羲(fuxi)调度器FuxiMaster功能简介(一) 多租户(QuotaGroup)管理...
  2. java将字符串逻辑表达式转成布尔值
  3. DataSet DataTable操作
  4. Linux信号之signal函数
  5. console.log打印:自定义样式(含源码、效果图)
  6. cad管线交叉怎么画_高效设计!多种方式进行管线连接、伸缩
  7. 【1+X】软件测试用例概述
  8. C++简易打字游戏(DEV可运行)
  9. 开发环境配置 - Python 3的安装(Win+Linux+Mac)
  10. win7计算机无法识别分辨率,window7分辨率显示不正常
  11. C++ 已知两个时间(年月日)求日期差
  12. 【割点 dfs】UVALive - 7456 Least Crucial Node
  13. Python爬虫|豆瓣图书Top250
  14. 山东大学项目实训十六——可控音乐变压器Controllable Music Transformer
  15. 尝试 Stable Diffusion(通过Google Colab)
  16. java流程图中平行四边形代表什么_程序流程图中通常用平行四边形表示分支结构...
  17. 部署zinnia的问题
  18. 区块链技术指南学习笔记2
  19. Au 闪退解决方法(很邪门)
  20. 20230614使用360安全卫士的断网急救箱解决不能上网的问题

热门文章

  1. Windows 11 WHQL认证的必要性
  2. Linux命令之计算器bc
  3. ESP32 的 I2C 原理 应用入门
  4. 【学术相关】女教师的两难困境:当生育遇上考核
  5. 若依框架使用自带的方法进行图片上传
  6. Python知道cos值求角度_机械臂正运动学-DH参数-Python快速实现
  7. python 桑基图 地理坐标_利用Python+Excel制作桑基(Sankey)图
  8. android ExtCertPathValidatorException: Could not validate
  9. 小丸子学MongoDB系列之——安装MongoDB
  10. 实战goldengate:安装配置+数据初始化+单向DML复制