简单版ViT(无attention部分)

主要记录一下Patch Embedding怎么处理和了解一下vit的简单基本框架,下一节写完整的ViT框架


图像上的Transformer怎么处理?如图
图片—>分块patch---->映射(可学习)---->特征

整体网络结构:

实践部分:

Patch Embedding用于将原始的2维图像转换成一系列的1维patch embeddings
Patch Embedding部分代码:

class PatchEmbedding(nn.Module):def __init__(self,image_size, in_channels,patch_size, embed_dim,dropout=0.):super(PatchEmbedding, self).__init__()#patch_embed相当于做了一个卷积self.patch_embed=nn.Conv2d(in_channels,embed_dim,kernel_size=patch_size,stride=patch_size,bias=False)self.drop=nn.Dropout(dropout)def forward(self,x):# x[4, 3, 224, 224]x=self.patch_embed(x)# x [4, 16, 32, 32]# x:[n,embed_dim,h',w']x = x.flatten(2)  #将x拉直,h'和w'合并   [n,embed,h'*w']   #x [4, 16, 1024]x = x.permute(0,2,1)     # [n,h'*w',embed]      #x [4, 1024, 16]x = self.drop(x)print(x.shape)           #    [4, 1024, 16] 对应[batchsize,num_patch,embed_dim]return x

ViT部分代码:
省略了attention部分

class Vit(nn.Module):def __init__(self):super(Vit, self).__init__()self.patch_embed=PatchEmbedding(224, 3, 7, 16)     #  image tokenslayer_list = [Encoder(16) for i in range(5)]   # 假设有5层encoder,Encoder维度16self.encoders=nn.Sequential(*layer_list)self.head=nn.Linear(16,10)     #做完5层Encoder后的输出维度16,最后做分类num_classes为10self.avg=nn.AdaptiveAvgPool1d(1)       # 所有tensor去平均def forward(self,x):x=self.patch_embed(x)      # #x [4, 1024, 16]for i in self.encoders:x=i(x)# [n,h*w,c]x=x.permute((0,2,1))  # [4, 16, 1024]# [n,c,h*w]x=self.avg(x)  # [n,c,1]  [4, 16, 1]x=x.flatten(1)  # [n,c]  [4,16]x=self.head(x)return x

完整代码:

from PIL import Image
import numpy as np
import torch
import torch.nn as nn# Identity  什么都不做
class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return x#在Mlp中,其实就是两层全连接层,该mlp一般接在attention层后面。首先将16的通道膨胀4倍到64,然后再缩小4倍,最终保持通道数不变。
class Mlp(nn.Module):def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.):       #  mlp_ratio就是膨胀参数super(Mlp, self).__init__()self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio))       # 膨胀self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)      # 尺寸变回去self.act = nn.GELU()self.dropout = nn.Dropout(dropout)def forward(self,x):x = self.fc1(x)x = self.act(x)x = self.dropout(x)x = self.fc2(x)x = self.dropout(x)return xclass PatchEmbedding(nn.Module):def __init__(self,image_size, in_channels,patch_size, embed_dim,dropout=0.):super(PatchEmbedding, self).__init__()#patch_embed相当于做了一个卷积self.patch_embed=nn.Conv2d(in_channels,embed_dim,kernel_size=patch_size,stride=patch_size,bias=False)self.drop=nn.Dropout(dropout)def forward(self,x):# x[4, 3, 224, 224]x=self.patch_embed(x)# x [4, 16, 32, 32]# x:[n,embed_dim,h',w']x = x.flatten(2)  #将x拉直,h'和w'合并   [n,embed,h'*w']   #x [4, 16, 1024]x = x.permute(0,2,1)     # [n,h'*w',embed]      #x [4, 1024, 16]x = self.drop(x)print(x.shape)           #    [4, 1024, 16] 对应[batchsize,num_patch,embed_dim]return xclass Encoder(nn.Module):def __init__(self,embed_dim):super(Encoder, self).__init__()self.atten = Identity()      # self-attention部分先不去实现self.layer_nomer = nn.LayerNorm(embed_dim)   # LN层self.mlp = Mlp(embed_dim)self.mlp_nomer = nn.LayerNorm(embed_dim)def forward(self,x):# 参差结构h = xx = self.atten(x)  # 先做self-attentionx = self.layer_nomer(x)  # 再做LN层x = h+xh = xx = self.mlp(x)  #先做FC层x = self.layer_nomer(x)  # 再做LN层x = h + xreturn xclass Vit(nn.Module):def __init__(self):super(Vit, self).__init__()self.patch_embed=PatchEmbedding(224, 3, 7, 16)     #  image tokenslayer_list = [Encoder(16) for i in range(5)]   # 假设有5层encoder,Encoder维度16self.encoders=nn.Sequential(*layer_list)self.head=nn.Linear(16,10)     #做完5层Encoder后的输出维度16,最后做分类num_classes为10self.avg=nn.AdaptiveAvgPool1d(1)       # 所有tensor去平均def forward(self,x):x=self.patch_embed(x)      # #x [4, 1024, 16]for i in self.encoders:x=i(x)# [n,h*w,c]x=x.permute((0,2,1))  # [4, 16, 1024]# [n,c,h*w]x=self.avg(x)  # [n,c,1]  [4, 16, 1]x=x.flatten(1)  # [n,c]  [4,16]x=self.head(x)return xdef test():# 1. create a imageimg=np.array(Image.open('test.jpg'))   # 224x224t = torch.tensor(img, dtype=torch.float32)print(t.shape)                # [224, 224, 3]sample = t.reshape([4,3,224,224])      # 将[224, 224, 3]reshape成一行print(sample)#print(t.transpose(1,0))# 2. patch embedding--------Patch Embedding用于将原始的2维图像转换成一系列的1维patch embeddings# patch_size是切分的大小,原始224 ∗ 224 ∗ 3 的图片会首先变成32 ∗ 32 ∗ 16# in_channel rgb图是3# embed_dim是需要映射的dimpatch_embedding = PatchEmbedding(image_size=224, patch_size=7, in_channels=3, embed_dim=1)# 做前向操作out = patch_embedding(sample)print(out)#print(out.shape)mlp=Mlp(embed_dim=1)out = mlp(out)print(out.shape)def main():t = torch.randn([4,3,224,224])model=Vit()out=model(t)print(out.shape)if __name__ == "__main__":main()

