文章目录

  • 【语义分割】2021-PVT2 CVMJ
    • 1. 简介
    • 2. 网络
      • 2.1 总体架构
      • 2.2 Linear Spatial Reduction Attention
      • 2.3 Overlapping Patch Embedding(重叠补丁嵌入)
      • 2.4 Convolutional FeedForward
    • 3. 代码 pvt2-upernet

【语义分割】2021-PVT2 CVMJ

论文题目:PVT v2: Improved Baselines with Pyramid Vision Transformer

论文链接: https://arxiv.org/abs/2106.13797

论文代码:https://github.com/whai362/PVT

论文翻译:PVT,PVTv2 - 简书 (jianshu.com)

1. 简介

计算机视觉中的Transformer最近取得了令人鼓舞的进展。在这项工作中,作者通过添加3个改进设计来改进原始金字塔视觉Transformer(PVTv1),其中包括:

  • 具有卷积的局部连续特征;
  • 具有zero paddings的位置编码,
  • 具有平均汇集。

通过这些简单的修改,PVTv2在分类、检测和分割方面显著优于PVTv1。此外,PVTv2在ImageNet-1K预训练下取得了比近期作品(包括 Swin Transformer)更好的性能。

2. 网络

PVTv1[33]的主要局限性有以下三个方面:

(1)与ViT类似,在处理高分辨率输入时(如短边为800像素),PVTv1的计算复杂度相对较大。

(2) PVTv1将一幅图像视为一组不重叠的patch序列,在一定程度上丧失了图像的局部连续性;

(3) PVTv1中的位置编码是固定大小的,对于处理任意大小的图像是不灵活的。这些问题限制了PVTv1在视觉任务中的性能。

2.1 总体架构

2.2 Linear Spatial Reduction Attention

用LinearSRA替代SRA。这里需要说明的一个问题,作者在PVTv1中说自己没用到卷积,但是在压缩K、V的时候使用的是Conv2D(参见github中代码)。在PVTv2中使用平均池化替代Conv2D。

2.3 Overlapping Patch Embedding(重叠补丁嵌入)

2.4 Convolutional FeedForward

3. 代码 pvt2-upernet

