Pytorch专题实战——交叉熵损失函数(CrossEntropyLoss )
文章目录
- 1.用CrossEntropyLoss预测单个目标
- 2.用CrossEntropyLoss预测多个目标
- 3.二分类使用BCELoss损失函数
- 4.多分类使用CrossEntropyLoss损失函数
1.用CrossEntropyLoss预测单个目标
loss = nn.CrossEntropyLoss() #实例化交叉熵损失函数
Y = torch.tensor([0]) #预测是第0个
Y_pred_good = torch.tensor([[2.0, 1.0, 0.1]])
Y_pred_bad = torch.tensor([[0.5, 2.0, 0.3]])
l1 = loss(Y_pred_good, Y) #计算loss
l2 = loss(Y_pred_bad, Y)
print(f'Pytorch Loss1:{l1.item():.4f}') #小数点后保留4位
print(f'Pytorch Loss2:{l2.item():.4f}')
_, predictions1 = torch.max(Y_pred_good, 1)
_, predictions2 = torch.max(Y_pred_bad, 1)
2.用CrossEntropyLoss预测多个目标
Y = torch.tensor([2,0,1]) #三个目标值
Y_pred_good = torch.tensor( #三组待预测[[0.1, 0.2, 3.9],[1.2, 0.1, 0.3],[0.3, 2.2, 0.2]])
Y_pred_bad = torch.tensor([[0.9, 0.2, 0.1],[0.1, 0.3, 1.5],[1.2, 0.2, 0.5]])l1 = loss(Y_pred_good, Y)
l2 = loss(Y_pred_bad, Y)
print(f'Batch Loss1: {l1.item():.4f}')
print(f'Batch Loss2:{l2.item():.4f}')
_, predictions1 = torch.max(Y_pred_good, 1)
_, predictions2 = torch.max(Y_pred_bad, 1)
print(f'Actual class:{Y}, Y_pred1:{predictions1}, Y_pred2:{predictions2}')
3.二分类使用BCELoss损失函数
class NeuralNet1(nn.Module):def __init__(self, input_size, hidden_size):super(NeuralNet1, self).__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.linear2 = nn.Linear(hidden_size, 1) #二分类最后输出单元个数为1def forward(self, x):out = self.linear1(x)out = self.relu(out)out = self.linear2(out)y_pred = torch.sigmoid(out)return y_predmodel = NeuralNet1(input_size=28*28, hidden_size=5)
criterion = nn.BCELoss()
4.多分类使用CrossEntropyLoss损失函数
class NeuralNet2(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet2, self).__init__()self.linear1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.linear2 = nn.Linear(hidden_size, num_classes)def forward(self, x):out = self.linear1(x)out = self.relu(out)out = self.linear2(out)return outmodel = NeuralNet2(input_size=28*28, hidden_size=5, num_classes=3)
criterion = nn.CrossEntropyLoss()
Pytorch专题实战——交叉熵损失函数(CrossEntropyLoss )相关推荐
- 【Pytorch】交叉熵损失函数 CrossEntropyLoss() 详解
文章目录 一.损失函数 nn.CrossEntropyLoss() 二.什么是交叉熵 三.Pytorch 中的 CrossEntropyLoss() 函数 参考链接 一.损失函数 nn.CrossEn ...
- 解决pytorch二分类任务交叉熵损失函数CrossEntropyLoss报错:IndexError: Target 1 is out of bounds.
解决方法 修改nn.CrossEntropyLoss()为nn.BCELoss() 问题解析 pytorch 中二分类任务交叉熵要用二分类交叉熵(Binary Cross Entropy),BCELo ...
- pytorch中交叉熵损失函数的细节
目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结. 类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss() 输入: ...
- 交叉熵损失函数分类_PyTorch学习笔记——多分类交叉熵损失函数
理解交叉熵 关于样本集的两个概率分布p和q,设p为真实的分布,比如[1, 0, 0]表示当前样本属于第一类,q为拟合的分布,比如[0.7, 0.2, 0.1]. 按照真实分布p来衡量识别一个样本所需的 ...
- 深入理解Pytorch负对数似然函数(torch.nn.NLLLoss)和交叉熵损失函数(torch.nn.CrossEntropyLoss)
在看Pytorch的交叉熵损失函数torch.nn.CrossEntropyLoss官方文档介绍中,给出的表达式如下.不免有点疑惑为何交叉熵损失的表达式是这个样子的 loss ( y , clas ...
- LESSON 10.110.210.3 SSE与二分类交叉熵损失函数二分类交叉熵损失函数的pytorch实现多分类交叉熵损失函数
在之前的课程中,我们已经完成了从0建立深层神经网络,并完成正向传播的全过程.本节课开始,我们将以分类深层神经网络为例,为大家展示神经网络的学习和训练过程.在介绍PyTorch的基本工具AutoGrad ...
- 交叉熵损失函数python_交叉熵损失函数nn.CrossEntropyLoss()
nn.CrossEntropyLoss() 1.引言 在使用pytorch深度学习框架做多分类时,计算损失函数通常会使用交叉熵损失函数nn.CrossEntropyLoss() 2. 信息量和熵 信息 ...
- pytorch中的二分类及多分类交叉熵损失函数
本文主要记录一下pytorch里面的二分类及多分类交叉熵损失函数的使用. import torch import torch.nn as nn import torch.nn.functional a ...
- 语义分割损失函数系列(1):交叉熵损失函数
最近一直在做一些语义分割相关的项目,找损失函数的时候发现网上这些大佬的写得各有千秋,也没说怎么用,在此记录一下自己在训练过程中使用损失函数的一些心得.本人是使用的Pytorch框架,故这一系列都会基于 ...
最新文章
- java IO知识总结
- Oozie基于Hue全流程调度
- try catch finally语句详解
- 【Python】while循环实现用户登录的三次机会
- pandas 作图 统计_解决pandas 作图无法显示中文的问题
- 历史上最怪异的23种飞行器,设计者脑子里都想什么了
- CentOS7下ab压力测试Nginx和Tomcat
- c++语言编程,一个电灯两个开关控制,[理学]四川大学计算机学院精品课程_面向对象程序设计C++课件_游洪越_第一章绪论.ppt...
- scala方法中的变量_Scala变量,变量范围,字段变量,方法参数示例
- 如何让html箭头绝对居中,html - 将垂直居中对齐的导航箭头定位到图像的左侧和右侧 - SO中文参考 - www.soinside.com...
- java笔试+面试总结(大纲)
- cocos2d-x-3.x 场景(3)场景切换特效
- python问题:IndentationError:expected an indented block错误
- 优秀的云计算工程师需要学什么?云计算Docker学习路线
- 北大计算机博士毕业难度,北京大学博士毕业要求
- 法人银行贷款逾期信息查询
- 使用js jquery去搭建完成京东购物车
- Lombok介绍、使用方法和总结
- Lending Club贷款违约预测
- Android 欢迎引导页的魅力
热门文章
- ORACLE REGEXP应用实例
- 深入浅出DDoS***
- h3c交换机怎么设置虚拟服务器,H3C交换机配置 | 如何实现两个网段主机与外部通信...
- python print sep,Python3.x语句print(1,2,3,sep=’:’)的输出结果为()。
- 用Photoshop制作简单贺卡
- DW8里面的HTML面板在哪里,打开Dreamweaver8窗口后,如果没有出现属性面板,可执行()菜单中的 - 问答库...
- 你可能没听过的 Java 8 中的 10 个特性
- pytorch 入门学习多分类问题-9
- Android OpenGL ES 开发教程(16):Viewing和Modeling(MODELVIEW) 变换
- java 中文乱码过滤器_JAVA中文乱码过滤器(用java过滤器解决中文乱码)V0422 整理版...