先看文字版解释相对位置编码解释

visiontransformer中使用到了可学习的绝对位置编码。

swintransformer中将相对值位置编码应用到了图像之中,其中的相对位置代码是通用的,在别的网络中也是这样用的。

1:位置编码应该加在那些地方?

2:位置编码前后的数据流是什么样的?

3:位置编码的代码是如何编写的?

答:

可学习的绝对位置编码在输入图片经过分块后,图片由(B,C,H,W)变成(B,num_patch,emb_dim)后,加上class_token后,加上位置编码。而可学习的编码则是直接初始化为(B,num_patch,emb_dim)大小的0,然后在学习中不断更新。

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))def forward_features(self, x):# [B, C, H, W] -> [B, num_patches, embed_dim]x = self.patch_embed(x)  # [B, 196, 768]# [1, 1, 768] -> [B, 1, 768]cls_token = self.cls_token.expand(x.shape[0], -1, -1)if self.dist_token is None:x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]else:x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)x = self.pos_drop(x + self.pos_embed)

而对于相对位置编码:根据公式我们可以看到在Q与K转置相乘后与相对位置编码相加。这里使用Utnet的代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class depthwise_separable_conv(nn.Module):def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride)self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)def forward(self, x):out = self.depthwise(x)out = self.pointwise(out)return outclass RelativePositionBias(nn.Module):# input-independent relative position attention# As the number of parameters is smaller, so use 2D here# Borrowed some code from SwinTransformer: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.pydef __init__(self, num_heads, h, w):  # (4,16,16)super().__init__()self.num_heads = num_heads #4self.h = h #16self.w = w #16self.relative_position_bias_table = nn.Parameter(torch.randn((2 * h - 1) * (2 * w - 1), num_heads) * 0.02)  # (961,4)coords_h = torch.arange(self.h)  # [0,16]coords_w = torch.arange(self.w)  # [0,16]coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, 16, 16)coords_flatten = torch.flatten(coords, 1)  # (2, 256)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] #(2,256,256)relative_coords = relative_coords.permute(1, 2, 0).contiguous() #(256,256,2)#转换到大于0relative_coords[:, :, 0] += self.h - 1 #(256,256,2)relative_coords[:, :, 1] += self.w - 1relative_coords[:, :, 0] *= 2 * self.h - 1#二维转换到一维relative_position_index = relative_coords.sum(-1)  # (256, 256)self.register_buffer("relative_position_index", relative_position_index)def forward(self, H, W):#relative_position_index->(256,256)#relative_position_bias_table->(961,4)relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.h,self.w,self.h * self.w,-1)  # h, w, hw, nH (16,16,256,4)relative_position_bias_expand_h = torch.repeat_interleave(relative_position_bias, H // self.h,dim=0)  # (在dim=0维度重复7次)->(112,16,256,4)relative_position_bias_expanded = torch.repeat_interleave(relative_position_bias_expand_h, W // self.w,dim=1)  # HW, hw, nH #(在dim=1维度重复7次)relative_position_bias_expanded = relative_position_bias_expanded.view(H * W, self.h * self.w,self.num_heads).permute(2, 0,1).contiguous().unsqueeze(0)return relative_position_bias_expanded
class LinearAttention(nn.Module):def __init__(self, dim, heads=4, dim_head=64, attn_drop=0., proj_drop=0., reduce_size=16, projection='maxpool',rel_pos=True):super().__init__()self.inner_dim = dim_head * headsself.heads = headsself.scale = dim_head ** (-0.5)self.dim_head = dim_headself.reduce_size = reduce_sizeself.projection = projectionself.rel_pos = rel_pos# depthwise conv is slightly better than conv1x1# self.to_qkv = nn.Conv2d(dim, self.inner_dim*3, kernel_size=1, stride=1, padding=0, bias=True)# self.to_out = nn.Conv2d(self.inner_dim, dim, kernel_size=1, stride=1, padding=0, bias=True)self.to_qkv = depthwise_separable_conv(dim, self.inner_dim * 3)self.to_out = depthwise_separable_conv(self.inner_dim, dim)self.attn_drop = nn.Dropout(attn_drop)self.proj_drop = nn.Dropout(proj_drop)if self.rel_pos:# 2D input-independent relative position encoding is a little bit better than# 1D input-denpendent counterpartself.relative_position_encoding = RelativePositionBias(heads, reduce_size, reduce_size)# self.relative_position_encoding = RelativePositionEmbedding(dim_head, reduce_size)def forward(self, x):# x = torch.rand(1,64,112,112)B, C, H, W = x.shape# B, inner_dim, H, Wqkv = self.to_qkv(x)  # (1,768,112,112)q, k, v = qkv.chunk(3, dim=1)  # (1,256,112,112)if self.projection == 'interp' and H != self.reduce_size:# 将(k,v)插值到reduce_size大小,(1,256,16,16)k, v = map(lambda t: F.interpolate(t, size=self.reduce_size, mode='bilinear', align_corners=True), (k, v))elif self.projection == 'maxpool' and H != self.reduce_size:k, v = map(lambda t: F.adaptive_max_pool2d(t, output_size=self.reduce_size), (k, v))# q--->rearrange--->(1,256(64*4),112,112)->(1,4,12544(112,112),64)q = rearrange(q, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head, heads=self.heads,h=H, w=W)# k,v--->map--->(1,256(64*4),16,16)->(1,4,256(16,16),64)k, v = map(lambda t: rearrange(t, 'b (dim_head heads) h w -> b heads (h w) dim_head', dim_head=self.dim_head,heads=self.heads, h=self.reduce_size, w=self.reduce_size), (k, v))# q@k--->(1,4,12544,64)@(1,4,64,256)=(1,4,12544,256)q_k_attn = torch.einsum('bhid,bhjd->bhij', q, k)if self.rel_pos:relative_position_bias = self.relative_position_encoding(H, W)  # (1,4,12544,256)q_k_attn += relative_position_bias# rel_attn_h, rel_attn_w = self.relative_position_encoding(q, self.heads, H, W, self.dim_head)# q_k_attn = q_k_attn + rel_attn_h + rel_attn_wq_k_attn *= self.scaleq_k_attn = F.softmax(q_k_attn, dim=-1)q_k_attn = self.attn_drop(q_k_attn)#(1,4,12544,256)@(1,4,256,64)=(1,4,12544,64)out = torch.einsum('bhij,bhjd->bhid', q_k_attn, v)#(1,4,12544,64)--->(1,256(64*4),112,112)out = rearrange(out, 'b heads (h w) dim_head -> b (dim_head heads) h w', h=H, w=W, dim_head=self.dim_head,heads=self.heads)#(1,256(64*4),112,112)--->(1,64,112,112)out = self.to_out(out)out = self.proj_drop(out)return out, q_k_attn
def main():#--------------------------------实例化-------------------------model = LinearAttention(64) #(传入参数)print(model)# m = model.state_dict()# print(type(m))# for key,value in m.items():#     print(key)model.eval()x = torch.rand(1,64,112,112)with torch.no_grad():output,q_k_attn= model(x)print(output.shape) #(1,64,112,112)if __name__ == '__main__':main()

