torch.nn.GroupNorm

字面意思是分组做Normalization,官方说明在这里。

torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)

计算公式


E[x]是x的均值;
Var[x]是标准差;
gama和beta是训练参数,如果不想使用,可以通过参数affine=False设置。默认为True;
eposilon是输入参数,防止Var为0,默认值为1e-05,可以通过参数eps修改。

输入张量要求

输入的张量至少是2维的,其中第一维度为Channel,后面的维度为特征数据。

使用示例

GroupNorm 是将第一维度的Channels按group分,然后每个group按照上面的计算公式做计算。

比如,
input shape = (4,5)
gn = GroupNorm (2,4)
output = gn(input)

那么output就是将4个channel的数据分为2组,前1-2channel为一组,并按公式计算;后3-4channel为一组,并按公式计算;但是这里输出的shape还是(4,5)
GroupNorm 不会改变输入张量的shape,它只是按照group做normalization

三维,四维以上都一样,比如这里的input shape =(4,1,2,3,4,5),GroupNorm 的作用仅仅针对第一维度的channel。

报错

如果GroupNorm 输入的channel num与输入不一致,则会报错
RuntimeError: Expected number of channels in input to be divisible by num_groups

【pytorch】torch.nn.GroupNorm的使用相关推荐

  1. PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx

    PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现在 torch.nn.xxx 和 torch.nn.funct ...

  2. pytorch torch.nn.MSELoss

    应用 # 1.计算绝对差总和:|0-1|^2+|1-1|^2+|2-1|^2+|3-1|^2=6 # 2.求平均: 6/4 =1.5 import torch import torch.nn as n ...

  3. pytorch torch.nn.Module.register_buffer

    API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...

  4. pytorch torch.nn.TransformerEncoderLayer

    API CLASS torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activa ...

  5. pytorch torch.nn.TransformerEncoder

    API CLASS torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None) TransformerEncoder is a ...

  6. pytorch torch.nn.Embedding

    词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...

  7. [Pytorch]torch.nn.functional.conv2d与深度可分离卷积和标准卷积

    torch.nn.functional.conv2d与深度可分离卷积和标准卷积 前言 F.conv2d与nn.Conv2d F.conv2d 标准卷积考虑Batch的影响 深度可分离卷积 深度可分离卷 ...

  8. pytorch TORCH.NN 到底是什么?

    PyTorch 提供了设计精美的模块和类torch.nn. torch.optim. Dataset和DataLoader 来帮助创建和训练神经网络.为了充分利用它们的力量并针对需求灵活的定制它们,需 ...

  9. pytorch torch.nn.MSELoss(size_average=True)(均方误差【损失函数】)Mean Squared Error(MSE)、SSE(和方差)

    class torch.nn.MSELoss(size_average=True)[source] 创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准. x 和 y 可以是任意形状,每个包含n ...

最新文章

  1. mysql excel 命令行_MySQL 命令行数据导出到 Excel
  2. 我的Linux随笔目录
  3. html新人入门代码,HTML入门(示例代码)
  4. 【网络安全】Linux内核部分文件分析
  5. c# wpf listbox 高度_WPF快速入门系列(1)——WPF布局概览
  6. 2019牛客暑期多校训练营(第五场)
  7. 下面不是python合法标识符_哪个不是python合法标识符
  8. 原则,策略,规范也是构架的一部分
  9. Form的is_valid校验规则及验证顺序
  10. Ubuntu输入ifconfig找不到IP地址,只有lo问题
  11. 2.privite私有变量的意义
  12. 重构《一》-- 提取方法
  13. linux cpu 工作频率,Linux系统限制CPU工作频率(示例代码)
  14. 主席树-----动态开点,不hash
  15. Winform 按钮权限拦截AOP
  16. Atitit 软件体系的进化,是否需要一个处理中心
  17. Windows64位安装git
  18. css中的单位换算_GitHub - WangQiangrong/cssUnitTransform: css单位转换工具
  19. 【Unity】出现NullReferenceException:Object reference not set to an instance of an object.的原因总结
  20. 神调侃!程序员专属成长书单,我比女朋友更了解你!

热门文章

  1. You must restart adb and Eclipse问题的解决
  2. QIIME 2教程. 28社区Community(2021.2)
  3. 342.基于高通量技术的微生物组研究实验设计
  4. Nat Rev Genet发表房刚组细菌表观组综述论文
  5. R语言:数据筛选match
  6. python将scikit-learn自带数据集转换为pandas dataframe格式
  7. UCL葡萄酒(red white wine quality)数据集字段解释、数据导入实战
  8. 为多模型寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)、结果数据保存
  9. matplotlib可视化去除轴标签、轴刻度线和轴刻度数值实战:Axis Text Ticks or Tick Labels
  10. Python时间序列模型推理预测实战:时序推理数据预处理(特征生成、lstm输入结构组织)、模型加载、模型预测结果保存、条件判断模型循环运行