Pytorch的nn.Conv2d()详解

  • nn.Conv2d()的使用、形参与隐藏的权重参数
    • in_channels
    • out_channels
    • kernel_size
    • stride = 1
    • padding = 0
    • dilation = 1
    • groups = 1
    • bias = True
    • padding_mode = 'zeros'

nn.Conv2d()的使用、形参与隐藏的权重参数

  二维卷积应该是最常用的卷积方式了,在Pytorch的nn模块中,封装了nn.Conv2d()类作为二维卷积的实现。使用方法和普通的类一样,先实例化再使用。下面是一个只有一层二维卷积的神经网络,作为nn.Conv2d()方法的使用简介:
  

class Net(nn.Module):def __init__(self):nn.Module.__init__(self)self.conv2d = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=4,stride=2,padding=1)def forward(self, x):print(x.requires_grad)x = self.conv2d(x)return xprint(net.conv2d.weight)
print(net.conv2d.bias)

  它的形参由Pytorch手册可以查得,前三个参数是必须手动提供的,后面的有默认值。接下来将一一介绍:

  也许有细心的同学已经发现了,emm…卷积层最重要的可学习参数——权重参数和偏置参数去哪了?在Tensorflow中都是先定义好weight和bias,再去定义卷积层的呀!别担心,在Pytorch的nn模块中,它是不需要你手动定义网络层的权重和偏置的,这也是体现Pytorch使用简便的地方。当然,如果有小伙伴适应不了这种不定义权重和偏置的方法,Pytorch还提供了nn.Functional函数式编程的方法其中的F.conv2d()就和Tensorflow一样要先定义好卷积核的权重和偏置,作为F.conv2d()的形参之一。
  回到nn.Conv2d上来,我们可以通过实例名.weight和实例名.bias来查看卷积层的权重和偏置,如上图所示。还有小伙伴要问了,那么它们是如何初始化的呢?
  首先给结论,在nn模块中,Pytorch对于卷积层的权重和偏置(如果需要偏置)初始化都是采用He初始化的,因为它非常适合于ReLU函数。这一点大家看Pytorch的nn模块中卷积层的源码实现就能清楚地发现了,当然,我们也可以重新对权重等参数进行其他的初始化,可以查看其他教程,此处不再多言。

in_channels

  这个很好理解,就是输入的四维张量[N, C, H, W]中的C了,即输入张量的channels数。这个形参是确定权重等可学习参数的shape所必需的。

out_channels

  也很好理解,即期望的四维输出张量的channels数,不再多说。

kernel_size

  卷积核的大小,一般我们会使用5x5、3x3这种左右两个数相同的卷积核,因此这种情况只需要写kernel_size = 5这样的就行了。如果左右两个数不同,比如3x5的卷积核,那么写作kernel_size = (3, 5),注意需要写一个tuple,而不能写一个列表(list)。

stride = 1

  卷积核在图像窗口上每次平移的间隔,即所谓的步长。这个概念和Tensorflow等其他框架没什么区别,不再多言。

padding = 0

  Pytorch与Tensorflow在卷积层实现上最大的差别就在于padding上
  Padding即所谓的图像填充,后面的int型常数代表填充的多少(行数、列数),默认为0。需要注意的是这里的填充包括图像的上下左右,以padding = 1为例,若原始图像大小为32x32,那么padding后的图像大小就变成了34x34,而不是33x33
  Pytorch不同于Tensorflow的地方在于,Tensorflow提供的是padding的模式,比如same、valid,且不同模式对应了不同的输出图像尺寸计算公式。而Pytorch则需要手动输入padding的数量,当然,Pytorch这种实现好处就在于输出图像尺寸计算公式是唯一的,即

  当然,上面的公式过于复杂难以记忆。大多数情况下的kernel_size、padding左右两数均相同,且不采用空洞卷积(dilation默认为1),因此只需要记 O = (I - K + 2P)/ S +1这种在深度学习课程里学过的公式就好了。

dilation = 1

  这个参数决定了是否采用空洞卷积默认为1(不采用)。从中文上来讲,这个参数的意义从卷积核上的一个参数到另一个参数需要走过的距离,那当然默认是1了,毕竟不可能两个不同的参数占同一个地方吧(为0)。
  更形象和直观的图示可以观察Github上的Dilated convolution animations,展示了dilation=2的情况。

