ti_vig

def pvig_ti_224_gelu(pretrained=False, **kwargs):class OptInit:def __init__(self, num_classes=1000, drop_path_rate=0.0, **kwargs):self.k = 9 # 邻域的数目,默认为9self.conv = 'mr' # 图卷积层=mrself.act = 'gelu' # 激活层=geluself.norm = 'batch' # batch or instance normalization {batch, instance}self.bias = True # bias of conv layer True or Falseself.dropout = 0.0 # dropout rateself.use_dilation = True # use dilated knn or notself.epsilon = 0.2 # stochastic epsilon for gcnself.use_stochastic = False # stochastic for gcn, True or Falseself.drop_path = drop_path_rateself.blocks = [2,2,6,2] # number of basic blocks in the backboneself.channels = [48, 96, 240, 384] # number of channels of deep featuresself.n_classes = num_classes # Dimension of out_channelsself.emb_dims = 1024 # Dimension of embeddingsopt = OptInit(**kwargs)model = DeepGCN(opt)model.default_cfg = default_cfgs['vig_224_gelu']return model

DeepGCN

class DeepGCN(torch.nn.Module):def __init__(self, opt):super(DeepGCN, self).__init__()print(opt)k = opt.k  # k=9act = opt.act  # active method = gelunorm = opt.norm  # norm = batchbias = opt.bias  # bias = trueepsilon = opt.epsilon  # epsilon = 0.2stochastic = opt.use_stochastic  # use_stochastic = Falseconv = opt.conv  # conv = mremb_dims = opt.emb_dims  # emb_dims = 1024drop_path = opt.drop_path  # drop_path = drop_path_rate = 0.0blocks = opt.blocks  # blocks = [2, 2, 6, 2]self.n_blocks = sum(blocks)   # n_blocks = 12channels = opt.channels  # channels = [80, 160, 400, 640]reduce_ratios = [4, 2, 1, 1]# stochastic depth decay rule# dpr = 0.0 x 12dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]# num_knn = 9 x 12num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]# 最大扩张 max_dilation = 49//9 = 5max_dilation = 49 // max(num_knn)# Stem(out_dim=80, act=gelu), output size = [h/4, w/4, 80]self.stem = Stem(out_dim=channels[0], act=act)# pos_embed = [1, 80, 56, 56]self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224//4, 224//4))HW = 224 // 4 * 224 // 4  # 3136self.backbone = nn.ModuleList([])idx = 0for i in range(len(blocks)):  # [2, 2, 6, 2], i = 0 1 2 3if i > 0:self.backbone.append(Downsample(channels[i-1], channels[i]))HW = HW // 4  #  784for j in range(blocks[i]):self.backbone += [Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm,bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],relative_pos=True),FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx]))]idx += 1self.backbone = Seq(*self.backbone)## ----- this part x2 -----## Grapher(channel=80, num_knn=9, 1, conv=mr, act=gelu, norm=batch, bias=true, stochastic=false, epsilon=0.2,##           reduce_ratios=4, n=3136, drop_path=0.0, relative_pos=True),## FFN(80, 320, act=gelu, drop_path=0.0)## ------------------------## Downsample(80, 160)## HW = 784## ----- this part x2 -----## Grapher(channel=160, num_knn=9, 1, conv=mr, act=gelu, norm=batch, bias=true, stochastic=false, epsilon=0.2,##           reduce_ratios=4, n=784, drop_path=0.0, relative_pos=True),## FFN(160, 640, act=gelu, drop_path=0.0)## ------------------------## Downsample(160, 400)## HW = 196## ----- this part x6 -----## Grapher(channel=400, num_knn=9, 1, conv=mr, act=gelu, norm=batch, bias=true, stochastic=false, epsilon=0.2,##           reduce_ratios=4, n=196, drop_path=0.0, relative_pos=True),## FFN(400, 1600, act=gelu, drop_path=0.0)## ------------------------## Downsample(400, 640)## HW = 49## ----- this part x2 -----## Grapher(channel=640, num_knn=9, 1, conv=mr, act=gelu, norm=batch, bias=true, stochastic=false, epsilon=0.2,##           reduce_ratios=4, n=196, drop_path=0.0, relative_pos=True),## FFN(640, 2560, act=gelu, drop_path=0.0)## ------------------------self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),nn.BatchNorm2d(1024),act_layer(act),nn.Dropout(opt.dropout),nn.Conv2d(1024, opt.n_classes, 1, bias=True))self.model_init()def model_init(self):for m in self.modules():if isinstance(m, torch.nn.Conv2d):torch.nn.init.kaiming_normal_(m.weight)m.weight.requires_grad = Trueif m.bias is not None:m.bias.data.zero_()m.bias.requires_grad = Truedef forward(self, inputs):x = self.stem(inputs) + self.pos_embedB, C, H, W = x.shapefor i in range(len(self.backbone)):x = self.backbone[i](x)x = F.adaptive_avg_pool2d(x, 1)return self.prediction(x).squeeze(-1).squeeze(-1)

Stem

