应用

概念讲解

1)假设有m张图片,经过神经网络后输出为m*n的矩阵(m是图片个数,n是图片类别),下例中:
m=2,n=2既有两张图片,供区分两种类别比如猫狗。假设第0维为猫,第1维为狗

import torch
input=torch.randn(2,2)
input
------------------------
tensor([[-1.6243, -0.4164],[-0.2492, -0.9667]])
------------------------

2)使用softmax将其转化为概率,我们可以看到,第一张图片是狗的概率大,第二张是猫的概率大。

soft = torch.nn.Softmax(dim=1) # 横向计算softmax
soft(input) # 将输出转化为概率
-------------------------
tensor([[0.2301, 0.7699],[0.6721, 0.3279]])
--------------------------

3)对上述结果取对数:(可以使用logsoft(input)替代2,3步骤)

torch.log(soft(input))
---------------------------
tensor([[-1.4694, -0.2615],[-0.3974, -1.1149]])
---------------------------

4)NLLLoss结果就是把上面取对数之后的结果与Label对应的那个值拿出来,再去掉负号,然后求和取均值。
假设target是[0,1]既第一张是猫,第二张是狗。第一行取第0个元素,第二行取第1个,去掉负号,求和取均值,既:

(-(-1.4694) + -(-1.1149))/2 = 1.29215

直接使用NLLLoss函数验证:

nll = nn.NLLLoss()
target = torch.tensor([0,1])
nll(torch.log(soft(input)),target)
-----------------------------------------------
tensor(1.2921)
-------------------------------------------------

5)CrossEntropyLoss其实就是Softmax–Log–NLLLoss合并成一步。

ce = nn.CrossEntropyLoss()
ce(input,target)
-----------------------------
tensor(1.2921)
------------------------------

API

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.

参考

https://blog.csdn.net/qq_22210253/article/details/85229988
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=nn%20crossentropyloss#torch.nn.CrossEntropyLoss

pytorch nn.CrossEntropyLoss相关推荐

  1. PyTorch nn.CrossEntropyLoss() dimension out of range (expected to be in range of [-1, 0], but got 1)

    import torch import torch.nn as nn loss_fn = nn.CrossEntropyLoss() # 方便理解,此处假设batch_size = 1 x_input ...

  2. [深度学习] Pytorch nn.CrossEntropyLoss()和nn.NLLLoss() 区别

    nn.NLLLoss()的参数是经过logsoftmax加工的,而CrossEntropyLoss的是原始输出数据 target = torch.tensor([1, 2]) entropy_out ...

  3. 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)

    在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ⁡ ( y , clas ...

  4. 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用

    在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...

  5. 深入浅出PyTorch中的nn.CrossEntropyLoss

    目录 一.前言 二.理论基础 三.主要参数 3.1 输入与输出 四.从零开始实现 `nn.CrossEntropyLoss` 一.前言 nn.CrossEntropyLoss 常用作多分类问题的损失函 ...

  6. pytorch的nn.CrossEntropyLoss()函数使用方法

    nn.CrossEntropyLoss()函数计算交叉熵损失 用法: # output是网络的输出,size=[batch_size, class] #如网络的batch size为128,数据分为1 ...

  7. PyTorch之torch.nn.CrossEntropyLoss()

    简介 信息熵: 按照真实分布p来衡量识别一个样本所需的编码长度的期望,即平均编码长度 交叉熵: 使用拟合分布q来表示来自真实分布p的编码长度的期望,即平均编码长度 多分类任务中的交叉熵损失函数 代码 ...

  8. nn.CrossEntropyLoss总结

    nn.CrossEntropyLoss总结 目录 nn.CrossEntropyLoss nn.LogSoftmax nn.NLLLoss Cross entropy 目录 版本 pytorch 1. ...

  9. 交叉熵损失函数python_交叉熵损失函数nn.CrossEntropyLoss()

    nn.CrossEntropyLoss() 1.引言 在使用pytorch深度学习框架做多分类时,计算损失函数通常会使用交叉熵损失函数nn.CrossEntropyLoss() 2. 信息量和熵 信息 ...

最新文章

  1. 高可用的Spring FTP上传下载工具类(已解决上传过程常见问题)
  2. 大家是否也习惯将常用到的Python软件包放在一个头文件里?
  3. 关于跨域策略文件crossdomain.xml文件
  4. web监听器监听mysql_JavaWEB开发15——ListenerListener
  5. JavaScript操作文件
  6. CY7C68013 USB接口相机开发记录 - 第一天:资料下载
  7. 并查集路径压缩_并查集简单教学
  8. 怎么看接收灵敏度desense问题?
  9. pythonxy官网下载_spyder安装包
  10. vue实现复制到剪切板的功能
  11. 服务器修改传奇道士神兽升级,1.76复古传奇道士玩家快速升级神兽的技巧
  12. 美团构建实时数仓的痛点是什么?如何解决?
  13. 继续:个人微信的自动收款解决(思路)
  14. [深度应用]·DC竞赛轴承故障检测开源Baseline(基于Keras1D卷积 val_acc:0.99780)
  15. rog魔霸新锐2022款 评测 怎么样
  16. Java 基础 | Java 中引用与指针的关系
  17. MySQL 去除字符串中的括号以及括号内的内容
  18. 视频教程-微信小程序系统教程Java版[3/3阶段]_微信小程序电商系统-微信开发
  19. 面了20家,原来大厂面试的套路是……
  20. excel数字不能累加_如何修复不累加的Excel编号

热门文章

  1. 50 - 算法 -二叉树 - 递归 - LeetCode 101
  2. 和pythondjango后端_webGIS实践:4_0_python django后端搭建web工程
  3. 数据分析20大基本分析方法技术总结【分析目的、分析案例、分析方法与思路】
  4. Java继承知识之基本控制语句(if、switch与穿透现象)
  5. ubuntu如何调出python_ubuntu|linux下 如何用python 模拟按键
  6. 热烈庆祝《Python可以这样学》在台湾发行繁体版
  7. python优先级排序_Python Numpy重新排列双向排序
  8. 遍历 in java_[Java教程]JavaScript中遍历数组 最好不要使用 for in 遍历
  9. python中set函数_python中set()函数简介及实例解析
  10. Python多线程笔记——简单函数版和类实现版