在pytorch中,我们常用的卷积是封装好了的卷积,如nn.Conv2d, 对应到原理图的左上那个子图

但如果封装好的这个卷积操作不能满足我们想要更细粒度的操作的话,pytorch还为我们提供了 unfold , matmul , fold 三个操作(conv = unfold + matmul + fold

torch.nn.Unfold就是原理图下面中间的那个,也就算把一个立体的tensor(feature)分成个部分(kernel_size-sized block),然后把每一个准备和kernel相乘的部分拉直。该类的构造器的参数有:

torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)

我们来看下unfold的输入和输出,其输入形状如:, 输出就是

unfold之后,我们构造个可以学习的tensor作为kernels,并把它像左下图那样展开成 ,注意这里的是kernel的个数

然后就用pytorch自带的matmul,把kernels 的展开乘unfold 之后的input tensor得到Output Maps,维度为

通过GPU并行加速相乘之后,我们还需要把计算结果Output Maps通过fold把它还原回tensor(feature),此时就需要用的pytorch中提供的

torch.nn.Fold(output_size, kernel_size, dilation=1, padding=0, stride=1)# Combines an array of sliding local blocks into a large containing tensor.# output_size (int or tuple) – the shape of the spatial dimensions of the output # (i.e., output.sizes()[2:])# kernel_size (int or tuple) – the size of the sliding blocks# stride (int or tuple) – the stride of the sliding blocks in the input spatial # dimensions. Default: 1# padding (int or tuple, optional) – implicit zero padding to be added on both sides of input. Default: 0# dilation (int or tuple, optional) – a parameter that controls the stride of elements within the neighborhood. Default: 1

fold的输入是Output Maps,输出是tensor

可以看出unfold 和 fold所作的仅仅是对 tensor的reshape而已.

当然实际上的官方文档,输入输出在维度上会在前面多加一个batch_num

实战例子:

输入 theta.shape = torch.Size([4, 256, 64, 64])

这里的

  • 输入图片大小 W×W 为(64 * 64)
  • Filter大小 F×F为(3*3)
  • 步长 S 为 1
  • padding的像素数 P 为 1

F.unfold(theta, kernel_size=3, padding=1).shape = {Size: 3} torch.Size([4, 2304, 4096])

2304  =  256 * 3 * 3

64 = (W − F + 2P )/S+1 = (64 - 3 + 2* 1) / 1 + 1

4096 = 64 * 64

图解卷积计算原理与pytorch中fold和unfold函数的使用相关推荐

  1. Pytorch中的nn.Unfold()和nn.Fold()详解

    1. nn.Unfold()函数 描述:pytorch中的nn.Unfold()函数,在图像处理领域,经常需要用到卷积操作,但是有时我们只需要在图片上进行滑动的窗口操作,将图片切割成patch,而不需 ...

  2. 【数字信号处理】卷积编程实现 ( 卷积计算原理 | 卷积公式计算 | 使用 matlab 计算卷积 | 使用 C 语言实现卷积计算 )

    文章目录 一.卷积计算原理 二.卷积计算 1.计算 y(0) 2.计算 y(1) 3.计算 y(2) 三.使用 matlab 计算卷积 四.使用 C 语言实现卷积计算 一.卷积计算原理 对于 线性时不 ...

  3. pytorch手动实现滑动窗口操作,论fold和unfold函数的使用

    ∇ \nabla ∇ 联系方式: e-mail: FesianXu@gmail.com QQ: 973926198 github: https://github.com/FesianXu 知乎专栏: ...

  4. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  5. python中squeeze函数_详解pytorch中squeeze()和unsqueeze()函数介绍

    squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...

  6. pytorch 中 expand ()函数

    pytorch 中 expand ()函数 expand函数的功能就是 用来扩展张量中某维数据的尺寸,它返回输入张量在某维扩展为更大尺寸后的张量. 例如: x = torch.tensor([1, 2 ...

  7. [PyTorch] 深度学习框架PyTorch中的概念和函数

    Pytorch的概念 Pytorch最重要的概念是tensor,意为"张量". Variable是能够构建计算图的 tensor(对 tensor 的封装).借用Variable才 ...

  8. Pytorch中的torch.where函数

    首先我们看一下Pytorch中torch.where函数是怎样定义的: @overload def where(condition: Tensor) -> Union[Tuple[Tensor, ...

  9. opencv和pytorch中的warp操作函数:cv2.warpAffine, torch.nn.functional.grid_sample, cv2.warpPerspective

    关于图像的warp操作是指利用一个旋转缩放矩阵对图像进行操作. 常见的操作有,平移,绕某个点旋转,缩放. opencv中有getRotationMatrix2D,warpAffine, getAffi ...

  10. Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

最新文章

  1. Hhadoop环境部署
  2. 元宇宙iwemeta: 重庆打造“数据之都”,拟成立重庆数据交易所
  3. 玩美自由行体验报告 | 手摸手产品研究院
  4. V2EX › 郁闷于Python GUI开发,有没有好的框架啊?
  5. 大学有哪些专业python_python就业方向有哪些?
  6. 【C++】 类型转换
  7. 外挂学习之路(14)--- 游戏中的二叉树
  8. eclipse显示包名的方式
  9. [渝粤教育] 西南科技大学 计算机辅助设计 在线考试复习资料2021版
  10. python连接mysql数据库简单例子
  11. 十年Java面向对象编程心路——函数与方法的概念区别
  12. access和filemaker_四个替代微软Access的开源产品
  13. C语言 输出斐波那契数列
  14. 学习可爱彩色线条PS极简马克笔简笔画:饮品篇
  15. java 并发 csp_CSP与并发编程
  16. Ab3d.PowerToys 破解
  17. [Solved] Splunk: Cannot get username when all users are selected“
  18. 实用科普|推荐收藏:我的车,到底该选什么功率充电桩?
  19. html 输入框变红色,为什么CAD的动态输入框变成红色?
  20. C++ STL函数 queue (henu.hjy)

热门文章

  1. mplab java失败_Microchip工程师社区 - MPLABX用PICC编译失败 - Microchip C语言编译器论坛 - 麦田论坛...
  2. 【代码优化】考虑使用静态工厂方法代替构造器
  3. vant组件二次封装-下拉刷新列表组件
  4. vue中动画效果的实现
  5. redux-saga中间件的安装和使用-(三)
  6. html+css基础-1-屏幕居中、双飞翼布局、清除浮动
  7. nginx rtmp直播无延迟_Ubuntu中使用Nginx+rtmp搭建流媒体直播服务
  8. java中12个月_C中的12个月日历
  9. java文件读取路径_java文件读取路径问与答
  10. Mybatis日志实现