一、前言

在阅读swin transformer这篇论文时,论文中提出了一种较为新颖的创新点,shifted window技术,该技术可以解决窗口之间信息未交互的问题,有兴趣的可以去查阅swin transformer对CNN的降维打击。在阅读代码的过程中,发现shifted window技术的实现原理其实非常简单,仅仅用到了torch.roll()这个方法,相信你理解了该方法后,会进一步理解swin transformer的高明之处。话不多说,let’s go go go!

二、torch.roll()方法解析


从pytorch的官方文档中,我们可以找到torch.roll()方法的解释:沿着给定的维度滚动tensor。在第一个位置处,重新引入超过最后一个位置的元素(这句话非常拗口,大家后续结合实例,可以很清晰地明白这句话是什么意思,请看完下列实例之后再来理解)。如果没有指定一个维度,那么tensor首先被展开,然后重新恢复到原来的shape。此外,roll()方法有三个参数。

  • input:输入的tensor
  • shifts:可以为int,也可以是int型的元组。可以理解为roll的步数(参照大富翁游戏中的步数)。如果是tuple型,那么维度必须与tuple具有相同size,并按照每个维度roll相应的步数。
  • dims:roll的维度

三、案例分析

3.1 例1 — shifts=1 & dims未指定

x = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
x = torch.from_numpy(x)
print('before roll', x)
x = torch.roll(x, 1)
print('after roll', x)

3.2 例2 — shifts=1 & dims=0

x = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
x = torch.from_numpy(x)
print('before roll', x)
x = torch.roll(x, 1, dims=0)
print('after roll', x)

3.3 例3 — shifts=1 & dims=1

x = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
x = torch.from_numpy(x)
print('before roll', x)
x = torch.roll(x, 1, dims=1)
print('after roll', x)

3.4 例4 — shifts=(1,1) & dims=(0,1)

x = np.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
x = torch.from_numpy(x)
print('before roll', x)
x = torch.roll(x, (1, 1), dims=(0,1))
print('after roll', x)

参考

torch.roll方法官方解释

PyTorch基础(14)-- torch.roll()方法相关推荐

  1. 3.Pytorch基础模块torch的API之Indexing,Slicing,Joining,Mutating Ops实例详解

    文章目录 0. torch 1. Tensors 2. Creation Ops 3. Indexing,Slicing,Joining,Mutating Ops 3.1 torch.cat() 3. ...

  2. pytorch基础-使用 TORCH.AUTOGRAD 进行自动微分(5)

    在训练神经网络时,最常用的算法是反向传播.PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计 ...

  3. PyTorch基础(15)-- torch.flatten()方法

    前言 最近在复现论文中一个块的时候需要使用到torch.flatten()这个方法,这个方法其实很简单,但其中有一些细节可能需要注意,且有个关键点很容易忘记,故在此记录以备查阅. 方法解析 flatt ...

  4. PyTorch基础(十)----- torch.max()方法

    一.前言 这个方法跟上一篇文章的torch.max()方法非常类似,只不过一个是求最大值,一个是求平均值.在某些情况下,甚至可以代替下采样中的最大池化和平均池化,所以说,这两个方法的用处还是蛮大的. ...

  5. PyTorch基础(六)----- torch.eq()方法

    一.torch.eq()方法详解 对两个张量Tensor进行逐元素的比较,若相同位置的两个元素相同,则返回True:若不同,返回False. torch.eq(input, other, *, out ...

  6. [pytorch]torch.roll函数

    torch中的roll函数可以用于张量的位置变换操作. 博客推荐 import torch import numpy as np import matplotlib.pyplot as pltshif ...

  7. pytorch基础知识+构建LeNet对Cifar10进行训练+PyTorch-OpCounter统计模型大小和参数量+模型存储与调用

    整个环境的配置请参考我另一篇博客.ubuntu安装python3.5+pycharm+anaconda+opencv+docker+nvidia-docker+tensorflow+pytorch+C ...

  8. 《深度学习之pytorch实战计算机视觉》第6章 PyTorch基础(代码可跑通)

    上一篇文章<深度学习之pytorch实战计算机视觉>第5章 Python基础讲了Python基础.接下来看看第6章 PyTorch基础. 目录 6.1 PyTorch中的Tensor 6. ...

  9. pyTorch——基础学习笔记

    pytorch基础学习笔记博文,在整理的时候借鉴的大量的网上资料,存在和一部分图片定义的直接复制黏贴,在本博文的最后将会表明所有的参考链接.由于参考的内容众多,所以博文的更新是一个长久的过程,如果大佬 ...

  10. python linspace函数_Python torch.linspace方法代碼示例

    本文整理匯總了Python中torch.linspace方法的典型用法代碼示例.如果您正苦於以下問題:Python torch.linspace方法的具體用法?Python torch.linspac ...

最新文章

  1. JDK 1.5 新特性——自动拆箱装箱
  2. oracle中睡眠,sql - ORACLE中的睡眠功能 - 堆栈内存溢出
  3. 从服务器检索时出错dfdferh01_大数据实战项目之海量人脸特征检索解决方案演进...
  4. 未声明spire。它可能因保护级别而不可访问_信息系统安全:访问控制技术概述...
  5. iOS应用开发的五个Java开源工具
  6. ThreadPoolExecutor源码学习(2)-- 在thrift中的应用
  7. 【HTML+CSS网页设计与布局 从入门到精通】第7章-class、ID选择器,CSS格式
  8. java定时器每一分钟执行一次_2行代码搞定一个定时器
  9. Python 学习第十七天 jQuery
  10. Linux下QT创建项目错误处理
  11. 管道 通过匿名管道在进程间双向通信
  12. 连锁加盟网站源码_连锁60秒:招商只是开始,养商才最重要
  13. nova红a6se升级鸿蒙,华为nova 8 SE配置揭晓:麒麟芯片到底是没了
  14. 快递鸟物流电子面单批量打印对接注意事项与技术说明
  15. Latex输入大小写罗马数字
  16. 计算机导论结业报告大一,河北工业大学计算机导论结业论文
  17. 30ea什么意思_阿玛尼ga是什么意思、和ea的区别
  18. 正则表达式--常用用法及lookahead、lookbehind
  19. Java类有个星号标记_Java中import包带*(星号)问题
  20. 【GD32F310开发板试用】编码器接口的使用

热门文章

  1. 寻找矩阵行最大列最小元素
  2. 绘制14段米字数码管显示,显示数字和英文字母。
  3. 220UF25V 10*7.7SMD铝电解电容封装
  4. C语言之父丹尼斯·里奇
  5. 金山词霸2006专业版(300M)的无法屏幕取词问题的解决方法!
  6. 《Windows内核原理与实现笔记》(一)Windows系统结构和基本概念
  7. 影响中国软件开发20人
  8. 计算机excel2010完整教程视频,刘伟公益课-Excel2010基础大全(1-66集)视频教程-高清版...
  9. 电脑计算机c盘缓存清理,电脑C盘缓存文件怎么删除
  10. 动易html编辑器漏洞,动易网站管理系统删除任意文件漏洞