pytorch中的二分类及多分类交叉熵损失函数
本文主要记录一下pytorch里面的二分类及多分类交叉熵损失函数的使用。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(2020)
<torch._C.Generator at 0x7f4e8b3298b0>
二分类交叉熵损失函数
该函数主要用于多标签分类中,针对每个标签进行二分类。
Single
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3, requires_grad=True)
print(input)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
print(output)
f_output = F.binary_cross_entropy(m(input), target)
print(f_output)
l_output = nn.BCEWithLogitsLoss()(input, target)
print(l_output)
tensor([ 1.2372, -0.9604, 1.5415], requires_grad=True)
tensor(0.2576, grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.2576, grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.2576, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
Batch
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(32,5, requires_grad=True)
target = torch.empty(32,5).random_(2)
output = loss(m(input), target)
print(output)
f_output = F.binary_cross_entropy(m(input), target)
print(f_output)
l_output = nn.BCEWithLogitsLoss()(input, target)
print(l_output)
tensor([[ 1.2986, 1.5832, -1.1648, 0.8027, -0.9628],[-1.5793, -0.2155, 0.4706, -1.2511, 0.7105],[-0.1274, -1.9361, 0.8374, 0.0081, -0.1504],[ 0.1521, 1.1443, 0.2171, -1.1438, 0.9341],[-3.3199, 1.2998, 0.3918, 0.8327, 1.2411],[-0.8507, -0.1016, -1.2434, -0.5755, 0.1871],[-0.3064, 1.3751, 1.8478, 0.0326, 0.2032],[ 0.1782, 2.3037, 1.5948, -1.4731, 1.5312],[-0.9075, -1.7135, 0.4650, -1.7061, 0.0625],[-1.1904, 0.1130, -1.6609, -0.2000, -0.1422],[ 0.3307, -0.8395, -1.3068, -0.8891, 0.9858],[ 0.5484, 0.7461, -1.0738, -2.2162, 0.6801],[-0.8803, 0.9934, -1.6438, 0.3860, 0.4111],[-1.1078, -0.9629, -0.9534, -0.6207, 0.6885],[-0.0175, 1.9496, 0.9740, -0.4687, -0.6127],[ 0.3713, 0.8074, 0.3072, 1.1604, -0.2669],[-0.1773, -0.2787, 0.1926, 0.7492, 0.7492],[-0.3126, -0.3321, -1.7287, -3.0126, 0.1194],[ 1.0486, -0.1890, -0.5853, 0.4353, 0.2619],[ 1.9726, -0.5510, -0.1826, -0.8600, -0.9906],[ 0.7551, 0.8431, -0.8461, -1.2120, 0.2908],[-0.0932, -0.7151, -0.0631, 1.7554, 0.7374],[-0.1494, -0.6990, -0.1666, 2.0430, 1.3968],[ 0.2280, -0.3187, 1.0309, -0.1067, 1.1622],[-1.5120, -0.8617, 1.4165, -0.2361, -0.0355],[-0.8757, -0.6554, 0.1121, -0.1669, -0.2628],[-0.8023, 0.2305, -1.1792, 0.4314, -0.3653],[ 0.7487, 0.5358, -0.2677, -0.8128, 0.3029],[ 1.4439, -0.5677, 0.5564, -0.2485, -0.3281],[-2.0259, 1.1038, 1.0615, 1.7317, -0.0531],[ 0.9083, -0.8274, 0.8101, -1.1375, -1.2009],[ 0.3300, -0.8760, 1.3459, -1.0209, -0.5313]], requires_grad=True)
tensor(0.8165, grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.8165, grad_fn=<BinaryCrossEntropyBackward>)
tensor(0.8165, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
Note
nn.BCELoss()
与F.binary_cross_entropy
计算结果是等价的,具体两者差距可见PyTorch 中,nn 与 nn.functional 有什么区别?nn.BCEWithLogitsLoss
: combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability. 至于为什么更稳定,见 https://blog.csdn.net/u010630669/article/details/105599067- 二分类交叉熵损失函数的input和target的shape是一致的
多分类交叉熵损失函数
Single
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
print(output)
f_output = F.cross_entropy(input, target)
print(f_output)
tensor(1.7541, grad_fn=<NllLossBackward>)
tensor(1.7541, grad_fn=<NllLossBackward>)
Batch
loss = nn.CrossEntropyLoss()
input = torch.randn(32, 10, 5, requires_grad=True)
target = torch.empty(32, 5, dtype=torch.long).random_(10)
output = loss(input, target)
print(output)
f_output = F.cross_entropy(input, target)
print(f_output)
tensor(2.7944, grad_fn=<NllLoss2DBackward>)
tensor(2.7944, grad_fn=<NllLoss2DBackward>)
Note
nn.CrossEntropyLoss
与F.cross_entropy
计算结果是等价的。两个函数都结合了LogSoftmax
andNLLLoss
运算nn.CrossEntropyLoss
的公式为:loss(x,class )=−log(exp(x[class])∑jexp(x[j]))=−x[class]+log(∑jexp(x[j]))\operatorname{loss}(\mathrm{x}, \text { class })=-\log \left(\frac{\exp (\mathrm{x}[\mathrm{class}])}{\sum_{\mathrm{j}} \exp (\mathrm{x}[\mathrm{j}])}\right)=-\mathrm{x}[\mathrm{class}]+\log \left(\sum_{\mathrm{j}} \exp (\mathrm{x}[\mathrm{j}])\right)loss(x, class )=−log(∑jexp(x[j])exp(x[class]))=−x[class]+log(∑jexp(x[j])),这与我们平时见到的多分类交叉熵损失函数有点不同,具体的推导过程见Pytorch里的CrossEntropyLoss详解
参考自:
- https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=crossentropyloss#torch.nn.CrossEntropyLoss
- https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html?highlight=bceloss#torch.nn.BCELoss
- https://www.cnblogs.com/marsggbo/p/10401215.html
pytorch中的二分类及多分类交叉熵损失函数相关推荐
- 交叉熵损失函数分类_BCE和CE交叉熵损失函数的区别
首先需要说明的是PyTorch里面的BCELoss和CrossEntropyLoss都是交叉熵,数学本质上是没有区别的,区别在于应用中的细节. BCE适用于0/1二分类,计算公式就是 " - ...
- LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数
在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...
- pytorch_lesson10 二分类交叉熵损失函数及调用+多分类交叉熵损失函数及调用
注:仅仅是学习记录笔记,搬运了学习课程的ppt内容,本意不是抄袭!望大家不要误解!纯属学习记录笔记!!!!!! 文章目录 一.机器学习中的优化思想 二.回归:误差平方和SSE 三.二分类交叉熵损失函数 ...
- 交叉熵损失函数分类_逻辑回归(Logistic Regression)二分类原理,交叉熵损失函数及python numpy实现...
本文目录: 1. sigmoid function (logistic function) 2. 逻辑回归二分类模型 3. 神经网络做二分类问题 4. python实现神经网络做二分类问题 ----- ...
- 二分类交叉熵损失函数python_【深度学习基础】第二课:softmax分类器和交叉熵损失函数...
[深度学习基础]系列博客为学习Coursera上吴恩达深度学习课程所做的课程笔记. 本文为原创文章,未经本人允许,禁止转载.转载请注明出处. 1.线性分类 如果我们使用一个线性分类器去进行图像分类该怎 ...
- 二分类交叉熵损失函数python_二分类问题的交叉熵损失函数多分类的问题的函数交叉熵损失函数求解...
二分类问题的交叉熵损失函数; 在二分类问题中,损失函数为交叉熵损失函数.对于样本(x,y)来讲,x为样本 y为对应的标签.在二分类问题中,其取值的集合可能为{0,1},我们假设某个样本的真实标签为yt ...
- pytoch人工神经网络基础:最简单的分类(softmax回归+交叉熵分类)
softmax回归分类原理 对于回归问题,可以用模型预测值与真实值比较,用均方误差这样的损失函数表示误差,迭代使误差最小训练模型. 那么分类问题是否可以用线性回归模型预测呢.最简单的方法就是用soft ...
- pytorch中交叉熵损失函数的细节
目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结. 类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss() 输入: ...
- 交叉熵损失函数分类_交叉熵损失函数
我们先从逻辑回归的角度推导一下交叉熵(cross entropy)损失函数. 从逻辑回归到交叉熵损失函数 这部分参考自 cs229-note1 part2. 为了根据给定的 预测 (0或1),令假设函 ...
最新文章
- 如何快速融入团队并成为团队核心?(九)
- 【C#食谱】【杭帮菜】菜单2:写一个TCP客户端
- Unity自定义角色控制器(一):碰撞检测
- 会话创建过程-创建Executor
- 基于 Flink+Iceberg 构建企业级实时数据湖 | 附 PPT 下载
- 只不过是R.java文件的特性-----出错信息:R.java was modified manually! Reverting to generated version!...
- PYSQLITE用法初探
- python爬虫之图片下载APP1.0
- WINDOWS获得当前执行程序路径的办法
- Mybatis 报错Mapper method ‘xxx‘ has an unsupported return type
- Linguistic Regularities in Continuous Space Word Representations
- 网页搜索(百度谷歌)你不得不知道的十个小技巧
- ORACLE通过dblink同步SDO_ORDINATE_ARRAY_STR的数据
- es6删除数组某一项_什么时候用集合,什么时候用数组?一文帮你清晰界定
- 特斯拉和拼多多,到底在「较真儿」什么?
- libpng16.so.16错误
- BlazeDS是什么?
- 浅浅瞅瞅RSA-PSS 算法
- Centos 7安装Harbor
- JQ----移动端h5页面通过地址调起通讯录以及高德地图、百度地图定位导航