pytorch nn.CrossEntropyLoss
应用
概念讲解
1)假设有m张图片,经过神经网络后输出为m*n的矩阵(m是图片个数,n是图片类别),下例中:
m=2,n=2既有两张图片,供区分两种类别比如猫狗。假设第0维为猫,第1维为狗
import torch
input=torch.randn(2,2)
input
------------------------
tensor([[-1.6243, -0.4164],[-0.2492, -0.9667]])
------------------------
2)使用softmax
将其转化为概率,我们可以看到,第一张图片是狗的概率大,第二张是猫的概率大。
soft = torch.nn.Softmax(dim=1) # 横向计算softmax
soft(input) # 将输出转化为概率
-------------------------
tensor([[0.2301, 0.7699],[0.6721, 0.3279]])
--------------------------
3)对上述结果取对数:(可以使用logsoft(input)
替代2,3步骤)
torch.log(soft(input))
---------------------------
tensor([[-1.4694, -0.2615],[-0.3974, -1.1149]])
---------------------------
4)NLLLoss结果就是把上面取对数之后的结果与Label对应的那个值拿出来,再去掉负号,然后求和取均值。
假设target是[0,1]既第一张是猫,第二张是狗。第一行取第0个元素,第二行取第1个,去掉负号,求和取均值,既:
(-(-1.4694) + -(-1.1149))/2 = 1.29215
直接使用NLLLoss函数验证:
nll = nn.NLLLoss()
target = torch.tensor([0,1])
nll(torch.log(soft(input)),target)
-----------------------------------------------
tensor(1.2921)
-------------------------------------------------
5)CrossEntropyLoss其实就是Softmax–Log–NLLLoss合并成一步。
ce = nn.CrossEntropyLoss()
ce(input,target)
-----------------------------
tensor(1.2921)
------------------------------
API
This criterion combines nn.LogSoftmax()
and nn.NLLLoss()
in one single class.
参考
https://blog.csdn.net/qq_22210253/article/details/85229988
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=nn%20crossentropyloss#torch.nn.CrossEntropyLoss
pytorch nn.CrossEntropyLoss相关推荐
- PyTorch nn.CrossEntropyLoss() dimension out of range (expected to be in range of [-1, 0], but got 1)
import torch import torch.nn as nn loss_fn = nn.CrossEntropyLoss() # 方便理解,此处假设batch_size = 1 x_input ...
- [深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别
nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据 target = torch.tensor([1, 2]) entropy_out ...
- 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)
在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ( y , clas ...
- 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用
在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...
- 深入浅出PyTorch中的nn.CrossEntropyLoss
目录 一.前言 二.理论基础 三.主要参数 3.1 输入与输出 四.从零开始实现 `nn.CrossEntropyLoss` 一.前言 nn.CrossEntropyLoss 常用作多分类问题的损失函 ...
- pytorch的nn.CrossEntropyLoss()函数使用方法
nn.CrossEntropyLoss()函数计算交叉熵损失 用法: # output是网络的输出,size=[batch_size, class] #如网络的batch size为128,数据分为1 ...
- PyTorch之torch.nn.CrossEntropyLoss()
简介 信息熵: 按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度 交叉熵: 使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度 多分类任务中的交叉熵损失函数 代码 ...
- nn.CrossEntropyLoss总结
nn.CrossEntropyLoss总结 目录 nn.CrossEntropyLoss nn.LogSoftmax nn.NLLLoss Cross entropy 目录 版本 pytorch 1. ...
- 交叉熵损失函数python_交叉熵损失函数nn.CrossEntropyLoss()
nn.CrossEntropyLoss() 1.引言 在使用pytorch深度学习框架做多分类时,计算损失函数通常会使用交叉熵损失函数nn.CrossEntropyLoss() 2. 信息量和熵 信息 ...
最新文章
- 高可用的Spring FTP上传下载工具类(已解决上传过程常见问题)
- 大家是否也习惯将常用到的Python软件包放在一个头文件里?
- 关于跨域策略文件crossdomain.xml文件
- web监听器监听mysql_JavaWEB开发15——ListenerListener
- JavaScript操作文件
- CY7C68013 USB接口相机开发记录 - 第一天:资料下载
- 并查集路径压缩_并查集简单教学
- 怎么看接收灵敏度desense问题?
- pythonxy官网下载_spyder安装包
- vue实现复制到剪切板的功能
- 服务器修改传奇道士神兽升级,1.76复古传奇道士玩家快速升级神兽的技巧
- 美团构建实时数仓的痛点是什么?如何解决?
- 继续:个人微信的自动收款解决(思路)
- [深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras1D卷积 val_acc:0.99780)
- rog魔霸新锐2022款 评测 怎么样
- Java 基础 | Java 中引用与指针的关系
- MySQL 去除字符串中的括号以及括号内的内容
- 视频教程-微信小程序系统教程Java版[3/3阶段]_微信小程序电商系统-微信开发
- 面了20家,原来大厂面试的套路是……
- excel数字不能累加_如何修复不累加的Excel编号
热门文章
- 50 - 算法 -二叉树 - 递归 - LeetCode 101
- 和pythondjango后端_webGIS实践:4_0_python django后端搭建web工程
- 数据分析20大基本分析方法技术总结【分析目的、分析案例、分析方法与思路】
- Java继承知识之基本控制语句(if、switch与穿透现象)
- ubuntu如何调出python_ubuntu|linux下 如何用python 模拟按键
- 热烈庆祝《Python可以这样学》在台湾发行繁体版
- python优先级排序_Python Numpy重新排列双向排序
- 遍历 in java_[Java教程]JavaScript中遍历数组 最好不要使用 for in 遍历
- python中set函数_python中set()函数简介及实例解析
- Python多线程笔记——简单函数版和类实现版