文章目录

  • 一、空间注意力机制简介
  • 二、空间注意力与pytorch代码
  • 三、使用案例

一、空间注意力机制简介

空间注意力的示意图如下:

长条的是通道注意力机制,而平面则是空间注意力机制,可以发现:

  • 通道注意力在意的是每个特怔面的权重
  • 空间注意力在意的是面上每一个局部的权重。

    注意:空间注意力是右边的部分:Spatial Attention Module

二、空间注意力与pytorch代码

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):  # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)  # 30,1,50,30return self.sigmoid(x)  # 30,1,50,30

简单的使用方法如下:

import torch
import torch.nn as nn
import torch.utils.data as Dataclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):  # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)  # 30,1,50,30return self.sigmoid(x)  # 30,1,50,30def get_total_train_data(H, W, C, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, H, W, C)))  # 维度是 [ 数据量, 高H, 宽W, 长C]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 100batch_size = 30output_class = 14H = 40W = 50C = 30# ================准备数据=================x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)# ================初始化模型=================model = SpatialAttention()# ================开始训练=================for i in range(epochs):for seq, labels in train_loader:attention_out = model(seq)seq_attention_out = attention_out.squeeze()for i in range(seq_attention_out.size()[0]):print(seq_attention_out[i])

三、使用案例

import torch
import torch.nn as nn
import torch.utils.data as Dataclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):  # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)  # 30,1,50,30return self.sigmoid(x)  # 30,1,50,30class UseAttentionModel(nn.Module):def __init__(self):super(UseAttentionModel, self).__init__()self.channel_attention = SpatialAttention()def forward(self, x):  # 反向传播attention_value = self.channel_attention(x)out = x.mul(attention_value)return outdef get_total_train_data(H, W, C, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, H, W, C)))  # 维度是 [ 数据量, 高H, 宽W, 长C]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 100batch_size = 30output_class = 14H = 40W = 50C = 30# ================准备数据=================x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)# ================初始化模型=================model = UseAttentionModel()cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器model.train()# ================开始训练=================for i in range(epochs):for seq, labels in train_loader:attention_out = model(seq)print(attention_out.size())print(attention_out)

注意力机制学习(二)——空间注意力与pytorch案例相关推荐

  1. 万字长文解析CV中的注意力机制(通道/空间/时域/分支注意力)

    点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 后台回复[transformer综述]获取2022最新ViT综述论文! 注意 ...

  2. Pytorch:Transformer(Encoder编码器-Decoder解码器、多头注意力机制、多头自注意力机制、掩码张量、前馈全连接层、规范化层、子层连接结构、pyitcast) part1

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) Encoder编码器-Decoder解码器框架 + Atten ...

  3. 【动手深度学习-笔记】注意力机制(一)注意力机制框架

    生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...

  4. 注意力机制学习(一)——通道注意力与pytorch案例

    文章目录 一.通道注意力机制简介 二.通道注意力机制pytorch代码 1. 单独使用通道注意力机制的小案例 2. 使用通道注意力机制的小案例 一.通道注意力机制简介 下面的图形象的说明了通道注意力机 ...

  5. Attention注意力机制学习(二)------->一些注意力网络整理

    SKNet--SENet孪生兄弟篇(2019) 论文 Selective Kernel Networks https://arxiv.org/abs/1903.06586  2019年 介绍 SKNe ...

  6. 注意力机制学习 BAM

    注意力机制学习-BAM 简介 思考 步骤 代码 实验 最后 简介 2018年BMVC,从通道和空间两方面去解释注意力机制,和CBAM为同一团队打造.论文连接:BAM BAM:Bottleneck At ...

  7. 第七周作业:注意力机制学习的part2

    [BMVC2018]BAM: Bottleneck Attention Module PDF:1807.06514.pdf (arxiv.org) 为使神经网络获得更强的表征能力,在文中作者提出了一种 ...

  8. 注意力机制 ——学习笔记

    文章目录 一.生物神经网络的注意力 1.1 生物注意力的种类 视觉注意力 听觉注意力 语言注意力 1.2 生物注意力的优势 1.3 注意力与记忆力的关系 1.4 人工神经网络的注意力 二.为什么要使用 ...

  9. Transformer、多头注意力机制学习笔记:Attention is All You Need.

    文章目录 相关参考连接: https://blog.csdn.net/hpulfc/article/details/80448570 https://blog.csdn.net/weixin_4239 ...

最新文章

  1. 在研究所工作是什么体验?和互联网公司比,你会怎么选?
  2. Python分布式+云计算
  3. OpenResty安装--增强版的nginx
  4. C++实现bellman ford贝尔曼-福特算法(最短路径)(附完整源码)
  5. 有机晶体数据库_技术专栏:一篇文章搞懂晶体学信息文件CIF及其获取方法
  6. 微信突然出现redirect_uri 参数错误
  7. 10年老电脑如何提速_电信宽带免费提速至200M,面向全国用户活动日期2020年11月9日至12月31日...
  8. c++中判断某个值在字典的value中_Python核心知识系列:字典
  9. [译]关于NODE_ENV,哪些你应该了解
  10. Win32编程之基于MATLAB与VC交互的多项式回归
  11. zookeeper中的ZAB协议理解
  12. C++ writestring 为什么不能写进中文 CStdioFile向无法向文本中写入中文【一】
  13. Mysql查询某列最长字符串记录
  14. 11.sql条件查询
  15. MailKit使用IMAP读取邮件找不到附件Attachments为空的解决方法
  16. 科技创新全球资本财富盛会暨联盟系统2.0启动大会圆满举行
  17. H5游戏开发:游戏引擎入门推荐
  18. EasyCVR人脸识别框在播放器上显示及消失的机制设定
  19. 罗技推出“语音鼠标”,隐藏着百度AI的产业化范式
  20. 友盟APM和bugly全面对比

热门文章

  1. php技术计算字符个数的函数是什么,php计算字符串中的单词数的函数str_word_count()...
  2. python嵌入shell代码_小白进!嵌入式开发如何快速入门?
  3. gitlab windows安装_【Thrift】Windows编译Thrift源码及其依赖库
  4. 485串口测试工具软件_(案例)电脑和仪表之间485通讯的奇怪现象及解决方案
  5. 关于ddx/ddy重建法线在edge边沿上的artifacts问题
  6. [五]java函数式编程归约reduce概念原理 stream reduce方法详解 reduce三个参数的reduce方法如何使用...
  7. Office web apps 服务器运行一段时间之后CPU就是达到100%
  8. Android ListView中 每一项都有不同的布局
  9. H3C S1526交换机端口镜像配置
  10. Closure--1