import torch
from torch import nn, Tensor
from torch.nn import functional as Fclass DropPath(nn.Module):"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Copied from timmThis is the same as the DropConnect impl I created for EfficientNet, etc networks, however,the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted forchanging the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use'survival rate' as the argument."""def __init__(self, p: float = None):super().__init__()self.p = pdef forward(self, x: Tensor) -> Tensor:if self.p == 0. or not self.training:return xkp = 1 - self.pshape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_()  # binarizereturn x.div(kp) * random_tensorclass DWConv(nn.Module):def __init__(self, dim):super().__init__()self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)def forward(self, x: Tensor, H: int, W: int) -> Tensor:B, _, C = x.shapex = x.transpose(1, 2).view(B, C, H, W)x = self.dwconv(x)return x.flatten(2).transpose(1, 2)class MLP(nn.Module):def __init__(self, dim, hidden_dim, out_dim=None) -> None:super().__init__()out_dim = out_dim or dimself.fc1 = nn.Linear(dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, out_dim)self.dwconv = DWConv(hidden_dim)def forward(self, x: Tensor, H: int, W: int) -> Tensor:return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))class Attention(nn.Module):def __init__(self, dim, head, sr_ratio):super().__init__()self.head = headself.sr_ratio = sr_ratioself.scale = (dim // head) ** -0.5self.q = nn.Linear(dim, dim, bias=True)self.kv = nn.Linear(dim, dim * 2, bias=True)self.proj = nn.Linear(dim, dim)if sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)self.norm = nn.LayerNorm(dim)def forward(self, x: Tensor, H, W) -> Tensor:B, N, C = x.shapeq = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)if self.sr_ratio > 1:x = x.permute(0, 2, 1).reshape(B, C, H, W)x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)x = self.norm(x)k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)return xclass Block(nn.Module):def __init__(self, dim, head, sr_ratio=1, mlp_ratio=4, dpr=0.):super().__init__()self.norm1 = nn.LayerNorm(dim)self.attn = Attention(dim, head, sr_ratio)self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()self.norm2 = nn.LayerNorm(dim)self.mlp = MLP(dim, int(dim * mlp_ratio))def forward(self, x: Tensor, H, W) -> Tensor:x = x + self.drop_path(self.attn(self.norm1(x), H, W))x = x + self.drop_path(self.mlp(self.norm2(x), H, W))return xclass PatchEmbed(nn.Module):def __init__(self, c1=3, c2=64, patch_size=7, stride=4):super().__init__()self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size // 2)self.norm = nn.LayerNorm(c2)def forward(self, x: Tensor) -> Tensor:x = self.proj(x)_, _, H, W = x.shapex = x.flatten(2).transpose(1, 2)x = self.norm(x)return x, H, Wpvtv2_settings = {'B1': [2, 2, 2, 2],  # depths'B2': [3, 4, 6, 3],'B3': [3, 4, 18, 3],'B4': [3, 8, 27, 3],'B5': [3, 6, 40, 3]
}class PVTv2(nn.Module):def __init__(self, model_name: str = 'B1') -> None:super().__init__()assert model_name in pvtv2_settings.keys(), f"PVTv2 model name should be in{list(pvtv2_settings.keys())}"depths = pvtv2_settings[model_name]embed_dims = [64, 128, 320, 512]drop_path_rate = 0.1self.embed_dims = embed_dims# patch_embedself.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]# transformer encodercur = 0self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, 8, dpr[cur + i]) for i in range(depths[0])])self.norm1 = nn.LayerNorm(embed_dims[0])cur += depths[0]self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, 8, dpr[cur + i]) for i in range(depths[1])])self.norm2 = nn.LayerNorm(embed_dims[1])cur += depths[1]self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, 4, dpr[cur + i]) for i in range(depths[2])])self.norm3 = nn.LayerNorm(embed_dims[2])cur += depths[2]self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, 4, dpr[cur + i]) for i in range(depths[3])])self.norm4 = nn.LayerNorm(embed_dims[3])def forward(self, x: Tensor) -> Tensor:B = x.shape[0]# stage 1x, H, W = self.patch_embed1(x)for blk in self.block1:x = blk(x, H, W)x1 = self.norm1(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)# stage 2x, H, W = self.patch_embed2(x1)for blk in self.block2:x = blk(x, H, W)x2 = self.norm2(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)# stage 3x, H, W = self.patch_embed3(x2)for blk in self.block3:x = blk(x, H, W)x3 = self.norm3(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)# stage 4x, H, W = self.patch_embed4(x3)for blk in self.block4:x = blk(x, H, W)x4 = self.norm4(x).reshape(B, H, W, -1).permute(0, 3, 1, 2)return x1, x2, x3, x4class PPM(nn.ModuleList):"""金字塔池化模型 Pyramid Pooling Modulehttps://arxiv.org/abs/1612.01105CVPR 2017年 的工作使用最大池化,获取"""def __init__(self, pool_sizes, in_channels, out_channels):super(PPM, self).__init__()self.pool_sizes = pool_sizesself.in_channels = in_channelsself.out_channels = out_channelsfor pool_size in pool_sizes:self.append(nn.Sequential(nn.AdaptiveMaxPool2d(pool_size),nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1),))def forward(self, x):out_puts = []for ppm in self:ppm_out = nn.functional.interpolate(ppm(x), size=(x.size(2), x.size(3)), mode='bilinear',align_corners=True)out_puts.append(ppm_out)return out_putsclass PPMHEAD(nn.Module):def __init__(self, in_channels, out_channels, pool_sizes=[1, 2, 3, 6], ):super(PPMHEAD, self).__init__()self.pool_sizes = pool_sizesself.in_channels = in_channelsself.out_channels = out_channelsself.psp_modules = PPM(self.pool_sizes, self.in_channels, self.out_channels)self.final = nn.Sequential(nn.Conv2d(self.in_channels + len(self.pool_sizes) * self.out_channels, 4 * self.out_channels,kernel_size=1),nn.BatchNorm2d(4 * self.out_channels),nn.ReLU(),)def forward(self, x):out = self.psp_modules(x)out.append(x)out = torch.cat(out, 1)out = self.final(out)return outclass FPNHEAD(nn.Module):def __init__(self, out_channels=512, num_classes=19, channels=[64, 128, 320, 512]):"""Args:out_channels: 最后一层融合的 通道数,在分类前的通道数num_classes:  最后分类数目channels: 四层backbone的通道数"""super(FPNHEAD, self).__init__()self.num_classes = num_classesself.PPMHead = PPMHEAD(in_channels=channels[-1], out_channels=channels[-1] // 4)self.Conv_fuse1 = nn.Sequential(nn.Conv2d(channels[-2], channels[-2], 1),nn.BatchNorm2d(channels[-2]),nn.ReLU())self.Conv_fuse1_ = nn.Sequential(nn.Conv2d(channels[-2] + channels[-1], channels[-2], 1),nn.BatchNorm2d(channels[-2]),nn.ReLU())self.Conv_fuse2 = nn.Sequential(nn.Conv2d(channels[-3], channels[-3], 1),nn.BatchNorm2d(channels[-3]),nn.ReLU())self.Conv_fuse2_ = nn.Sequential(nn.Conv2d(channels[-3] + channels[-2], channels[-3], 1),nn.BatchNorm2d(channels[-3]),nn.ReLU())self.Conv_fuse3 = nn.Sequential(nn.Conv2d(channels[-4], channels[-4], 1),nn.BatchNorm2d(channels[-4]),nn.ReLU())self.Conv_fuse3_ = nn.Sequential(nn.Conv2d(channels[-4] + channels[-3], channels[-4], 1),nn.BatchNorm2d(channels[-4]),nn.ReLU())self.fuse_all = nn.Sequential(nn.Conv2d(sum(channels), out_channels, 1),nn.BatchNorm2d(out_channels),nn.ReLU())self.cls_seg = nn.Sequential(nn.Conv2d(out_channels, self.num_classes, kernel_size=3, padding=1),)def forward(self, input_fpn):"""Args:input_fpn: 四个特征图Returns:"""############################### x1 = torch.randn(1, 64, 56, 56)# x2 = torch.randn(1, 128, 28, 28)# x3 = torch.randn(1, 320, 14, 14)# x4 = torch.randn(1, 512, 7, 7)#  1/32的特征图 使用PPMHead torch.Size([1, 2048, 7, 7])# x1= [1, 512, 7, 7]x1 = self.PPMHead(input_fpn[-1])# print(x1.shape)# [1, 512, 7, 7]-->[1, 512, 14, 14]x = F.interpolate(x1,size=(x1.size(2) * 2, x1.size(3) * 2),mode='bilinear',align_corners=True)# 融合1/16的图  torch.Size([1, 3072, 14, 14])。仅仅在通道上拼接# [1, 512, 14, 14] + [1,320, 14, 14] =[1, 832, 14, 14]x = torch.cat([x, self.Conv_fuse1(input_fpn[-2])], dim=1)############################### [1, 832, 14, 14] -->[1, 320, 14, 14] ,进行通道数上的减少x2 = self.Conv_fuse1_(x)#  [1, 320, 14, 14]->[1, 320, 28,28]x = F.interpolate(x2,size=(x2.size(2) * 2, x2.size(3) * 2),mode='bilinear',align_corners=True)# 融合1/8的图# [1, 320, 28,28] +[1,  128, 28,28] = [1,  448, 28,28]x = torch.cat([x, self.Conv_fuse2(input_fpn[-3])], dim=1)# print(x.shape)############################### [1,  448, 28,28] -> [1, 128, 28, 28]进行通道上缩减。x3 = self.Conv_fuse2_(x)#  对1/8---> 1/4# [1, 128, 28, 28]-> [1, 128, 56, 56]x = F.interpolate(x3,size=(x3.size(2) * 2, x3.size(3) * 2),mode='bilinear',align_corners=True)# 融合1/4的图# [1, 128, 56, 56]+[1, 64, 56, 56]=[1, 192, 56, 56]x = torch.cat([x, self.Conv_fuse3(input_fpn[-4])], dim=1)############################### [1, 192, 56, 56]-> [1, 64, 56, 56]x4 = self.Conv_fuse3_(x)x1 = F.interpolate(x1, x4.size()[-2:], mode='bilinear', align_corners=True)x2 = F.interpolate(x2, x4.size()[-2:], mode='bilinear', align_corners=True)x3 = F.interpolate(x3, x4.size()[-2:], mode='bilinear', align_corners=True)x = self.fuse_all(torch.cat([x1, x2, x3, x4], 1))# print(x.shape)x = F.interpolate(x, size=(x.size(2) * 4, x.size(3) * 4), mode='bilinear', align_corners=True)# print(x.shape)x = self.cls_seg(x)return xclass pvt2_upernet(nn.Module):def __init__(self, num_classes, channels, size="B1"):"""类别数Args:num_classes:"""super(pvt2_upernet, self).__init__()self.backbone = PVTv2(size)self.decoder = FPNHEAD(num_classes=num_classes, channels=channels)def forward(self, x):x = self.backbone(x)x = self.decoder(x)return xdef pvt2_B1_upernet(num_classes):model = pvt2_upernet(num_classes=num_classes, size="B1", channels=[64, 128, 320, 512])return modeldef pvt2_B2_upernet(num_classes):model = pvt2_upernet(num_classes=num_classes, size="B2", channels=[64, 128, 320, 512])return modeldef pvt2_B3_upernet(num_classes):model = pvt2_upernet(num_classes=num_classes, size="B3", channels=[64, 128, 320, 512])return modeldef pvt2_B4_upernet(num_classes):model = pvt2_upernet(num_classes=num_classes, size="B3", channels=[64, 128, 320, 512])return modelif __name__ == '__main__':x=torch.randn(1,3,224,224)model=pvt2_B2_upernet(num_classes=19)y=model(x)print(y.shape)

