一、BCELoss()

功能详解,见blog《pytorch验证CrossEntropyLoss ,BCELoss 和 BCEWithLogitsLoss的关系》

class torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')
  • weight (Tensoroptional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

  • size_average (booloptional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True

  • reduce (booloptional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True

  • reduction (stringoptional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum''none': no reduction will be applied, 'mean': the sum of the output will be divided by the number of elements in the output, 'sum': the output will be summed. Note: size_average and reduce are in the process of being deprecated, and in the meantime, specifying either of those two args will override reduction. Default: 'mean'

(1) weight必须和target的shape一致,默认为none。定义BCELoss的时候指定即可。
(2) 默认情况下 nn.BCELoss(),reduce = True,size_average = True。
(3) 如果reduce为False,size_average不起作用,返回向量形式的loss。
(4) 如果reduce为True,size_average为True,返回loss的均值,即loss.mean()。
(5) 如果reduce为True,size_average为False,返回loss的和,即loss.sum()。
(6) 如果reduction = ‘none’,直接返回向量形式的 loss。
(7) 如果reduction = ‘sum’,返回loss之和。
(8) 如果reduction = ''elementwise_mean,返回loss的平均值。
(9) 如果reduction = ''mean,返回loss的平均值

1、对于weight参数的研究(weight参数的定义方法)

我通过研究后发现,weight参数的shape必须能够广播成和input的shape完全一样;即weight是作用在BCELoss()的输入的每一个元素上。
对比下面三分代码的weight。

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01weight = torch.ones((2, 3, 4, 4)) * 0.1
weight[1] = 10bce_loss = torch.nn.BCELoss(weight=weight, reduction='none')
loss = bce_loss(aa, target)
print(loss.size())
print(loss)

也可以:

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01weight = torch.ones((3, 4, 4)) * 0.1
weight[1] = 1
weight[2] = 10bce_loss = torch.nn.BCELoss(weight=weight, reduction='none')
loss = bce_loss(aa, target)
print(loss.size())
print(loss)

还可以:

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01weight = torch.ones((4, 4)) * 0.1bce_loss = torch.nn.BCELoss(weight=weight, reduction='none')
loss = bce_loss(aa, target)
print(loss.size())
print(loss)

weight的shape如果没有办法广播成和input的shape相同,就会报错。我也实验过了。

2、对于size_average、reduce、reduction参数的研究

通过前面的官方讲解,可以看出,size_average和reduce现在已经deprecated;size_average和reduce的组合就是实现reduction里面包括的几个功能。细节的代码示例,见KLDivLoss()。

二、KLDivLoss()

首先是KL散度损失的计算公式

假设输入x大小为(MXCXWXH),  y的大小也为(MXCXWXH)。N是M*C*W*H的乘积,共计算N个 , 也就是每一个位置的KL散度。具体细节,可以参考后面代码中的reduction='none'的情况。

class torch.nn.KLDivLoss(size_average=None, reduce=None, reduction='mean', log_target=False)
  • size_average (booloptional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True

  • reduce (booloptional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True

  • reduction (stringoptional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean''none': no reduction will be applied. 'batchmean': the sum of the output will be divided by batchsize. 'sum': the output will be summed. 'mean': the output will be divided by the number of elements in the output. Default: 'mean'

  • log_target (booloptional) – Specifies whether target is passed in the log space. Default: False

1、这几个参数的中文解释,见前面的BCELoss()
2、通过前面的官方讲解,可以看出,size_average和reduce现在已经deprecated;size_average和reduce的组合就是实现reduction里面包括的几个功能。

①reduction='none'

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01kl_loss = torch.nn.KLDivLoss(reduction='none')loss = kl_loss(aa, target)
print(loss.size())
print(loss)

返回:

②不设置reduction参数,此时等价于reduction='mean'

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01kl_loss = torch.nn.KLDivLoss()loss = kl_loss(aa, target)
print(loss.size())
print(loss)

返回:

会出现一个警告: /Users/chensi/Library/Python/3.8/lib/python/site-packages/torch/nn/functional.py:1958: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.
  warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size."

没啥影响。出现这个警告的原因,见⑤分析

③reduction='sum'

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01kl_loss = torch.nn.KLDivLoss(reduction='sum')loss = kl_loss(aa, target)
print(loss.size())
print(loss)

返回:

④reduction='batchmean'

import torchaa = torch.ones((2, 3, 4, 4), dtype=torch.float32) * 0.1
target = aa + 0.01kl_loss = torch.nn.KLDivLoss(reduction='batchmean')loss = kl_loss(aa, target)
print(loss.size())
print(loss)

返回:

⑤分析:

-24.3648 / 2 = -12.1824
-12.1824 / (3 * 4 * 4)  = -0.2538
 可以看出:
'batchmean' 就是在 'sum'的基础上除以了一个batchsize;即计算batchsize上的平均loss值。
'mean'就是在'batchmean'的基础上,计算每个像素点的平均loss值。
前面'mean'里面出现的警告,就是建议使用'batchmean',即使用batchsize上的平均loss值;因为在KL散度的定义里,就是除以的batchsize。而且直观的感觉'batchmean'更加合理,因为除以batchsize相当于计算一个完整样本的loss,和目标分类在样本个数上进行平均类似;'mean'就是在像素点上求平均loss了。

pytorch BCELoss()、KLDivLoss()的参数 及 “对于size_average、reduce、reduction参数的研究”相关推荐

  1. pytorch小知识点(二)-------CrossEntropyLoss(reduction参数)

    在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax. 首先要知道上面提到的这些函数一部分是来自于torch. ...

  2. PyTorch官方中文文档:torch.optim 优化器参数

    内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...

  3. Pytorch中tensor维度和torch.max()函数中dim参数的理解

    Pytorch中tensor维度和torch.max()函数中dim参数的理解 维度 参考了 https://blog.csdn.net/qq_41375609/article/details/106 ...

  4. 数组做参数_ES6 系列:你不知道的 Rest 参数与 Spread 语法细节

    Rest 参数与 Spread 语法 在 JavaScript 中,很多内建函数都支持传入任意数量的参数. 例如: Math.max(arg1, arg2, ..., argN) -- 返回入参中的最 ...

  5. R语言使用caret包对GBM模型自定义参数调优:自定义优化参数网格、可视化核心参数与评估指标关系、Accuracy与树的深度、个数的关系、Kappa与树的深度、个数的关系

    R语言使用caret包对GBM模型自定义参数调优:自定义优化参数网格.可视化核心参数与评估指标关系.Accuracy与树的深度.个数的关系.Kappa与树的深度.个数的关系 目录 R语言使用caret ...

  6. python一个函数可以有参数也可以没有参数_python 传入任意多个参数(方法调用可传参或不传参)...

    1.可传参数与不传参数,在定义中给参数设置默认值 class HandleYmal: """ 获取测试环境的配置 """ def __ini ...

  7. php函数多个参数_php中,用函数,如果有很多个参数,只使用最后一个参数,有什么优雅的写法?...

    分两种情况讨论这个问题. 如果你是想固定其中某几个值 如果你是想让其中某些参数有默认值 情况一:如果你是想提前固定其中某几个参数的值 你可以对函数进行部分求值(柯里化),得到一个新的函数,后续使用的时 ...

  8. 【数据挖掘】高斯混合模型 ( 与 K-Means 每个步骤对比 | 初始参数设置 | 计算概率 | 计算平均值参数 | 计算方差参数 | 计算高斯分布概率参数 | 算法终止条件 )

    文章目录 I . 高斯混合模型 ( 样本 -> 模型 ) II . 高斯混合模型 ( 模型 -> 样本 ) III . 高斯混合模型 与 K-Means 迭代过程对比 IV . 高斯混合模 ...

  9. python可变长参数(非关键字及关键字参数)

    可变长参数存在的意义是:每次调用一个函数处理不同量的参数输入.即,参数在调用之前输入的参数数量是未知的,或者多次调用该函数,每次的参数输入的量是不一致的: 可变长参数分为非关键字和关键字类型,分别对应 ...

  10. java 变长参数 知乎_变长参数探究

    前言 变长参数,指的是函数参数数量可变,或者说函数接受参数的数量可以不固定.实际上,我们最开始学C语言的时候,就用到了这样的函数:printf,它接受任意数量的参数,向终端格式化输出字符串.本文就来探 ...

最新文章

  1. git使用-设置项目忽略文件
  2. Django RestFramework BaseSerializer
  3. 云原生应用架构转型不好做?阿里云这个平台让你一步到位!
  4. 华为P50 Pro渲染图再曝光:液态镜头、四曲面屏很吸睛
  5. Asia Hong Kong Regional Contest 2016
  6. windos开启IIS管理器
  7. Stream Editor 流编辑器命令
  8. 【OpenCV】 - 图像分割之分水岭算法,watershed()函数的输出,对marker和image的改变
  9. 虚拟服务器 共享打印机,教你轻松解决打印机共享难题
  10. 谷歌提出新框架Soft Diffusion:从通用扩散过程中正确调度、学习和采样
  11. 联想Thinkpad E470 笔记本 无声音解决方案
  12. vue-父子组件传参以及无限级评论
  13. 【python】python3.7数据分析入门学习笔记 研读
  14. CSS:使用线性渐变实现标签右上角三角形角标效果/HTML上标、下标
  15. 关于WIN7输入法的小问题
  16. java实现第一个数字
  17. 年薪30W的程序员,都在哪些平台兼职接私活?
  18. Python 视频教程 ( 猿课 )
  19. 电脑上总显示宽带连接服务器怎么办啊,宽带连接不上_10招解决方法轻松搞
  20. (求老师啊,求同伴啊)php 生命数字密码设计第一步:数据库基本连接

热门文章

  1. [ 淘宝商城 ] 商城SEO
  2. Item 10.const成员函数 (Meaning of a Const Member Function)
  3. Python 爬虫入门(二)—— IP代理使用
  4. Atitit 提升可读性 流畅接口 1.1. 大接口vs 小接口 小接口可用流畅api串接起来 1 1.2. 部分comm fun可用大接口固化 1 2. 流畅接口 方法连 “Fluent接口
  5. Atitit maven 编译与资源文件与web目录自定义配置 与eclipse的集成与war包打包 1.1. 配置webapp目录 plugin设置 webappDirectory 1 1.2.
  6. Atitit 健康减肥与软件健康减肥的总结 attilax著 1. 几大最佳实践减肥行为 1 1.1. 控制饮食分量用小碗 小盘子 小餐具 1 1.2. 软件如何减肥,控制资源占有率,比如体积 打包
  7. Atitit.每周末总结 于每周一计划日程表 流程表 v8 -------------import 上周遗漏日志补充 检查话费 检查流量情况 Crm问候 Crm表total and 问候
  8. Atitit 软件与互联网理论 attilax总结
  9. atitit.guice3 绑定方式打总结生成非单例对象toInstance toProvider区别 v2 pb29
  10. atitit.解决net.sf.json.JSONException There is a cycle in the hierarchy