1、掩码模式:是相对于变长的循环序列而言的,如果输入的样本序列长度不同,那么会先对其进行对齐处理(对短序列补0,对长序列截断),再输入模型。这样,模型中的部分样本中就会有大量的零值。为了提升运算性能,需要以掩码的方式将不需要的零值去掉,并保留非零值进行计算,这就是掩码的作用
2、均值模式:正常模式对每个维度的所有序列计算注意力分数,而均值模式对每个维度注意力分数计算平均值。均值模式会平滑处理同一序列不同维度之间的差异,认为所有维度都是平等的,将注意力用在序列之间。这种方式更能体现出序列的重要性。


代码 Attention_cclassification.py

import torchvision
import torchvision.transforms as tranforms
import pylab
import torch
from matplotlib import pyplot as plt
import numpy as np
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # 可能是由于是MacOS系统的原因data_dir = './fashion_mnist'
tranform = tranforms.Compose([tranforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(root=data_dir,train=True,transform=tranform,download=True)
print("训练数据集条数",len(train_dataset))
val_dataset = torchvision.datasets.FashionMNIST(root=data_dir, train=False, transform=tranform)
print("测试数据集条数",len(val_dataset))
im = train_dataset[0][0]
im = im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print("该图片的标签为:",train_dataset[0][1])## 数据集的制造
batch_size = 10
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)def imshow(img):print("图片形状:",np.shape(img))npimg = img.numpy()plt.axis('off')plt.imshow(np.transpose(npimg,(1,2,0)))classes = ('T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle_Boot')
sample = iter(train_loader)
images,labels = sample.next()
print("样本形状:",np.shape(images))
print("样本标签",labels)
imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print(','.join('%5s' % classes[labels[j]] for j in range(len(images))))class myLSTMNet(torch.nn.Module): #定义myLSTMNet模型类,该模型包括 2个RNN层和1个全连接层def __init__(self,in_dim, hidden_dim, n_layer, n_class):super(myLSTMNet, self).__init__()# 定义循环神经网络层self.lstm = torch.nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)self.Linear = torch.nn.Linear(hidden_dim * 28, n_class)  # 定义全连接层self.attention = AttentionSeq(hidden_dim, hard=0.03) # 定义注意力层,使用硬模式的注意力机制def forward(self, t):  # 搭建正向结构t, _ = self.lstm(t)  # 使用LSTM对象进行RNN数据处理t = self.attention(t)   # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。1t = t.reshape(t.shape[0], -1) # 对循环神经网络结果进行注意力机制的处理,将处理后的结果变形为二维数据,传入全连接输出层。2out = self.Linear(t)  # 进行全连接处理return outclass AttentionSeq(torch.nn.Module):def __init__(self, hidden_dim, hard=0.0): # 初始化super(AttentionSeq, self).__init__()self.hidden_dim = hidden_dimself.dense = torch.nn.Linear(hidden_dim, hidden_dim)self.hard = harddef forward(self, features, mean=False): # 类的处理方法# [batch,seq,dim]batch_size, time_step, hidden_dim = features.size()weight = torch.nn.Tanh()(self.dense(features)) # 全连接计算# 计算掩码,mask给负无穷使得权重为0mask_idx = torch.sign(torch.abs(features).sum(dim=-1))# mask_idx = mask_idx.unsqueeze(-1).expand(batch_size, time_step, hidden_dim)mask_idx = mask_idx.unsqueeze(-1).repeat(1, 1, hidden_dim)# 将掩码作用在注意力结果上# torch.where函数的意思是按照第一参数的条件对每个元素进行检查,如果满足,那么使用第二个参数里对应元素的值进行填充,如果不满足,那么使用第三个参数里对应元素的值进行填充。# torch.ful_likeO函数是按照张量的形状进行指定值的填充,其第一个参数是参考形状的张量,第二个参数是填充值。weight = torch.where(mask_idx == 1, weight,torch.full_like(mask_idx, (-2 ** 32 + 1))) # 利用掩码对注意力结果补0序列填充一个极小数,会在Softmax中被忽略为0weight = weight.transpose(2, 1)# 必须对注意力结果补0序列填充一个极小数,千万不能填充0,因为注意力结果是经过激活函数tanh()计算出来的,其值域是 - 1~1, 在这个区间内,零值是一个有效值。如果填充0,那么会对后面的Softmax结果产生影响。填充的值只有远离这个有效区间才可以保证被Softmax的结果忽略。weight = torch.nn.Softmax(dim=2)(weight) # 计算注意力分数if self.hard != 0:  # hard modeweight = torch.where(weight > self.hard, weight, torch.full_like(weight, 0))if mean: # 支持注意力分数平均值模式weight = weight.mean(dim=1)weight = weight.unsqueeze(1)weight = weight.repeat(1, hidden_dim, 1)weight = weight.transpose(2, 1)features_attention = weight * features # 将注意力分数作用于特征向量上return features_attention # 返回结果#实例化模型对象
network = myLSTMNet(28, 128, 2, 10)  # 图片大小是28x28,28:输入数据的序列长度为28。128:每层放置128个LSTM Cell。2:构建两层由LSTM Cell所组成的网络。10:最终结果分为10类。
#指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
network.to(device)
print(network)  #打印网络criterion = torch.nn.CrossEntropyLoss()  # 实例化损失函数类
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)for epoch in range(2): # 数据集迭代2次running_loss = 0.0for i, data in enumerate(train_loader, 0): # 循环取出批次数据inputs, labels = datainputs = inputs.squeeze(1) # 由于输入数据是序列形式,不再是图片,因此将通道设为1inputs, labels = inputs.to(device), labels.to(device) # 指定设备optimizer.zero_grad() # 清空之前的梯度outputs = network(inputs)loss = criterion(outputs, labels) # 计算损失loss.backward()  #反向传播optimizer.step() #更新参数running_loss += loss.item()if i % 1000 == 999:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss = 0.0
print('Finished Training')#使用模型
dataiter = iter(test_loader)
images, labels = dataiter.next()
inputs, labels = images.to(device), labels.to(device)imshow(torchvision.utils.make_grid(images,nrow=batch_size))
print('真实标签: ', ' '.join('%5s' % classes[labels[j]] for j in range(len(images))))
inputs = inputs.squeeze(1)
outputs = network(inputs)
_, predicted = torch.max(outputs, 1)print('预测结果: ', ' '.join('%5s' % classes[predicted[j]]for j in range(len(images))))#测试模型
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():for data in test_loader:images, labels = dataimages = images.squeeze(1)inputs, labels = images.to(device), labels.to(device)outputs = network(inputs)_, predicted = torch.max(outputs, 1)predicted = predicted.to(device)c = (predicted == labels).squeeze()for i in range(10):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1sumacc = 0
for i in range(10):Accuracy = 100 * class_correct[i] / class_total[i]print('Accuracy of %5s : %2d %%' % (classes[i], Accuracy ))sumacc =sumacc+Accuracy
print('Accuracy of all : %2d %%' % ( sumacc/10. ))

