在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。

BCELoss是Binary CrossEntropyLoss的缩写,nn.BCELoss()为二元交叉熵损失函数,只能解决二分类问题。

在使用nn.BCELoss()作为损失函数时,需要在该层前面加上Sigmoid函数,一般使用nn.Sigmoid()即可,

而在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层。下面会详细说明。

nn.CrossEntropyLoss()

nn.CrossEntropyLoss()的计算公式如下:

其中x为输入也是网络的最后一层的输出,其shape为[batchsize,class],log以e为底,,class为类别索引。

例如,输入x=[[4,8,3]],shape=(1,3),即batchsize=1,class=3,我们首先计算一下。

loss(x,0)=-x[0]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-4+log(exp(4)+exp(8)+exp(3))=-4+8.0247=4.0247

loss(x,1)=-x[1]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-8+log(exp(4)+exp(8)+exp(3))=-8+8.0247=0.0247

loss(x,2)=-x[2]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-3+log(exp(4)+exp(8)+exp(3))=-3+8.0247=5.0247

可以看到loss(x,1)的损失函数最小,也即网络输出为第1类(0表示第0类,1表示第1类,2表示第2类)可能性最大。

loss(x,1)的损失函数最大,则网络输出为第2类的可能性最小。

可以知道若网络的输出为x=[[4,8,3]],而对应的标签为1,则得到损失函数loss=0.0247。

我们在计算机上实现一下:

import torch
from torch import nna=torch.Tensor([[4,8,3]])
y=torch.Tensor([2.]).long()
print(a.numpy())
print(y.numpy(),y.type())
criteon = nn.CrossEntropyLoss()   #nn.CrossEntropyLoss会自动加上Sofrmax层。loss = criteon(a, y)
print("loss=",loss.item())
y1=torch.Tensor([1.]).long()
y2=torch.Tensor([0.]).long()
print("loss1=",criteon(a, y1).item())
print("loss2=",criteon(a, y2).item())

输出结果为:

[[4. 8. 3.]]
[2] torch.LongTensor
loss= 5.024744987487793
loss1= 0.024744924157857895
loss2= 4.024744987487793

跟上面手动计算的基本一致,值得注意的是nn.CrossEntropyLoss()的参数中,第一个参数为网络的输出结果,类型为FloatTensor

第二个参数为标签,为LongTensor类型

nn.BCELoss()

参考nn.BCELoss()的英文源码

计算公式如下:

t[i]—— 表示样本i的label,正类为1,负类为0
o[i]—— 表示样本i预测为正的概率,是神经网络的输出,再通过softmax,log以e为底。

例如,神经网络的输出结果为x=[6,3,-4,6],我们很容易发现,batchsize=4,通过softmax得到o=[0.9975, 0.9526, 0.0180, 0.9975]

lable=[1,0,0,1]通过公式手动计算loss:

r1 = 1 * log(0.9975) + (1-1) *log(1 - 0.9975)=-0.002503130218118477
r2 = 0 * log(0.9526) + (1-0) *log(1 - 0.9526)=...
r3 = 0 * log(0.0180) + (1-0) *log(1 - 0.0180)=...
r4 = 1 * log(0.9975) + (1-1) *log(1 - 0.9975)=...
loss=(-1/4)(r1+r2+r3+r4)=0.7680758203362535

在计算机上使用nn.BCELoss()实现:

import torch
import numpy as np
from torch import nna=torch.Tensor([6,3,-4,6])
y=torch.Tensor([1,0,0,1])
print(a.numpy())
print(y.numpy(),y.type())
criteon = nn.BCELoss()   pred=nn.Sigmoid()(a)
print("pred=",pred)
loss = criteon(pred, y)
print("loss=",loss.item())

输出结果如下:

[ 6.  3. -4.  6.]
[1. 0. 0. 1.] torch.FloatTensor
pred= tensor([0.9975, 0.9526, 0.0180, 0.9975])
loss= 0.7679222226142883

可以看到与手动计算的loss,相差不大,对于nn.BCELoss()的输入其第一个参数为神经网络nn.Sigmoid()的输出,第二个参数为FloatTensor类型。

对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用相关推荐

  1. pytorch 中维度(Dimension)概念的理解

    pytorch 中维度(Dimension)概念的理解 Dimension为0(即维度为0时) 维度为0时,即tensor(张量)为标量.例如:神经网络中损失函数的值即为标量. 接下来我们创建一个di ...

  2. pytorch中的CrossEntropyLoss

    这里主要探讨torch.nn.CrossEntropyLoss函数的用法. 使用方法如下: # 首先定义该类 loss = torch.nn.CrossEntropyLoss() #然后传参进去 lo ...

  3. pytorch中网络loss传播和参数更新理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...

  4. [深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

    nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据 target = torch.tensor([1, 2]) entropy_out ...

  5. pytorch 中pad函数toch.nn.functional.pad()的使用

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  6. pytorch中的transforms.ToTensor和transforms.Normalize理解

  7. Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

    目录 nn.Softmax和nn.LogSoftmax nn.NLLLoss nn.CrossEntropy nn.BCELoss 总结 在Pytorch中的交叉熵函数的血泪史要从nn.CrossEn ...

  8. nn.BCELoss和nn.CrossEntropyloss

    nn.BCELoss和nn.CrossEntropyloss总结 nn.BCEloss 公式如下: 1.输入的X 代表模型的最后输出 y 代表你的label 我们的目的就是为了让模型去更好的学习lab ...

  9. 速成pytorch学习——5天nn.functional 和 nn.Module

    一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模 ...

最新文章

  1. php get memory,PHP memory_get_usage 和 memory_get_peak_usage获取内存的区别
  2. Hbase表结构设计
  3. linux两个文件修改主机名
  4. python基础知识~ 函数详解2
  5. delphi 调用php接口_新浪图床 API 接口调用与请求方法详细教程
  6. PostgreSQL 8.0 中文手册
  7. java架构师之路:JAVA程序员必看的15本书的电子版下载地址
  8. Andorid 反编译App
  9. SpringMVC框架使用注解执行定时任务
  10. 用vb6.0查看计算机用户名,如何使用vb6.0输入登录窗口的用户名和密码?
  11. numpy教程:函数库和ufunc函数
  12. Atitit 文件上传功能的实现 图片 视频 目录 1. 上传原理 1 1.1. http post编码 multipart / form-data 1 1.2. 临时文件模式 最简单 2 1.3
  13. oracle数据库提示ORA-01033
  14. 儿童python编程教程-一款儿童编程入门的理想工具——PythonTurtle
  15. 网页打开5秒后弹出广告窗口
  16. 就是要让你搞懂Nginx,这篇就够了!
  17. 阿里 P10 是怎样的存在?
  18. websocket连接不稳定_帮你解决WiFi卡顿:拒绝连接不稳定、网速慢
  19. 从单个系统到云翼一体化支撑,京东云DevOps推进中的一波三折
  20. CSS进阶(6)- 居中总结

热门文章

  1. 「SwiftUI」延迟执行代码
  2. 转载 sap FI-CO总账科目简析
  3. 企业面临的7大数据分析挑战
  4. 无法安装64位版本的微软Office
  5. 云原生之Kubernetes:18、详解准入控制器
  6. 用Javascript 编写 HTML在线编辑器
  7. Veins文档(中文)
  8. PMP考试 变更管理专题
  9. 计算机体系结构:不同改进方案的性价比计算(1.4)
  10. 域控知识与安全01:域控知识基础