Learning Attentive Pairwise Interaction for Fine-Grained Classification

2020 AAAI。网络结构倒是不复杂,但是这么大的batch size要怎么跑起来。

文章目录

  • Learning Attentive Pairwise Interaction for Fine-Grained Classification
    • 摘要
    • 1 引言
    • 2 API-Net
      • 2.1 互矢量学习
      • 2.2 门向量
      • 2.3 成对交互
      • 2.4 训练与测试
    • 3 实验
      • 3.1 消融实验
      • 3.2 比较SOTA
      • 3.3 可视化
    • 4 源码阅读

摘要

动机:目前方法都是通过单张图像学习区分性表示,而人类可以通过比较图像对来有效地识别。

网络:注意力成对交互网络(API-Net),通过交互逐步识别成对的细粒度图像。

  1. 先学习一个共同的特征向量,捕获输入对中的语义差异
  2. 将该向量与各个向量比较,为每个输入图像生成门
  3. 端到端,分数排序正则化

代码:https://github.com/PeiqinZhuang/API-Net

1 引言

人类是在比较中识别细粒度对象的。

引入注意力成对交互网络API-Net,可以从一对细粒度图像中自适应地发现对比线索,并通过成对交互进行区分。

API-Net由三个子模块组成,即相互向量学习,门向量生成和成对交互。输入一对图像,先学习一个互矢量,以将输入对的对比线索概括为上下文。再将互向量与单个向量进行比较生成不同的,可以从每个单个图像的角度突出显示语义差异。将这些门作为区分性注意力执行成对交互。每个图像可以生成两个增强的特征向量,分别从其自身的门矢量和该对中另一个图像的门矢量激活。通过端到端的训练方式和分数排名正则化。

即插即用。

2 API-Net

2.1 互矢量学习

两张图片分别经过主干网络生成 D D D维特征向量 x 1 , x 2 x_1,x_2 x1,x2,映射函数(多层感知机)学习一个 D D D维互矢量 x m = f m ( [ x 1 , x 2 ] ) x_m=f_m([x_1,x_2]) xm=fm([x1,x2])。由于 x m x_m xm是两个的自适应总结,通常包含特征通道,指示成对的高层次的对比线索。

2.2 门向量

x m x_m xm作为指导,寻找每个 x i x_i xi包含的对比线索,生成门:
g i = s i g m o i d ( x m ⊙ x i ) , i ∈ { 1 , 2 } g_i=sigmoid(x_m\odot x_i),i\in \{1,2\} gi=sigmoid(xmxi),i{1,2}
g i g_i gi成为有区别的注意力,以不同角度指出了每个 x i x_i xi的语义差异。

2.3 成对交互

通过门向量进行成对交互:
x 1 s e l f = x 1 + x 1 ⊙ g 1 x 2 o t h e r = x 2 + x 2 ⊙ g 2 x 1 s e l f = x 1 + x 1 ⊙ g 2 x 2 o t h e r = x 2 + x 2 ⊙ g 1 x^{self}_1=x_1+x_1\odot g_1\\ x^{other}_2=x_2+x_2\odot g_2\\ x^{self}_1=x_1+x_1\odot g_2\\ x^{other}_2=x_2+x_2\odot g_1\\ x1self=x1+x1g1x2other=x2+x2g2x1self=x1+x1g2x2other=x2+x2g1
x i s e l f x^{self}_i xiself由自己的门向量激活, x i o t h e r x^{other}_i xiother由另一个图像的门向量激活。

2.4 训练与测试

特征向量经过softmax分类器,得到 p i j p^j_i pijj ∈ { s e l f , o t h e r } , i ∈ { 1 , 2 } j\in\{self,other\},i\in \{1,2\} j{self,other},i{1,2})。

损失: L = L c e + λ L r k L=L_{ce}+\lambda L_{rk} L=Lce+λLrk
L r k = ∑ i ∈ { 1 , 2 } max ⁡ ( 0 , p i o t h e r ( c i ) − p i s e l f ( c i ) + ϵ ) L_{rk}=\sum_{i\in\{1,2\}}\max(0,p^{other}_i(c_i)-p^{self}_i(c_i)+\epsilon) Lrk=i{1,2}max(0,piother(ci)piself(ci)+ϵ)
由自己的门激活得到的结果应当更有区别性。