首先我们实例化LinearAttention类,我们输入x,首先获得x的形状,与VisionTransformer不同的是,(VIT首先会进行patchembedding,然后展平,交换维度,然后加入class_token,再加入可学习的位置编码,再经过线性层,最后生成q,k,v),而这里直接经过self.to_qkv函数,即深度可分离函数,升高维度,加入我们x大小为(1,64,112,112),维度变为(1,768,112,112)。

class depthwise_separable_conv(nn.Module):def __init__(self, in_ch, out_ch, stride=1, kernel_size=3, padding=1, bias=False):super().__init__()self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=kernel_size, padding=padding, groups=in_ch, bias=bias, stride=stride)self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=bias)def forward(self, x):out = self.depthwise(x)out = self.pointwise(out)return outself.to_qkv = depthwise_separable_conv(dim, self.inner_dim * 3)

然后我们经过chunk函数进行划分,沿着通道维度划分三份,分别为q,k,v的维度,分别为(1,256,112,112)。接着将q,k,v投射或者缩减到我们想要的维度,即(1,256,16,16),然后q经过rearrange函数,由(1256,112,112)转换到(1,4,12544,64)这里和·VIT的类似,都转换到了(B,num_head,HxW,dim_head),k和v转换到由(1,256,16,16)到(1,4,256,16),然后Q乘以K转置,维度变换为q@k--->(1,4,12544,64)@(1,4,64,256)=(1,4,12544,256)。

接着就到了我们的相对位置编码:

这里我们一步一步debug单步调试,看结果的显示

首先h和w都是16,接着我们生成要训练的relative_position_bias_table,这也是我们要用生成的索引去table查找值,具体看文章开头文字版的解释。

我们生成(2M-1)x(2M-1)个值,分别代表行和列,自己的左边和右边共有31个位置一共961个,共有4个头,所以维度为(961,4)。

然以我们生成长和宽的网格用于生成相对位置索引。长和宽都为16

然后meshgrid生成网格:

接着展平:

然后获得每个位置的索引:

交换维度:

下面的三部将索引的值限制到大于0,且将二维索引转换到一维:

下一步将相对位置索引注册到缓冲区。

在forward函数中,我们将相对位置索引展平,由长和宽拉长为序列,变为

根据生成的索引去relative_position_bias_table列表里面查找对应的值。然后我们将序列再转换到矩阵,大小为 (16,16,256,4)。

由于我们的Q@K大小为(1,4,12544,256),所以我们要将数据进行扩充,长和宽分别扩充七倍。

expand_h为:(112,16,256,4)

expanded为:(112,112,256,4)

生成的结果进行view,将(112,112,256,4)转换为(1,4,12544,256)。

这样我们的bias就与Q@k大小一致了,然后我们相加。接着乘以根号d,在与V相乘。最后reshape为原始大小即可。

最后我们看一下相对位置编码带来的效果提升:以swintransformer为例:

以语义分割为例,在ADE20K,相对位置编码为46.1,绝对位置编码为43.2,提升了快三个点。究其原因Transformer学到了归纳偏置。

