一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization)

Pytorch中的损失函数一般在训练模型时候指定。

注意Pytorch中内置的损失函数的参数和tensorflow不同,是y_pred在前,y_true在后,而Tensorflow是y_true在前,y_pred在后。

对于回归模型,通常使用的内置损失函数是均方损失函数nn.MSELoss 。

对于二分类模型,通常使用的是二元交叉熵损失函数nn.BCELoss (输入已经是sigmoid激活函数之后的结果) 或者 nn.BCEWithLogitsLoss (输入尚未经过nn.Sigmoid激活函数) 。

对于多分类模型,一般推荐使用交叉熵损失函数 nn.CrossEntropyLoss。 (y_true需要是一维的,是类别编码。y_pred未经过nn.Softmax激活。)

此外,如果多分类的y_pred经过了nn.LogSoftmax激活,可以使用nn.NLLLoss损失函数(The negative log likelihood loss)。 这种方法和直接使用nn.CrossEntropyLoss等价。

如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。

Pytorch中的正则化项一般通过自定义的方式和损失函数一起添加作为目标函数。

一,内置损失函数

内置的损失函数一般有类的实现和函数的实现两种形式。

如:nn.BCE 和 F.binary_cross_entropy 都是二元交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。

实际上类的实现形式通常是调用函数的实现形式并用nn.Module封装后得到的。

一般我们常用的是类的实现形式。它们封装在torch.nn模块下,并且类名以Loss结尾。

常用的一些内置损失函数说明如下。

nn.MSELoss(均方误差损失,也叫做L2损失,用于回归)

nn.L1Loss (L1损失,也叫做绝对值误差损失,用于回归)

nn.SmoothL1Loss (平滑L1损失,当输入在-1到1之间时,平滑为L2损失,用于回归)

nn.BCELoss (二元交叉熵,用于二分类,输入已经过nn.Sigmoid激活,对不平衡数据集可以用weigths参数调整类别权重)

nn.BCEWithLogitsLoss (二元交叉熵,用于二分类,输入未经过nn.Sigmoid激活)

nn.CrossEntropyLoss (交叉熵,用于多分类,要求label为稀疏编码,输入未经过nn.Softmax激活,对不平衡数据集可以用weigths参数调整类别权重)

nn.NLLLoss (负对数似然损失,用于多分类,要求label为稀疏编码,输入经过nn.LogSoftmax激活)

nn.CosineSimilarity(余弦相似度,可用于多分类)

nn.AdaptiveLogSoftmaxWithLoss (一种适合非常多类别且类别分布很不均衡的损失函数,会自适应地将多个小类别合成一个cluster)

更多损失函数的介绍参考如下知乎文章:

《PyTorch的十八个损失函数》

二,自定义L1和L2正则化项

通常认为L1 正则化可以产生稀疏权值矩阵,即产生一个稀疏模型,可以用于特征选择。

而L2 正则化可以防止模型过拟合(overfitting)。一定程度上,L1也可以防止过拟合。

# L2正则化
def L2Loss(model,alpha):l2_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name: #一般不对偏置项使用正则l2_loss = l2_loss + (0.5 * alpha * torch.sum(torch.pow(param, 2)))return l2_loss# L1正则化
def L1Loss(model,beta):l1_loss = torch.tensor(0.0, requires_grad=True)for name, param in model.named_parameters():if 'bias' not in name:l1_loss = l1_loss +  beta * torch.sum(torch.abs(param))return l1_loss# 将L2正则和L1正则添加到FocalLoss损失,一起作为目标函数
def focal_loss_with_regularization(y_pred,y_true):focal = FocalLoss()(y_pred,y_true) l2_loss = L2Loss(model,0.001) #注意设置正则化项系数l1_loss = L1Loss(model,0.001)total_loss = focal + l2_loss + l1_lossreturn total_lossmodel.compile(loss_func =focal_loss_with_regularization,optimizer= torch.optim.Adam(model.parameters(),lr = 0.01),metrics_dict={"accuracy":accuracy})

只写了部分,具体的参考《20天吃透Pytorch》

