注意:这里展示的是本篇博文写时的版本最新的实现,但是后续会代码可能会迭代更新,建议对照官方文档进行学习。

先来看源码:

# 这个类是是许多池化类的基类,这里有必要了解一下
class _MaxPoolNd(Module):__constants__ = ['kernel_size', 'stride', 'padding', 'dilation','return_indices', 'ceil_mode']return_indices: boolceil_mode: bool# 构造函数,这里只需要了解这个初始化函数即可。def __init__(self, kernel_size: _size_any_t, stride: Optional[_size_any_t] = None,padding: _size_any_t = 0, dilation: _size_any_t = 1,return_indices: bool = False, ceil_mode: bool = False) -> None:super(_MaxPoolNd, self).__init__()self.kernel_size = kernel_sizeself.stride = stride if (stride is not None) else kernel_sizeself.padding = paddingself.dilation = dilationself.return_indices = return_indicesself.ceil_mode = ceil_modedef extra_repr(self) -> str:return 'kernel_size={kernel_size}, stride={stride}, padding={padding}' \', dilation={dilation}, ceil_mode={ceil_mode}'.format(**self.__dict__)class MaxPool2d(_MaxPoolNd):kernel_size: _size_2_tstride: _size_2_tpadding: _size_2_tdilation: _size_2_tdef forward(self, input: Tensor) -> Tensor:return F.max_pool2d(input, self.kernel_size, self.stride,self.padding, self.dilation, self.ceil_mode,self.return_indices)

MaxPool2d 这个类的实现十分简单。

我们先来看一下基本参数,一共六个:

  1. kernel_size :表示做最大池化的窗口大小,可以是单个值,也可以是tuple元组
  2. stride :步长,可以是单个值,也可以是tuple元组
  3. padding :填充,可以是单个值,也可以是tuple元组
  4. dilation :控制窗口中元素步幅
  5. return_indices :布尔类型,返回最大值位置索引
  6. ceil_mode :布尔类型,为True,用向上取整的方法,计算输出形状;默认是向下取整。

关于 kernel_size 的详解

注意这里的 kernel_size 跟卷积核不是一个东西。 kernel_size 可以看做是一个滑动窗口,这个窗口的大小由自己指定,如果输入是单个值,例如 333 ,那么窗口的大小就是 3×33 \times 33×3 ,还可以输入元组,例如 (3, 2) ,那么窗口大小就是 3×23 \times 23×2 。

最大池化的方法就是取这个窗口覆盖元素中的最大值。

关于 stride 的详解

上一个参数我们确定了滑动窗口的大小,现在我们来确定这个窗口如何进行滑动。如果不指定这个参数,那么默认步长跟最大池化窗口大小一致。如果指定了参数,那么将按照我们指定的参数进行滑动。例如 stride=(2,3) , 那么窗口将每次向右滑动三个元素位置,或者向下滑动两个元素位置。

关于 padding 的详解

这参数控制如何进行填充,填充值默认为0。如果是单个值,例如 1,那么将在周围填充一圈0。还可以用元组指定如何填充,例如 padding=(2,1)padding=(2, 1)padding=(2,1) ,表示在上下两个方向个填充两行0,在左右两个方向各填充一列0。

关于 dilation 的详解

不会

关于 return_indices 的详解

这是个布尔类型值,表示返回值中是否包含最大值位置的索引。注意这个最大值指的是在所有窗口中产生的最大值,如果窗口产生的最大值总共有5个,就会有5个返回值。

关于 ceil_mode 的详解

这个也是布尔类型值,它决定的是在计算输出结果形状的时候,是使用向上取整还是向下取整。怎么计算输出形状,下面会讲到。一看就知道了。

——————————————参数解析结束分界线——————————————