【Pytorch神经网络实战案例】12 利用注意力机制的神经网络实现对FashionMNIST数据集图片的分类相关推荐

  1. (pytorch-深度学习系列)pytorch实现对Fashion-MNIST数据集进行图像分类

    pytorch实现对Fashion-MNIST数据集进行图像分类 导入所需模块: import torch import torchvision import torchvision.transfor ...

  2. 【Pytorch神经网络实战案例】10 搭建深度卷积神经网络

    识别黑白图中的服装图案(Fashion-MNIST)https://blog.csdn.net/qq_39237205/article/details/123379997基于上述代码修改模型的组成 1 ...

  3. 《神经网络与深度学习》-注意力机制与外部记忆

    注意力机制与外部记忆 1. 认知神经学中的注意力 2. 注意力机制 2.1 注意力机制的变体 2.1.1 硬性注意力 2.1.2 键值对注意力 2.1.3 多头注意力 2.1.4 结构化注意力 2.1 ...

  4. 空间注意力机制sam_一种基于注意力机制的神经网络的人体动作识别方法与流程...

    本发明属于计算机视觉领域,具体来说是一种基于注意力机制的神经网络的人体动作识别的方法. 背景技术: 人体动作识别,具有着非常广阔的应用前景,如人机交互,视频监控.视频理解等方面.按目前的主流方法,可主 ...

  5. AI实战:搭建带注意力机制的 seq2seq 模型来做数值预测

    AI实战:搭建带注意力机制的 seq2seq 模型来做数值预测 seq2seq 框架图 环境依赖 Linux python3.6 tensorflow.keras 源码搭建模型及说明 依赖库 impo ...

  6. TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

    TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%) 目录 输出结果 实现代码 输出结果 Successfully downloaded t ...

  7. ML之nyoka:基于nyoka库利用LGBMClassifier模型实现对iris数据集训练、保存为pmml模型并重新载入pmml模型进而实现推理

    ML之nyoka:基于nyoka库利用LGBMClassifier模型实现对iris数据集训练.保存为pmml模型并重新载入pmml模型进而实现推理 目录 基于nyoka库利用LGBMClassifi ...

  8. CV之FDFA:利用MTCNN的脚本实现对LFW数据集进行FD人脸检测和FA人脸校准

    CV之FD&FA:利用MTCNN的脚本实现对LFW数据集进行FD人脸检测和FA人脸校准 目录 运行结果 运行过程 运行(部分)代码 在裁剪好的LFW数据集进行验证 运行结果 运行过程 time ...

  9. 【Pytorch神经网络实战案例】22 基于Cora数据集实现图注意力神经网络GAT的论文分类

    注意力机制的特点是,它的输入向量长度可变,通过将注意力集中在最相关的部分来做出决定.注意力机制结合RNN或者CNN的方法. 1 实战描述 [主要目的:将注意力机制用在图神经网络中,完成图注意力神经网络 ...

最新文章

  1. 谁在为网易云音乐2亿用户的即时通讯保驾护航?
  2. 九十二、动态规划系列之股票问题(上)
  3. win7右键点击文件夹进入命令窗口方法
  4. asp.net控件开发基础(2)
  5. 如果在docker中部署tomcat,并且部署java应用程序
  6. 深海迷航创造模式火箭怎么飞_深海迷航被玩成养鱼游戏 奇葩玩家的养殖之路...
  7. bash中正则表达式
  8. 智能一代云平台(八):代码依赖分析系统
  9. JS的作用域和作用域链
  10. 使用JavaScript分别实现4种样式的九九乘法表(1X1分别在左上、左下、右上、右下)...
  11. bat脚本监控tomcat并启动_windows使用批处理发布web到tomcat并启动tomcat脚本分享
  12. 快排序和堆排序,最小堆、最大堆
  13. 云课堂计算机教室怎么使用,锐捷“云课堂2.0”焕发计算机教室青春活力
  14. 分式的化简(约分、通分)
  15. python xlrd读取excel慢_与xlrd相比,使用openpyxl读取Excel文件要慢很多
  16. 移动硬盘无法读取是怎么回事?
  17. 通过面积证明:两个函数相乘 / 相除的导数为什么长成这样?
  18. 计算机网络汇聚层,【大白电气】接入层、汇聚层、核心层——中大型计算机网络系统结构介绍及交换机选型建议...
  19. 11g ocm认证考试经历
  20. git操作提示warning: redirecting to git@github.com:XXXXX

热门文章

  1. ipv6相对于ipv4的改进
  2. 电开大计算机应用基础作业,2016年电大-电大计算机应用基础作业 答案.doc
  3. java编写螺旋矩阵讲解_Java如何实现螺旋矩阵 Java实现螺旋矩阵代码实例
  4. 机器学习中qa测试_如何对机器学习做单元测试
  5. python有多少种模块_python如何查看有哪些模块
  6. 【机器学习】总结:线性回归求解中梯度下降法与最小二乘法的比较
  7. *args, **kwargs的用法
  8. 华为软件研发面试题1
  9. 本地音频播放,使用AVFoundation.framework中的AVAudioPlayer来实现
  10. CATia对计算机配置要求,【2人回答】求CATIA对电脑的详细配置要求-3D溜溜网