对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用
在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题。
BCELoss是Binary CrossEntropyLoss的缩写,nn.BCELoss()为二元交叉熵损失函数,只能解决二分类问题。
在使用nn.BCELoss()作为损失函数时,需要在该层前面加上Sigmoid函数,一般使用nn.Sigmoid()即可,
而在使用nn.CrossEntropyLoss()其内部会自动加上Sofrmax层。下面会详细说明。
nn.CrossEntropyLoss()
nn.CrossEntropyLoss()的计算公式如下:
其中x为输入也是网络的最后一层的输出,其shape为[batchsize,class],log以e为底,,class为类别索引。
例如,输入x=[[4,8,3]],shape=(1,3),即batchsize=1,class=3,我们首先计算一下。
loss(x,0)=-x[0]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-4+log(exp(4)+exp(8)+exp(3))=-4+8.0247=4.0247
loss(x,1)=-x[1]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-8+log(exp(4)+exp(8)+exp(3))=-8+8.0247=0.0247
loss(x,2)=-x[2]+log(exp(x[0])+exp(x[1])+exp(x[2]))=-3+log(exp(4)+exp(8)+exp(3))=-3+8.0247=5.0247
可以看到loss(x,1)的损失函数最小,也即网络输出为第1类(0表示第0类,1表示第1类,2表示第2类)可能性最大。
loss(x,1)的损失函数最大,则网络输出为第2类的可能性最小。
可以知道若网络的输出为x=[[4,8,3]],而对应的标签为1,则得到损失函数loss=0.0247。
我们在计算机上实现一下:
import torch
from torch import nna=torch.Tensor([[4,8,3]])
y=torch.Tensor([2.]).long()
print(a.numpy())
print(y.numpy(),y.type())
criteon = nn.CrossEntropyLoss() #nn.CrossEntropyLoss会自动加上Sofrmax层。loss = criteon(a, y)
print("loss=",loss.item())
y1=torch.Tensor([1.]).long()
y2=torch.Tensor([0.]).long()
print("loss1=",criteon(a, y1).item())
print("loss2=",criteon(a, y2).item())
输出结果为:
[[4. 8. 3.]]
[2] torch.LongTensor
loss= 5.024744987487793
loss1= 0.024744924157857895
loss2= 4.024744987487793
跟上面手动计算的基本一致,值得注意的是nn.CrossEntropyLoss()的参数中,第一个参数为网络的输出结果,类型为FloatTensor
第二个参数为标签,为LongTensor类型
nn.BCELoss()
参考nn.BCELoss()的英文源码
计算公式如下:
t[i]—— 表示样本i的label,正类为1,负类为0
o[i]—— 表示样本i预测为正的概率,是神经网络的输出,再通过softmax,log以e为底。
例如,神经网络的输出结果为x=[6,3,-4,6],我们很容易发现,batchsize=4,通过softmax得到o=[0.9975, 0.9526, 0.0180, 0.9975]
lable=[1,0,0,1]通过公式手动计算loss:
r1 = 1 * log(0.9975) + (1-1) *log(1 - 0.9975)=-0.002503130218118477
r2 = 0 * log(0.9526) + (1-0) *log(1 - 0.9526)=...
r3 = 0 * log(0.0180) + (1-0) *log(1 - 0.0180)=...
r4 = 1 * log(0.9975) + (1-1) *log(1 - 0.9975)=...
loss=(-1/4)(r1+r2+r3+r4)=0.7680758203362535
在计算机上使用nn.BCELoss()实现:
import torch
import numpy as np
from torch import nna=torch.Tensor([6,3,-4,6])
y=torch.Tensor([1,0,0,1])
print(a.numpy())
print(y.numpy(),y.type())
criteon = nn.BCELoss() pred=nn.Sigmoid()(a)
print("pred=",pred)
loss = criteon(pred, y)
print("loss=",loss.item())
输出结果如下:
[ 6. 3. -4. 6.]
[1. 0. 0. 1.] torch.FloatTensor
pred= tensor([0.9975, 0.9526, 0.0180, 0.9975])
loss= 0.7679222226142883
可以看到与手动计算的loss,相差不大,对于nn.BCELoss()的输入其第一个参数为神经网络nn.Sigmoid()的输出,第二个参数为FloatTensor类型。
对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用相关推荐
- pytorch 中维度(Dimension)概念的理解
pytorch 中维度(Dimension)概念的理解 Dimension为0(即维度为0时) 维度为0时,即tensor(张量)为标量.例如:神经网络中损失函数的值即为标量. 接下来我们创建一个di ...
- pytorch中的CrossEntropyLoss
这里主要探讨torch.nn.CrossEntropyLoss函数的用法. 使用方法如下: # 首先定义该类 loss = torch.nn.CrossEntropyLoss() #然后传参进去 lo ...
- pytorch中网络loss传播和参数更新理解
相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...
- [深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别
nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据 target = torch.tensor([1, 2]) entropy_out ...
- pytorch 中pad函数toch.nn.functional.pad()的使用
padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...
- pytorch中的transforms.ToTensor和transforms.Normalize理解
- Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)
目录 nn.Softmax和nn.LogSoftmax nn.NLLLoss nn.CrossEntropy nn.BCELoss 总结 在Pytorch中的交叉熵函数的血泪史要从nn.CrossEn ...
- nn.BCELoss和nn.CrossEntropyloss
nn.BCELoss和nn.CrossEntropyloss总结 nn.BCEloss 公式如下: 1.输入的X 代表模型的最后输出 y 代表你的label 我们的目的就是为了让模型去更好的学习lab ...
- 速成pytorch学习——5天nn.functional 和 nn.Module
一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模 ...
最新文章
- php get memory,PHP memory_get_usage 和 memory_get_peak_usage获取内存的区别
- Hbase表结构设计
- linux两个文件修改主机名
- python基础知识~ 函数详解2
- delphi 调用php接口_新浪图床 API 接口调用与请求方法详细教程
- PostgreSQL 8.0 中文手册
- java架构师之路:JAVA程序员必看的15本书的电子版下载地址
- Andorid 反编译App
- SpringMVC框架使用注解执行定时任务
- 用vb6.0查看计算机用户名,如何使用vb6.0输入登录窗口的用户名和密码?
- numpy教程:函数库和ufunc函数
- Atitit 文件上传功能的实现 图片 视频 目录 1. 上传原理	1 1.1. http post编码 multipart / form-data	1 1.2. 临时文件模式 最简单	2 1.3
- oracle数据库提示ORA-01033
- 儿童python编程教程-一款儿童编程入门的理想工具——PythonTurtle
- 网页打开5秒后弹出广告窗口
- 就是要让你搞懂Nginx,这篇就够了!
- 阿里 P10 是怎样的存在?
- websocket连接不稳定_帮你解决WiFi卡顿:拒绝连接不稳定、网速慢
- 从单个系统到云翼一体化支撑,京东云DevOps推进中的一波三折
- CSS进阶(6)- 居中总结