最大池化层输出形状计算
Hout=⌊Hin+2×padding⌊0⌋−dilation⌊0⌋×(kernel_size⌊0⌋−1)−1stride⌊0⌋+1⌋H_{out}=\lfloor \frac{H_{in} + 2 \times padding\lfloor 0 \rfloor - dilation \lfloor 0 \rfloor \times (kernel\_size\lfloor 0 \rfloor - 1)-1}{stride\lfloor 0 \rfloor} + 1 \rfloor Hout​=⌊stride⌊0⌋Hin​+2×padding⌊0⌋−dilation⌊0⌋×(kernel_size⌊0⌋−1)−1​+1⌋

Wout=⌊Win+2×padding⌊1⌋−dilation⌊1⌋×(kernel_size⌊1⌋−1)−1stride⌊1⌋+1⌋W_{out}=\lfloor \frac{W_{in} + 2 \times padding\lfloor 1 \rfloor - dilation \lfloor 1 \rfloor \times (kernel\_size\lfloor 1 \rfloor - 1)-1}{stride\lfloor 1 \rfloor} + 1 \rfloor Wout​=⌊stride⌊1⌋Win​+2×padding⌊1⌋−dilation⌊1⌋×(kernel_size⌊1⌋−1)−1​+1⌋

看到向下取整的符号了吗?这个就是由 ceil_mode 控制的。

——————————————结束分界线——————————————

下面我们写代码验证一下最大池化层是如何计算的:

首先验证 kernel_size 参数

import torch
import torch.nn as nn# 仅定义一个 3x3 的池化层窗口
m = nn.MaxPool2d(kernel_size=(3, 3))# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是6x6
input = torch.randn(1, 1, 6, 6)print(input)output = m(input)print(output)

结果:

第一个tensor是我们的输入数据 1×1×6×61 \times 1 \times 6 \times 61×1×6×6 ,我们画红线的区域就是我们设置的窗口大小 3×33 \times 33×3 ,背景色为红色的值,为该区域的最大值。

第二个tensor就是我们最大池化后的结果,跟我们标注的一模一样。

这个就是最基本的最大池化。

之后我们验证一下 stride 参数

import torch
import torch.nn as nn# 仅定义一个 3x3 的池化层窗口
m = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是6x6
input = torch.randn(1, 1, 6, 6)print(input)output = m(input)print(output)

结果:

红色的还是我们的窗口,但是我们的步长变为了2,可以看到第一个窗口和向右滑动后的窗口,他们的最大值刚好是重叠的部分都是2.688,向下滑动之后,最大值是0.8030,再次向右滑动,最大值是2.4859。

可以看到我们在滑动的时候省略了部分数值,因为剩下的数据不够一次滑动了,于是我们将他们丢弃了。

其实最后图片的宽度和高度还可以通过上面两个公式来计算,我们公式中用的是向下取整,因此我们丢弃了不足的数据。现在我们试试向上取整。

利用 ceil_mode 参数向上取整

import torch
import torch.nn as nn# 仅定义一个 3x3 的池化层窗口
m = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), ceil_mode=True)# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是6x6
input = torch.randn(1, 1, 6, 6)print(input)output = m(input)print('\n\n\n\n\n')print(output)

结果:

从结果可以看出,输出的size由原来的 2×22 \times 22×2 变成了现在的 3×33 \times 33×3 。这就是向上取整的结果。为什么会出现这样的结果呢?

这看起来像是我们对输入进行了填充,但是这个填充值不会参与到计算最大值中。

继续验证 padding 参数

import torch
import torch.nn as nn# 仅定义一个 3x3 的池化层窗口
m = nn.MaxPool2d(kernel_size=(3, 3), stride=(3, 3), padding=(1, 1))# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是6x6
input = torch.randn(1, 1, 6, 6)print(input)output = m(input)print('\n\n')print(output)

结果:

我们对周围填充了一圈0,我们滑动窗口的范围就变化了,这就是填充的作用。

但是有一点需要注意,就是即使我们填充了0,这个0也不会被选为最大值。例如上图的左上角四个数据,如果我们全部变为负数,结果是-0.1711,而不会是我们填充的0值,这一点要注意。