最后输出[4,10]
下一节写完整的ViT代码

Transformer——patch embedding代码相关推荐

  1. ViT Patch Embedding理解

    ViT(Vision Transformer)中的Patch Embedding用于将原始的2维图像转换成一系列的1维patch embeddings. 假设输入图像的维度为HxWxC,分别表示高,宽 ...

  2. Swin Transformer原文及其代码的理解

    Swin Transformer原文及其代码的理解 第一版 更好的排版笔记:Notion 名词解释 基础知识: 搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三) tok ...

  3. Swin Transformer原理与代码精讲

    课程链接:Swin Transformer原理与代码精讲--计算机视觉视频教程-人工智能-CSDN程序员研修院 Transformer在许多NLP(自然语言处理)任务中取得了最先进的成果. Swin ...

  4. Intra-Instance VICReg: Bag of Self-Supervised Image Patch Embedding

    最近,自监督学习(SSL)在学习图像表示方面取得了巨大的经验进步.然而,我们对表示的理解和知识仍然有限.这项工作表明,siamese-network-based SSL取得SOTA的成功主要基于学习图 ...

  5. 【NLP】简单学习一下NLP中的transformer的pytorch代码

    经典transformer的学习 文章转自微信公众号[机器学习炼丹术] 作者:陈亦新(已授权) 联系方式: 微信cyx645016617 欢迎交流,共同进步 代码细讲 transformer Embe ...

  6. 【深度学习】搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    作者丨科技猛兽 编辑丨极市平台 导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Transformer的实现和代码以及Tr ...

  7. NLP-生成模型-2017-Transformer(二):Transformer各模块代码分析

    一.WordEmbedding层模块(文本嵌入层) Embedding Layer(文本嵌入层)的作用:无论是源文本嵌入还是目标文本嵌入,都是为了将文本中词汇的数字表示转变为向量表示, 由一维转为多维 ...

  8. 搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了

    ↑ 点击蓝字 关注极市平台 作者丨科技猛兽 编辑丨极市平台 极市导读 本文对Vision Transformer的原理和代码进行了非常全面详细的解读,一切从Self-attention开始.Trans ...

  9. 刷爆 AI 圈!基于 Transformer 的 DALL-E 代码刚刚开源了

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 转自 | AI科技评论 OpenAI在1月5日公布DALL-E模型以 ...

最新文章

  1. 开始使用Bootstrap
  2. java中无符号类型的处理[转]
  3. 前端工程师能力评估测试题(2020最新版附答案及解析)
  4. 幅度响应怎么计算_四电平脉冲幅度调制(PAM4)信号的误码分析
  5. hbase 操作指令集合
  6. 三元运算符(Java)
  7. WebUI Case(1): www.swt-designer.com 首页 (续)
  8. layui 监听表单提交form.on(‘submit(sub)‘,function (){}) ajax请求失败问题
  9. 计算机组成原理与体系结构知识概括
  10. 计算机毕业设计springboot门诊管理系统
  11. 华为又要给员工分红了!预计每股 1.61 元,网友:点赞任正非
  12. 介绍两种提取视频语音变成文字的方式
  13. Java Class
  14. 关于一个学习计算机专业,迷茫的大一新生的看法和理解
  15. 业务系统成功微服务化改造的实施步骤
  16. Jodd-Java的瑞士军刀 demo
  17. mysql 数据连续不走索引6_MySql组合索引最左侧原则失效
  18. fcitx只能打繁体字无法切换的一个解决方法linux mint
  19. Android图片海报制作-自定义文字排版控件组件
  20. zabbix3.0版本部署使用

热门文章

  1. 梦幻西游服务器维护,《梦幻西游》12月1日维护公告
  2. 电信融合机ip906h-fv2,线刷包(当贝桌面)
  3. 可以用于毕设参考,请勿过度借鉴
  4. 圣天诺HL加密锁(原HASP加密锁)快速入门
  5. 群晖aria2 bt没速度_群晖终于开窍了!联手迅雷推出下载套件 NAS功能再上一层楼...
  6. postgresql12 pgpool搭建(3)
  7. android启动页背景设置,Android APP启动页白(黑)屏问题及解决方法
  8. 浙江农林大学第二十一届程序设计竞赛校选拔赛A E G H
  9. 电脑键盘equals在哪个位置_总结了一下键盘上所有符号的英文说法
  10. HTML简单练习——个人名片