torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

参数

  • kernel_size,滑动窗口的大小
  • dilation,和空洞卷积中的dilation一样,控制滑动窗口取值的间隔
  • padding,控制输入张量四周padding的大小
  • stride,控制滑动窗口滑动的步长

输入输出维度

  • 输入维度为 (N, C, *),其中N代表batch_size,C为通道数,*代表任意的空间维度。
  • 输出维度为 \((N,C\times \prod (kernel\_size),L)\),其中 \(\prod (kernel\_size)\) 表示滑动窗口的大小,\(L\) 表示滑动窗口在输入张量中可以滑动的所有位置的个数。\(L\) 的大小计算如下

这里 \(L\) 的计算方式和卷积中输出大小的计算是一模一样的

这个函数的作用就是通过一个滑动窗口遍历输入张量,在这个过程中,取出输入张量中滑动窗口对应位置的值,组成输出向量。

这个过程有点类似于卷积,卷积过程中滑动窗口就是卷积核,卷积核滑动到某个位置,取出输入张量中对应位置的值与该卷积核对应位置的值分别相乘最后相加得到单个值的输出。而unfold是滑动到某一位置,直接取出输入张量中对应位置的值,然后flatten成一个一维向量,放到输出张量的对应位置处。

比如下图中的一个输入张量,batch_size=1,shape=(1, C, W, H),我们想沿WxH维度取出一个个局部特征,就可以通过unfold来实现。

示例

如下代码所示,输入x是一个维度为(2, 5, 3, 4)的张量,其中2是batch_size,5是通道数,3,4代表spatial维度的大小,对应上面输入维度中的 *

kernel大小为(2, 3),unfold做的就是在spatial维度滑动该kernel对应的窗口,取出对应的值,注意这里滑动窗口时取窗口内所有通道的值。上面的 \(\prod (kernel\_size)=2\times 3=6\),窗口内共有2x3x5=30个值,对应输出维度中的 \(C\times \prod (kernel\_size)\)。然后把窗口内对应的block展平成一个一维向量,放到输出中。接下来因为默认stride=1,dilation=1,padding=0,因此(2, 3)大小的kernel在维度(3, 4)中两个方向都只能滑动一次,因此滑动窗口只有4个位置,对应上面的 \(L=4\)。

x = torch.randn(2, 5, 3, 4)
output = F.unfold(x, kernel_size=(2, 3))
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
print(output.size())
# torch.Size([2, 30, 4])

参考

Unfold — PyTorch 1.12 documentation

PyTorch中torch.nn.functional.unfold函数使用详解_咆哮的阿杰的博客-CSDN博客_pytorch unfold

torch.nn.functional.unfold 用法解读相关推荐

  1. 【Pytorch 】nn.functional.unfold()==>卷积操作中的提取kernel filter对应的滑动窗口

    使用方法: def unfold(input, kernel_size, dilation=1, padding=0, stride=1):"""input: tenso ...

  2. [Pytorch]torch.nn.functional.conv2d与深度可分离卷积和标准卷积

    torch.nn.functional.conv2d与深度可分离卷积和标准卷积 前言 F.conv2d与nn.Conv2d F.conv2d 标准卷积考虑Batch的影响 深度可分离卷积 深度可分离卷 ...

  3. torch.nn.functional.interpolate函数

    torch.nn.functional.interpolate实现插值和上采样 torch.nn.functional.interpolate(input, size=None, scale_fact ...

  4. torch.nn.functional.cross_entropy.ignore_index

    ignore_index表示计算交叉熵时,自动忽略的标签值,example: import torch import torch.nn.functional as F pred = [] pred.a ...

  5. 【pytorch】torch.nn.functional.pad的使用

    torch.nn.functional.pad 是对Tensor做padding,输入的参数必须的torch的Tensor 一般地,习惯上会做如下声明 import torch.nn.function ...

  6. torch.nn.functional.pad

    作用 用来对一个tensor进行填充.最典型的就是图片了,原来是2*2的,现在想要变成3*3的,那么就需要填充,此时有很多选择,例如是在原来的右上进行填充还是左下?又或者是左上?等等. 这个函数就可以 ...

  7. PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx

    PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现在 torch.nn.xxx 和 torch.nn.funct ...

  8. pytorch笔记:torch.nn.functional.pad

    1 torch.nn.functional.pad函数 torch.nn.functional.pad是pytorch内置的tensor扩充函数,便于对数据集图像或中间层特征进行维度扩充 torch. ...

  9. torch.nn.functional.pad(input, pad, mode=‘constant‘, value=0)

    torch.nn.functional.pad(input, pad, mode='constant', value=0) 填充Tensor. 填充大小: 填充input的某些维度的填充大小从最后一个 ...

最新文章

  1. java设计模式 观察者模式_理解java设计模式之观察者模式
  2. TensorFlow、MXNet、Keras如何取舍? 常用深度学习框架对比
  3. 数据处理的两个基本问题---汇编学习笔记
  4. dart系列之:还在为编码解码而烦恼吗?用dart试试
  5. 人工智能第六课:如何做研究
  6. 女博士年薪156万入职华为!实力演绎美貌与智慧并重
  7. android访问静态内部类,Java 内部类详解
  8. Linux中zsh插件,ubuntu / zsh shell / oh-my-zsh / 常用插件
  9. MySQL-高并发优化
  10. 小程序数据框有重影_小程序开发(二):数据绑定
  11. Unity Shader数学基础——笛卡尔坐标,点,矢量
  12. 十大必知开源WebRTC服务器
  13. R语言学习记录:array()函数
  14. win8: html5+css3+js
  15. Oracle报错:不是单组分组函数解决
  16. 基于STM8的数字温度计设计
  17. 利用Ubuntu的U盘安装盘安装build-essential
  18. macbook 插上移动硬盘后 WIFI 上不了网的解决办法
  19. matlab读取二进制文件字符串,matlab读取内容为二进制的TXT文件
  20. 工具介绍:js-beautify,整理压缩混淆后的js,html,css

热门文章

  1. BIM与GIS可视化平台的三维城市应用
  2. Gartner发布2023年十大战略技术趋势,元宇宙等技术上榜
  3. 数字矿山AI综合监控平台
  4. 计算机专业课ds是什么,ds学长科普贴之扯谈计算机
  5. [文献阅读]——ERNIE-Gram: Pre-Training with Explicitly N-Gram Masked Language Modeling for NLU(TBC)
  6. Python实现A*算法的十五数码
  7. 近视200度能学计算机吗,近视200度大概是4.几,4.6的视力相当于近视多少度。很多人不知道...
  8. 抖音3d相册html代码,抖音上很火的3D立体动态相册.html
  9. unicode汉字、数字、英文等字符范围表示
  10. 两线式键盘(AD按键)电路的设计与实现