class Stem(nn.Module):""" Image to Visual EmbeddingOverlap: https://arxiv.org/pdf/2106.13797.pdf"""def __init__(self, img_size=224, in_dim=3, out_dim=768, act='relu'):super().__init__()        self.convs = nn.Sequential(nn.Conv2d(in_dim, out_dim//2, 3, stride=2, padding=1),  # in_ch=3, out_ch=40, outputsize=[h/2,w/2,40]nn.BatchNorm2d(out_dim//2),  # 40act_layer(act),  # relunn.Conv2d(out_dim//2, out_dim, 3, stride=2, padding=1),  # in_ch=40, out_ch=80, outputsize=[h/4,w/4,80]nn.BatchNorm2d(out_dim),  # 80act_layer(act),  # relunn.Conv2d(out_dim, out_dim, 3, stride=1, padding=1),  # in_ch=80, out_ch=80, outputsize=[h/4,w/4,80]nn.BatchNorm2d(out_dim),)def forward(self, x):x = self.convs(x)return x

ViG核心代码及网络结构图相关推荐

  1. 深度学习中经典网络结构图和代码

    Inception网络与其它网络的性能对比 PyTorch-Networks: 包含了分类.检测以及姿态估计等网络的pytorch代码 caffe-model-zoo: AlexNet.VGGNet. ...

  2. Yolov3、v4、v5、Yolox模型权重及网络结构图资源下载

    对于Yolo相关的算法,大白从Yolov3.Yolov4.Yolov5.Yolox的角度,制作了一些文章和视频: 算法开发重磅福利: (1)算法工程师模型部署利器,算法开发平台,安卓手机即可使用,点击 ...

  3. 融资 2000 万美元后,他竟将核心代码全开源,这……能行吗?

    立即报名:https://t.csdnimg.cn/KqnS 有这么一位"任性"的技术创业者: 2017 年,50 岁开始第三次创业,踏足自己从未深入涉及过的物联网大数据平台,敲下 ...

  4. 太牛了!30 年开源老兵,10 年躬耕 OpenStack,开源 1000 万行核心代码!

    受访者 | Jonathan Bryce 记者 | 伍杏玲 出品 | CSDN(ID:CSDNnews) 万物互联时代下,我们的一切都在依赖计算基础设施,科学.金融.政府.教育.通信和医疗保健依赖现代 ...

  5. python 重写断言_历时四年,Dropbox 用 Rust 重写同步引擎核心代码

    开源 GO 语言工具库.研究 iOS 和 Android 的 C++ 跨平台开发,花费五年时间从云平台向数据中心反向迁移-Dropbox 从未停止对技术的"折腾".如今,这家公司又 ...

  6. tensorflow打印模型图_从Tensorflow模型文件中解析并显示网络结构图(pb模型篇)...

    最近看到一个巨牛的人工智能教程,分享一下给大家.教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家.平时碎片时间可以当小说看,[点这里可以去膜拜一下大神的" ...

  7. 不要怂,就是GAN (生成式对抗网络) (五):无约束条件的 GAN 代码与网络的 Graph...

    GAN 这个领域发展太快,日新月异,各种 GAN 层出不穷,前几天看到一篇关于 Wasserstein GAN 的文章,讲的很好,在此把它分享出来一起学习:https://zhuanlan.zhihu ...

  8. 30 年开源老兵,10 年躬耕 OpenStack,开源 1000 万行核心代码

    受访者 | Jonathan Bryce 记者 | 伍杏玲 出品 | CSDN(ID:CSDNnews) 万物互联时代下,我们的一切都在依赖计算基础设施,科学.金融.政府.教育.通信和医疗保健依赖现代 ...

  9. linux 内核 核心代码,8分钟掌握Linux内核分析的核心科技

    原标题:8分钟掌握Linux内核分析的核心科技 作者: OUYANG_LINUX007 来源: http://blog.csdn.net/ouyang_linux007/article/details ...

最新文章

  1. 如果我的实验室也这样布置,那多好。
  2. Windows Phone 执行模型概述
  3. 熊猫tv新功能介绍_您应该知道的4种熊猫绘图功能
  4. 终端操作MySQL数据库
  5. iis部署错误:HTTP 错误 500.21 - Internal Server Error
  6. HtmlDocument.ExecCommand() 方法
  7. 【牛客网】滴滴出行2017秋招测试岗笔试真题汇总
  8. 个人随笔/小白应该如何学习Linux,我的一些心得分享.
  9. 逐帧动画案例(奔跑的小人)
  10. 中介者模式 java_Java设计模式学习记录-中介者模式
  11. 解决云服务器添加了安全组端口无法访问问题
  12. 前端每日实战:50# 视频演示如何用纯 CSS 创作一个永动的牛顿摆
  13. 编码的奥秘:两种典型的微处理器
  14. windows安装OpenSSL
  15. [附源码]Python计算机毕业设计Django大学生考勤管理系统论文
  16. CCF C³-19@航天宏图:星链互联,创新未来——商业卫星互联网时代的思考丨开始报名...
  17. uniapp ios原生插件开发之插件包格式(package.json)
  18. 【计算题】(三)连续与导数
  19. Solaris10如何确认DirectIO是否已经启用
  20. 【产业互联网周报】互联网集体进入调整期:张勇兼任阿里云总裁,卢伟冰晋升小米集团总裁,小鹏组织架构调整;...

热门文章

  1. 常用图像数据集:标注、检索
  2. 通信达接口python通过第三方库的pytdx怎样获取最新股票行情?
  3. PVE虚拟机命令行管理虚拟机启动和停止
  4. tableau高级绘图(十一)-tableau绘制日历圆形图
  5. 奇虎360通过亚马逊云科技云服务加速创新
  6. 【LM】360N4S解决手机关屏后经常无法唤醒的情况刷机第三方
  7. iPhone 8快速充电技术简介,苹果再也不能被安卓嘲笑充电慢
  8. 既约分数(最大相除法)
  9. 运维工程师和网络工程师的区别?
  10. 微信小程序开发工具初次使用Git 记录一下