就是用了一个卷积降了一下k,v 的size

可以理解为将R个点聚合成一个,然后attention的时候Q和聚合成的点的K和V算

import torch
from torch import nnclass SpatialReductionAttention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.dropout = nn.Dropout(proj_drop)self.sr_ratio = sr_ratio# 实现上这里等价于一个卷积层if sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)def forward(self, x, H, W):B, N, D = x.shape  #N=h*wq = self.q(x).reshape(B, N, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(B, D, H, W)x_ = self.sr(x_).reshape(B, D, -1).permute(0, 2, 1) # 这里x_.shape = (B, N/R^2, D)x_ = self.norm(x_)kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)else:kv = self.kv(x).reshape(B, -1, 2, self.num_heads, D // self.num_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, D)x = self.proj(x)x = self.dropout(x)return xx = torch.rand(4, 224*128, 256)
attn = SpatialReductionAttention(dim=256, sr_ratio = 2)
output = attn(x, H=224, W=128)

PVT的spatial reduction attention(SRA)相关推荐

  1. 【论文笔记】SPAN: Spatial Pyramid Attention Network for Image Manipulation Localization

    SPAN: Spatial Pyramid Attention Network for Image Manipulation Localization 发布于ECCV2020 原文链接:https:/ ...

  2. Vision Transformer在CV任务中的速度如何保证?

    本文作者丨盘子正@知乎    编辑丨极市平台 来源丨https://zhuanlan.zhihu.com/p/569482746 我(盘子正@知乎)的PhD课题是Vision Transformer的 ...

  3. Transformer合集1

    最近Transformer文章太多了 索性一起发了得~~  以后关于这个的都不单发了 如何提高ViT的效率?可以是让模型更容易训练,减少训练时间,也可以减少模型部署在硬件上的功耗等等.本文主要讲inf ...

  4. SepViT:可分离视觉Transformer

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 转载自:集智书童 SepViT: Separable Vision Transformer 论文:https ...

  5. 论文精读:PVT v2: Improved Baselines with Pyramid Vision Transformer

    论文地址:https://arxiv.org/abs/2106.13797 源码地址:https://github.com/whai362/PVT Abstract 在这项工作中,作者改进了PVT v ...

  6. 论文:Pyramid Vision Transformer

    Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions 金字塔视觉Tran ...

  7. PVTV2--Pyramid Vision TransformerV2学习笔记

    PVTV2–Pyramid Vision TransformerV2学习笔记 PVTv2: Improved Baselines with Pyramid Vision Transformer Abs ...

  8. 【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

    目录 0. 详情 1. 简述 2.主要工作 2.1 ViT遗留的问题 2.2 引入金字塔结构 3.PVT的设计方案 3.1 Patch embedding 代码 3.2position embeddi ...

  9. transformer与视觉

    目录 综述 优秀网文 基本transformer 视觉transformer原理 具体的transformer 一般方法 ViT :一张图等于 16x16 个字,计算机视觉也用上 Transforme ...

最新文章

  1. 巩固好基础,才能学好Linux
  2. 【Android UI设计与开发】第02期:引导界面(二)使用ViewPager实现欢迎引导页面
  3. 『HTML5制造仿JQuery作用』减速
  4. python中if错误-Python中常见的异常处理
  5. 滚动触发的翻转式文字引用效果
  6. GetDlgItem的用法
  7. linux怎么开ssh端口,如何查看linux中的ssh端口开启状态
  8. 首次公开!菜鸟弹性调度系统的架构设计
  9. Linux文件操作实用笔记
  10. 工作328:uni-局部过滤器处理数据
  11. 【Day12】整个前端性能提升大致分几类
  12. python有哪些知识_Python有哪些基础知识
  13. mongodb db.serverStatus() 仍然不能提示认证失败
  14. 金蝶eas怎么引出凭证_金蝶EAS该如何导出凭证
  15. Linux的学习之路grep命令
  16. Atitit 防烫伤指南与规范 attilax总结
  17. (转)软件商在做券商的事,券商在做搬运工的事,第三方正变成第三者
  18. 将网页转换成PDF文件的N种方式
  19. 【C语言】输出100内素数
  20. 我们应该怎么去认识信贷

热门文章

  1. Python 趋势:当今最热门语言的热门话题
  2. mysql replication 监控_MySQL之-Replication监控及自动故障切换的详细分析
  3. MySQL优化之Explain
  4. ATM自动取款机程序设计
  5. 易得无价宝,难得有情郎
  6. 漫谈程序员系列 怎么告别 混日子
  7. MIUI“息屏听剧”功能实现调研
  8. 常用音频软件大比拼,再也不为选择哪一款犯愁了!
  9. 从实验室跃进产业,腾讯AI是如何向to B进化的?
  10. 号脉数据中心全生命周期,业务永续从细节做起