【pytorch】torch.nn.GroupNorm的使用
torch.nn.GroupNorm
字面意思是分组做Normalization,官方说明在这里。
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)
计算公式
E[x]是x的均值;
Var[x]是标准差;
gama和beta是训练参数,如果不想使用,可以通过参数affine=False设置。默认为True;
eposilon是输入参数,防止Var为0,默认值为1e-05,可以通过参数eps修改。
输入张量要求
输入的张量至少是2维的,其中第一维度为Channel,后面的维度为特征数据。
使用示例
GroupNorm 是将第一维度的Channels按group分,然后每个group按照上面的计算公式做计算。
比如,
input shape = (4,5)
gn = GroupNorm (2,4)
output = gn(input)
那么output就是将4个channel的数据分为2组,前1-2channel为一组,并按公式计算;后3-4channel为一组,并按公式计算;但是这里输出的shape还是(4,5)
GroupNorm 不会改变输入张量的shape,它只是按照group做normalization
三维,四维以上都一样,比如这里的input shape =(4,1,2,3,4,5),GroupNorm 的作用仅仅针对第一维度的channel。
报错
如果GroupNorm 输入的channel num与输入不一致,则会报错
RuntimeError: Expected number of channels in input to be divisible by num_groups
【pytorch】torch.nn.GroupNorm的使用相关推荐
- PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx
PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现在 torch.nn.xxx 和 torch.nn.funct ...
- pytorch torch.nn.MSELoss
应用 # 1.计算绝对差总和:|0-1|^2+|1-1|^2+|2-1|^2+|3-1|^2=6 # 2.求平均: 6/4 =1.5 import torch import torch.nn as n ...
- pytorch torch.nn.Module.register_buffer
API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...
- pytorch torch.nn.TransformerEncoderLayer
API CLASS torch.nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1, activa ...
- pytorch torch.nn.TransformerEncoder
API CLASS torch.nn.TransformerEncoder(encoder_layer, num_layers, norm=None) TransformerEncoder is a ...
- pytorch torch.nn.Embedding
词嵌入矩阵,可以加载使用word2vector,glove API CLASS torch.nn.Embedding(num_embeddings: int, embedding_dim: int, ...
- [Pytorch]torch.nn.functional.conv2d与深度可分离卷积和标准卷积
torch.nn.functional.conv2d与深度可分离卷积和标准卷积 前言 F.conv2d与nn.Conv2d F.conv2d 标准卷积考虑Batch的影响 深度可分离卷积 深度可分离卷 ...
- pytorch TORCH.NN 到底是什么?
PyTorch 提供了设计精美的模块和类torch.nn. torch.optim. Dataset和DataLoader 来帮助创建和训练神经网络.为了充分利用它们的力量并针对需求灵活的定制它们,需 ...
- pytorch torch.nn.MSELoss(size_average=True)(均方误差【损失函数】)Mean Squared Error(MSE)、SSE(和方差)
class torch.nn.MSELoss(size_average=True)[source] 创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准. x 和 y 可以是任意形状,每个包含n ...
最新文章
- mysql excel 命令行_MySQL 命令行数据导出到 Excel
- 我的Linux随笔目录
- html新人入门代码,HTML入门(示例代码)
- 【网络安全】Linux内核部分文件分析
- c# wpf listbox 高度_WPF快速入门系列(1)——WPF布局概览
- 2019牛客暑期多校训练营(第五场)
- 下面不是python合法标识符_哪个不是python合法标识符
- 原则,策略,规范也是构架的一部分
- Form的is_valid校验规则及验证顺序
- Ubuntu输入ifconfig找不到IP地址,只有lo问题
- 2.privite私有变量的意义
- 重构《一》-- 提取方法
- linux cpu 工作频率,Linux系统限制CPU工作频率(示例代码)
- 主席树-----动态开点,不hash
- Winform 按钮权限拦截AOP
- Atitit 软件体系的进化,是否需要一个处理中心
- Windows64位安装git
- css中的单位换算_GitHub - WangQiangrong/cssUnitTransform: css单位转换工具
- 【Unity】出现NullReferenceException:Object reference not set to an instance of an object.的原因总结
- 神调侃!程序员专属成长书单,我比女朋友更了解你!
热门文章
- You must restart adb and Eclipse问题的解决
- QIIME 2教程. 28社区Community(2021.2)
- 342.基于高通量技术的微生物组研究实验设计
- Nat Rev Genet发表房刚组细菌表观组综述论文
- R语言:数据筛选match
- python将scikit-learn自带数据集转换为pandas dataframe格式
- UCL葡萄酒(red white wine quality)数据集字段解释、数据导入实战
- 为多模型寻找模型最优参数、多模型交叉验证、可视化、指标计算、多模型对比可视化(系数图、误差图、混淆矩阵、校正曲线、ROC曲线、AUC、Accuracy、特异度、灵敏度、PPV、NPV)、结果数据保存
- matplotlib可视化去除轴标签、轴刻度线和轴刻度数值实战:Axis Text Ticks or Tick Labels
- Python时间序列模型推理预测实战:时序推理数据预处理(特征生成、lstm输入结构组织)、模型加载、模型预测结果保存、条件判断模型循环运行