三篇论文

《Supervised Contrastive Learning》
《A Simple Framework for Contrastive Learning of Visual Representations》
《What Makes for Good Views for Contrastive Learning》

对比学习的思想起源于无监督学习,相比于监督学习算法,无监督学习由于没有标签的指导,训练过程学习样本的特征会更加困难。对比学习的核心思想就是通过数据增强构造原来样本的多样性,损失函数的设计用来拉进正样本与锚样本的距离,增大与负样本的距离,在这一过程中,网络更容易学到由源样本经过数据增强之后的多个样本所具有的共同特征,而这一特征对于源样本来说更可能是本质性的。

《A Simple Framework for Contrastive Learning of Visual Representations》SimCLR

论文提出了一种更简洁的对比学习算法,主要有三个贡献:

  • 使用不同组合形式的数据增强对于下游的预测任务非常重要
  • 在特征提取的encoder和对抗损失之间引入可学的多层感知机可以提高网络的学习能力
  • 在一个batch中,样本的数目越多越容易提高训练性能

这一工作也是后面诸多对比学习工作的基础。

网络框架

  1. 对于一个锚样本xxx,使用随机的数据增强方式生成一对正样本,用x~i\tilde{x}_ix~i​和x~j\tilde{x}_jx~j​表示
  2. 一个特征提取网络encoderf(⋅)f(\cdot)f(⋅)用来提取x~i\tilde{x}_ix~i​和x~j\tilde{x}_jx~j​的特征,用来提高网络的泛化能力,hi=f(x~i),hj=f(x~j)h_i=f(\tilde{x}_i),h_j=f(\tilde{x}_j)hi​=f(x~i​),hj​=f(x~j​)。特征提取的网络通常使用resnet。
  3. 在特征表示和对抗损失之间添加一个架构为多层感知机的投影网络,即zi=g(hi)=W(2)σ(W(1)hi)z_i=g(h_i)=W^{(2)}\sigma(W^{(1)}h_i)zi​=g(hi​)=W(2)σ(W(1)hi​),这也是文章的贡献之一:在投影网络的输出端进行对比损失的计算要比直接在f(⋅)f(\cdot)f(⋅)的输出计算更有用。

对三个贡献做出解释

1. 为什么不同形式的数据增强的组合有助于学到好的特征?
对比学习的目的是学到对于一个样本最核心的特征,如果使用单一的数据增强,比如只使用随机裁剪(random cropping),那么网络在训练过程就会认为颜色信息可能也是有用的,因为没有label来指导它学到下游任务的目标,网络无法提取对于下游更核心的特征。而采用多个数据增强的组合可以让网络认识到什么信息是不相关的,比如一个颜色失真的样本和一个高斯噪声的样本,这两个样本来源于同一个样本,网络在优化过程中需要认为他们两个着某些特征上是相同的,从而认识到颜色和噪声对于要提取的信息都是不重要的。
2. 为什么在encoder后面添加一个多层感知机可以提高学习能力?
z=g(h)z=g(h)z=g(h)的训练目的是增加对于数据变换的不变性,根据神经网络传统的学习方式,由于投影层处于较高的网络层次,网络学到的特征就更倾向于任务相关(high-level),低层的网络学到的更倾向于细节特征,如果没有投影层来学习高级特征,全部由encoder完成的话,encoder学到的特征在不同下游任务上的泛化能力会下降。
3. 为什么batchsize越大越容易收敛?
根据损失函数可以知道,当batchsize比较大的时候,意味着分母上的负样本数量也比较多,损失函数的目的是从一堆样本中找出锚样本,或者说,找出最能够区分锚样本与负样本的表征,当负样本数目多的时候,网络更容易排除什么信息对于该样本是不相关的,所以能够加快训练。

损失函数

Lself=∑i∈ILiself=−∑i∈Ilog⁡exp⁡(zi⋅zj(i)/τ)∑α∈A(i)exp⁡(zi⋅zα/τ)\mathcal{L}^{self}=\sum_{i\in I}\mathcal{L}^{self}_i=-\sum_{i\in I}\log \frac{\exp (z_i \cdot z_{j(i)}/\tau)}{\sum_{\alpha \in A(i)}\exp (z_i \cdot z_{\alpha}/\tau)} Lself=i∈I∑​Liself​=−i∈I∑​log∑α∈A(i)​exp(zi​⋅zα​/τ)exp(zi​⋅zj(i)​/τ)​
其中III表示当前的一个batch,算法实现的时候,首先是从定义好的大小为batchsize的样本数目中数据增强出两个batchsize的样本来(multiviewed batch),这个batchsize就是公式中的III,对于一个batch中的每个样本,计算Liself\mathcal{L}^{self}_{i}Liself​,其中ziz_izi​是当前的样本(也称锚样本),zj(i)z_{j(i)}zj(i)​是与ziz_izi​同源的样本(由同一个样本数据增强得到),A(i)A(i)A(i)包含整个batchsize中除了当前样本之外的其他样本,τ\tauτ是温度系数,实际在训练的过程中,一个batch中的每个样本都会做一次锚样本。
这样说感觉上不是很直观,通过代码会加深对公式的理解。

