目录

1、训练过程

A、Weakly Supervised Attention Learning

B、Attention-Guided Data Augmentation

2、预测过程


细粒度分类:
细粒度分类是为了解决“类内分类”问题,有别于猫狗分类,它解决的是“这只狗是萨摩还是哈士奇”这类问题;这类问题是类别之间的区别较小。
WS-DAN(Weakly Supervised Data Augmentation Network)是一种针对细粒度视觉分类任务的方法,采用基于弱监督学习的图像增强的方法,结合注意力机制,这使得网络在不需要额外标注的情况下聚焦那些图像中有“话语权”的部分。

论文亮点:
1、 Bilinear Attention Pooling(双线性注意力池化机制)
2、attention regularization loss (loss惩罚机制)
3、Attention Cropping和Attention Dropping(用于数据的增强)

1、训练过程

上图是整个网络的训练过程,整个训练分为两部分A:Weakly Supervised Attention Learning;B:Attention-Guided Data Augmentation

A、Weakly Supervised Attention Learning

这一步是基于弱监督的注意力区域学习。首先,网络对原始图片基于CNN提取特征,提取到的特征为feature maps,然后feature maps经过kernel size为1的卷积运算得到Attention maps,也就是说Attention maps 经过feature maps降维得到的,具体降到多少维M是一个超参数,可以自行配置。而这M个Attention maps代表着物体的一个部位,例如鸟的头部,飞机的机翼等。后面还会根据Attention maps对图像进行针对性的增强。

attentions map的生成

class BasicConv2d(nn.Module):def __init__(self, in_channels, out_channels, **kwargs):super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)self.bn = nn.BatchNorm2d(out_channels, eps=0.001)def forward(self, x):x = self.conv(x)x = self.bn(x)return F.relu(x, inplace=True)# num_features:CNN输出值
# M:M个attentions map
attentions = BasicConv2d(num_features, M, kernel_size=1)

backbone网络首先产生feature maps和attention maps; 每个attention map都指向物体的特定部位; feature maps和attention maps的element-wise乘积产生局部feature maps, 并通过卷积或者池化来提取局部特征; 所得局部特征可以组成特征矩阵. 具体而言设有 N 个feature maps和 M 个attention maps, 由第 k 个attention map Ak 和feature map F 进行element-wise乘积生成第 K 个part feature map , 其中 表示element-wise乘法. 以此类推, 一个feature map和 M个attention maps进行element-wise乘法, 得到 M个part feature maps. 这 M个part feature maps经全局池化(GMP 或GAP), 得到 M维向量,其中第 k个元素为, 将这些 M维向量拼接, 生成 N ×M维向量. 这些向量组成物体的特征P。

Feature Matrix的生成

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BAP(nn.Module):def __init__(self, pool='GAP'):super(BAP, self).__init__()assert pool in ['GAP', 'GMP']if pool == 'GAP':self.pool = Noneelse:self.pool = nn.AdaptiveMaxPool2d(1)def forward(self, features, attentions):B, C, H, W = features.size()_, M, AH, AW = attentions.size()# # match sizeif AH != H or AW != W:# attentions = F.upsample_bilinear(attentions, size=(H, W))attentions = nn.functional.interpolate(attentions, size=(H, W), mode='bilinear', align_corners=False)# feature_matrix: (B, M, C) -> (B, M * C)if self.pool is None:feature_matrix = (torch.einsum('imjk,injk->imn', (attentions, features)) / float(H * W)).view(B, -1)else:feature_matrix = []for i in range(M):AiF = self.pool(features * attentions[:, i:i + 1, ...]).view(B, -1)feature_matrix.append(AiF)feature_matrix = torch.cat(feature_matrix, dim=1)# sign-sqrtfeature_matrix = torch.sign(feature_matrix) * torch.sqrt(torch.abs(feature_matrix) + EPSILON)# l2 normalization along dimension M and Cfeature_matrix = F.normalize(feature_matrix, dim=-1)return feature_matrix

B、Attention-Guided Data Augmentation

