文章目录

  • 普通卷积复习
  • Groups是如何改变卷积方式的
  • 实验验证
  • 参考资料

普通卷积复习

首先我们先来简单复习一下普通的卷积行为。


从上图可以看到,输入特征图为3,经过4个filter卷积后生成了4个输出特征图。对于普通的卷积操作,我们可以得到几个重要的结论:

  1. 输入通道数 = 每个filter的卷积核的个数。(注意区分卷积核和Filter,它们俩的关系是:多个卷积核组成一个Filter)
  2. Filter的个数 = 输出通道数

此时,我们的参数量为:

参数量=输入通道数×输出通道数×卷积核大小=卷积核个数×Filter数 ×卷积核大小\text{参数量} = 输入通道数 \times 输出通道数 \times 卷积核大小 = 卷积核个数 \times \text{Filter数 } \times 卷积核大小 参数量=输入通道数×输出通道数×卷积核大小=卷积核个数×Filter数 ×卷积核大小

这里忽略了偏置

Groups是如何改变卷积方式的

那现在我们不想按照上面的方式,我想让一个Filter只负责一部分输入通道,例如:

上图中,我们将输入通道分成了2组(也就是groups=2),每一组对应一个Filter,这样我们的参数量就下降了1倍。此时,我们还是有4个Filter(因为有4个输出通道),但每个Filter只有2个卷积核,所以一个Filter只对2个输入通道进行卷积。

为了巩固,我们再举个例子:

在该例子中,我们的输入通道为4,输出通道为8。这次我们将4个输入通道分成了4组,也就是groups=4,此时我们的每个Filter的卷积核数量就是1。

从上面两个例子,大家应该很清楚group的作用了,这里进行一个总结:

  1. Groups做的事情将输入通道进行分组,groups的值就是具体分的组数。所以,in_channel ÷ groups 一定要是整数,要不然就没法分组了。每个Filter负责处理一组输入通道,所以Filter的卷积核数量也会随之改变,即每个Filter的卷积核数 = in_channel ÷ groups
  2. Groups的作用:减少计算量和参数量。
  3. Groups其他注意事项输出通道 ÷ groups 也一定要是整数,要不然就会有几组没有Filter与之对应了。

综上,如果加入了groups,则卷积参数量的计算公式为:

参数量=输入通道数groups×输出通道数×卷积核大小\text{参数量} = \frac{\text{输入通道数}}{groups} \times 输出通道数 \times 卷积核大小 参数量=groups输入通道数​×输出通道数×卷积核大小

这里同样忽略了偏置

实验验证

我们现在就来做一组实验,验证上面的说法。 这里我准备一个1x1的图片,卷积核大小也为1x1,输入通道数为4, 输出通道数为8,groups设为2。用图像表示则为:

实验开始:

首先,我们先导包和准备一个打印参数数量的辅助函数:

import torch.nn as nn
import torchdef get_parameter_number(net):total_num = sum(p.numel() for p in net.parameters())return {'Total': total_num}

接下来定义卷积模型,并打印参数量:

model = nn.Conv2d(4, 8, 1, 1, groups=2, bias=False)
get_parameter_number(model)
{'Total': 16}

可以看到,参数量和预期的是一致的。8个Filter,每个Filter两个卷积核,所以一共16个参数。

接下来定义输入层,输入层是1x1的图片,值都为1:

inputs = torch.ones(1, 4, 1, 1)

然后修改卷积核的参数,改为图片上的[1,2,3,4…,16]:

for param in model.parameters():print(param.size())param.data = torch.FloatTensor([list(range(1, 17))]).view(8,2,1,1)
torch.Size([8, 2, 1, 1])

通过参数的shape也可以看出来,8个filter,每个filter2个卷积核。接下来进行前向传递:

model(inputs)
tensor([[[[ 3.]],[[ 7.]],[[11.]],[[15.]],[[19.]],[[23.]],[[27.]],[[31.]]]], grad_fn=<MkldnnConvolutionBackward0>)

完美,跟预想中的结果完全一致。

参考资料

