pytorch PixelShuffle和Upscale函数

​ 该函数设计思想来源于2016年的一篇SR文章,Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network 原理如下图:

​ 子像素卷积的实现原理:利用卷积得到图像r2r^2r2个通道的特征图,并且特征图的大小和输入图像的大小一致,然后将特征图上的一个元素位置的r2r^2r2个特征点按次序排列开,形成r∗rr*rr∗r的像素分布,实现图像扩大的功能;

pytorch中的定义在:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/pixelshuffle.py文件中

函数定义

其中upscale_factor为放大倍数

class torch.nn.PixleShuffle(upscale_factor)

输入输出

Input:KaTeX parse error: Expected '}', got '_' at position 31: …\text { upsclae_̲factor} ^ { 2 }…

output:KaTeX parse error: Expected '}', got '_' at position 23: … \text {upscale_̲factor},W*\text…

例子:

pixel_shuffle = nn.PixelShuffle(3) #放大3倍
input = torch.randn(1, 9, 4, 4)
output = pixel_shuffle(input)
print(output.size())  #输出为[1,1,12,12]

Upsample 函数

pytorch实现文件为:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/upsampling.py

对给定多通道的1维(temporal)、2维(spatial)、3维(volumetric)数据进行上采样。

定义

class torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)

参数说明:

  • size 是要输出的尺寸,数据类型为tuple: ([optional D_out], [optional H_out], W_out)

  • scale_factor 在高度、宽度和深度上面的放大倍数。数据类型既可以是int——表明高度、宽度、深度都扩大同一倍数;亦或是tuple——指定高度、宽度、深度的扩大倍数。

  • mode 上采样的方法,包括最近邻(nearest),线性插值(linear),双线性插值(bilinear),三次线性插值(trilinear),默认是最近邻(nearest)。

  • align_corners 如果设为True,输入图像和输出图像角点的像素将会被对齐(aligned),这只在mode = linear, bilinear, or trilinear才有效,默认为False。

例子