分批随机抽样 N c l N_{cl} Ncl类,每类随机抽取 N i m N_{im} Nim训练图像,生成其特征向量。对于每个图像,根据欧式距离将其特征与其他特征比较。结果可以为每个图像构造两对:类内最像对、类间最像对。每批共 2 × N c l × N i m 2\times N_{cl}\times N_{im} 2×Ncl×Nim对。

测试时:特征向量直接经过全连接分类。

3 实验

backbone:Resnet101。每个批次中随机抽取30个类别,每类随机采样4张图像,有240个图像对

3.1 消融实验

基线模型:

image-20210525153635871

互向量:

  1. 不采用互向量,各自生成门向量
  2. 双线性池化操作
  3. 逐元素操作,包括平方差、和、点积三种
  4. 权重注意力,两层MLP生成两个向量的权重
  5. MLP

image-20210525154249025

门向量:

  1. 一个门: g m = s i g m o i d ( x m ) g_m=sigmoid(x_m) gm=sigmoid(xm),一种注意力 x i s e l f = x i + x i ⊙ g m x^{self}_i=x_i+x_i\odot g_m xiself=xi+xigm
  2. 两个门

image-20210525154849632

交互:

  1. 仅使用交叉熵损失

  2. 交叉熵+排名损失

    image-20210525154935074

图像对的构建:

  1. 随机对
  2. 类别对

image-20210525155104300

S表示最相似、D表示最不相似

批次的样本数:

image-20210525155506851

3.2 比较SOTA

image-20210525155808445

3.3 可视化

根据门向量得到top-5激活通道,再全局池化前进行可视化。以及resnet101的对应通道。

即使API-Net主要在高级特征上运行,也能自动关注特征图中的可区分对象部分。

4 源码阅读

模型:

def pdist(vectors):"""计算欧氏距离:-2(v1+v2) + v1^2 + v2^2vectors: b*c,b个c维度向量"""# vectors.mm(torch.t(vectors)) v1*v2 ,b*b# vectors.pow(2).sum(dim=1).view(1, -1) v1^2 ,1*b# vectors.pow(2).sum(dim=1).view(-1, 1) v2^2 ,b*1distance_matrix = -2 * vectors.mm(torch.t(vectors)) + vectors.pow(2).sum(dim=1).view(1, -1) + vectors.pow(2).sum(dim=1).view(-1, 1)return distance_matrixclass API_Net(nn.Module):def __init__(self):super(API_Net, self).__init__()resnet101 = models.resnet101(pretrained=True)layers = list(resnet101.children())[:-2]self.conv = nn.Sequential(*layers)self.avg = nn.AvgPool2d(kernel_size=14, stride=1)# 互向量生成self.map1 = nn.Linear(2048 * 2, 512)self.map2 = nn.Linear(512, 2048)self.fc = nn.Linear(2048, 200)self.drop = nn.Dropout(p=0.5)self.sigmoid = nn.Sigmoid()def forward(self, images, targets=None, flag='train'):conv_out = self.conv(images)  # b*c*h*wpool_out = self.avg(conv_out).squeeze()  # b*c*1*1 -> b*cif flag == 'train':intra_pairs, inter_pairs, intra_labels, inter_labels = self.get_pairs(pool_out, targets)features1 = torch.cat([pool_out[intra_pairs[:, 0]], pool_out[inter_pairs[:, 0]]], dim=0)  # 样本,样本, 2b * cfeatures2 = torch.cat([pool_out[intra_pairs[:, 1]], pool_out[inter_pairs[:, 1]]],dim=0)  # 类外最像样本,类内最像样本, 2b * clabels1 = torch.cat([intra_labels[:, 0], inter_labels[:, 0]], dim=0)labels2 = torch.cat([intra_labels[:, 1], inter_labels[:, 1]], dim=0)mutual_features = torch.cat([features1, features2],dim=1)  # dim=1拼接,2b * 2c,前b个是(样本,类外最像样本),后b个是(样本,类内最像样本),map1_out = self.map1(mutual_features)map2_out = self.drop(map1_out)map2_out = self.map2(map2_out)  # 生成互向量gate1 = torch.mul(map2_out, features1)gate1 = self.sigmoid(gate1)gate2 = torch.mul(map2_out, features2)gate2 = self.sigmoid(gate2)  # 生成门向量# 成对交互features1_self = torch.mul(gate1, features1) + features1features1_other = torch.mul(gate2, features1) + features1features2_self = torch.mul(gate2, features2) + features2features2_other = torch.mul(gate1, features2) + features2logit1_self = self.fc(self.drop(features1_self))logit1_other = self.fc(self.drop(features1_other))logit2_self = self.fc(self.drop(features2_self))logit2_other = self.fc(self.drop(features2_other))return logit1_self, logit1_other, logit2_self, logit2_other, labels1, labels2elif flag == 'val':return self.fc(pool_out)def get_pairs(self, embeddings, labels):distance_matrix = pdist(embeddings).detach().cpu().numpy()  # b*blabels = labels.detach().cpu().numpy().reshape(-1, 1)  # b*1num = labels.shape[0]  # 样本数dia_inds = np.diag_indices(num)  # (array([0, 1, 2, ..., num]), array([0, 1, 2, ..., num])lb_eqs = (labels == labels.T)  # 同一类标签的坐标lb_eqs[dia_inds] = False  # 自己不能和自己成对dist_same = distance_matrix.copy()dist_same[lb_eqs == False] = np.inf  # 不能和自己匹配的举例是无穷intra_idxs = np.argmin(dist_same, axis=1)  # 每个样本的同一个类中的最接近的坐标dist_diff = distance_matrix.copy()lb_eqs[dia_inds] = Truedist_diff[lb_eqs == True] = np.infinter_idxs = np.argmin(dist_diff, axis=1)  # 每个样本的不同类中的最接近的坐标# 组对intra_pairs = np.zeros([embeddings.shape[0], 2])inter_pairs = np.zeros([embeddings.shape[0], 2])intra_labels = np.zeros([embeddings.shape[0], 2])inter_labels = np.zeros([embeddings.shape[0], 2])for i in range(embeddings.shape[0]):# 不同类intra_labels[i, 0] = labels[i]intra_labels[i, 1] = labels[intra_idxs[i]]intra_pairs[i, 0] = iintra_pairs[i, 1] = intra_idxs[i]# 同一类inter_labels[i, 0] = labels[i]inter_labels[i, 1] = labels[inter_idxs[i]]inter_pairs[i, 0] = iinter_pairs[i, 1] = inter_idxs[i]intra_labels = torch.from_numpy(intra_labels).long().to(device)intra_pairs = torch.from_numpy(intra_pairs).long().to(device)inter_labels = torch.from_numpy(inter_labels).long().to(device)inter_pairs = torch.from_numpy(inter_pairs).long().to(device)return intra_pairs, inter_pairs, intra_labels, inter_labels

pytorch实现平衡采样:

class BalancedBatchSampler(BatchSampler):def __init__(self, dataset, n_classes, n_samples):self.labels = dataset.labels # dataset自定义的属性self.labels_set = list(set(self.labels.numpy()))self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]for label in self.labels_set}for l in self.labels_set:np.random.shuffle(self.label_to_indices[l])self.used_label_indices_count = {label: 0 for label in self.labels_set}self.count = 0self.n_classes = n_classesself.n_samples = n_samplesself.dataset = datasetself.batch_size = self.n_samples * self.n_classesdef __iter__(self):self.count = 0while self.count + self.batch_size < len(self.dataset):classes = np.random.choice(self.labels_set, self.n_classes, replace=False)indices = []for class_ in classes:indices.extend(self.label_to_indices[class_][self.used_label_indices_count[class_]:self.used_label_indices_count[class_] + self.n_samples])self.used_label_indices_count[class_] += self.n_samplesif self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):np.random.shuffle(self.label_to_indices[class_])self.used_label_indices_count[class_] = 0yield indicesself.count += self.n_classes * self.n_samplesdef __len__(self):return len(self.dataset) // self.batch_size
"""
使用方法:
"""
train_sampler = BalancedBatchSampler(train_dataset, n_classes, n_samples)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=train_sampler,num_workers=args.workers, pin_memory=True)