https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

图解 Pytorch 中 nn.Conv2d 的 groups 参数相关推荐

  1. Pytorch中nn.Conv2d数据计算模拟

    Pytorch中nn.Conv2d数据计算模拟 最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块.在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin ...

  2. Pytorch中nn.Conv2d的用法

    官网链接: nn.Conv2d     Applies a 2D convolution over an input signal composed of several input planes. ...

  3. Pytorch的nn.Conv2d()参数详解

    nn.Conv2d()的使用.形参与隐藏的权重参数   二维卷积应该是最常用的卷积方式了,在Pytorch的nn模块中,封装了nn.Conv2d()类作为二维卷积的实现.使用方法和普通的类一样,先实例 ...

  4. PyTorch中 nn.Conv2d与nn.ConvTranspose2d函数的用法

    原文链接 1. 通道数问题 : 描述一个像素点,如果是灰度,那么只需要一个数值来描述它,就是单通道.如果有RGB三种颜色来描述它,就是三通道.最初输入的图片样本的 channels ,取决于图片类型: ...

  5. Pytorch中nn.Conv2d的dilation

    dilation原文解释如下: controls the spacing between the kernel points; also known as the à trous algorithm. ...

  6. pytorch中nn.Conv2d卷积的padding的取值问题

    明确卷积的计算公式:d = (d - kennel_size + 2 * padding) / stride + 1 保证输入输出的分辨率大小一致,padding的取值:如果kernal_size = ...

  7. Pytorch的nn.Conv2d()详解

    Pytorch的nn.Conv2d()详解 nn.Conv2d()的使用.形参与隐藏的权重参数 in_channels out_channels kernel_size stride = 1 padd ...

  8. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  9. 对于pytorch中nn.CrossEntropyLoss()与nn.BCELoss()的理解和使用

    在pytorch中nn.CrossEntropyLoss()为交叉熵损失函数,用于解决多分类问题,也可用于解决二分类问题. BCELoss是Binary CrossEntropyLoss的缩写,nn. ...

最新文章

  1. YY的GCD 莫比乌斯反演
  2. python类介绍_python类介绍
  3. oracle删除表空间中的表,ORACLE删除表空间中的所有表
  4. Photoshop2018详细安装教程
  5. mysql安装被打断_MySQL安装未响应解决方法
  6. vue页面按钮点击后,呈现loading加载状态
  7. java语言之数组-----选择排序
  8. ConstraintLayout约束控件详解
  9. 生物信息学biojava|从本地读取并解析遍历genbank文件|从genbank中提取CDS等注释信息
  10. 经纬财富:徐州炒白银需要注意哪些技术指标
  11. mysql navicat视图_navicat怎么创建视图
  12. C语言实现贪吃蛇(双人版本)
  13. linux中怎样隐藏文件,Linux下如何隐藏文件
  14. 《蜗居》触动人心灵的100个瞬间
  15. grok java_ELK实战 - Grok简易入门
  16. 《富爸爸穷爸爸 》 读书笔记
  17. 了解iOS各个版本新特性总结
  18. window安装mysql默认密码忘记_MySQL忘记root密码的处理办法及安装windows服务
  19. nexmo - 当晚售前打电话
  20. Excel如何在姓名与字母之间加空格

热门文章

  1. 读《远见:如何规划职业生涯3大阶段》
  2. 配置 Deepin Linux 支持 中文 宋体、楷体等字体
  3. oracle 11g下载和安装教程
  4. Azure数据仓库表中的数据经常使用的三种分布策略(hash、round_robin 或 replicated)简介
  5. 7-3 优美的括号序列
  6. 中国推进系统行业市场供需与战略研究报告
  7. matlab两列矩阵相除,矩阵运算矩阵除法运算 - matlab资源网2
  8. 轿车麦弗逊式悬架设计(设计说明书+4张CAD图纸+中英文翻译)
  9. windows下启动cmd,打开指定目录,执行指定命令
  10. itune备份在哪里_iTune的iPhone备份文件在哪里,以及如何从中获取真实文件?