input=torch.arange(1,5).view(1,1,2,2).float()
input
tensor([[[[ 1.,  2.],[ 3.,  4.]]]])input
m=nn.Upsample(scale_factor=2,mode='nearest')
m(input)
tensor([[[[ 1.,  1.,  2.,  2.],[ 1.,  1.,  2.,  2.],[ 3.,  3.,  4.,  4.],[ 3.,  3.,  4.,  4.]]]])m = nn.Upsample(scale_factor=2, mode='bilinear')
m(input)
tensor([[[[ 1.0000,  1.2500,  1.7500,  2.0000],[ 1.5000,  1.7500,  2.2500,  2.5000],[ 2.5000,  2.7500,  3.2500,  3.5000],[ 3.0000,  3.2500,  3.7500,  4.0000]]]])input
m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
m(input)
tensor([[[[ 1.0000,  1.3333,  1.6667,  2.0000],[ 1.6667,  2.0000,  2.3333,  2.6667],[ 2.3333,  2.6667,  3.0000,  3.3333],[ 3.0000,  3.3333,  3.6667,  4.0000]]]])input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
input_3x3[:, :, :2, :2].copy_(input)
tensor([[[[ 1.,  2.],[ 3.,  4.]]]])input_3x3
tensor([[[[ 1.,  2.,  0.],[ 3.,  4.,  0.],[ 0.,  0.,  0.]]]])m = nn.Upsample(scale_factor=2, mode='bilinear')  # align_corners=False
m(input_3x3)
tensor([[[[ 1.0000,  1.2500,  1.7500,  1.5000,  0.5000,  0.0000],[ 1.5000,  1.7500,  2.2500,  1.8750,  0.6250,  0.0000],[ 2.5000,  2.7500,  3.2500,  2.6250,  0.8750,  0.0000],[ 2.2500,  2.4375,  2.8125,  2.2500,  0.7500,  0.0000],[ 0.7500,  0.8125,  0.9375,  0.7500,  0.2500,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]])m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
m(input_3x3)
tensor([[[[ 1.0000,  1.4000,  1.8000,  1.6000,  0.8000,  0.0000],[ 1.8000,  2.2000,  2.6000,  2.2400,  1.1200,  0.0000],[ 2.6000,  3.0000,  3.4000,  2.8800,  1.4400,  0.0000],[ 2.4000,  2.7200,  3.0400,  2.5600,  1.2800,  0.0000],[ 1.2000,  1.3600,  1.5200,  1.2800,  0.6400,  0.0000],[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]]])

pytorch PixelShuffle和Upscale函数相关推荐

  1. 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题

    关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...

  2. Pytorch中的collate_fn函数用法

    Pytorch中的collate_fn函数用法 官方的解释:   Puts each data field into a tensor with outer dimension batch size ...

  3. 关于Pytorch的F.unfold函数

    关于Pytorch的F.unfold函数 最近使用到这个函数,但是一直不明白什么意思.在此做一个测试 大致上,可以理解为这个函数在做卷积的滑动窗口.但是只有卷的部分,没有积的部分. 可以看原文 htt ...

  4. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

  5. 【python】tensorflow框架中的tf.gather_nd()函数对应的 pytorch框架的gather_nd()函数

    tf.gather_nd 函数对应的pytorch函数 1. 简单介绍 2. 步入正题 2.1 tensorflow tf.gather_nd() 2.2 pytorch框架手动实现gather_nd ...

  6. PyTorch入门笔记-matmul函数详解

    PyTorch入门笔记-matmul函数详解 本文转载自:PyTorch入门笔记-matmul函数详解 - 腾讯云开发者社区-腾讯云 (tencent.com) 41409)]

  7. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  8. Batch Normalization原理及pytorch的nn.BatchNorm2d函数

    下面通过举个例子来说明Batch Normalization的原理,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2,4为batch的大小,3为channel ...

  9. python 近期用到的基础知识汇总(主要是numpy和pytorch的相关矩阵变化函数)(一)

    ps两个库好多类似的函数傻傻分不清,总结下. 1.np.where where()的用法 首先强调一下,where()函数对于不同的输入,返回的只是不同的. 1当数组是一维数组时,返回的值是一维的索引 ...

最新文章

  1. pytorch.forward()方法
  2. .net反射详解 原文://http://blog.csdn.net/wenyan07/article/details/27882363
  3. 烂泥:学习Nagios(三): NRPE安装及配置
  4. 设计案例——点和圆的关系
  5. 酷派删除android系统软件,【玩机教程】酷派手机root后不可删除系统自带程序+组件中英对照...
  6. mybatis3.2.2的一些测试
  7. zabbix---agent安装
  8. 主成分分析与因子分析法
  9. 我非英雄,广目无双,我本坏蛋,无限嚣张
  10. 区块链 liquity源代码分析之一 赎回奖励trove_open_liquidate
  11. 三年磨一剑大话数据结构——数据结构起源、概念和术语
  12. 音乐与计算机的论文题目,音乐类毕业论文选题参考
  13. 在uni-app中如何使用一键登录,如何使用手机号一键登录
  14. suse linux如何重置密码忘记,SUSE Linux忘记root密码的对策
  15. 西电数据挖掘实验1——二分网络上的链路预测
  16. linux源码0.11解析pdf,linux0.11 赵炯的Linux源代码剖析中的带中文注释的源代码 - 下载 - 搜珍网...
  17. 常用数学符号读法及其含义
  18. #error “Please select first the target STM32F4xx device used in your application (in stm32f4xx.h
  19. 小米电视android刷机,小米电视怎么刷机?小米电视刷第三方系统固件下载
  20. CentOS更换为阿里云的源

热门文章

  1. 2020年Java语言发展现状
  2. 大一计算机应用基础实验指导,大学计算机应用基础实验指导详解.doc
  3. python最新官网图片_Python轻松爬取Rosimm写真网站全部图片
  4. 图像超分中的深度学习网络
  5. 搞定支付接口(一) 支付宝即时到账支付接口详细流程和代码
  6. ubuntu18.04安装PCL
  7. 资产、负债及所有者权益类帐户
  8. SOLIDWORKS怎么把STEP曲面转换成实体
  9. 副业项目做什么比较靠谱,如何知道自己适合做什么?
  10. python123程序设计题说句心里话_电脑怎么写程序