理解对比表示学习(Contrastive Learning)
目录
- 一、前言
- 二、对比学习
- 三、主要论文(附代码分析)
- 1. AMDIM ([Bachman](https://arxiv.org/pdf/1906.00910.pdf) *et al.* 2019)
- 2. SIMCLR ([Geoffrey Hinton](https://arxiv.org/pdf/2002.05709.pdf) *et al* 2020)
- 3.MOCO ([Kaiming He](https://ieeexplore.ieee.org/document/9157636) *et al.* 2020)
- 四、总结
一、前言
监督学习近些年获得了巨大的成功,但是有如下的缺点:
- 人工标签相对数据来说本身是稀疏的,蕴含的信息不如数据内容丰富;
- 监督学习只能学到特定任务的知识,不是通用知识,一般难以直接迁移到其他任务中。
由于这些原因,自监督学习的发展被给予厚望。监督学习,无监督学习和自监督学习的区别
如果说自监督学习是蛋糕,那么监督学习就是蛋糕上的小冰块,强化学习就是蛋糕上点缀的樱桃。(“self-supervised learning is the cake, supervised learning is the icing on the cake, reinforcement learning is the cherry on the cake”) —Yann LeCun
自监督学习不需要人工标注的类别标签信息,直接利用数据本身作为监督信息,学习样本数据的特征表达,应用于下游的任务。自监督学习又可以分为对比学习(contrastive learning) 和 生成学习(generative learning) 两条主要的技术路线。对比学习的核心思想是讲正样本和负样本在特征空间对比,学习样本的特征表示,难点在于如何构造正负样本。
最近,诸如BERT和T5之类的自然语言处理模型已经表明,可以通过首先在一个大型的未标记数据集上进行预训练,然后在一个较小的标记数据集上进行微调,从而用很少的类标签来获得良好的结果。 同样,对未标记的大型图像数据集进行预训练,有可能提高计算机视觉任务的性能。这点已经在对比表示学习的相关论文,例如Exemplar-CNN, Instance Discrimination, CPC, AMDIM, CMC, MoCo,获得了证实。对比学习训练得到的神经网络模型,可以被用作下游的任务,例如分类、分割、检测等。经过对比学习预训练得到的神经网络,已经具有很强的表达能力,一般只需要再用很少的有标签数据微调,就可以获得非常优秀的性能。
以下图片引用
二、对比学习
对比学习首先学习未标记数据集上图像的通用表示形式,然后可以使用少量标记图像对其进行微调,以提升在给定任务(例如分类)的性能。简单地说,对比表示学习可以被认为是通过比较学习。相对来说,生成学习(generative learning)是学习某些(伪)标签的映射的判别模型然后重构输入样本。在对比学习中,通过在输入样本之间进行比较来学习表示。对比学习不是一次从单个数据样本中学习信号,而是通过在不同样本之间进行比较来学习。可以在“相似”输入的正对和“不同”输入的负对之间进行比较。以下图片引用。
对比学习通过同时最大化同一图像的不同变换视图(例如剪裁,翻转,颜色变换等)之间的一致性,以及最小化不同图像的变换视图之间的一致性来学习的。 简单来说,就是对比学习要做到相同的图像经过各类变换之后,依然能识别出是同一张图像,所以要最大化各类变换后图像的相似度(因为都是同一个图像得到的)。相反,如果是不同的图像(即使经过各种变换可能看起来会很类似),就要最小化它们之间的相似度。通过这样的对比训练,编码器(encoder)能学习到图像的更高层次的通用特征 (image-level representations),而不是图像级别的生成模型(pixel-level generation)。
Pixel-level generation is computationally expensive and may not be necessary for representation learning. —SimCLR论文
三、主要论文(附代码分析)
1. AMDIM (Bachman et al. 2019)
下图(b)就是本文的Augmented Multiscale Deep InfoMax (AMDIM)结构。
Encoder部分的核心代码如下:
class Encoder(nn.Module):def __init__(self, dummy_batch, num_channels=3, ndf=64, n_rkhs=512, n_depth=3, encoder_size=32, use_bn=False):super(Encoder, self).__init__()self.ndf = ndfself.n_rkhs = n_rkhsself.use_bn = use_bnself.dim2layer = None# encoding block for local featuresprint('Using a {}x{} encoder'.format(encoder_size, encoder_size))if encoder_size == 32:self.layer_list = nn.ModuleList([Conv3x3(num_channels, ndf, 3, 1, 0, False),ConvResNxN(ndf, ndf, 1, 1, 0, use_bn),ConvResBlock(ndf * 1, ndf * 2, 4, 2, 0, n_depth, use_bn),ConvResBlock(ndf * 2, ndf * 4, 2, 2, 0, n_depth, use_bn),MaybeBatchNorm2d(ndf * 4, True, use_bn),ConvResBlock(ndf * 4, ndf * 4, 3, 1, 0, n_depth, use_bn),ConvResBlock(ndf * 4, ndf * 4, 3, 1, 0, n_depth, use_bn),ConvResNxN(ndf * 4, n_rkhs, 3, 1, 0, use_bn),MaybeBatchNorm2d(n_rkhs, True, True)])elif encoder_size == 64:self.layer_list = nn.ModuleList([Conv3x3(num_channels, ndf, 3, 1, 0, False),ConvResBlock(ndf * 1, ndf * 2, 4, 2, 0, n_depth, use_bn),ConvResBlock(ndf * 2, ndf * 4, 4, 2, 0, n_depth, use_bn),ConvResBlock(ndf * 4, ndf * 8, 2, 2, 0, n_depth, use_bn),MaybeBatchNorm2d(ndf * 8, True, use_bn),ConvResBlock(ndf * 8, ndf * 8, 3, 1, 0, n_depth, use_bn),ConvResBlock(ndf * 8, ndf * 8, 3, 1, 0, n_depth, use_bn),ConvResNxN(ndf * 8, n_rkhs, 3, 1, 0, use_bn),MaybeBatchNorm2d(n_rkhs, True, True)])elif encoder_size == 128:self.layer_list = nn.ModuleList([Conv3x3(num_channels, ndf, 5, 2, 2, False, pad_mode='reflect'),Conv3x3(ndf, ndf, 3, 1, 0, False),ConvResBlock(ndf * 1, ndf * 2, 4, 2, 0, n_depth, use_bn),ConvResBlock(ndf * 2, ndf * 4, 4, 2, 0, n_depth, use_bn),ConvResBlock(ndf * 4, ndf * 8, 2, 2, 0, n_depth, use_bn),MaybeBatchNorm2d(ndf * 8, True, use_bn),ConvResBlock(ndf * 8, ndf * 8, 3, 1, 0, n_depth, use_bn),ConvResBlock(ndf * 8, ndf * 8, 3, 1, 0, n_depth, use_bn),ConvResNxN(ndf * 8, n_rkhs, 3, 1, 0, use_bn),MaybeBatchNorm2d(n_rkhs, True, True)])else:raise RuntimeError("Could not build encoder.""Encoder size {} is not supported".format(encoder_size))self._config_modules(dummy_batch, [1, 5, 7], n_rkhs, use_bn)def init_weights(self, init_scale=1.):'''Run custom weight init for modules...'''for layer in self.layer_list:if isinstance(layer, (ConvResNxN, ConvResBlock)):layer.init_weights(init_scale)for layer in self.modules():if isinstance(layer, (ConvResNxN, ConvResBlock)):layer.init_weights(init_scale)if isinstance(layer, FakeRKHSConvNet):layer.init_weights(init_scale)def _config_modules(self, x, rkhs_layers, n_rkhs, use_bn):'''Configure the modules for extracting fake rkhs embeddings for infomax.'''enc_acts = self._forward_acts(x)self.dim2layer = {}for i, h_i in enumerate(enc_acts):for d in rkhs_layers:if h_i.size(2) == d:self.dim2layer[d] = i# get activations and feature sizes at different layersself.ndf_1 = enc_acts[self.dim2layer[1]].size(1)self.ndf_5 = enc_acts[self.dim2layer[5]].size(1)self.ndf_7 = enc_acts[self.dim2layer[7]].size(1)# configure modules for fake rkhs embeddingsself.rkhs_block_1 = NopNet()self.rkhs_block_5 = FakeRKHSConvNet(self.ndf_5, n_rkhs, use_bn)self.rkhs_block_7 = FakeRKHSConvNet(self.ndf_7, n_rkhs, use_bn)def _forward_acts(self, x):'''Return activations from all layers.'''# run forward pass through all layerslayer_acts = [x]for _, layer in enumerate(self.layer_list):layer_in = layer_acts[-1]layer_out = layer(layer_in)layer_acts.append(layer_out)# remove input from the returned list of activationsreturn_acts = layer_acts[1:]return return_actsdef forward(self, x):'''Compute activations and Fake RKHS embeddings for the batch.'''if has_many_gpus():if x.abs().mean() < 1e-4:r1 = torch.zeros((1, self.n_rkhs, 1, 1),device=x.device, dtype=x.dtype).detach()r5 = torch.zeros((1, self.n_rkhs, 5, 5),device=x.device, dtype=x.dtype).detach()r7 = torch.zeros((1, self.n_rkhs, 7, 7),device=x.device, dtype=x.dtype).detach()return r1, r5, r7# compute activations in all layers for xacts = self._forward_acts(x)# gather rkhs embeddings from certain layersr1 = self.rkhs_block_1(acts[self.dim2layer[1]])r5 = self.rkhs_block_5(acts[self.dim2layer[5]])r7 = self.rkhs_block_7(acts[self.dim2layer[7]])return r1, r5, r7
2. SIMCLR (Geoffrey Hinton et al 2020)
核心代码如下:
from PIL import Image
from torchvision import transformstrain_transform = transforms.Compose([transforms.RandomResizedCrop(32),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),transforms.RandomGrayscale(p=0.2),transforms.ToTensor(),transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
第二步,由基编码器f(⋅)f(\cdot)f(⋅)得到表示hih_ihi,hjh_jhj。文章中作者使用ResNet-50作为卷积神经网络编码器。输出向量hih_ihi的维度是2048.
第三步,投影端(projection head) g(⋅)g(\cdot)g(⋅),主要由全连接层和激活层ReLU组成,将表示hih_ihi和hjh_jhj进一步非线性映射为ziz_izi,zjz_jzj。作者说,非线性投影端很重要,一方面可以将映射后的表达ziz_izi用来计算相似度,另一方面,可以让投影端之前的表达hih_ihi保留更多图像信息。
核心代码如下,包含基编码器f(⋅)f(\cdot)f(⋅)和投影端g(⋅)g(\cdot)g(⋅):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50class Model(nn.Module):def __init__(self, feature_dim=128):super(Model, self).__init__()self.f = []for name, module in resnet50().named_children():if name == 'conv1':module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):self.f.append(module)# encoderself.f = nn.Sequential(*self.f)# projection headself.g = nn.Sequential(nn.Linear(2048, 512, bias=False), nn.BatchNorm1d(512),nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))def forward(self, x):x = self.f(x)feature = torch.flatten(x, start_dim=1)out = self.g(feature)return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)
第四步,训练网络,计算图像之间的相似度,再以此计算网络的交叉熵损失。
相似度:为了比较投影端产生的表示,使用余弦相似度,其定义为:
sim(u,v)=uTv∥u∥∥v∥(1)\operatorname{sim}(u, v)=\frac{u^{T} v}{\|u\|\|v\|} \tag{1}sim(u,v)=∥u∥∥v∥uTv(1)
损失函数:基于相似度,正对示例的损失函数定义为(与MOCO损失函数类似):
ℓi,j=−logexp(sim(zi,zj)/τ)∑k=12N1[k≠i]exp(sim(zi,zk)/τ)(2)\ell_{i, j}=-\log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{j}\right) / \tau\right)}{\sum_{k=1}^{2 N} \mathbb{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{k}\right) / \tau\right)} \tag{2}ℓi,j=−log∑k=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)(2)
其中,τ\tauτ被称为temperature parameter。该损失函数又称作normalized temperature-scaled cross-entropy loss。
import torch
from model import Model# train for one epoch to learn unique features
def train(net, data_loader, train_optimizer):net.train()total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)for pos_1, pos_2, target in train_bar:pos_1, pos_2 = pos_1.cuda(non_blocking=True), pos_2.cuda(non_blocking=True)feature_1, out_1 = net(pos_1)feature_2, out_2 = net(pos_2)# [2*B, D]out = torch.cat([out_1, out_2], dim=0)# [2*B, 2*B]sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()# [2*B, 2*B-1]sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)# compute losspos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)# [2*B]pos_sim = torch.cat([pos_sim, pos_sim], dim=0)loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()train_optimizer.zero_grad()loss.backward()train_optimizer.step()total_num += batch_sizetotal_loss += loss.item() * batch_sizetrain_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_num))return total_loss / total_num
在对比学习任务中对SimCLR模型进行了训练之后,舍弃投影端g(⋅)g(\cdot)g(⋅),使用基编码器(base encoder) f(⋅)f(\cdot)f(⋅) 获得的图像的表示,将表示向量用于下游任务,例如ImageNet分类。
3.MOCO (Kaiming He et al. 2020)
本文认为,如果字典足够大,包含的负样本足够丰富 (large) 的话,可以学到更好的特征表达。与此同时,用于字典键值的编码器要在学习进化的过程中尽量保持一致 (consistent)。MOCO有两个核心模块:(1) 用队列实现字典,主要的作用是可以实现字典大小和mini-batch大小的耦合,如此便可不受限制地提高bath size;(2) 动量更新,主要是为了解决引入队列维护字典之后,字典的编码器无法通过梯度反传获得参数更新的问题,具体为:
θk←mθk+(1−m)θq(4)\theta_{\mathrm{k}} \leftarrow m \theta_{\mathrm{k}}+(1-m) \theta_{\mathrm{q}} \tag{4}θk←mθk+(1−m)θq(4)
如图©所示,通过这种动量更新的方法,可以从qqq 的梯度反向传播间接获得 kkk 的梯度。相对于直接用 qqq 的梯度更新替代 kkk 的梯度更新,这种动量更新的方式更加平稳。mmm 一般取0.99 ,如果取 0.90.90.9 会太小,实验效果不好 ,这说明 θq\theta_qθq 和 θk\theta_kθk 的耦合不宜过强。
下图总结对比了常用的三种负样本管理机制。图(a)是原始的end-to-end结构,最主要的问题是batch size和ditionary size相互耦合,ditionary size因此受限于GPU的内存大小。图(b)通过增加memory bank 结构,改进了图(a)的结构。memory bank可以存储数据集中所有样本的特征表达,每个字典随机地从memory bank中采样。但是从memory bank随机采样的问题是在不同的更新阶段,样本缺乏一致性,这就是MOCO反复强调的consistent问题。
核心代码如下 (pytorch伪代码):
# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature
f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samplesx_q = aug(x) # a randomly augmented versionx_k = aug(x) # another randomly augmented versionq = f_q.forward(x_q) # queries: NxCk = f_k.forward(x_k) # keys: NxCk = k.detach() # no gradient to keys# positive logits: Nx1l_pos = bmm(q.view(N,1,C), k.view(N,C,1))# negative logits: NxKl_neg = mm(q.view(N,C), queue.view(C,K))# logits: Nx(1+K)logits = cat([l_pos, l_neg], dim=1)# contrastive loss, Eqn.(1)labels = zeros(N) # positives are the 0-thloss = CrossEntropyLoss(logits/t, labels)# SGD update: query networkloss.backward()update(f_q.params)# momentum update: key networkf_k.params = m*f_k.params+(1-m)*f_q.params# update dictionaryenqueue(queue, k) # enqueue the current minibatchdequeue(queue) # dequeue the earliest minibatch# bmm: batch matrix multiplication; mm: matrix multiplication; cat: concatenation.
最后,附上一张现有主流对比学习模型的性能图,来自SIMCLR论文。
四、总结
本文介绍了自监督学习中的一种重要方法–对比学习(contrastive learning)的基本概念和三篇代表性的最新论文,并且从模型创新点和代码实现角度进行了分析。对比学习是当前自监督学习一个重要的分支,目的在于从小样本无标签的数据中,学习到更有效的特征表达。目前的研究进展表明,自监督学习正在逐步逼近监督学习的水平。在很多场景中,例如医学影像分析,有标签的数据极其稀有。用自监督学习进行表示学习和预训练,将会是重要的一环。
本博客撰写过程参考了以下博客内容:
- Google AI blog
- SimCLR Post
- 对比学习(Contrastive Learning)相关进展梳理
- 对比学习(Contrastive Learning)
- A Framework For Contrastive Self-Supervised Learning And Designing A New Approach
理解对比表示学习(Contrastive Learning)相关推荐
- 从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering)
从对比学习(Contrastive Learning)到对比聚类(Contrastive Clustering) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailug ...
- (ICML-2020)通过超球面的对齐和均匀性理解对比表示学习(一)
文章目录 通过超球面的对齐和均匀性理解对比表示学习 Abstract 1. Introduction 2. Related Work 3.无监督对比表征学习的初步研究 4. Feature Distr ...
- (ICML-2020)通过超球面的对齐和均匀性理解对比表示学习(二)
文章目录 通过超球面的对齐和均匀性理解对比表示学习 5. Experiments 6. Discussion 通过超球面的对齐和均匀性理解对比表示学习 paper题目:Understanding Co ...
- ICML 2020: 从Alignment 和 Uniformity的角度理解对比表征学习
Title: <Understanding Contrastive Representation Learning through Alignment and Uniformity on the ...
- 对比学习Contrastive Learning
对比学习是一种常用的自监督学习方法. 核心思想:把正样本距离拉近,正样本与负样本距离拉远.(类似度量学习中的margin, 但是对比学习为正负样本分类,无margin概念) 方式:通过一个正样本,以及 ...
- 图对比学习入门 Contrastive Learning on Graph
对比学习作为近两年的深度学习界的一大宠儿,受到了广大研究人员的青睐.而图学习因为图可以用于描述生活中广泛出现的非欧式数据,具有广大的应用前景.当图学习遇上了对比学习- 本文从对比学习入手,再介绍图对比 ...
- Contrastive Learning Based on Transformer for Hyperspectral Image Classification
自娱自乐对比学习高光谱图像分类第二篇 1. Introduction 在高光谱图像分类中 3D 比 2D 好?不知道这句话怎么得来的. 无监督学习: 表示学习(AE, GAN) 判别学习 – cont ...
- 对比学习(Contrastive Learning)的理解
参考网址:https://blog.csdn.net/yyhaohaoxuexi/article/details/113824125 一.Info Noise-contrastive estimati ...
- 对比学习(Contrastive Learning)综述
A.引入 https://zhuanlan.zhihu.com/p/346686467 A.引入 深度学习的成功往往依赖于海量数据的支持,其中对于数据的标记与否,可以分为监督学习和无监督学习. 1 ...
最新文章
- HA: Dhanush靶机渗透测试
- 武汉.NET俱乐部论坛已经恢复
- Python轻量级IDE推荐 -- Jupyter QTConosle
- c语言内存拷贝 memcpy()函数
- 虚拟机找不到共享文件夹
- abap 生成流水号每天从1开始_条码软件如何制作循环流水号
- 使用JPA和Spring 3.1进行事务配置
- Web前端笔记-使用@media(媒体查询)展示及隐藏div
- 再谈节奏与动力---平淡与枯燥的力量
- 用WAP手机远程遥控电脑1
- C++ 析构函数不要抛出异常
- java swing取消按钮_在Java Swing中取消选择单选按钮
- 《概率论与数理统计》
- java txt导出_Java导出txt文件的方法
- 【数仓】大数据领域建模综述-《大数据之路》读书笔记
- 鸿蒙系统可以上外网吗,【图片】华为鸿蒙系统的厉害之处在于 你可能非用不可
!【手机吧】_百度贴吧...
- java.sql.SQLNonTransientConnectionException: Data source rejected establishment of connection, messa
- 大数据时代,做大数据开发要学Java框架吗?
- html把图片做成导航条背景,DIV+CSS背景图片导航菜单的实现方法
- 平面设计中的网格系统pdf_【200421】平面设计必看书籍超越平凡的设计平面设计中的网格系统等|电子书资源免费分享...