torch.roll 函数的理解
torch.roll 函数官方解释
如果是看swin-transformer进来的,推荐看看GitHub上的这个问题,会很有帮助!
https://github.com/microsoft/Swin-Transformer/issues/38
翻译
torch.roll(input, shifts, dims=None) → Tensor
- input (Tensor) —— 输入张量。
- shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
- dims (int 或 tuple of python:int) 确定的维度。
沿给定维数滚动张量,移动到最后一个位置以外的元素将在第一个位置重新引入。如果没有指定尺寸,张量将在轧制前被压平,然后恢复到原始形状。
官方例子
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2)
>>> x
tensor([[1, 2],[3, 4],[5, 6],[7, 8]])
'''第0维度向下移1位,多出的[7,8]补充到顶部'''
>>> torch.roll(x, 1, 0)
tensor([[7, 8],[1, 2],[3, 4],[5, 6]])
'''第0维度向上移1位,多出的[1,2]补充到底部'''
>>> torch.roll(x, -1, 0)
tensor([[3, 4],[5, 6],[7, 8],[1, 2]])
'''tuple元祖,维度一一对应:
第0维度向下移2位,多出的[5,6][7,8]补充到顶部,
第1维向右移1位,多出的[6,8,2,4]补充到最左边'''
>>> torch.roll(x, shifts=(2, 1), dims=(0, 1))
tensor([[6, 5],[8, 7],[2, 1],[4, 3]])
简单理解:shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面
以下一个多维张量的例子(参考swin transformer论文源码):
torch.roll(x, shifts=(-20, -20), dims=(1, 2))
完整代码
import torch
import numpy as np
import matplotlib.pyplot as pltshift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)if shift_size > 0:shifted_x = torch.roll(x, shifts=(-20, -20), dims=(1, 2))#shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))print("---------经过循环位移了---------")
else:shifted_x = x'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):plt.title("non_shifted")
else:plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()
torch.roll 函数的理解相关推荐
- [pytorch]torch.roll函数
torch中的roll函数可以用于张量的位置变换操作. 博客推荐 import torch import numpy as np import matplotlib.pyplot as pltshif ...
- torch.roll() 函数用法
Pytorch 官方文档:https://pytorch.org/docs/master/generated/torch.roll.htmlhttps://pytorch.org/docs/maste ...
- torch.roll图片实验
torch.roll(input, shifts, dims=None) → Tensor input为输入张量,shifts表示要滚动的方向.负数表示左上,正数表示右下.dims表示要滚动的维度. ...
- Pytorch中tensor维度和torch.max()函数中dim参数的理解
Pytorch中tensor维度和torch.max()函数中dim参数的理解 维度 参考了 https://blog.csdn.net/qq_41375609/article/details/106 ...
- 【深度学习】pytorch自动求导机制的理解 | tensor.backward() 反向传播 | tensor.detach()梯度截断函数 | with torch.no_grad()函数
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 一.pytorch里自动求导的基础概念 1.1.自动求导 requires_grad=True 1.2.求导 requ ...
- 代码阅读 | torch.sort函数
知识点https://blog.csdn.net/u012495579/article/details/106117511网友讲得非常清晰了. 代码如下: lengths = [17,17,19,23 ...
- gather torch_浅谈Pytorch中的torch.gather函数的含义
pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...
- 2021.08.22学习内容torch.cat()和torch.stack()函数
torch.cat()函数 将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. def cat(tensors: List[torch.Tensor], d ...
- 关于torch.bmm()函数计算过程
很多框架中提供的矩阵乘法都是出于简化计算的考虑,很多情况下在进行计算时候都会牵扯到 batch size 这一个维度,这就使得很多矩阵的计算是三维的,Pytorch中的bmm()函数就可以很方便的实现 ...
最新文章
- 【青少年编程】【三级】小鸡吃虫
- Linux——top命令查看cpu利用率超过100%
- workerman mmo_2020了,我们为什么还在做MMO端游
- 稀疏多项式的运算用链表_用最简单的大白话聊一聊面试必问的HashMap原理和部分源码解析...
- java是值调用_Java 只有值调用
- 小程序 redux_Redux应用程序最重要的ESLint规则
- python是什么语言-python是什么语言?哪些人适合学习Python?
- MySQL中针对大数据量常用技术
- Apache Flink 官方文档--流(DataStream API)-旁路输出
- getch方法_C++中getch函数使用时注意事项
- 5. DICOM图像层级分类-DCMTK-压缩图像PixelData读取
- 新型计算机离我们还有多远
- 1688API item_search_img - 拍立淘搜索淘宝商品
- android 设置闹钟,android 设置闹钟
- Python 八大数据类型。
- wps如何设置表格中文字的行间距
- CodeForces - 1324D Pair of Topics(二分或双指针)
- ARM CM0 push和pop指令
- 华为交换机SEP双半环设计方案及配置详细步骤
- 论文精读:Selective Convolutional Descriptor Aggregation