最后验证 return_indices 参数:

import torch
import torch.nn as nn# 仅定义一个 3x3 的池化层窗口
m = nn.MaxPool2d(kernel_size=(3, 3), return_indices=True)# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是6x6
input = torch.randn(1, 1, 6, 6)print(input)output = m(input)print(output)

结果:

仅仅是多返回了一个位置信息。元素位置从0开始计数,6表示第7个元素,9表示第10个元素…需要注意的是,返回值实际上是多维的数据,但是我们只看相关的元素位置信息,忽略维度的问题。

最后一个参数 dilation ,不会

torch.nn.MaxPool2d详解相关推荐

  1. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  2. torch.nn.Linear详解

    在学习transformer时,遇到过非常频繁的nn.Linear()函数,这里对nn.Linear进行一个详解. 参考:https://pytorch.org/docs/stable/_module ...

  3. torch.nn.parameter详解

    :-- 目录: 参考: 1.parameter基本解释: 2.参数requires_grad的深入理解: 2.1 Parameter级别的requires_grad 2.2Module级别的requi ...

  4. Pytorch损失函数torch.nn.NLLLoss()详解

    在各种深度学习框架中,我们最常用的损失函数就是交叉熵(torch.nn.CrossEntropyLoss),熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度.交 ...

  5. pytorch笔记:torch.nn.MaxPool2d

    1 基本用法 class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=Fals ...

  6. torch.stack(), torch.cat()用法详解

    torch.stack(), torch.cat()用法详解 if __name__ == '__main__':import torchx_dat = torch.tensor([[1, 2], [ ...

  7. torch nn.MaxPool2d

    1.应用 import torch import torch.nn as nnm = nn.MaxPool2d(2) input = torch.randn(1, 1, 4, 4) output = ...

  8. Torch.arange函数详解

    torch.arange函数详解 官方文档:torch.arange 函数原型 arange(start=0, end, step=1, *, out=None, dtype=None, layout ...

  9. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

最新文章

  1. 小功率荧光灯拆解分析
  2. Win7下共享文件(以及凭据管理简单介绍)
  3. Oracle的if else if
  4. Airdoc创始人:工智能可以在医疗领域多个环节发挥作用 但有局限性
  5. (转)学习密度与专注力
  6. Couchbase 2.0归类视图简介
  7. 【leecode】小练习(简单8题)
  8. zuul源码分析之Request生命周期管理
  9. 使用Kotlin在活动之间进行Android意向处理
  10. 1次订单事故,扣了我3个月绩效!
  11. 面试专题:Python面试题陷阱,你是否会中招?
  12. python之模块copy_reg(在python3中为copyreg,功能基本不变)
  13. 上班摸鱼的模拟经营文字游戏(管理后台页面,老板都看不出来)
  14. linux查看iozone安装目录,IOZone的基本使用
  15. Android os 4.4.4 魅族,魅族Mx3刷机包 Android 4.4.4 稳定版Flyme OS 3.7.3A 流畅顺滑体验
  16. html按钮位置设置吗,html改变button按钮位置
  17. matlab相机标定工具箱下载,matlab相机标定工具箱
  18. GreenDao 使用详解(入门篇)
  19. 如何解决 “无法成功完成操作,因为文件包含病毒或潜在垃圾软件
  20. Android 打造形形色色的进度条 实现可以如此简单

热门文章

  1. 大数据入门学习之环境搭建
  2. 015 Rust死灵书之Transmutes转换
  3. 墨者靶场 入门:WebShell文件上传漏洞分析溯源(第1题)
  4. 中国矿业大学考研经验分享
  5. 注册快捷键(单快捷键、组合快捷键)
  6. STC8A单片机应用开发
  7. 宇宙第一帅的HTML笔记
  8. Windows系统中设置软件的开机自动启动
  9. 赛龙代小权终审无罪释放,重燃创业之心
  10. Qt工具栏中设置小部件间隔的方法