相对位置编码,绝对位置编码代码pytorch实现相关推荐

  1. Transformer升级之路:二维位置的旋转式位置编码

    ©PaperWeekly 原创 · 作者 | 苏剑林 单位 | 追一科技 研究方向 | NLP.神经网络 在之前的文章 Transformer 升级之路:博采众长的旋转式位置编码中我们提出了旋转式位置 ...

  2. positional encoding位置编码详解:绝对位置与相对位置编码对比

    目录 前言 Why What 绝对位置编码 相对位置编码 Sinusoidal Position Encoding Complex embedding How 前言 相信熟悉BERT的小伙伴对posi ...

  3. 百度地图采集经纬度坐标数据定位的javascript实战开发(地理坐标拾取系统、地址定位点选插件、实时定位、数据导入、地理编码、位置纠偏)

    坐标采集 前言 1.百度地图地理坐标拾取系统 2.位置选择插件 百度地图经纬度选择插件 默认参数配置 3.数据导入 4.地理编码 爬取百度webAPI 返回参数 前端封装转换函数 5.手机GPS定位 ...

  4. transformer引入位置信息--Sinusoidal位置编码《个人学习笔记》

    transformer引入位置信息--Sinusoidal位置编码 为什么transformer需要位置编码 Sinusoidal绝对位置编码 首先,所有技术都是个人理解,并感谢技术各位分享,由此根据 ...

  5. 二代身份证编码规则及校验代码实现

    本文主要讨论的是二代身份证编码规则及其Java代码实现,下面的校验方式还不是特别严谨,由于只校验了前两位的省份信息,中间六位的出生日期信息和最后一位的校验码信息,故对于部分不满足要求的证件号码刚好同时 ...

  6. 【小代码讲解】独热编码(One-Hot编码)

    独热编码(One-Hot编码) 独热编码介绍 独热编码表示 独热编码实现 使用sklearn 不使用sklearn 独热编码介绍 在机器学习中,标签的处理总是需要进行独热编码的处理,因为独热编码有以下 ...

  7. js php base64,JavaScript实现Base64编码与解码的代码详解

    本篇文章给大家分享的是jJavaScript实现Base64编码与解码的代码详解,内容挺不错的,希望可以帮助到有需要的朋友 一.加密解密方法使用//1.加密 var str = '124中文内容'; ...

  8. 哈夫曼编码(Huffman)Java实现代码简化版

    这个网上发现的Huffuman编码Java实现在组织上相对简化,便于理解文件压缩过程:提取文件统计字符频度-根据字符频度创建huffman树-根据huffman树生成huffman可变字长无前缀编码- ...

  9. vue 怎么在字符串中指定位置插入字符_vue项目中在可编辑div光标位置插入内容的实现代码...

    vue项目中在可编辑div光标位置插入内容 html: @dragstart="dragStart($event, item.labelName)" draggable='true ...

  10. 【转】刨根究底字符编码之七——ANSI编码与代码页

    一.ANSI编码 1. 如前所述,在全世界所有国家和地区的文字符号统一编码的UCS/Unicode编码方案问世之前(UCS.Unicode后文有详细介绍),各个国家.地区为了用计算机记录并显示自己的字 ...

最新文章

  1. Multisim 12.0 笔记
  2. iphone屏蔽系统更新_iPhone 屏蔽系统更新教程,支持 iOS13 / iOS12 系统
  3. oracle备份慢,诊断Oracle RMAN备份慢的原因
  4. ftp 服务器 文件 连接 导出,ftp 服务器 文件 连接 导出
  5. Median(二分+二分)
  6. LoadRunner场景设置里的各参数解释
  7. ApplicationEventMulticaster not initialized - call 'refresh' before
  8. php 虚拟主机ip配置文件,基于IP的虚拟主机配置
  9. 关于(TabHost),(Button配合Fragment),(Menu)这三种常见的切换界面手法分析...
  10. 吴恩达机器学习系列课程笔记——第一章:什么是机器学习(Machine Learning)
  11. 项目计划概述及计划过程
  12. 谷歌卫星影像免费下载?来看这些软件
  13. 更加清晰的报名要点讲解视频(附图文介绍)
  14. Deepin系统标题栏及其按钮美化
  15. 《OKR工作法》学习总结
  16. IDEA代码以及注释格式化,行宽设置,以及自动换行
  17. python将数据做直方图_用python 制作直方图
  18. hd530黑苹果硬解_解决黑苹果HD3000核显 VGA和HDMI外接显示器无反应问题
  19. SQL Server基础操作(此随笔仅作为本人学习进度记录七 !--函数)
  20. 敲简单前端小游戏——贪吃蛇

热门文章

  1. android之挂断电话
  2. html 网页黑夜模式,网站添加暗黑模式html+js
  3. 设计模式-工厂、建造、观察
  4. global关键字(在局部作用域中声明使用全局变量)
  5. 绘画学习遇到Q版人物不会画怎么办?那你看看这个!!!
  6. ILRuntime1.安装
  7. NullPointException
  8. Python使用Web API数据可视化
  9. ppt的快捷键的使用和显示
  10. conda安装python3.8虚拟环境报错