在随机数据增强方法中,背景噪声等干扰因素会影响最终的效果。这一步是用之前获得的Attention map来指导数据增强,这会比普通的随机数据增强更有优势,将Attention map提取的部位放大,作为增强后的数据进行训练,为细粒度分类这一问题提供了有效的解决方法。可以有效的过滤掉背景噪声的干扰。

在获得的M个Attention map中,随机的选取一个作为后面数据增强的依据,至于随机选取的原因可以理解为一是增加鲁棒性,二是对多个物体部位做到“雨露均沾”。随机选取一个Attention map之后先对其进行归一化,以方便后续的操作。

根据选取的这个Attention map生成Crop Mask。Crop Mask的生成策略:将 Ak*中大于阈值 θc的元素置为 1 ,其他置为 0,这一块为 1 的区域就是我们细粒度分类中需要的细节区域,将它上采样至模型输入的图片大小,当作一个新的“样本”输入对模型进行训练,以强制模型“注意”这些细节区域。上面的 θc 作为一个超参数也是可以根据具体问题进行调节的,文中默认为 0.5。

Attention Dropping 与 Attention Cropping 类似,将 Attention Map 中小于阈值 θd的元素置为 1 ,其他为 0 。加入这个操作是因为作者发现不同的 Attention Maps 可能聚焦了物体相同的部位,为了让模型也可以注意到其他位置,比如上图中的 Attention Map 是鸟的头部,该操作就可以让模型注意到鸟的其他部位。Attention Dropping 操作让模型提高了 0.6% 的准确率。

Data Augumentation

def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1):batches, _, imgH, imgW = images.size()if mode == 'crop':crop_images = []for batch_index in range(batches):atten_map = attention_map[batch_index:batch_index + 1]if isinstance(theta, tuple):theta_c = random.uniform(*theta) * atten_map.max()else:theta_c = theta * atten_map.max()crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_cnonzero_indices = torch.nonzero(crop_mask[0, 0, ...])height_min = max(int(nonzero_indices[:, 0].min().item() - padding_ratio * imgH), 0)height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)crop_images.append(F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],size=(imgH, imgW)))crop_images = torch.cat(crop_images, dim=0)return crop_imageselif mode == 'drop':drop_masks = []for batch_index in range(batches):atten_map = attention_map[batch_index:batch_index + 1]if isinstance(theta, tuple):theta_d = random.uniform(*theta) * atten_map.max()else:theta_d = theta * atten_map.max()drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d)drop_masks = torch.cat(drop_masks, dim=0)drop_images = images * drop_masks.float()return drop_imageselse:raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode)

损失函数

训练过程中损失函数的设计,除了计算结果的交叉熵损失函数外,作者为了每次各个 Attention Map 可以找到相同的物体部位,还加入了特征图与部位中心的平方差之和作为惩罚项,如下公式,这就会让每个特征图固定到每个部位的中心。其中部位中心也是每次学习到的特征图来更新的。(使得相同物体上同一部位的特征尽可能相似)

Ck可初始化为0, 然后按照以下滑动平均公式来更新其值.

2、预测过程

预测过程依然分为两部分:

1、原始图片输入训练好的模型中得到属于各个类别的概率,以及 Attention Maps;

2、将第一步中得到的 M 个 Attention Maps 取平均值,注意这里不是像训练过程里面随机取一个区域,我的理解是这里如果随机取的话,可能会导致模型不稳定,每次的预测结果不一样。下面就是与训练过程类似了,根据 Attention Maps 的平均值 Am 画出截取框,将截取框上采样再放入训练好的网络中,得到“注意力区域”属于各个类别的概率;

3、将上面两部分的结果取平均值得到最后的分类结果。

