pytorch中的torch.nn.Unfold和torch.nn.Fold

  • 目的
    • Unfold
    • Fold

目的

平时使用卷积操作时,既卷积核滑动窗口操作,对于pytorch,以二维图像为例,调用nn.Conv2d就能完成对输入(feature maps)的卷积操作。
但有时,maybe要探究卷积核对应的某一channel的单个窗口的卷积操作,或显式地进行卷积操作。此时,就需要nn.Unfold和nn.Fold。前段时间引起较大争议的BagNet(Bag of local feature net) 的分块卷积操作既由此函数完成。
一般来说,Conv2d 就是 Unfold + matmul + fold

Unfold

torch.nn.Unfold按照官方的说法,既从一个batch的样本中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。

如上图所示,蓝色框部分就是kernel filter(红色框部分)对应的滑动窗口。
首先来看下torch.nn.Unfold的参数:

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

跟nn.Conv2d的参数很相似,卷积核的尺寸,空洞大小,填充大小和步长。

官方解释中:unfold的输入为(NNN, CCC, HHH, WWW),其中N为batch_size,C是channel个数,H和W分别是channel的长宽。则unfold的输出为(NNN, C×∏C \times \prodC×∏(kernel_size), LLL),其中∏\prod∏(kernel_size)为kernel_size长和宽的乘积, L是channel的长宽根据kernel_size的长宽滑动裁剪后,得到的区块的数量。

以输入(1, 3, 4, 4)为例,设定kernel_size = (2, 2),stride = 2,根据官方给出的LLL计算公式:L=∏d⌊spatial_size[d]+2×padding[d]−dilation[d]×(kernel_size[d]−1)stride[d]+1⌋L = \prod_d \lfloor{\dfrac {spatial\_size[d] + 2 \times padding[d] - dilation[d] \times (kernel\_size[d]-1)}{stride[d]} + 1}\rfloorL=d∏​⌊stride[d]spatial_size[d]+2×padding[d]−dilation[d]×(kernel_size[d]−1)​+1⌋ddd是channel的维度,二维图像既长宽的维度。
则LLL(区块数量)为⌊(4+2×0−1×(2−1)−12+1)⌋×⌊(4+2×0−1×(2−1)−12+1)⌋=4\lfloor(\dfrac{4 + 2 \times 0 - 1 \times (2-1) -1}{2} + 1) \rfloor \times \lfloor(\dfrac{4 + 2 \times 0 - 1 \times (2-1) -1}{2} + 1) \rfloor = 4⌊(24+2×0−1×(2−1)−1​+1)⌋×⌊(24+2×0−1×(2−1)−1​+1)⌋=4,每个区块的大小为C×kernel_size[0]×kernel_size[1]C \times kernel\_size[0] \times kernel\_size[1]C×kernel_size[0]×kernel_size[1],既2×2×2=82 \times 2 \times 2 = 82×2×2=8,做为输出的第二个维度。
为了更直观的展示unfold函数所做的操作,以下述代码和结果为例:

inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)
torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688,  0.5384],[-0.4693, -0.0775, -0.7504,  0.2283],[-0.1414,  1.0006, -0.0942,  2.2981],[-0.9429,  1.1908,  0.9374, -1.3168]],[[-1.8184, -0.3926,  0.1875,  1.3847],[-0.4124,  0.9766, -1.3303, -0.0970],[ 1.7679,  0.6961, -1.6445,  0.7482],[ 0.1729, -0.3196, -0.1528,  0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],[-0.1026,  0.5384,  1.0006,  2.2981],[-0.4693, -0.7504, -0.9429,  0.9374],[-0.0775,  0.2283,  1.1908, -1.3168],[-1.8184,  0.1875,  1.7679, -1.6445],[-0.3926,  1.3847,  0.6961,  0.7482],[-0.4124, -1.3303,  0.1729, -0.1528],[ 0.9766, -0.0970, -0.3196,  0.2180]]])

对代码结果分析,nn.Unfold对输入channel的每一个kernel_size[0]×kernel_size[1]kernel\_size[0] \times kernel\_size[1]kernel_size[0]×kernel_size[1]的滑动窗口区块做了展平操作。

Fold

torch.nn.Fold的操作与Unfold相反,将提取出的滑动局部区域块还原成batch的张量形式。
代码如下:

fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())
tensor([[[[ 1.4818, -0.1026, -1.7688,  0.5384],[-0.4693, -0.0775, -0.7504,  0.2283],[-0.1414,  1.0006, -0.0942,  2.2981],[-0.9429,  1.1908,  0.9374, -1.3168]],[[-1.8184, -0.3926,  0.1875,  1.3847],[-0.4124,  0.9766, -1.3303, -0.0970],[ 1.7679,  0.6961, -1.6445,  0.7482],[ 0.1729, -0.3196, -0.1528,  0.2180]]]])
torch.Size([1, 2, 4, 4])

分析结果,Fold的操作通过设定output_size=(4, 4),完成与Unfold的互逆的操作。

pytorch中的torch.nn.Unfold和torch.nn.Fold相关推荐

  1. pytorch中的卷积操作详解

    首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...

  2. pytorch中实现Balanced Cross-Entropy

    当你明白了pytorch中F.cross_entropy以及F.binary_cross_entropy是如何实现的之后,你再基于它们做改进重新实现一个损失函数就很容易了. 1.背景 变化检测中,往往 ...

  3. Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

    Pytorch中的学习率调整:lr_scheduler,ReduceLROnPlateau torch.optim.lr_scheduler:该方法中提供了多种基于epoch训练次数进行学习率调整的方 ...

  4. pytorch中RNN注意事项(关于input和output维度)

    pytorch中RNN注意事项 batch_first为False的情况下,认为input的数据维度是(seq,batch,feature),output的数据维度(seq,batch,feature ...

  5. PyTorch基础(13)-- torch.nn.Unfold()方法

    前言 最近在看新论文的过程中,发现新论文中的代码非常简洁,只用了unfold和fold方法便高效的将论文的思想表达出,因此学习记录一下unfold和fold方法. 一.方法详解 方法 torch.nn ...

  6. PyTorch中的torch.nn.Parameter() 详解

    PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...

  7. opencv和pytorch中的warp操作函数:cv2.warpAffine, torch.nn.functional.grid_sample, cv2.warpPerspective

    关于图像的warp操作是指利用一个旋转缩放矩阵对图像进行操作. 常见的操作有,平移,绕某个点旋转,缩放. opencv中有getRotationMatrix2D,warpAffine, getAffi ...

  8. Pytorch中torch.nn.Softmax的dim参数含义

    自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义

  9. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  10. pytorch中torch.nn.utils.rnn相关sequence的pad和pack操作

    目录 一.pad_sequence 二.pack_padded_sequence 三.pad_packed_sequence 四.pack_sequence 自然语言处理任务中,模型的输入一般都是变长 ...

最新文章

  1. 如何在小型pcb的移动设备上获得更好的无线性能
  2. firebug 的使用
  3. 深入理解学习Git常用工作流
  4. 工业机器人电柜布线_协作并联,重新注解并联机器人
  5. linux 光盘yum源搭建
  6. C# Winfrom DataGridView DataSource绑定数据源后--解决排序问题
  7. 记录——《C Primer Plus (第五版)》第十章编程练习第十二题
  8. 斐讯k3cfe刷lede_斐讯 K3 A1 刷机经历
  9. CentOS 8使用 Kickstart配置 UEFI PXE 启动
  10. Python入门教程之安装MyEclipse插件和安装Python环境
  11. STM32Cube IDE环境安装
  12. java转大写的方法_Java字母大小写转换的方法
  13. 腾讯企业版邮箱服务器类型,腾讯邮箱企业版怎样开通,企业邮箱服务器系统申请...
  14. 什么是大数据,大数据最缺什么样的人才?
  15. 胃与十二指肠溃疡的食疗方
  16. Python的电子邮件操作
  17. psid mysql_使用Python对MySQL数据库插入二十万条数据
  18. html中怎样插入视频博客园,关于博客园内嵌入bilibili视频
  19. sql server 数字转大写
  20. Fastadmin创蓝短信插件源码

热门文章

  1. Draco嵌入式AI开发板使用手册V0.1.1
  2. 向云再出发:如数据般飞驰的内蒙古
  3. 蓝桥杯校赛第十二届第二期模拟赛 c语言
  4. Matlab学习笔记(8)——hist函数
  5. 通过OPENSSL建立证书以及CSR证书签名过程
  6. 中国需要怎样的智慧城市联盟?中外41家联盟组织大起底
  7. PMSG孕马血清促性腺激素适用的应用方案
  8. Spring文件上传接口学习(MultipartFile,MultiparHttpservletRequest,MultipartResolver)
  9. 日期格式 Wed Oct 16 00:00:00 CEST 2020 转换
  10. python中、常见的结构化数据不包括_数据分析的主要内容仍是结构化计算_数据分析师...