原文:https://arxiv.org/pdf/1811.11721.pdf

记录一下,目前习惯用笔写的:

cc attention这个模块的代码:

from .functions import CrissCrossAttention'''
This code is borrowed from Serge-weihao/CCNet-Pure-Pytorch
'''import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmaxdef INF(B,H,W):return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)class CrissCrossAttention(nn.Module):""" Criss-Cross Attention Module"""def __init__(self, in_dim):super(CrissCrossAttention,self).__init__()self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)self.softmax = Softmax(dim=3)self.INF = INFself.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):m_batchsize, _, height, width = x.size()proj_query = self.query_conv(x)proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)proj_key = self.key_conv(x)proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)proj_value = self.value_conv(x)proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)concate = self.softmax(torch.cat([energy_H, energy_W], 3))att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)#print(concate)#print(att_H) att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)#print(out_H.size(),out_W.size())return self.gamma*(out_H + out_W) + xif __name__ == '__main__':model = CrissCrossAttention(64)x = torch.randn(2, 64, 5, 6)out = model(x)print(out.shape)

(CCNET)criss-cross attention network学习笔记相关推荐

  1. 深度学习(二十六)Network In Network学习笔记-ICLR 2014

    Network In Network学习笔记 原文地址:http://blog.csdn.net/hjimce/article/details/50458190 作者:hjimce 一.相关理论 本篇 ...

  2. 深度学习(二十六)Network In Network学习笔记

    Network In Network学习笔记 原文地址:http://blog.csdn.net/hjimce/article/details/50458190 作者:hjimce 一.相关理论 本篇 ...

  3. NIN(Network in Network)学习笔记

    NIN(Network in Network)学习笔记 一.前言 <Network In Network>是一篇比较老的文章了(2014年ICLR的一篇paper),是当时比较牛逼的一篇论 ...

  4. matlab rbm 语音,Deep Belief Network 学习笔记-RBM

    Deep Belief Network 学习笔记-RBM By Placebo (纯属个人笔记) 第一次知道deep learning,是上学期dengli博士来实验室的一次报告,他讲到,当神经网络的 ...

  5. GNN金融应用之Classifying and Understanding Financial Data Using Graph Neural Network学习笔记

    Classifying and Understanding Financial Data Using Graph Neural Network 摘要 1. 概述 2. 数据表示-加权图 3. GNN利 ...

  6. 图神经网络框架DGL实现Graph Attention Network (GAT)笔记

    参考列表: [1]深入理解图注意力机制 [2]DGL官方学习教程一 --基础操作&消息传递 [3]Cora数据集介绍+python读取 一.DGL实现GAT分类机器学习论文 程序摘自[1],该 ...

  7. network 学习笔记

    #cat /etc/modprobe.conf Ethernet : eth0,eth1,ethN Token Ring: tr0,tr1,trN FDDI : fddi0,fddi1,fddiN ( ...

  8. Neural Network学习笔记2

    torch.nn: Containers: 神经网络骨架 Convolution Layers 卷积层 Pooling Layers  池化层 Normalization Layers 正则化层 No ...

  9. 内容分发网络 - Content Delivery Network 学习笔记

    缓存是将文件副本存储在缓存或临时存储位置的过程,以便可以更快地访问它们.从技术上讲,缓存是文件或数据副本的任何临时存储位置,但该术语通常用于指代 Internet 技术. Web 浏览器缓存 HTML ...

  10. Focusing Attention Network(FAN)自然图像文本识别 学习笔记

    Focusing Attention: Towards Accurate Text Recognition in Natural Images Author: Zhanzhan Cheng,Fan B ...

最新文章

  1. 安装Windows 2012域控(For SQLServer 2014 AlwaysOn)
  2. [codevs 2926] 黑白瓷砖(2002年安徽省队选拔赛)
  3. 微型计算机技术第三版第三章答案,第3章微机组装技术作业(答案)
  4. python解析pcap包已text格式输出_python分析pcap包
  5. python进阶装饰器_Python进阶: 通过实例详解装饰器(附代码)
  6. 终极解决方案UnicodeEncodeError: 'ascii' codec can't encode character u'\uff08' in position 13: ordinal not
  7. python图像跟踪代码_python如何实现图像外边界跟踪 python实现图像外边界跟踪代码示例...
  8. cssci核心期刊(cssci核心期刊目录)
  9. 萝卜小姐的整车第一弹—MCU 软件烧录及升级说明
  10. 应用锁(AppLocker)原理及代码实现
  11. 暴力电脑锁机生成器(加机械硬盘锁)
  12. 三元运算 微信小程序_微信小程序使用三元运算符代替wx:if
  13. 浙江大学许威威教授招聘博士后
  14. uhuntu五笔输入法fcitx安装
  15. TCP/IP 主要报文头格式
  16. 小型网络拓扑(vlan)
  17. [GDOI2016][树链剖分+主席树]疯狂动物城
  18. 论文阅读(3) 用气泡PIV测量加利福尼亚海狮推进冲程的速度场(2022)
  19. LIVE555再学习 -- live555实现RTSP直播服务器
  20. Windows下UDP编程

热门文章

  1. Apple Store教育优惠(161103)
  2. 5分+细胞器基因组好文!多线南蜥线粒体基因组及比较基因组研究
  3. 世界500强企业名称中英对照
  4. AI安全技术总结与展望
  5. C语言每日一练——第61天:掷骰子游戏
  6. 三层交换机也不贵:自己动手做三层交换机
  7. 获取多边形的最大最小坐标
  8. http概述(相关详解)
  9. HTTP协议:无状态协议
  10. 马氏距离 java实现_马氏距离与欧氏距离