pytorch PixelShuffle和Upscale函数
pytorch PixelShuffle和Upscale函数
该函数设计思想来源于2016年的一篇SR文章,Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network 原理如下图:
子像素卷积的实现原理:利用卷积得到图像r2r^2r2个通道的特征图,并且特征图的大小和输入图像的大小一致,然后将特征图上的一个元素位置的r2r^2r2个特征点按次序排列开,形成r∗rr*rr∗r的像素分布,实现图像扩大的功能;
pytorch中的定义在:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/pixelshuffle.py文件中
函数定义
其中upscale_factor为放大倍数
class torch.nn.PixleShuffle(upscale_factor)
输入输出
Input:KaTeX parse error: Expected '}', got '_' at position 31: …\text { upsclae_̲factor} ^ { 2 }…
output:KaTeX parse error: Expected '}', got '_' at position 23: … \text {upscale_̲factor},W*\text…
例子:
pixel_shuffle = nn.PixelShuffle(3) #放大3倍
input = torch.randn(1, 9, 4, 4)
output = pixel_shuffle(input)
print(output.size()) #输出为[1,1,12,12]
Upsample 函数
pytorch实现文件为:https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/upsampling.py
对给定多通道的1维(temporal)、2维(spatial)、3维(volumetric)数据进行上采样。
定义
class torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)
参数说明:
size 是要输出的尺寸,数据类型为tuple: ([optional D_out], [optional H_out], W_out)
scale_factor 在高度、宽度和深度上面的放大倍数。数据类型既可以是int——表明高度、宽度、深度都扩大同一倍数;亦或是tuple——指定高度、宽度、深度的扩大倍数。
mode 上采样的方法,包括最近邻(nearest),线性插值(linear),双线性插值(bilinear),三次线性插值(trilinear),默认是最近邻(nearest)。
align_corners 如果设为True,输入图像和输出图像角点的像素将会被对齐(aligned),这只在mode = linear, bilinear, or trilinear才有效,默认为False。
例子
input=torch.arange(1,5).view(1,1,2,2).float()
input
tensor([[[[ 1., 2.],[ 3., 4.]]]])input
m=nn.Upsample(scale_factor=2,mode='nearest')
m(input)
tensor([[[[ 1., 1., 2., 2.],[ 1., 1., 2., 2.],[ 3., 3., 4., 4.],[ 3., 3., 4., 4.]]]])m = nn.Upsample(scale_factor=2, mode='bilinear')
m(input)
tensor([[[[ 1.0000, 1.2500, 1.7500, 2.0000],[ 1.5000, 1.7500, 2.2500, 2.5000],[ 2.5000, 2.7500, 3.2500, 3.5000],[ 3.0000, 3.2500, 3.7500, 4.0000]]]])input
m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
m(input)
tensor([[[[ 1.0000, 1.3333, 1.6667, 2.0000],[ 1.6667, 2.0000, 2.3333, 2.6667],[ 2.3333, 2.6667, 3.0000, 3.3333],[ 3.0000, 3.3333, 3.6667, 4.0000]]]])input_3x3 = torch.zeros(3, 3).view(1, 1, 3, 3)
input_3x3[:, :, :2, :2].copy_(input)
tensor([[[[ 1., 2.],[ 3., 4.]]]])input_3x3
tensor([[[[ 1., 2., 0.],[ 3., 4., 0.],[ 0., 0., 0.]]]])m = nn.Upsample(scale_factor=2, mode='bilinear') # align_corners=False
m(input_3x3)
tensor([[[[ 1.0000, 1.2500, 1.7500, 1.5000, 0.5000, 0.0000],[ 1.5000, 1.7500, 2.2500, 1.8750, 0.6250, 0.0000],[ 2.5000, 2.7500, 3.2500, 2.6250, 0.8750, 0.0000],[ 2.2500, 2.4375, 2.8125, 2.2500, 0.7500, 0.0000],[ 0.7500, 0.8125, 0.9375, 0.7500, 0.2500, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])m = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
m(input_3x3)
tensor([[[[ 1.0000, 1.4000, 1.8000, 1.6000, 0.8000, 0.0000],[ 1.8000, 2.2000, 2.6000, 2.2400, 1.1200, 0.0000],[ 2.6000, 3.0000, 3.4000, 2.8800, 1.4400, 0.0000],[ 2.4000, 2.7200, 3.0400, 2.5600, 1.2800, 0.0000],[ 1.2000, 1.3600, 1.5200, 1.2800, 0.6400, 0.0000],[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
pytorch PixelShuffle和Upscale函数相关推荐
- 关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题
关于PyTorch中的register_forward_hook()函数未能执行其中hook函数的问题 Hook 是 PyTorch 中一个十分有用的特性.利用它,我们可以不必改变网络输入输出的结构, ...
- Pytorch中的collate_fn函数用法
Pytorch中的collate_fn函数用法 官方的解释: Puts each data field into a tensor with outer dimension batch size ...
- 关于Pytorch的F.unfold函数
关于Pytorch的F.unfold函数 最近使用到这个函数,但是一直不明白什么意思.在此做一个测试 大致上,可以理解为这个函数在做卷积的滑动窗口.但是只有卷的部分,没有积的部分. 可以看原文 htt ...
- 【Pytorch】torch.argmax 函数详解
文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...
- 【python】tensorflow框架中的tf.gather_nd()函数对应的 pytorch框架的gather_nd()函数
tf.gather_nd 函数对应的pytorch函数 1. 简单介绍 2. 步入正题 2.1 tensorflow tf.gather_nd() 2.2 pytorch框架手动实现gather_nd ...
- PyTorch入门笔记-matmul函数详解
PyTorch入门笔记-matmul函数详解 本文转载自:PyTorch入门笔记-matmul函数详解 - 腾讯云开发者社区-腾讯云 (tencent.com) 41409)]
- PyTorch中torch.norm函数详解
torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...
- Batch Normalization原理及pytorch的nn.BatchNorm2d函数
下面通过举个例子来说明Batch Normalization的原理,我们假设在网络中间经过某些卷积操作之后的输出的feature map的尺寸为4×3×2×2,4为batch的大小,3为channel ...
- python 近期用到的基础知识汇总(主要是numpy和pytorch的相关矩阵变化函数)(一)
ps两个库好多类似的函数傻傻分不清,总结下. 1.np.where where()的用法 首先强调一下,where()函数对于不同的输入,返回的只是不同的. 1当数组是一维数组时,返回的值是一维的索引 ...
最新文章
- pytorch.forward()方法
- .net反射详解 原文://http://blog.csdn.net/wenyan07/article/details/27882363
- 烂泥:学习Nagios(三): NRPE安装及配置
- 设计案例——点和圆的关系
- 酷派删除android系统软件,【玩机教程】酷派手机root后不可删除系统自带程序+组件中英对照...
- mybatis3.2.2的一些测试
- zabbix---agent安装
- 主成分分析与因子分析法
- 我非英雄,广目无双,我本坏蛋,无限嚣张
- 区块链 liquity源代码分析之一 赎回奖励trove_open_liquidate
- 三年磨一剑大话数据结构——数据结构起源、概念和术语
- 音乐与计算机的论文题目,音乐类毕业论文选题参考
- 在uni-app中如何使用一键登录,如何使用手机号一键登录
- suse linux如何重置密码忘记,SUSE Linux忘记root密码的对策
- 西电数据挖掘实验1——二分网络上的链路预测
- linux源码0.11解析pdf,linux0.11 赵炯的Linux源代码剖析中的带中文注释的源代码 - 下载 - 搜珍网...
- 常用数学符号读法及其含义
- #error “Please select first the target STM32F4xx device used in your application (in stm32f4xx.h
- 小米电视android刷机,小米电视怎么刷机?小米电视刷第三方系统固件下载
- CentOS更换为阿里云的源