参考资料

PVTv2: Improved Baselines with Pyramid Vision Transformer——PVT2解读 - 知乎 (zhihu.com)

Transformer主干网络——PVT_V2保姆级解析_只会git clone的程序员的博客-CSDN博客

【语义分割】2021-PVT2 CVMJ相关推荐

  1. ICCV 2021 | 简而优:用分类器变换器进行小样本语义分割

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | 卢治合 编辑 | 王晔 本文是对发表于计算机视觉领域的顶级 ...

  2. 三种基于自监督深度估计的语义分割方法(arXiv 2021)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨泡泡机器人 来源丨泡泡机器人SLAM 标题: Three Ways to Improve Sem ...

  3. ICCV 2021 | PMF: 基于视觉感知的多传感器融合点云语义分割方法

    作者丨月明星稀风萧萧@知乎 来源丨https://zhuanlan.zhihu.com/p/419187044 编辑丨3D视觉工坊 今天,我将分享一个 ICCV 2021 中的工作,基于视觉感知的多传 ...

  4. ICCV 2021 | Transformer再助力!用CWT进行小样本语义分割

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者 | 卢治合  编辑 | 王晔 本文转载自:AI科技评论 本文是对发表于计算机视觉领域的顶级会议 ICC ...

  5. CVPR 2021 | 北大MSRA提出CPS:基于交叉伪监督的半监督语义分割

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Charles  |  源:知乎 https://zhuanlan.zhihu.com/p/37812 ...

  6. 【文献阅读】用对比学习做弱监督语义分割(Sung-Hoon Yoon等人,ArXiv,2021)

    一.背景 文章题目:<Exploring Pixel-level Self-supervision for Weakly Supervised Semantic Segmentation> ...

  7. 【深度学习】语义分割:论文阅读(NeurIPS 2021)MaskFormer: per-pixel classification is not all you need

    目录 详情 知识补充 语义分割 实例分割 基本流程 主要技术路线 自上而下的实例分割方法 自下而上的实例分割方法 掩膜 Mask 什么是mask掩码? mask掩码有什么用? mask classif ...

  8. CVPR2021|基于双边扩充和自适应融合方法的点云语义分割网络

    Semantic Segmentation for Real Point Cloud Scenes via Bilateral Augmentation and Adaptive Fusion 1.M ...

  9. 重磅发布!Google语义分割新数据集来啦!又一个分割SOTA模型

    [导读]自动驾驶里视觉一直为人所诟病,特斯拉就是经常被拉出来批判的典型.谷歌最近开发了一个新模型,效果拔群,已被CVPR2021接收. 对于人来说,看一张平面照片能够想象到重建后的3D场景布局,能够根 ...

  10. 如何用PyTorch进行语义分割?一个教程教会你|资源

    木易 发自 凹非寺  量子位 报道 | 公众号 QbitAI 很久没给大家带来教程资源啦. 正值PyTorch 1.7更新,那么我们这次便给大家带来一个PyTorch简单实用的教程资源:用PyTorc ...

