Pytorch基础(九)——损失函数
一、概念
损失函数在深度学习领域是用来计算搭建模型预测的输出值和真实值之间的误差。
具体实现过程:在一个批次(batch)前向传播完成后,得到预测值,然后损失函数计算出预测值和真实值之间的差值,反向传播去更新权值和偏置等参数,以降低差值,不断向真实值接近,最终得到效果良好的模型。
常见的损失函数包括:MSE(均方差, 也可以叫L2Loss),Cross Entropy Loss(交叉熵),L1 Loss(L1平均绝对值误差),Smooth L1 Loss(平滑的L1 loss),BCELoss (Binary Cross Entropy)等。下面分别对这些损失函数举例说明。
只写了一部分,后面陆续增加。。
二、Pytorch举例
2.1 MSELoss
MSELoss 就是计算真实值和预测值的均方差,也可以叫L2 Loss。
特点:MSE收敛速度比较快,能提供最大似然估计,是回归问题、模式识别、图像处理中最常使用的损失函数。
import torch
from torch import nn
from torch.nn import MSELossinputs = torch.tensor([1, 2, 3], dtype=torch.float32)
outputs = torch.tensor([2, 2, 4], dtype=torch.float32)# MSE
# size_average为True,表示计算批前向传播后损失函数的平均值,如果为False,则计算损失函数的和。
# 同样,reduce为True,返回标量;reduce为False, size_average参数失效,直接返回向量形式的loss
# reduction目的为减少tensor中元素的数量。为none,表示不减少;为'sum',表示求和;为'mean',表示求平均值loss_mse = nn.MSELoss()
result_mse = loss_mse(inputs, outputs)
print(result_mse)loss_mse1 = nn.MSELoss(reduction='sum')
result_mse1 = loss_mse1(inputs, outputs)
print(result_mse1)loss_mse2 = nn.MSELoss(size_average=False, reduce=False, reduction='sum')
result_mse2 = loss_mse2(inputs, outputs)
print(result_mse2)
输出
tensor(0.6667)
tensor(2.)
tensor([1., 0., 1.])
2.2 L1Loss
L1Loss是计算预测值和真实值的平均绝对误差。
特点:对异常点的鲁棒性更强,但在残差为零处不可导,收敛速度比较慢。
loss_l1 = L1Loss()
result_l1 = loss_l1(inputs, outputs)
print(result_l1)
tensor(0.6667)
2.3 SmoothL1loss
SmoothL1loss是L1Loss 和MSE的混合,最早在Fast R-CNN中提出。
特点:收敛速度稳定,模型更容易收敛到局部最优,防止梯度爆炸。
# beta默认为1,表示指定要在L1和L2损失之间更改的阈值。
loss_smol1 = SmoothL1Loss()
result_smol1 = loss_smol1(inputs, outputs)
print(result_smol1)
tensor(0.3333)
2.4 CrossEntropyLoss
CrossEntropyLoss表示概率分布之间的距离,当交叉熵越小说明二者之间越接近,对于高维输入比较有用。一般都需要激活函数将输入转变为(0,1)之间。
经典公式:
其实这个表示BCELoss(二分类交叉熵)。
pytorch的公式表示的是多分类问题:
1)当目标targets 包括类索引,ignore_index才可以设置.
2)表示每个类别的概率;当每个小批项目需要超过单个类别的标签时非常有用,例如混合标签、标签平滑等。
其中: x为输入值,y为目标值,C代表类别数量,w为权值参数。
# weight :为每个类指定的手动重缩放权重。
# ignore_index:ignore_index表示指定忽略目标值,但不影响输入梯度。
# label_smoothing :在[0.0,1.0]之间的浮点型。指定计算损失时的平滑量,其中 0.0 表示不平滑。 如重新思考计算机视觉的初始架构中所述,目标成为原始基本事实和均匀分布的混合。
# 交叉熵损失
x = torch.tensor([0.5, 0.2, 0.3])
x = torch.reshape(x, (1, 3))
print(x)
y = torch.tensor([1])loss_cross = CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)
tensor([[0.5000, 0.2000, 0.3000]])
tensor(1.2398)
三、参考文章
目标检测回归损失函数简介:SmoothL1/IoU/GIoU/DIoU/CIoU Loss
损失函数(八)
Pytorch基础(九)——损失函数相关推荐
- 深度学习之Pytorch基础教程!
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展 ...
- 吴恩达深度学习笔记2-Course1-Week2【神经网络基础:损失函数、梯度下降】
神经网络基础:损失函数.梯度下降 本篇以最简单的多个输入一个输出的1层神经网络为例,使用logistic regression讲解了神经网络的前向反向计算(forward/backward propa ...
- 【深度学习】深度学习之Pytorch基础教程!
作者:李祖贤,Datawhale高校群成员,深圳大学 随着深度学习的发展,深度学习框架开始大量的出现.尤其是近两年,Google.Facebook.Microsoft等巨头都围绕深度学习重点投资了一系 ...
- PyTorch基础(part3)
学习笔记,仅供参考,有错必纠 文章目录 PyTorch 基础 线性回归 常用代码 导包 生成数据 构建神经网络模型 非线性回归 生成数据 构建神经网络模型 PyTorch 基础 线性回归 常用代码 # ...
- PyTorch学习笔记(四):PyTorch基础实战
PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...
- 第1周学习笔记:深度学习和pytorch基础
目录 一 视频学习 1.绪论 2.深度学习概述 二 代码学习 1.Pytorch基础练习 2.螺旋数据分类 一 视频学习 1.绪论 人工智能(Artificial Intelligence):使一部机 ...
- OUC暑期培训(深度学习)——第一周学习记录:深度学习和pytorch基础
第一周学习:深度学习和pytorch基础 目录 第一周学习:深度学习和pytorch基础 Part 1:视频学习: 1. 绪论: 2. 深度学习概述: Part 2:代码练习: 1. pytorch基 ...
- Pytorch基础打卡01
1 课程规划 1.1 第一部分 pytorch深度学习基础知识 pytorch简介与安装 pytorch基础知识 pytorch 主要组成模块 基础实战 Fashion-MNIST时装分类 ## 1. ...
- 第1周学习:深度学习和pytorch基础
第1周学习:深度学习和pytorch基础 一.概念学习 1.1关于一些基本问题的思考 1.2深度学习基础 二.代码练习 pytorch 基础练习 螺旋数据分类问题 一.概念学习 1.1关于一些基本问题 ...
- Pytorch ——基础指北_叁 [Pytorch API 构建基础模型]
Pytorch --基础指北_叁 系列文章目录 Pytorch --基础指北_零 Pytorch --基础指北_壹 Pytorch --基础指北_贰 Pytorch --基础指北_叁 文章目录 Pyt ...
最新文章
- Kotlin for 循环使用
- pfSense 2.4.0-RC版发布了!
- 要离开苏州,一大堆东西要处理(包括租的房子)
- c3p0 mysql maven_Maven+JSP+Servlet+C3P0+Mysql实现的音乐库管理系统
- c语言里的%p的作用,C语言中geiwei=m%10什么意思,求解!
- Python爬虫的框架有哪些?推荐这五个!
- HashSet 的contains方法
- Spring4.x(7)---对象的生命周期方法
- Bzoj 4147: [AMPPZ2014]Euclidean Nim(博弈)
- echo linux命令_Linux echo命令示例
- android之app自动启动
- c语言课程设计酒店管理系统实验报告 免费下载,C语言酒店管理系统设计
- 二阶有源带通滤波器滤波原理
- 不透明度十六进制_十六进制不透明度表
- java基础中的基础,简单中的简单
- 微信公众号微信支付提示 调用支付JSAPI缺少参数:appId
- “工业互联网+安全生产”,提升工业企业安全水平
- 奇数值结点链表(C语言实现)
- 中文拼写检测(Chinese Spelling Checking)相关方法、评测任务、榜单
- Hollo world
热门文章
- 如何学习Linux / 新手入门
- Dynamic Data Web Application编译是报GetActionPath调用模糊解决办法
- 简单的动画函数封装(2)
- javaScript第二天(1)
- 修改wordpress上传文件大小限制
- 【LOJ】#2184. 「SDOI2015」星际战争
- 2018暑假集训测试六总结
- Spring_01 spring容器、控制反转(IOC)、依赖注入(DI)
- Java基础之写文件——缓冲区中的多条记录(PrimesToFile3)
- 关于ICallbackEventHandler的疑问