原始无监督对比学习的代码及注释

重点在数据集的加载方式,loss的设计上

数据集准备

class ContrastiveLearningDataset:def __init__(self, root_folder):self.root_folder = root_folder@staticmethoddef get_simclr_pipeline_transform(size, s=1):"""Return a set of data augmentation transformations as described in the SimCLR paper.定义数据增强的方式,选择训练的数据集"""color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),transforms.RandomHorizontalFlip(),transforms.RandomApply([color_jitter], p=0.8),transforms.RandomGrayscale(p=0.2),GaussianBlur(kernel_size=int(0.1 * size)),transforms.ToTensor()])return data_transformsdef get_dataset(self, name, n_views):valid_datasets = {'cifar10': lambda: datasets.CIFAR10(self.root_folder, train=True,transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(32),n_views),download=True),'stl10': lambda: datasets.STL10(self.root_folder, split='unlabeled',transform=ContrastiveLearningViewGenerator(self.get_simclr_pipeline_transform(96),n_views),download=True)}try:dataset_fn = valid_datasets[name]except KeyError:raise InvalidDatasetSelection()else:return dataset_fn()class ContrastiveLearningViewGenerator(object):"""Take two random crops of one image as the query and key.默认使用两个view做数据增强,即如果有一个batchsize为4 的样本[a1, b1, c1, d1]经过viewGenerator之后的形式为: [ a1, a2b1, b2c1, c2d1, d2]其中每一行表示同一个源样本产生的两个view样本。"""def __init__(self, base_transform, n_views=2):self.base_transform = base_transformself.n_views = n_viewsdef __call__(self, x):return [self.base_transform(x) for i in range(self.n_views)]

特征提取的模型

class ResNetSimCLR(nn.Module):'''选择使用resnet-18还是resnet-50作为backbone,对应论文里面的encoder ==》 Enc(.)以及投影网络Projection Network ==》 Proj(i)其中encoder使用resnet的非全连接层部分,投影网络使用多层感知机'''def __init__(self, base_model, out_dim):super(ResNetSimCLR, self).__init__()self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, num_classes=out_dim),"resnet50": models.resnet50(pretrained=False, num_classes=out_dim)}self.backbone = self._get_basemodel(base_model)dim_mlp = self.backbone.fc.in_features# add mlp projection headself.backbone.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.backbone.fc)def _get_basemodel(self, model_name):try:model = self.resnet_dict[model_name]except KeyError:raise InvalidBackboneError("Invalid backbone architecture. Check the config file and pass one of: resnet18 or resnet50")else:return modeldef forward(self, x):return self.backbone(x)