论文阅读:API-Net相关推荐

  1. 论文阅读——INSIDER:Designing In-Storage Computing System for Emerging High-Performance Drive

    存算一体论文阅读之 INSIDER:Designing In-Storage Computing System for Emerging High-Performance Drive 相关代码已开源. ...

  2. 论文阅读笔记——VulDeePecker: A Deep Learning-Based System for Vulnerability Detection

    本论文相关内容 论文下载地址--Engineering Village 论文中文翻译--VulDeePecker: A Deep Learning-Based System for Vulnerabi ...

  3. 论文阅读:In the Eye of the Beholder: A Survey of Models for Eyes and Gaze

    In the Eye of the Beholder: A Survey of Models for Eyes and Gaze 第二篇EGT的论文阅读,同样是review性质的一篇论文 In the ...

  4. [论文阅读] (20)USENIXSec21 DeepReflect:通过二进制重构发现恶意行为(恶意代码ROI分析经典)

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  5. 【论文阅读】基于区块链的无人集群作战信息共享架构_臧义华

    区块链论文阅读 以下所有的内容都是我的观点,本人能力有限,该篇主要作为我自己的读书笔记. 基于区块链的无人集群作战信息共享架构_臧义华 一.阅读笔记 1. 本文概述 本文针对无人机群的场景,利用区块链 ...

  6. 论文阅读笔记 - Chubby: The Chubby lock service for loosely-coupled distributed systems

    作者:刘旭晖 Raymond 转载请注明出处 Email:colorant at 163.com BLOG:http://blog.csdn.net/colorant/ 更多论文阅读笔记 http:/ ...

  7. ISCA2022部分论文阅读整理

    ISCA2022部分论文阅读整理 GPU设计: 一.GPU tensor core的扩展设计和编译器优化 二.GPU分析模型 剪枝: 一.剪枝self-attention的冗余计算量 二.增大剪枝带来 ...

  8. 【论文阅读】定量评估服务模式__Quantitative Assessment of Service Pattern: Framework, Language, and Metrics

    [论文阅读]定量评估服务模式__Quantitative Assessment of Service Pattern: Framework, Language, and Metrics 文章目录 [论 ...

  9. 模型预测控制与强化学习-论文阅读(一)Integration of reinforcement learning and model predictive

    模型预测控制与强化学习-论文阅读(一)Integration of reinforcement learning and model predictive 最近才把初步的研究方向定下来,导师放养,实验 ...

  10. [论文阅读] (18)英文论文Model Design和Overview如何撰写及精句摘抄——以系统AI安全顶会为例

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

最新文章

  1. iOS 11开发教程(二十一)iOS11应用视图美化按钮之实现按钮的响应(1)
  2. PHP文件上传,下载,Sql工具类!
  3. JavaBeans四个作用域 范围
  4. explain ref_你必须要掌握的MySQL命令:explain
  5. docker内手动安装python环境
  6. Rails + React +antd + Redux环境搭建
  7. 学习Java: Queue
  8. 大学计算机基础实训excel,大学计算机基础实训指导书
  9. ASP.NET中的两个Cookie类:HttpCookie类与Cookie类
  10. SpringBoot中自定义错误页面
  11. Camera 初始化(Preview) 一(Framework-HAL3)
  12. 深搜和广搜--原理彼此的优缺点
  13. 9550电机_电机扭矩计算公式里面的9550*P是怎么得来的?
  14. Ubuntu 中文转换成英文方法
  15. 【面试题】数字转成汉字形式
  16. ode45 matlab 出错,Matlab中ode45求解微分方程组出错。
  17. 一寸照像素和厘米的关系及换算
  18. 509实验室打印机双面打印的方法
  19. 开源开放 | 多模态地球科学知识图谱GAKG
  20. 一款最好用的windows文件管理器

热门文章

  1. 40 个为开发者提供的免费工具
  2. 小程序 背景图 repeat_小仙女壁纸9月9日热门壁纸
  3. 【Idea技巧】02.Idea包进行展开
  4. LTE----003 eNodeB
  5. 硬核卖家天天骂顾客,美团还给评了一个“人气店铺”。
  6. 鸿蒙系统有没有hicar,华为鸿蒙系统发布后!又一款华为操作系统火了:开启智慧出行新时代...
  7. Karhunen-Loève expansion and random field
  8. Selenium之模拟登录铁路12306
  9. iOS 微信、支付宝、银联、Paypal 支付组件封装
  10. 域名解析的记录类型:A记录、CNAME…