NLLLoss

在图片单标签分类时,输入m张图片,输出一个m*N的Tensor,其中N是分类个数。比如输入3张图片,分三类,最后的输出是一个3*3的Tensor,举个例子:

第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
可以看出模型认为第123张都更可能是猫。
然后对每一行使用Softmax,这样可以得到每张图片的概率分布。

这里dim的意思是计算Softmax的维度,这里设置dim=1,可以看到每一行的加和为1。比如第一行0.6600+0.0570+0.2830=1。

如果设置dim=0,就是一列的和为1。比如第一列0.2212+0.3050+0.4738=1。
我们这里一张图片是一行,所以dim应该设置为1。
然后对Softmax的结果取自然对数:

Softmax后的数值都在0~1之间,所以ln之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
假设我们现在Target是[0,2,1](第一张图片是猫,第二张是猪,第三张是狗)。第一行取第0个元素,第二行取第2个,第三行取第1个,去掉负号,结果是:[0.4155,1.0945,1.5285]。再求个均值,结果是:

下面使用NLLLoss函数验证一下:

嘻嘻,果然是1.0128!

CrossEntropyLoss

CrossEntropyLoss就是把以上Softmax–Log–NLLLoss合并成一步,我们用刚刚随机出来的input直接验证一下结果是不是1.0128:

NLLLoss CrossEntropyLoss Pytorch相关推荐

  1. NLLLOSS CrossEntropyLoss

    今天在看论文的时候,看到了NLLLOSS函数,嗯?这是个啥,然后就查了查,原来是跟CrossEntropyLoss一样的,这里整理一下,方便以后查阅. NLLLOSS & CrossEntro ...

  2. 获取当前脚本目录路径问题汇总

    20211223 https://blog.csdn.net/qq_43178297/article/details/88053836 获取上一层目录 import osprint('***获取当前目 ...

  3. 深度学习网络模型可视化netron

    很多时候,复现人家工程的时候,需要了解人家的网络结构.但不同框架之间可视化网络层方法不一样,这样给研究人员造成了很大的困扰. 前段时间,发现了一个可视化模型结构的神奇:Netron 查看全文 http ...

  4. BCELoss、crossentropyLoss、NLLLoss的使用(pytorch)

    文章目录 BCELoss 参考文档 理解 demo 应用 crossentropyLoss.NLLLoss 参考文档 crossEntropyLoss NLLLoss BCELoss 用于二分类问题, ...

  5. PyTorch 入坑十一: 损失函数、正则化----深刻剖析softmax+CrossEntropyLoss

    这里写目录标题 概念 Loss Function Cost Function Objective Function 正则化 损失函数 交叉熵损失函数nn.CrossEntropyLoss() 自信息 ...

  6. PyTorch的十七个损失函数

    20220113 选损失函数的标准:能使得真实值和预测值越相近的时候总损失越小 20220303 机器学习大牛是如何选择回归损失函数的? MSE,MAE,huber loss 20210925 交叉熵 ...

  7. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  8. Pytorch:RNN、LSTM、GRU 构建人名分类器(one-hot版本、Embedding嵌入层版本)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 2. RNN经典案例 2.1 使用RNN模型构建人名分类器 学 ...

  9. Pytorch损失函数解析

    本文根据pytorch里面的源码解析各个损失函数,各个损失函数的python接口定义于包torch.nn.modules中的loss.py,在包modules的初始化__init__.py中关于损失函 ...

最新文章

  1. ceph bluestore源码分析:C++ 获取线程id
  2. Python3学习笔记(一):基础语法
  3. 目标检测(Google object_detection) API 上训练自己的数据集
  4. Matlab中fileter和conv的区别及卷积的计算方法
  5. python学起来难不难-Python自学难不难,培训班推荐?
  6. android NDK 编译hellojni 例子文件
  7. ora-03113 访问某条记录_用了Excel十几年,你居然不知道“记录单”?!可能错过一个亿……...
  8. 90 岁程序员,他的压缩算法改变了世界!
  9. Python3.x:pytesseract识别率提高(样本训练)
  10. Cesium 获取经纬度的几种方法
  11. mysql 共享锁(读写锁) 修改数据问题(update,insert)(LOCK IN SHARE MODE)
  12. dpdk LRO功能总结
  13. CentOS下MySQL安装失败,报socket '/tmp/mysql.sock错误解决方法
  14. 区块链交易——举例说明
  15. JavaScript - 正则(RegExp)判断文本框中是否包含特殊符号
  16. element日历(Calendar)排班
  17. Android 虚拟按键上报
  18. 联想电脑安装虚拟机出现不可恢复的错误
  19. mysql 统计七日留存率_移动APP中,7日留存率到底如何定义?
  20. 任务栏微信图标显示为白框,解决办法

热门文章

  1. 2022-2028年中国交通建设PPP模式深度分析及发展战略研究报告(全卷)
  2. Python gRPC 安装
  3. Sping中利用HandlerExceptionResolver实现全局异常捕获
  4. shap_value
  5. [翻译]Python中yield的解释
  6. tensorflow 学习笔记-- tf.reduce_max、tf.sequence_mask
  7. scikit-learn - 分类模型的评估 (classification_report)
  8. LeetCode简单题之按既定顺序创建目标数组
  9. 将编译器pass添加到Relay
  10. TVM如何训练TinyML