groups = 1

  决定了是否采用分组卷积,groups参数可以参考groups参数详解

bias = True

  即是否要添加偏置参数作为可学习参数的一个,默认为True。

padding_mode = ‘zeros’

  即padding的模式,默认采用零填充。

Pytorch的nn.Conv2d()详解相关推荐

  1. nn.Conv2d详解

    nn.Conv2d 是 PyTorch 中的一个卷积层,用于实现二维卷积操作.其主要参数有: in_channels:表示输入图像的通道数,也就是输入特征图的深度. out_channels:表示输出 ...

  2. PyTorch的nn.Linear()详解

    1. nn.Linear() nn.Linear():用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量 一般形状为[batch_size, size],不同于卷积层要求输入输出是 ...

  3. Tensorflow(r1.4)API--tf.nn.conv2d详解

    (一)函数简介 conv2d(input,filter,strides,padding,use_cudnn_on=True,data_format='NHWC',name=None) 1.参数: in ...

  4. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  5. Pytorch中nn.Conv2d数据计算模拟

    Pytorch中nn.Conv2d数据计算模拟 最近在研究dgcnn网络的源码,其网络架构部分使用的是nn.Conv2d模块.在Pytorch的官方文档中,nn.Conv2d的输入数据为(B, Cin ...

  6. 【PyTorch】nn.Conv2d函数详解

    文章目录 1. 函数语法格式 2. 参数解释 3. 尺寸关系 4. 使用案例 5. nn.functional.conv2d 1. 函数语法格式 CONV2D官方链接 torch.nn.Conv2d( ...

  7. Pytorch的nn.Conv2d()参数详解

    nn.Conv2d()的使用.形参与隐藏的权重参数   二维卷积应该是最常用的卷积方式了,在Pytorch的nn模块中,封装了nn.Conv2d()类作为二维卷积的实现.使用方法和普通的类一样,先实例 ...

  8. 【小白学PyTorch】12.SENet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 11 MobileNet详解及PyTorch实现 小白学PyTorch | 10 pytorch常见运算详解 小白学Py ...

  9. 【小白学PyTorch】11.MobileNet详解及PyTorch实现

    <<小白学PyTorch>> 小白学PyTorch | 10 pytorch常见运算详解 小白学PyTorch | 9 tensor数据结构与存储结构 小白学PyTorch | ...

最新文章

  1. System.currentTimeMillis()竟然存在性能问题,这我能信?
  2. mysql缺少函数_零散的MySQL基础总是记不住?看这一篇就够了!
  3. docker build命令详解_Docker 搭建你的第一个 Node 项目到服务器
  4. 机器人末端执行器气爪怎么吸合_平行气爪工作原理是什么?平行气爪原理图作用是什么...
  5. Node.js 笔记01
  6. GlusterFS架构与维护
  7. Velocity.js中文文档
  8. 找出第i个小元素(算法导论第三版9.2-4题)
  9. PyCharm2018 汉化激活
  10. 纪念盘古工坊开发的一款手机游戏正式发布
  11. vue 前端显示图片加token_Vue 页面权限控制和登陆验证
  12. Seurat | 不同单细胞转录组的整合方法
  13. 怎么提高国外服务器速度?
  14. 国产芯片传来好消息,纯国产CPU测试数据“曝光”
  15. java开发中遇到的问题_Java开发过程中遇到的问题及解决方法
  16. 无需下载软件,有手就能做的线上个人简历
  17. squad战术小队steam服务器搭建教程。
  18. Hulu全球研发副总裁诸葛越谈人工智能
  19. 笔记本或者台式机安装kali操作系统
  20. 物联网查流量_物联网流量管理平台

热门文章

  1. idea 上传项目到码云git仓库提交到gitee(完整操作流程)
  2. 团队作业之一:团队介绍及选题背景与意义
  3. mnn模型从训练-转换-预测
  4. JMockit didn't get initialized
  5. Windows 10 64bit 安装dotnetfx 3.5出错的解决办法(备忘)
  6. 冒泡排序(C语言版)
  7. vs2015已停止工作,事件名称APPCRASH 故障模块KERNELBASE.dll
  8. 保护你的 Flutter 应用程序
  9. 避免使用隐式类型转换
  10. java-之冒泡排序法