PyTorch Basics (13) -- the torch.nn.Unfold() Method

Preface

While reading a recent paper, I noticed that its code was remarkably concise: the paper's core idea was implemented efficiently with nothing more than the unfold and fold methods. This post is my study note on these two methods.

1. Method Details

  • Method
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
  • Parameters
    • kernel_size (int or tuple) – size of the sliding window

    • stride (int or tuple, optional) – stride of the sliding window in the spatial dimensions. Default: 1

    • padding (int or tuple, optional) – implicit zero padding added on both sides of the input. Default: 0

    • dilation (int or tuple, optional) – spacing between the kernel elements, as in dilated (atrous) convolution. Default: 1

  • Meaning: Unfold extracts all the values covered by the sliding window. For example, given the input below,
[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
[-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
[-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
[ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
[ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]]

sliding a window with kernel_size=3 over it records the 9×9 matrix below, in which each column is one flattened 3×3 patch:

[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
 [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
 [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
 [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
 [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
 [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
 [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
 [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
 [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467]]
  • Note: the input to Unfold must be 4-dimensional, i.e. (N, C, H, W).
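To make the patch-extraction behaviour described above concrete, here is a minimal sketch using a 4×4 arange input so every value is recognizable (the 4×4 size and the variable names are my own choices, not from the original example):

```python
import torch
import torch.nn as nn

# Single-channel 4x4 input holding the easily recognizable values 0..15
x = torch.arange(16.0).reshape(1, 1, 4, 4)
y = nn.Unfold(kernel_size=3)(x)  # shape (1, 9, 4): 9 values per patch, 4 patches

# Each column of the output is one flattened 3x3 patch;
# column 0 is the top-left patch, read row by row.
print(y[0, :, 0])  # tensor([ 0.,  1.,  2.,  4.,  5.,  6.,  8.,  9., 10.])
```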

2. How to Compute the Output Size

  • Example
import torch
import torch.nn as nn

if __name__ == '__main__':
    x = torch.randn(2, 3, 5, 5)
    print(x)
    unfold = nn.Unfold(2)
    y = unfold(x)
    print(y.size())
    print(y)
  • Result
torch.Size([2, 12, 16])

Next, let's work out step by step how this result is computed!

First, remember that the input must be 4-dimensional, i.e. (B, C, H, W), where B is the batch size, C the number of channels, H the height of the feature map, and W its width. Suppose the size after Unfold is (B, h, w). We first need to compute h (the output "height"), using the following formula:

h = C × k_H × k_W

where (k_H, k_W) is the kernel size.

For example: with an input of 3 channels (the most common case for images, which is why we use it here) and a kernel size of (2, 2), the output height after Unfold is h = 3 × 2 × 2 = 12.
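The h = C × k_H × k_W rule can be checked with a non-square kernel as well (this variation is my own, not from the original example):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 5, 5)           # C = 3
y = nn.Unfold(kernel_size=(3, 2))(x)  # k_H = 3, k_W = 2
print(y.shape)  # torch.Size([2, 18, 12]): h = 3 * 3 * 2 = 18
```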

Having computed h, we next compute w, using the following formula:

w = ∏_{d=1}^{D} ⌊(spatial_size[d] + 2 × padding[d] − dilation[d] × (kernel_size[d] − 1) − 1) / stride[d] + 1⌋

where D is the number of spatial dimensions; for example, with spatial dimensions (H, W), D = 2. Let's compute w through an example.

Example: if the input H and W are both 5 and the kernel size is 2 (with the defaults padding = 0, dilation = 1, stride = 1), each spatial dimension contributes ⌊(5 + 0 − 1 × (2 − 1) − 1) / 1 + 1⌋ = 4 positions, so w = 4 × 4 = 16, and the final output size is [2, 12, 16].
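The formula above can be sketched as a small helper (num_blocks is my own name, not a PyTorch function) and checked against nn.Unfold under several parameter settings:

```python
import math
import torch
import torch.nn as nn

def num_blocks(size, kernel, padding=0, dilation=1, stride=1):
    """Sliding positions along one spatial dimension, per the formula above."""
    return math.floor((size + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1

x = torch.randn(2, 3, 5, 5)
for kwargs in [dict(kernel_size=2),
               dict(kernel_size=3, stride=2),
               dict(kernel_size=3, padding=1, dilation=2)]:
    y = nn.Unfold(**kwargs)(x)
    k = kwargs['kernel_size']
    opts = {key: val for key, val in kwargs.items() if key != 'kernel_size'}
    w = num_blocks(5, k, **opts) ** 2  # H = W = 5 and a square kernel
    print(kwargs, y.shape[-1], w)
    assert y.shape[-1] == w
```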

3. Case Study

  • Example
import torch
import torch.nn as nn

if __name__ == '__main__':
    x = torch.randn(1, 3, 5, 5)
    print(x)
    unfold = nn.Unfold(kernel_size=3)
    output = unfold(x)
    print(output, output.size())
  • Result
tensor([[[[ 0.4009,  0.6350, -0.5197,  0.8148, -0.7235],
          [-1.2102,  0.4621, -0.3421, -0.9261, -2.8376],
          [-1.5553,  0.1713,  0.6820, -2.0880, -0.0204],
          [ 1.1419, -0.4881, -0.9510, -0.0367, -0.8108],
          [ 0.1459, -0.4568,  1.0039, -1.2385, -1.4467]],

         [[-0.9973, -0.7601, -0.2161,  1.2120, -0.3036],
          [-0.7279,  0.0833, -0.8886, -0.9168,  0.7503],
          [-0.6748,  0.7064,  0.6903, -1.0447,  0.8688],
          [-0.5230, -1.2308, -0.3932,  1.2521, -0.2523],
          [-0.3930,  0.6452,  0.1690,  0.3744,  0.2015]],

         [[ 0.6403,  1.3915, -1.9529,  0.2899, -0.8897],
          [-0.1720,  1.0843, -1.0177, -1.7480, -0.5217],
          [-0.9648, -0.0867, -0.2926,  0.3010,  0.3192],
          [ 0.1181, -0.2218,  0.0766,  0.5914, -0.8932],
          [-0.4508, -0.3964,  1.1163,  0.6776, -0.8948]]]])
tensor([[[ 0.4009,  0.6350, -0.5197, -1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820],
         [ 0.6350, -0.5197,  0.8148,  0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880],
         [-0.5197,  0.8148, -0.7235, -0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204],
         [-1.2102,  0.4621, -0.3421, -1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510],
         [ 0.4621, -0.3421, -0.9261,  0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367],
         [-0.3421, -0.9261, -2.8376,  0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108],
         [-1.5553,  0.1713,  0.6820,  1.1419, -0.4881, -0.9510,  0.1459, -0.4568,  1.0039],
         [ 0.1713,  0.6820, -2.0880, -0.4881, -0.9510, -0.0367, -0.4568,  1.0039, -1.2385],
         [ 0.6820, -2.0880, -0.0204, -0.9510, -0.0367, -0.8108,  1.0039, -1.2385, -1.4467],
         [-0.9973, -0.7601, -0.2161, -0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903],
         [-0.7601, -0.2161,  1.2120,  0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447],
         [-0.2161,  1.2120, -0.3036, -0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688],
         [-0.7279,  0.0833, -0.8886, -0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932],
         [ 0.0833, -0.8886, -0.9168,  0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521],
         [-0.8886, -0.9168,  0.7503,  0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523],
         [-0.6748,  0.7064,  0.6903, -0.5230, -1.2308, -0.3932, -0.3930,  0.6452,  0.1690],
         [ 0.7064,  0.6903, -1.0447, -1.2308, -0.3932,  1.2521,  0.6452,  0.1690,  0.3744],
         [ 0.6903, -1.0447,  0.8688, -0.3932,  1.2521, -0.2523,  0.1690,  0.3744,  0.2015],
         [ 0.6403,  1.3915, -1.9529, -0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926],
         [ 1.3915, -1.9529,  0.2899,  1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010],
         [-1.9529,  0.2899, -0.8897, -1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192],
         [-0.1720,  1.0843, -1.0177, -0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766],
         [ 1.0843, -1.0177, -1.7480, -0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914],
         [-1.0177, -1.7480, -0.5217, -0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932],
         [-0.9648, -0.0867, -0.2926,  0.1181, -0.2218,  0.0766, -0.4508, -0.3964,  1.1163],
         [-0.0867, -0.2926,  0.3010, -0.2218,  0.0766,  0.5914, -0.3964,  1.1163,  0.6776],
         [-0.2926,  0.3010,  0.3192,  0.0766,  0.5914, -0.8932,  1.1163,  0.6776, -0.8948]]]) torch.Size([1, 27, 9])
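Since the preface also mentions fold, here is a minimal round-trip sketch: nn.Fold sums overlapping patch values back into an image, so dividing by the per-pixel overlap count recovers the original input. The normalization trick below is a common idiom, not part of the original example.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 5, 5)
unfold = nn.Unfold(kernel_size=3)
fold = nn.Fold(output_size=(5, 5), kernel_size=3)

patches = unfold(x)      # (1, 27, 9)
summed = fold(patches)   # overlapping positions are summed, not averaged

# Count how many patches cover each pixel, then normalize
counts = fold(unfold(torch.ones_like(x)))
recovered = summed / counts
print(torch.allclose(recovered, x))  # True
```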

If you found this post helpful, a like, comment, or bookmark would be much appreciated!
