pytorch中的torch.nn.Unfold和torch.nn.Fold
pytorch中的torch.nn.Unfold和torch.nn.Fold
- 目的
- Unfold
- Fold
目的
平时使用卷积操作时,既卷积核滑动窗口操作,对于pytorch,以二维图像为例,调用nn.Conv2d就能完成对输入(feature maps)的卷积操作。
但有时,maybe要探究卷积核对应的某一channel的单个窗口的卷积操作,或显式地进行卷积操作。此时,就需要nn.Unfold和nn.Fold。前段时间引起较大争议的BagNet(Bag of local feature net) 的分块卷积操作既由此函数完成。
一般来说,Conv2d 就是 Unfold + matmul + fold
Unfold
torch.nn.Unfold按照官方的说法,既从一个batch的样本中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。
如上图所示,蓝色框部分就是kernel filter(红色框部分)对应的滑动窗口。
首先来看下torch.nn.Unfold的参数:
torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
跟nn.Conv2d的参数很相似,卷积核的尺寸,空洞大小,填充大小和步长。
官方解释中:unfold的输入为(NNN, CCC, HHH, WWW),其中N为batch_size,C是channel个数,H和W分别是channel的长宽。则unfold的输出为(NNN, C×∏C \times \prodC×∏(kernel_size), LLL),其中∏\prod∏(kernel_size)为kernel_size长和宽的乘积, L是channel的长宽根据kernel_size的长宽滑动裁剪后,得到的区块的数量。
以输入(1, 3, 4, 4)为例,设定kernel_size = (2, 2),stride = 2,根据官方给出的LLL计算公式:L=∏d⌊spatial_size[d]+2×padding[d]−dilation[d]×(kernel_size[d]−1)stride[d]+1⌋L = \prod_d \lfloor{\dfrac {spatial\_size[d] + 2 \times padding[d] - dilation[d] \times (kernel\_size[d]-1)}{stride[d]} + 1}\rfloorL=d∏⌊stride[d]spatial_size[d]+2×padding[d]−dilation[d]×(kernel_size[d]−1)+1⌋ddd是channel的维度,二维图像既长宽的维度。
则LLL(区块数量)为⌊(4+2×0−1×(2−1)−12+1)⌋×⌊(4+2×0−1×(2−1)−12+1)⌋=4\lfloor(\dfrac{4 + 2 \times 0 - 1 \times (2-1) -1}{2} + 1) \rfloor \times \lfloor(\dfrac{4 + 2 \times 0 - 1 \times (2-1) -1}{2} + 1) \rfloor = 4⌊(24+2×0−1×(2−1)−1+1)⌋×⌊(24+2×0−1×(2−1)−1+1)⌋=4,每个区块的大小为C×kernel_size[0]×kernel_size[1]C \times kernel\_size[0] \times kernel\_size[1]C×kernel_size[0]×kernel_size[1],既2×2×2=82 \times 2 \times 2 = 82×2×2=8,做为输出的第二个维度。
为了更直观的展示unfold函数所做的操作,以下述代码和结果为例:
inputs = torch.randn(1, 2, 4, 4)
print(inputs.size())
print(inputs)
unfold = torch.nn.Unfold(kernel_size=(2, 2), stride=2)
patches = unfold(inputs)
print(patches.size())
print(patches)
torch.Size([1, 2, 4, 4])
tensor([[[[ 1.4818, -0.1026, -1.7688, 0.5384],[-0.4693, -0.0775, -0.7504, 0.2283],[-0.1414, 1.0006, -0.0942, 2.2981],[-0.9429, 1.1908, 0.9374, -1.3168]],[[-1.8184, -0.3926, 0.1875, 1.3847],[-0.4124, 0.9766, -1.3303, -0.0970],[ 1.7679, 0.6961, -1.6445, 0.7482],[ 0.1729, -0.3196, -0.1528, 0.2180]]]])
torch.Size([1, 8, 4])
tensor([[[ 1.4818, -1.7688, -0.1414, -0.0942],[-0.1026, 0.5384, 1.0006, 2.2981],[-0.4693, -0.7504, -0.9429, 0.9374],[-0.0775, 0.2283, 1.1908, -1.3168],[-1.8184, 0.1875, 1.7679, -1.6445],[-0.3926, 1.3847, 0.6961, 0.7482],[-0.4124, -1.3303, 0.1729, -0.1528],[ 0.9766, -0.0970, -0.3196, 0.2180]]])
对代码结果分析,nn.Unfold对输入channel的每一个kernel_size[0]×kernel_size[1]kernel\_size[0] \times kernel\_size[1]kernel_size[0]×kernel_size[1]的滑动窗口区块做了展平操作。
Fold
torch.nn.Fold的操作与Unfold相反,将提取出的滑动局部区域块还原成batch的张量形式。
代码如下:
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2)
inputs_restore = fold(patches)
print(inputs_restore)
print(inputs_restore.size())
tensor([[[[ 1.4818, -0.1026, -1.7688, 0.5384],[-0.4693, -0.0775, -0.7504, 0.2283],[-0.1414, 1.0006, -0.0942, 2.2981],[-0.9429, 1.1908, 0.9374, -1.3168]],[[-1.8184, -0.3926, 0.1875, 1.3847],[-0.4124, 0.9766, -1.3303, -0.0970],[ 1.7679, 0.6961, -1.6445, 0.7482],[ 0.1729, -0.3196, -0.1528, 0.2180]]]])
torch.Size([1, 2, 4, 4])
分析结果,Fold的操作通过设定output_size=(4, 4),完成与Unfold的互逆的操作。
pytorch中的torch.nn.Unfold和torch.nn.Fold相关推荐
- pytorch中的卷积操作详解
首先说下pytorch中的Tensor通道排列顺序是:[batch, channel, height, width] 我们常用的卷积(Conv2d)在pytorch中对应的函数是: torch.nn. ...
- pytorch中实现Balanced Cross-Entropy
当你明白了pytorch中F.cross_entropy以及F.binary_cross_entropy是如何实现的之后,你再基于它们做改进重新实现一个损失函数就很容易了. 1.背景 变化检测中,往往 ...
- Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau
Pytorch中的学习率调整:lr_scheduler,ReduceLROnPlateau torch.optim.lr_scheduler:该方法中提供了多种基于epoch训练次数进行学习率调整的方 ...
- pytorch中RNN注意事项(关于input和output维度)
pytorch中RNN注意事项 batch_first为False的情况下,认为input的数据维度是(seq,batch,feature),output的数据维度(seq,batch,feature ...
- PyTorch基础(13)-- torch.nn.Unfold()方法
前言 最近在看新论文的过程中,发现新论文中的代码非常简洁,只用了unfold和fold方法便高效的将论文的思想表达出,因此学习记录一下unfold和fold方法. 一.方法详解 方法 torch.nn ...
- PyTorch中的torch.nn.Parameter() 详解
PyTorch中的torch.nn.Parameter() 详解 今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实 ...
- opencv和pytorch中的warp操作函数:cv2.warpAffine, torch.nn.functional.grid_sample, cv2.warpPerspective
关于图像的warp操作是指利用一个旋转缩放矩阵对图像进行操作. 常见的操作有,平移,绕某个点旋转,缩放. opencv中有getRotationMatrix2D,warpAffine, getAffi ...
- Pytorch中torch.nn.Softmax的dim参数含义
自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义
- Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化
Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...
- pytorch中torch.nn.utils.rnn相关sequence的pad和pack操作
目录 一.pad_sequence 二.pack_padded_sequence 三.pad_packed_sequence 四.pack_sequence 自然语言处理任务中,模型的输入一般都是变长 ...
最新文章
- 如何在小型pcb的移动设备上获得更好的无线性能
- firebug 的使用
- 深入理解学习Git常用工作流
- 工业机器人电柜布线_协作并联,重新注解并联机器人
- linux 光盘yum源搭建
- C# Winfrom DataGridView DataSource绑定数据源后--解决排序问题
- 记录——《C Primer Plus (第五版)》第十章编程练习第十二题
- 斐讯k3cfe刷lede_斐讯 K3 A1 刷机经历
- CentOS 8使用 Kickstart配置 UEFI PXE 启动
- Python入门教程之安装MyEclipse插件和安装Python环境
- STM32Cube IDE环境安装
- java转大写的方法_Java字母大小写转换的方法
- 腾讯企业版邮箱服务器类型,腾讯邮箱企业版怎样开通,企业邮箱服务器系统申请...
- 什么是大数据,大数据最缺什么样的人才?
- 胃与十二指肠溃疡的食疗方
- Python的电子邮件操作
- psid mysql_使用Python对MySQL数据库插入二十万条数据
- html中怎样插入视频博客园,关于博客园内嵌入bilibili视频
- sql server 数字转大写
- Fastadmin创蓝短信插件源码
热门文章
- Draco嵌入式AI开发板使用手册V0.1.1
- 向云再出发:如数据般飞驰的内蒙古
- 蓝桥杯校赛第十二届第二期模拟赛 c语言
- Matlab学习笔记(8)——hist函数
- 通过OPENSSL建立证书以及CSR证书签名过程
- 中国需要怎样的智慧城市联盟?中外41家联盟组织大起底
- PMSG孕马血清促性腺激素适用的应用方案
- Spring文件上传接口学习(MultipartFile,MultiparHttpservletRequest,MultipartResolver)
- 日期格式 Wed Oct 16 00:00:00 CEST 2020 转换
- python中、常见的结构化数据不包括_数据分析的主要内容仍是结构化计算_数据分析师...