最新文章

  1. oracle dataguard
  2. Centos 7.0设置/etc/rc.local无效问题解决
  3. eclipse 启动后maven插件报错
  4. 需要符合互联网时代需求的《飞秋》
  5. 由B+树看MySQL索引结构
  6. 缓存最关心的问题是什么
  7. Linux运维跳槽40道面试精华题
  8. Axure中SVG矢量图标的使用方法
  9. smobiler仿饿了么app筛选页面
  10. Python 调用 kafka 构建完整实例分析与应用
  11. 员工效率低下,责任在管理层的数学解释和分析
  12. h5php大转盘抽奖,html5转盘抽奖 完整代码下载(网页版)
  13. nginx: [emerg] open() /var/run/nginx/nginx.pid failed (2: No such file or directory)解决方法
  14. Android蓝牙搜索连接通信
  15. 工程伦理学_笔记(复习用)
  16. 网络安全“攻防战”:“魔”“道”大盘点
  17. word段落间距调整:格式刷/取消对齐网格/分节符/擦除格式都无效的应对方法
  18. Android长图文截图的实现(支持截取第三方app)-(一)
  19. VS2017生成项目报 找不到资产文件“xxxx\obj\project.assets.json
  20. C#由于从未加载设计器

热门文章

  1. MySql Lock wait timeout exceeded该如何处理
  2. cmd输入光标消失解决
  3. 什么是localhost(127.0.0.1)?
  4. Debian 7 源(32/64bit)好用的源
  5. 计算机技能大赛简讯内,【报道】2010学西城区职业高中计算机排版技能竞赛简讯...
  6. PPT:WMS仓储系统解决方案
  7. kaka启动出现:Java HotSpot(TM) 64-Bit Server VM warning: INFO: os::commit_memory(0x00000c00000, 1073,0)
  8. 管理学研究中应用计算机仿真,计算机仿真在企业流程再造中应用研究.doc
  9. Factory IO的应用(一)
  10. QT 读BIN文件的两种方式