损失函数设计

    def info_nce_loss(self, features):# 这里的labels用来做mask,方便后面与矩阵做逐元素相乘的时候筛选正样本和负样本,以batchsize=3为例,# 经过数据增强后一个batch的大小实际上为6,输入的features = [6, 128]# 最后生成的labels:tensor([[1., 0., 0., 1., 0., 0.],#                        [0., 1., 0., 0., 1., 0.],#                        [0., 0., 1., 0., 0., 1.],#                        [1., 0., 0., 1., 0., 0.],#                        [0., 1., 0., 0., 1., 0.],#                        [0., 0., 1., 0., 0., 1.]])labels = torch.cat([torch.arange(self.args.batch_size) for i in range(self.args.n_views)], dim=0)labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()labels = labels.to(self.args.device)features = F.normalize(features, dim=1)# 计算相似度矩阵,即如果一个batch的输入样本为[ a1, a2#                                       b1, b2#                                       c1, c2]# 经过网络特征提取之后为:[a1 b1 c1 a2 b2 c2]# 相应地相似度矩阵为:[a1a1 a1b1 a1c1 a1a2 a1b2 a1c2#                  b1a1 b1b1 b1c1 b1a2 b1b2 b1c2#                  c1a1 c1b1 c1c1 c1a2 c1b2 c1c2#                  a2a1 a2b1 a2c1 a2a2 a2b2 a2c2#                  b2a1 b2b1 b2c1 b2a2 b2b2 b2c2#                  c2a1 c2b1 c2c1 c2a2 c2b2 c2c2]similarity_matrix = torch.matmul(features, features.T)# assert similarity_matrix.shape == (#     self.args.n_views * self.args.batch_size, self.args.n_views * self.args.batch_size)# assert similarity_matrix.shape == labels.shape# discard the main diagonal from both: labels and similarities matrixmask = torch.eye(labels.shape[0], dtype=torch.bool).to(self.args.device)labels = labels[~mask].view(labels.shape[0], -1)# 此时的labels为:# tensor([[0., 0., 1., 0., 0.],#         [0., 0., 0., 1., 0.],#         [0., 0., 0., 0., 1.],#         [1., 0., 0., 0., 0.],#         [0., 1., 0., 0., 0.],#         [0., 0., 1., 0., 0.]])# 相比原来的labels删除了对角线上锚样本与自己做乘积的情况,# 对应在原相似度矩阵的位置上只保留label为1的数,相当于只保留了正样本与锚样本的乘积,即a1a2,b1b2,c1c2...# mask为:tensor([[ True, False, False, False, False, False],#               [False,  True, False, False, False, False],#               [False, False,  True, False, False, False],#               [False, False, False,  True, False, False],#               [False, False, False, False,  True, False],#               [False, False, False, False, False,  True]])# 相应地,在相似度矩阵上面排除锚样本与自己相乘的情况similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)# assert similarity_matrix.shape == labels.shape# select and combine multiple positivespositives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)# positives 保留正样本与锚样本的乘积:[a1a2#                                 b1b2#                                 c1c2#                                 a2a1#                                 b2b1#                                 c2c1]# negatives 保留锚样本与负样本的乘积:[a1b1 a1c1 a1b2 a1c2#                                b1a1 b1c1 b1a2 b1c2#                                c1a1 c1b1 c1a2 c1b2#                                a2b1 a2c1 a2b2 a2c2#                                b2a1 b2c1 b2a2 b2c2#                                c2a1 c2b1 c2a2 c2b2]# select only the negatives the negativesnegatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)logits = torch.cat([positives, negatives], dim=1)# 将positives堆在negatives的前面,形如[a1a2 a1b1 a1c1 a1b2 a1c2#         #                        b1b2 b1a1 b1c1 b1a2 b1c2#         #                        c1c2 c1a1 c1b1 c1a2 c1b2#         #                        a2a1 a2b1 a2c1 a2b2 a2c2#         #                        b2b1 b2a1 b2c1 b2a2 b2c2#         #                        c2c1 c2a1 c2b1 c2a2 c2b2]# 最左边一列为infoloss的分子,右边为分子labels = torch.zeros(logits.shape[0], dtype=torch.long).to(self.args.device)# labels = [0, 0, 0, 0, 0, 0],这里相当于交叉熵损失函数里面样本的真实标签为0# 因为对比损失函数跟交叉熵损失的计算形式是一样的,所以如果类别全部为0,表示的对于logits的每一行,都使用索引为0(也就是第一个)的元素作为分子logits = logits / self.args.temperaturereturn logits, labels

训练过程

# 损失函数与交叉熵的形式一样
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)def train(self, train_loader):# pytorch的GradScaler和autocast使用混合精度可以节约内存空间,运行较大的batchsizescaler = GradScaler(enabled=self.args.fp16_precision)# save config filesave_config_file(self.writer.log_dir, self.args)n_iter = 0logging.info("Start SimCLR training for {self.args.epochs} epochs.")logging.info("Training with gpu: {self.args.disable_cuda}.")for epoch_counter in range(self.args.epochs):for images, _ in tqdm(train_loader):images = torch.cat(images, dim=0)images = images.to(self.args.device)with autocast(enabled=self.args.fp16_precision):# 对输入的正负样本图像提取的特征features = self.model(images)print(features.shape)logits, labels = self.info_nce_loss(features)loss = self.criterion(logits, labels)self.optimizer.zero_grad()scaler.scale(loss).backward()scaler.step(self.optimizer)scaler.update()

