SoftMarginLoss

用于二分类任务

对于包含NNN个样本的batch数据 D(x,y)D(x, y)D(x,y), xxx代表模型输出,yyy代表真实的类别标签,yyy中元素的值属于{1,−1}\{1,-1\}{1,−1}。losslossloss计算如下:

loss=∑ilog⁡(1+exp⁡(−y[i]∗x[i]))x.nelement ()loss= \frac{\sum_{i}\log (1+\exp (-y[i] * x[i]))}{\text { x.nelement }()}loss= x.nelement ()∑i​log(1+exp(−y[i]∗x[i]))​

x.nelement (){\text { x.nelement }()} x.nelement ()代表xxx中元素的个数

若每个样本对应一个二分类,则x.nelement ()=N{\text { x.nelement }()}=N x.nelement ()=N
若每个样本对应于M个二分类,则x.nelement ()==M∗N{\text { x.nelement }()}==M*N x.nelement ()==M∗N

  • 当 x[i]x[i]x[i]与y[i]y[i]y[i]同号,即预测正确时,exp⁡(−y[i]∗x[i])<1\exp(-y[i] * x[i]) <1exp(−y[i]∗x[i])<1, log⁡(1+exp⁡(−y[i]∗x[i]))<log2\log(1+\exp (-y[i] * x[i])) < log2log(1+exp(−y[i]∗x[i]))<log2, 值很小。并且y[i]∗x[i]y[i] * x[i]y[i]∗x[i]乘积越大,分类确信度越高,loss越小;

  • 当 x[i]x[i]x[i]与y[i]y[i]y[i]异号,即预测错误时,log⁡(1+exp⁡(−y[i]∗x[i]))\log (1+\exp (-y[i] * x[i]))log(1+exp(−y[i]∗x[i]))取值较大;

losslossloss取值log⁡(1+exp⁡(−y[i]∗x[i]))\log (1+\exp (-y[i] * x[i]))log(1+exp(−y[i]∗x[i]))而不是log⁡(exp⁡(−y[i]∗x[i]))\log (\exp (-y[i] * x[i]))log(exp(−y[i]∗x[i])),是为了避免losslossloss计算为负数

例子:

import torch
import torch.nn as nn
import mathdef validate_SoftMarginLoss(input, target):val = 0for li_x, li_y in zip(input, target):for x, y in zip(li_x, li_y):loss_val = math.log(1 + math.exp(- y * x), math.e)val += loss_valreturn val / input.nelement()x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8], [0.1, 0.2, 0.4, 0.8]])
print(x.size())
y = torch.FloatTensor([[1, -1, 1, 1], [1, -1, 1, 1]])
print(y.size())loss = nn.SoftMarginLoss(reduction="none")
loss_val = loss(x, y)
print(loss_val)loss = nn.SoftMarginLoss(reduction="sum")
loss_val = loss(x, y)
print(loss_val.item())
print(loss_val.item() / x.nelement())loss = nn.SoftMarginLoss(reduction="mean")
loss_val = loss(x, y)
print(loss_val.item())valid_loss_val = validate_SoftMarginLoss(x, y)
print(valid_loss_val)

结果:

torch.Size([2, 4])
torch.Size([2, 4])
tensor([[0.6444, 0.7981, 0.5130, 0.3711],[0.6444, 0.7981, 0.5130, 0.3711]])
4.653303146362305
0.5816628932952881
0.5816628932952881
0.5816628606614725

loss函数之SoftMarginLoss相关推荐

  1. 一起学习ML和DL中常用的几种loss函数

    摘要:本篇内容和大家一起学习下机器学习和深度学习中常用到的几种loss函数. 本文分享自华为云社区<[MindSpore易点通]网络实战之交叉熵类Loss函数>,作者:Skytier . ...

  2. 【Dual-Path-RNN-Pytorch源码分析】loss函数:SI-SNR

    DPRNN使用的loss函数是 SI-SNR SI-SNR 是scale-invariant source-to-noise ratio的缩写,中文翻译为尺度不变的信噪比,意思是不受信号变化影响的信噪 ...

  3. tensorflow学习(4.loss函数以及正则化的使用 )

    本文还是以MNIST的CNN分析为例 loss函数一般有MSE均方差函数.交叉熵损失函数,说明见 https://blog.csdn.net/John_xyz/article/details/6121 ...

  4. 多分类loss函数本质理解

    一.面对一个多分类问题,如何设计合理的损失函数呢? 1.损失函数的本质在数学上称为目标函数:这个目标函数的目标值符合最完美的需求:损失函数的目标值肯定是0,完美分类的损失必然为0 : 2.损失函数分为 ...

  5. 深度学习基础(三)loss函数

    loss函数,即损失函数,是决定网络学习质量的关键.若网络结构不变的前提下,损失函数选择不当会导致模型精度差等后果.若有错误,敬请指正,Thank you! 目录 一.loss函数定义 二.常见的lo ...

  6. Keras自定义Loss函数

    Keras作为一个深度学习库,非常适合新手.在做神经网络时,它自带了许多常用的目标函数,优化方法等等,基本能满足新手学习时的一些需求.具体包含目标函数和优化方法.但它也支持用户自定义目标函数,下边介绍 ...

  7. 商汤使用AutoML设计Loss函数,全面超越人工设计

    点击我爱计算机视觉标星,更快获取CVML新技术 深度学习领域,神经架构搜索得到的算法如雨后春笋般出现. 今天一篇arXiv论文<AM-LFS: AutoML for Loss Function ...

  8. 深度学习中的损失函数总结以及Center Loss函数笔记

    北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                        ...

  9. 'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数

    在编写SVM中的Hinge loss函数的时候报错"'int' object has no attribute 'backward'" for epoch in range(50) ...

最新文章

  1. html embed详解
  2. windows服务器双网卡链路聚合_基于windows server 2012的多网卡链路聚合实验设计与......
  3. anaconda与pip 清华镜像源
  4. docker下MySQL修改配置并重启生效:表名不区分大小写
  5. php动态写入vue,Vue自定义动态组件使用详解
  6. 电源模块的9个主要性能指标及其作用
  7. Django的核心思想ORM
  8. 解决AS3 Socket编程中最令人头疼的问题
  9. docker安装部署和常用命令
  10. Apk去掉签名以及重新签名的方法
  11. matlab建立机器人模型,matlab 机器人工具箱8-通过URDF建立机器人模型
  12. 从哪些方面评价一款在线客服系统产品
  13. 分享一些程序员必备网站
  14. Perl中的单行注释和多行注释
  15. (ssl1458)数字金字塔(三角形)逆推法
  16. Quartus||仿真图设计
  17. 位置无关(PIC)代码原理剖析
  18. Ansoft Maxwell 永磁体表面磁场和空间磁场仿真
  19. HTML5期末大作业:个人主页网站设计——服装明星主页(7页)表格带留言板带音乐
  20. GNSS定位原理(伪距)

热门文章

  1. GDI+简单现实文字旋转
  2. Swing-JTree树模型的操作
  3. matlab方阵对角线清零
  4. linux mysql apache php 安装_linux下安装apache与php;Apache+PHP+MySQL配置攻略
  5. angular获取路由参数_Angular应用带参数的路由实现
  6. 笔记本我的计算机怎么找不到了,Win10我的电脑在哪?图标没了怎么办?Win10此电脑不见了解决方法...
  7. PHP九宫格翻牌抽奖,PHP 九宫格抽奖代码
  8. Jmeter负载和压力测试
  9. mysql 记录更新时间_MySQL表内更新时,自动记录时间
  10. python字符串转化为数字信号_用python实现简单的数字信号软件滤波处理