Pytorch损失函数losses简介相关推荐

  1. pytorch深度学习简介(包括cnn,rnn等我只挑我感觉有必要记录)

    一,深度学习基础 1. 了解常见的四个机器学习方法 监督学习.无监督学习.半监督学习.强化学习是我们日常接触到的常见的四个机器学习方法: 监督学习:通过已有的训练样本(即已知数据以及其对应的输出)去训 ...

  2. pytorch 损失函数总结

    PyTorch深度学习实战 4 损失函数 损失函数,又叫目标函数,是编译一个神经网络模型必须的两个参数之一.另一个必不可少的参数是优化器. 损失函数是指用于计算标签值和预测值之间差异的函数,在机器学习 ...

  3. Pytorch —— 损失函数(二)

    目录 5.nn.L1Loss 6.nn.MSELoss 7.nn.SmoothL1Loss 8.nn.PoissonNLLLoss 9.nn.KLDivLoss 10.nn.MarginRanking ...

  4. Pytorch:损失函数

    4.1.4 损失函数 在深度学习中要用到各种各样的损失函数(loss function),这些损失函数可看作是一种特殊的layer,PyTorch也将这些损失函数实现为nn.Module的子类.然而在 ...

  5. [PyTorch] 损失函数

    参考 pytorch常用损失函数 为什么用交叉熵做损失函数 BCELoss 这是 官方文档 二分类交叉熵损失(BinaryCrossEntropyLoss) ln=−wn[yn⋅log2xn+(1−y ...

  6. Pytorch 损失函数 Mean Squared Error

    Pytorch的损失函数定义在torch.nn.functional下,可以直接使用. Mean Squared Error(MSE)即均方误差,常用在数值型输出上: 其中θ是网络的参数,取决于使用的 ...

  7. PyTorch:生态简介

    PyTorch生态简介 PyTorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了Py ...

  8. pytorch损失函数中‘reduction‘参数

    内容介绍 在调用pytorch的损失函数时,会有一个'reduction'的参数,本文介绍使用不同的参数对应的结果,以L1 loss为例子: reduction = mean 当使用的参数为 mean ...

  9. Pytorch损失函数解析

    本文根据pytorch里面的源码解析各个损失函数,各个损失函数的python接口定义于包torch.nn.modules中的loss.py,在包modules的初始化__init__.py中关于损失函 ...

最新文章

  1. 期望dp ---- B. Tree Array 思维+期望dp 逆序对期望数
  2. j - 数据结构实验:哈希表_一看就懂的数据结构基础「哈希表」
  3. iOS一些推荐的学习路径发展
  4. linux mmu的实现的讲解_Linux中的段
  5. linux游戏调试,LINUX游戏服务器的安装与调试.doc
  6. 【Elasticsearch】所有可用 Qbox 插件的概述:第二部分
  7. 最简单的flex bison例子
  8. javacript 验证函数
  9. 基于物联网的工业分析将席卷制造业
  10. 7. Document write() 方法
  11. pycharm调试步骤(详细)
  12. UE4官方文档学习笔记材质篇——分层材质
  13. matlab运行出现:Optimization terminated.
  14. 3w+字,Python办公自动化之Excel报表自动化,看这一篇就够了!
  15. Windows下安装 msysGit 以及初始化 Git server环境
  16. 最活跃FPGA论坛推荐社区
  17. 论文解读:学习蛋白质的空间结构可以提高蛋白质相互作用的预测
  18. 物联网平台常见问题与答案汇总
  19. 浅谈商城站点如何创造价值
  20. STOCHRSI 指标理解

热门文章

  1. HDU2683——欧拉完全数
  2. Wormholes——Bellman-Ford判断负环
  3. Linux 进程学习(四)------ sigaction 函数
  4. netstat 相关命令解析
  5. linux下成功安装ffmpeg( 亲测有效 )
  6. 07-图6 旅游规划 (25 分)
  7. java 注解 方法 参数_java在注解中绑定方法参数的解决方案
  8. Android APK 打包过程 MD
  9. webapi 找到了与请求匹配的多个操作(ajax报500,4的错误)
  10. 快速学习23种设计模式思想Design Patterns