对比学习(Contrastive Learning) (1)相关推荐

  1. 从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering)

    从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailug ...

  2. 对比学习Contrastive Learning

    对比学习是一种常用的自监督学习方法. 核心思想:把正样本距离拉近,正样本与负样本距离拉远.(类似度量学习中的margin, 但是对比学习为正负样本分类,无margin概念) 方式:通过一个正样本,以及 ...

  3. 理解对比表示学习(Contrastive Learning)

    目录 一.前言 二.对比学习 三.主要论文(附代码分析) 1. AMDIM ([Bachman](https://arxiv.org/pdf/1906.00910.pdf) *et al.* 2019 ...

  4. 理解对比学习(contrasive learning)

    1.什么是对比学习? 对比学习,顾名思义就是在训练中和某些东西进行对比从而学习,在自编码器中,输出与自己进行对比,从而得到一个中间量latent code,我认为这也是一种对比学习. 2.对比学习框架 ...

  5. 图对比学习入门 Contrastive Learning on Graph

    对比学习作为近两年的深度学习界的一大宠儿,受到了广大研究人员的青睐.而图学习因为图可以用于描述生活中广泛出现的非欧式数据,具有广大的应用前景.当图学习遇上了对比学习- 本文从对比学习入手,再介绍图对比 ...

  6. SimCSE:用于句子嵌入的对比学习

    目录 引言 对比学习Contrastive Learning SimCSE思想 无监督下的SimCSE 有监督下的SimCSE 连接各向异性Connection to Anisotropy 分析 引言 ...

  7. MoCo 动量对比学习——一种维护超大负样本训练的框架

    MoCo 动量对比学习--一种维护超大负样本训练的框架 FesianXu 20210803 at Baidu Search Team 前言 在拥有着海量数据的大型互联网公司中,对比学习变得逐渐流行起来 ...

  8. 对比学习顶会论文系列-3-2

    文章目录 一.特定任务中的对比学习 1.2 摘要生成中的对比学习--SimCLS: A Simple Framework for Contrastive Learning of Abstractive ...

  9. 张俊林:对比学习研究进展精要

    作者 | 张俊林 编辑 | 夕小瑶的卖萌屋 对比学习(Contrastive Learning)最近一年比较火,各路大神比如Hinton.Yann LeCun.Kaiming He及一流研究机构比如F ...

  10. 对比学习(一)-双塔模型-simCLR

    对比学习链接 对比学习 引言 bert在对比学习中起到的作用: **对比学习的作用:** 生成式自监督学习: 判别式自监督学习 simCLR SimCLR正负例构建 SimCLR表示学习系统构建 Si ...

最新文章

  1. vue2.0transition过渡的使用介绍
  2. Vivotek 摄像头远程栈溢出漏洞分析及利用
  3. java throw 什么意思_[转载]java中throw和throws的区别
  4. [trouble shoot]atol和atoll
  5. Shell and powershell
  6. P1020 导弹拦截(n*log n时间的最长上升子序列思想)
  7. java实现可选形参_Java:可选的可选实现
  8. C++堆和栈详解(转)
  9. python的map怎么用_python中的map怎么使用
  10. linux下mysql乱码_linux下mysql中文乱码
  11. vscode+vim使用技巧
  12. 统计正数和负数的个数然后计算这些数的平均值_计算机中的二进制原来是这样:原码、反码和补码
  13. 《高质量程序设计指南---C++/C语言》 下载
  14. Win7 下替代NetMeeting的屏幕共享工具 InletexEMC
  15. 第四天:Spark Streaming
  16. 卢卡斯定理求组合数(逆元+费马小定理+扩展欧几里得)
  17. 显示测试漏光软件,屏幕漏光测试怎么做(液晶显示器屏幕漏光的检测方法)
  18. Cell Stem Cell | 动物所刘光慧等显示年轻血液可逆转衰老进程
  19. 凉面经-维恩贝特面试复盘
  20. linux怎么设置永久变量,Linux环境变量永久设置方法(zsh)

热门文章

  1. matlab一维haar信号塔式分解,matlab小波分解与重构
  2. canfd收不到数据_CAN FD网络的通信距离问题分析
  3. 英文词根词典简化笔记
  4. Bus Hound 的使用方法
  5. 使用docx4j解析word模板,替换占位符生成新的docx,并生成pdf
  6. docx4j文档差异比较
  7. 如何准备互联网产品岗面试
  8. 2022年深圳杯数学建模A题代码思路-- 破除“尖叫效应”与“回声室效应”,走出“信息茧房”
  9. 医院绩效考核和奖金分配方案
  10. 华为研发部门绩效考核制度及方案