WS-DAN论文解读相关推荐

  1. 目标检测学习笔记2——ResNet残差网络学习、ResNet论文解读

    ResNet残差网络学习.ResNet论文解读 一.前言 为什么会提出ResNet? 什么是网络退化现象? 那网络退化现象是什么造成的呢? ResNet要如何解决退化问题? 二.残差模块 三.残差模块 ...

  2. 【论文解读】CVPR 2021 妆容迁移 论文解读Spatially-invariant Style-codes Controlled Makeup Transfer

    [论文解读]CVPR 2021 妆容迁移 论文解读 Spatially-invariant Style-codes Controlled Makeup Transfer 摘要 方法特点 实现方法 公式 ...

  3. Open-Vocabulary Multi-Label Classification via Multi-modal Knowledge Transfer 论文解读

    Open-Vocabulary Multi-Label Classification via Multi-modal Knowledge Transfer 论文解读 前言 Motivation Con ...

  4. 自监督学习(Self-Supervised Learning)多篇论文解读(下)

    自监督学习(Self-Supervised Learning)多篇论文解读(下) 之前的研究思路主要是设计各种各样的pretext任务,比如patch相对位置预测.旋转预测.灰度图片上色.视频帧排序等 ...

  5. 自监督学习(Self-Supervised Learning)多篇论文解读(上)

    自监督学习(Self-Supervised Learning)多篇论文解读(上) 前言 Supervised deep learning由于需要大量标注信息,同时之前大量的研究已经解决了许多问题.所以 ...

  6. 可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读

    可视化反投射:坍塌尺寸的概率恢复:ICCV9论文解读 Visual Deprojection: Probabilistic Recovery of Collapsed Dimensions 论文链接: ...

  7. 从单一图像中提取文档图像:ICCV2019论文解读

    从单一图像中提取文档图像:ICCV2019论文解读 DewarpNet: Single-Image Document Unwarping With Stacked 3D and 2D Regressi ...

  8. 点云配准的端到端深度神经网络:ICCV2019论文解读

    点云配准的端到端深度神经网络:ICCV2019论文解读 DeepVCP: An End-to-End Deep Neural Network for Point Cloud Registration ...

  9. 图像分类:CVPR2020论文解读

    图像分类:CVPR2020论文解读 Towards Robust Image Classification Using Sequential Attention Models 论文链接:https:// ...

  10. CVPR2020论文解读:手绘草图卷积网络语义分割

    CVPR2020论文解读:手绘草图卷积网络语义分割 Sketch GCN: Semantic Sketch Segmentation with Graph Convolutional Networks ...

最新文章

  1. mysql与ofbiz,ofbiz+mysql安装求教
  2. tcp/ip 协议栈Linux源码分析一 IPv4分片报文重组分析一
  3. colorkey唇釉是否安全_colorkey空气唇釉,19/支
  4. P4548 [CTSC2006]歌唱王国
  5. c++ 解析xml文件
  6. [转]最常用的15大Eclipse开发快捷键技巧
  7. android学习笔记---38_采用广播接收者拦截外拔电话,实现原理以及实例源码
  8. flink Sql查询异常NoResourceAvailableException: Could not acquire the minimum required resources
  9. python封面是什么样子_Python诱变剂:通过url添加封面照片/相册图片?
  10. Xtreme ToolkitPro 编译选项
  11. Pytorch专题实战——逻辑回归(Logistic Regression)
  12. BIRCH算法(Java实现)
  13. jquery 设置背景
  14. 继电器和蜂鸣器的使用
  15. 日期格式化java_JAVA格式化时间日期
  16. 看看在职场里是怎么混社会的……
  17. 计算机软件如何助力科研,研究生必备科研绘图软件,助力科学研究
  18. discuz数据字典
  19. Android MotionEvent详解
  20. Android: MultiDex原理和优化

热门文章

  1. Open-Falcon 安装
  2. 大数据时代的知识图谱
  3. Cisco 开启三层交换机ip routing
  4. 基于555定时器的函数信号发生器
  5. 04-防火墙双机热备
  6. 程序员必学电脑计算机专业英语词汇 09 (111 单词)
  7. 程序员必学电脑计算机专业英语词汇 11 (125 单词)
  8. 10 款开源的在线游戏,点开就能玩的那种
  9. 苹果android投屏,如何将iPhone手机投屏到电脑电视?
  10. golang实现andflow流程引擎