注意力机制学习(二)——空间注意力与pytorch案例
文章目录
- 一、空间注意力机制简介
- 二、空间注意力与pytorch代码
- 三、使用案例
一、空间注意力机制简介
空间注意力的示意图如下:
长条的是通道注意力机制,而平面则是空间注意力机制,可以发现:
- 通道注意力在意的是每个特怔面的权重
- 空间注意力在意的是面上每一个局部的权重。
注意:空间注意力是右边的部分:Spatial Attention Module
二、空间注意力与pytorch代码
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x): # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True) # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x) # 30,1,50,30return self.sigmoid(x) # 30,1,50,30
简单的使用方法如下:
import torch
import torch.nn as nn
import torch.utils.data as Dataclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x): # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True) # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x) # 30,1,50,30return self.sigmoid(x) # 30,1,50,30def get_total_train_data(H, W, C, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, H, W, C))) # 维度是 [ 数据量, 高H, 宽W, 长C]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long() # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 100batch_size = 30output_class = 14H = 40W = 50C = 30# ================准备数据=================x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train), # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size, # 每块的大小shuffle=True, # 要不要打乱数据 (打乱比较好)num_workers=6, # 多进程(multiprocess)来读数据drop_last=True,)# ================初始化模型=================model = SpatialAttention()# ================开始训练=================for i in range(epochs):for seq, labels in train_loader:attention_out = model(seq)seq_attention_out = attention_out.squeeze()for i in range(seq_attention_out.size()[0]):print(seq_attention_out[i])
三、使用案例
import torch
import torch.nn as nn
import torch.utils.data as Dataclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x): # x.size() 30,40,50,30avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True) # 30,1,50,30x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x) # 30,1,50,30return self.sigmoid(x) # 30,1,50,30class UseAttentionModel(nn.Module):def __init__(self):super(UseAttentionModel, self).__init__()self.channel_attention = SpatialAttention()def forward(self, x): # 反向传播attention_value = self.channel_attention(x)out = x.mul(attention_value)return outdef get_total_train_data(H, W, C, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, H, W, C))) # 维度是 [ 数据量, 高H, 宽W, 长C]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long() # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 100batch_size = 30output_class = 14H = 40W = 50C = 30# ================准备数据=================x_train, y_train = get_total_train_data(H, W, C, class_count=output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train), # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size, # 每块的大小shuffle=True, # 要不要打乱数据 (打乱比较好)num_workers=6, # 多进程(multiprocess)来读数据drop_last=True,)# ================初始化模型=================model = UseAttentionModel()cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器model.train()# ================开始训练=================for i in range(epochs):for seq, labels in train_loader:attention_out = model(seq)print(attention_out.size())print(attention_out)
注意力机制学习(二)——空间注意力与pytorch案例相关推荐
- 万字长文解析CV中的注意力机制(通道/空间/时域/分支注意力)
点击下方卡片,关注"自动驾驶之心"公众号 ADAS巨卷干货,即可获取 点击进入→自动驾驶之心技术交流群 后台回复[transformer综述]获取2022最新ViT综述论文! 注意 ...
- Pytorch:Transformer(Encoder编码器-Decoder解码器、多头注意力机制、多头自注意力机制、掩码张量、前馈全连接层、规范化层、子层连接结构、pyitcast) part1
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) Encoder编码器-Decoder解码器框架 + Atten ...
- 【动手深度学习-笔记】注意力机制(一)注意力机制框架
生物学中的注意力提示 非自主性提示: 在没有主观意识的干预下,眼睛会不自觉地注意到环境中比较突出和显眼的物体. 比如我们自然会注意到一堆黑球中的一个白球,马路上最酷的跑车等. 自主性提示: 在主观意识 ...
- 注意力机制学习(一)——通道注意力与pytorch案例
文章目录 一.通道注意力机制简介 二.通道注意力机制pytorch代码 1. 单独使用通道注意力机制的小案例 2. 使用通道注意力机制的小案例 一.通道注意力机制简介 下面的图形象的说明了通道注意力机 ...
- Attention注意力机制学习(二)------->一些注意力网络整理
SKNet--SENet孪生兄弟篇(2019) 论文 Selective Kernel Networks https://arxiv.org/abs/1903.06586 2019年 介绍 SKNe ...
- 注意力机制学习 BAM
注意力机制学习-BAM 简介 思考 步骤 代码 实验 最后 简介 2018年BMVC,从通道和空间两方面去解释注意力机制,和CBAM为同一团队打造.论文连接:BAM BAM:Bottleneck At ...
- 第七周作业:注意力机制学习的part2
[BMVC2018]BAM: Bottleneck Attention Module PDF:1807.06514.pdf (arxiv.org) 为使神经网络获得更强的表征能力,在文中作者提出了一种 ...
- 注意力机制 ——学习笔记
文章目录 一.生物神经网络的注意力 1.1 生物注意力的种类 视觉注意力 听觉注意力 语言注意力 1.2 生物注意力的优势 1.3 注意力与记忆力的关系 1.4 人工神经网络的注意力 二.为什么要使用 ...
- Transformer、多头注意力机制学习笔记:Attention is All You Need.
文章目录 相关参考连接: https://blog.csdn.net/hpulfc/article/details/80448570 https://blog.csdn.net/weixin_4239 ...
最新文章
- 在研究所工作是什么体验?和互联网公司比,你会怎么选?
- Python分布式+云计算
- OpenResty安装--增强版的nginx
- C++实现bellman ford贝尔曼-福特算法(最短路径)(附完整源码)
- 有机晶体数据库_技术专栏:一篇文章搞懂晶体学信息文件CIF及其获取方法
- 微信突然出现redirect_uri 参数错误
- 10年老电脑如何提速_电信宽带免费提速至200M,面向全国用户活动日期2020年11月9日至12月31日...
- c++中判断某个值在字典的value中_Python核心知识系列:字典
- [译]关于NODE_ENV,哪些你应该了解
- Win32编程之基于MATLAB与VC交互的多项式回归
- zookeeper中的ZAB协议理解
- C++ writestring 为什么不能写进中文 CStdioFile向无法向文本中写入中文【一】
- Mysql查询某列最长字符串记录
- 11.sql条件查询
- MailKit使用IMAP读取邮件找不到附件Attachments为空的解决方法
- 科技创新全球资本财富盛会暨联盟系统2.0启动大会圆满举行
- H5游戏开发:游戏引擎入门推荐
- EasyCVR人脸识别框在播放器上显示及消失的机制设定
- 罗技推出“语音鼠标”,隐藏着百度AI的产业化范式
- 友盟APM和bugly全面对比
热门文章
- php技术计算字符个数的函数是什么,php计算字符串中的单词数的函数str_word_count()...
- python嵌入shell代码_小白进!嵌入式开发如何快速入门?
- gitlab windows安装_【Thrift】Windows编译Thrift源码及其依赖库
- 485串口测试工具软件_(案例)电脑和仪表之间485通讯的奇怪现象及解决方案
- 关于ddx/ddy重建法线在edge边沿上的artifacts问题
- [五]java函数式编程归约reduce概念原理 stream reduce方法详解 reduce三个参数的reduce方法如何使用...
- Office web apps 服务器运行一段时间之后CPU就是达到100%
- Android ListView中 每一项都有不同的布局
- H3C S1526交换机端口镜像配置
- Closure--1