ECCV2022 论文 Contrastive Deep Supervision
论文链接:https://arxiv.org/pdf/2207.05306.pdf
代码链接:GitHub - ArchipLab-LinfengZhang/contrastive-deep-supervision: Codes for ECCV2022 paper - contrastive deep supervision
动机
近年来,由于大量数据的出现以及计算机算力的提升,深度学习统治了计算机视觉领域。然而,随着神经网络深度增加的同时,也带来了一些挑战。传统的有监督方法仅对模型的最后一层进行监督,然后再将误差反向传播到中间层。由于反向传播过程中可能会出现梯度消失、爆炸及弥漫等问题,怎么优化好模型中间层的参数成为了一个难点。
近期,深度监督被用于解决上述问题,它的做法是在中间层中添加辅助的分类器。在训练期间,辅助分类器与最终的分类器一同优化。大量实验证明,深度监督加速了模型的收敛。然而,通常来说,不同深度的特征学到的信息不同,底层特征往往含有丰富的纹理及颜色等信息,而深层特征往往含有丰富的语义信息,简单地将辅助分类器应用到中间层特征显然存在问题,因为底层特征没有丰富的语义信息,不适合进行分类 (底层特征往往用于目标定位,因为它含有较多的空间位置信息)。基于这些理论,就有了这篇文章 《Contrastive Deep Supervision》,以下简称 CDS。
创新点
这篇文章的作者认为:相比于有监督的任务损失,对比学习能给中间层的特征提供更好的监督。对比学习通常在同一张图片中使用两种不同的数据增强 (增强方法可以相同,但其中的参数不同),随后将增强后的两张图片视为正样本对,与其余图片构成负样本对。作者提出的方法如下图中的 (d) 所示,几个投影头会附件在中间层的后面,用于将特征映射到嵌入空间,以便进行对比学习,这些投影头在推理期间会被 kill 掉,这样就避免了额外的计算及额外的存储空间。与训练中间层特征去学习特定任务知识的深度监督不同,CDS 学习的是图片中的本质信息,这些信息不受数据增强的影响,这也使神经网络能更好地泛化。此外,由于对比学习可以在未标记的数据上进行,CDS 也可应用到半监督任务中。这篇文章的主要创新点如下:
(1) 提出了 CDS,这是一种神经网络训练方法,其中中间层直接通过对比学习进行优化。它使神经网络能够学习更好的视觉表示,且无需在推理过程中增加额外的开销
(2) 从深度监督的角度来看,作者第一个表明除了有监督任务损失之外,中间层还可以通过其他方式进行训练
(3) 从表示学习的角度来看,作者首个表明对比学习和监督学习可以以一阶段的深度监督的方式联合训练模型,而不是两阶段的 “pretrain-finetune” 方案 (先预训练,后微调)
方法论
CDS
假定一个 minibatch 有 N 张图片,对每张图片都进行两次随机的数据增强,增强后就有 2N 张图片。为了方便,作者把 和 作为来自同一图像的两个增强表示,这两张图片也被视为一个正样本对。 为经过投影层并标准化后的输出,对比学习的公式如下:
鼓励编码器网络从同一图像中学习不同增强的相似表示,同时增加来自不同图像的增强表示之间的差异。
CDS 与深度监督之间的主要区别在于深度监督通过交叉熵损失来训练辅助分类器,而 CDS 则通过对比学习来训练。CDS 整体损失函数公式如下:
这个公式表示有 K-1 个中间层使用了对比学习来训练,最后一层使用交叉熵损失来训练。
CDS 还可以推广到半监督学习和知识蒸馏中:
在半监督学习中,作者假设有 个带标签的图片,对应的标签为 ,无标签的数据为 。在有标签数据中,可以直接使用 CDS。在无标签数据中,只能进行对比学习。整体的损失公式如下:
在知识蒸馏中,作者进一步提出通过将教师模型学到的图像在数据增强中的不变性传递给学生模型,来改进具有 CDS 的知识蒸馏。 和 分别表示知识蒸馏中的学生模型和教师模型,原始的知识蒸馏直接最小化了学生和教师模型的骨干特征之间的距离,可以表示为:
与原始知识蒸馏不同,带有 CDS 的知识蒸馏最小化的是两个模型的嵌入向量 (经投影层得到) 之间的距离,公式如下:
知识蒸馏中的整体损失函数公式如下:
一些细节和 tricks
投影层的设计
在 CDS 的训练期间,将几个投影头添加到神经网络的中间层。这些投影头将骨干特征映射到归一化的嵌入空间,其中应用了对比学习损失。通常,投影头是由两个全连接层和一个 ReLU 函数堆叠而成的非线性投影。然而,在 CDS 中,输入特征来自中间层而不是最终层,因此需要修改投影层的设计。作者通过在非线性投影之前添加卷积层来增加这些投影头的复杂性。
对比学习
CDS 是一个通用的训练框架,不依赖于特定的对比学习方法。在这篇文章中,作者在大多数实验中采用 SimCLR 和 SupCon 作为对比学习的方法。如果使用更好的对比学习算法,模型最终的性能也会进一步提升。
负样本
以前的研究表明,负样本的数量对对比学习的表现有着重要的影响,因此在对比学习中通常使用大的 batch size。但在 CDS 中,作者认为诸如交叉熵之类的损失已经足以防止对比学习收敛到崩溃的解决方案。
实验结果
在 CIFAR100 和 CIFAR10 上的分类结果如下:
ImageNet 上的分类结果如下:
在目标检测数据集 COCO2017 上的结果如下:
在细粒度数据集上的结果如下:
代码
代码也比较简单,拿 resnet18 来举例:
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]# model_urls = {
# "resnet18": "./pretrain/resnet18-5c106cde.pth",
# "resnet34": "./pretrain/resnet34-333f7ec4.pth",
# "resnet50": "./pretrain/resnet50-19c8e357.pth",
# "resnet101": "./pretrain/resnet101-5d3b4d8f.pth",
# "resnet152": "./pretrain/resnet152-b121ed2d.pth",
# }model_urls = {"resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth","resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth","resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth","resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth","resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}def conv3x3(in_planes, out_planes, stride=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)def conv1x1(in_planes, out_planes, stride=1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)class SepConv(nn.Module):def __init__(self, channel_in, channel_out, kernel_size=3, stride=2, padding=1, affine=True):# depthwise and pointwise convolution, downsample by 2super(SepConv, self).__init__()self.op = nn.Sequential(nn.Conv2d(channel_in,channel_in,kernel_size=kernel_size,stride=stride,padding=padding,groups=channel_in,bias=False,),nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False),nn.BatchNorm2d(channel_in, affine=affine),nn.ReLU(inplace=False),nn.Conv2d(channel_in,channel_in,kernel_size=kernel_size,stride=1,padding=padding,groups=channel_in,bias=False,),nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False),nn.BatchNorm2d(channel_out, affine=affine),nn.ReLU(inplace=False),)def forward(self, x):return self.op(x)class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock, self).__init__()self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=100, zero_init_residual=False, align="CONV"):super(ResNet, self).__init__()self.inplanes = 64self.align = align# reduce the kernel-size and stride of ResNet on cifar datasets.self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)# remove maxpooling layer for ResNet on cifar datasets.# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.auxiliary1 = nn.Sequential(SepConv(channel_in=64 * block.expansion, channel_out=128 * block.expansion),SepConv(channel_in=128 * block.expansion, channel_out=256 * block.expansion),SepConv(channel_in=256 * block.expansion, channel_out=512 * block.expansion),nn.AvgPool2d(4, 4),)self.auxiliary2 = nn.Sequential(SepConv(channel_in=128 * block.expansion,channel_out=256 * block.expansion,),SepConv(channel_in=256 * block.expansion,channel_out=512 * block.expansion,),nn.AvgPool2d(4, 4),)self.auxiliary3 = nn.Sequential(SepConv(channel_in=256 * block.expansion,channel_out=512 * block.expansion,),nn.AvgPool2d(4, 4),)self.auxiliary4 = nn.AvgPool2d(4, 4)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def _make_layer(self, block, planes, blocks, stride=1):downsample = Noneif stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),nn.BatchNorm2d(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return nn.Sequential(*layers)def forward(self, x):feature_list = []x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.layer1(x)feature_list.append(x)x = self.layer2(x)feature_list.append(x)x = self.layer3(x)feature_list.append(x)x = self.layer4(x)feature_list.append(x)out1_feature = self.auxiliary1(feature_list[0]).view(x.size(0), -1)out2_feature = self.auxiliary2(feature_list[1]).view(x.size(0), -1)out3_feature = self.auxiliary3(feature_list[2]).view(x.size(0), -1)out4_feature = self.auxiliary4(feature_list[3]).view(x.size(0), -1)out = self.fc(out4_feature)feat_list = [out4_feature, out3_feature, out2_feature, out1_feature]for index in range(len(feat_list)):feat_list[index] = F.normalize(feat_list[index], dim=1)if self.training:return out, feat_listelse:return outdef resnet18(pretrained=False, **kwargs):"""Constructs a ResNet-18 model.Args:pretrained (bool): If True, returns a model pre-trained on ImageNet"""model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))return model
就是在 resnet 的4个 layer 后添加了 auxiliary head,而 auxiliary head 又由深度可分离卷积与平均池化层构成,用于进一步提取特征 (因为作者认为 resnet 提取的特征的表达能力还不够强,需要进一步提取)
对比学习的损失函数代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SupConLoss(nn.Module):"""Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.It also supports the unsupervised contrastive loss in SimCLR"""def __init__(self, temperature=0.07, contrast_mode='all',base_temperature=0.07):super(SupConLoss, self).__init__()self.temperature = temperatureself.contrast_mode = contrast_modeself.base_temperature = base_temperaturedef forward(self, features, labels=None, mask=None):"""Compute loss for model. If both `labels` and `mask` are None,it degenerates to SimCLR unsupervised loss:https://arxiv.org/pdf/2002.05709.pdfArgs:features: hidden vector of shape [bsz, n_views, ...].labels: ground truth of shape [bsz].mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample jhas the same class as sample i. Can be asymmetric.Returns:A loss scalar."""device = (torch.device('cuda')if features.is_cudaelse torch.device('cpu'))if len(features.shape) < 3:raise ValueError('`features` needs to be [bsz, n_views, ...],''at least 3 dimensions are required')if len(features.shape) > 3:features = features.view(features.shape[0], features.shape[1], -1)batch_size = features.shape[0]if labels is not None and mask is not None:raise ValueError('Cannot define both `labels` and `mask`')elif labels is None and mask is None:mask = torch.eye(batch_size, dtype=torch.float32).to(device)elif labels is not None:labels = labels.contiguous().view(-1, 1)if labels.shape[0] != batch_size:raise ValueError('Num of labels does not match num of features')mask = torch.eq(labels, labels.T).float().to(device)else:mask = mask.float().to(device)contrast_count = features.shape[1] # 2contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)# 256 x 512if self.contrast_mode == 'one':anchor_feature = features[:, 0]anchor_count = 1elif self.contrast_mode == 'all':anchor_feature = contrast_feature # 256 x 512anchor_count = contrast_count # 2else:raise ValueError('Unknown mode: {}'.format(self.contrast_mode))# compute logitsanchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T),self.temperature)# for numerical stability# print (anchor_dot_contrast.size()) 256 x 256logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)logits = anchor_dot_contrast - logits_max.detach()# tile maskmask = mask.repeat(anchor_count, contrast_count)# mask-out self-contrast caseslogits_mask = torch.scatter(torch.ones_like(mask),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0)mask = mask * logits_mask# compute log_probexp_logits = torch.exp(logits) * logits_masklog_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))# compute mean of log-likelihood over positivemean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)loss = - (self.temperature / self.base_temperature) * mean_log_prob_posloss = loss.view(anchor_count, batch_size).mean()return loss
在 CIFAR100 上,我使用 resnet18 复现的结果为 80.54%,与论文中的 80.84% 差别不大
ECCV2022 论文 Contrastive Deep Supervision相关推荐
- ECCV2022论文列表(中英对照)
Paper ID Paper Title 论文标题 8 Learning Uncoupled-Modulation CVAE for 3D Action-Conditioned Human Motio ...
- 深度学习100问:什么是深监督(Deep Supervision)?
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 所谓深监督(Deep Supervision),就是在深度神经网络 ...
- Deep Supervision:深度监督(2014)+DHM
深度监督(Deep Supervision)又称为(中继监督 intermediate supervision),就是在深度神经网络的某些中间隐藏层加了一个辅助的分类器作为一种网络分支来对主干网络进行 ...
- [论文解读]Deep active learning for object detection
Deep active learning for object detection 文章目录 Deep active learning for object detection 简介 摘要 初步 以前 ...
- deep supervision
深度监督学习(deep supervision learning) 和常规的深度学习机制相比,深度监督学习不仅在网络的最后输出结果out,同时在网络的中间特征图,经过反卷积和上采样操作,得到和out尺 ...
- [论文翻译] Deep Learning
[论文翻译] Deep Learning 论文题目:Deep Learning 论文来源:Deep learning Nature 2015 翻译人:BDML@CQUT实验室 Deep learnin ...
- [论文翻译]Deep Learning 翻译及阅读笔记
论文题目:Deep Learning 论文来源:Deep Learning_2015_Nature 翻译人:BDML@CQUT实验室 Deep Learning Yann LeCun∗ Yoshua ...
- 论文翻译[Deep Residual Learning for Image Recognition]
论文来源:Deep Residual Learning for Image Recognition [翻译人]:BDML@CQUT实验室 Deep Residual Learning for Imag ...
- 读论文Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank
读论文Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank 原地址:https://blog.cs ...
最新文章
- Java Scanner类
- Xshell-密钥登录
- python代码风格_Python编码风格,看这篇就够了
- 客户端页面不更新CSS样式或JS脚本的方法 (2018-08-17 17:33)
- excel 电阻并联计算_电阻器的构成及取代原则
- 找不到android的sdk,CircleCI – 找不到Android Studio项目的SDK位置
- 质量管理系统_晟通集团内训 | 质量管理系统提升实战训练
- SMB、FTP、DNS、等六个服务总结
- 如何批量登陆远程主机和配置【转】
- 蓝桥杯试题开灯游戏c语言,[蓝桥杯][算法提高VIP]开灯游戏 (C++代码)
- 终于给cs来了一次小整容
- 操作系统 第二部分 进程管理(五)
- 我们建立数据中心,需要考虑哪些问题?
- fastDB核心心得
- Java程序员的8个级别,你在哪?
- 半小时一篇文过完C语言基础知识点
- trim函数 html,trim函数的使用方法(你会用TRIMMEAN 函数吗?)
- 使用大白菜装机维护版软件取消Win7开机密码
- Dockerfile 的详解
- 【翻译】--19C Oracle 安装指导
热门文章
- c1能力认证考试训练任务03-web基础与布局
- xAxis、yAxis-配置项
- 中国建设银行信息技术岗笔试
- 中信银行软件开发中心 c语言笔试题目,中信银行软件研发中心笔试内容
- opta球员大数据预测胜负_大数据预测简介及使用流程
- python代码直接关机_关机信号在python脚本中运行代码
- cad画直角命令_cad中怎么把直角倒角
- 卷积神经网络CNN——使用keras识别猫咪
- Latex表格与图片旋转,且标题同时旋转 (表格的标题可设置于表格的上方或下方)
- 【Unicode编码表】UniCode编码表+转化器