SimSiam

Abstract

模型坍塌,在siamese中主要是输入数据经过卷积激活后收敛到同一个常数上,导致无论输入什么图像,输出结果都能相同。

而He提出的simple Siamese networks在没有采用之前的避免模型坍塌那些方法:

  • 使用负样本
  • large batches
  • momentum encoders(论文直接用的encoder)

实验表明对于损失和结构确实存在坍塌解,但stop-gradient操作在防止坍塌方面起着至关重要的作用。

Method

如图为simsiam 的结构,输入是训练集中随机选取的一个图像,使用随机数据增强生成两个图像;左右两个encoder是完全一样的,包含卷积和全连接,将图像进行编码(特征提取);perdictor 是一般的MLP,左右都是有predictor模块的(看伪代码),只右侧是没画出来,用来转换视图的输出,并将其与另一个视图相匹配,(encoder是一样的,x1和x2即使经过数据增强大小也是一样的,那为啥要再加一个predictor模块使两个视图相匹配呢?);

similarity是对比predictor输出的特征向量,loss为经过encoder的p和predictor的输出z,p1和z2对比,p2和z1的负余弦相似度 如 D ( p 1 , z 2 ) = − p 1 ∣ ∣ p 1 ∣ ∣ 2 z 2 ∣ ∣ z 2 ∣ ∣ 2 D(p_1,z_2)=-\frac{p_1}{||p_1||_2} \frac{z_2}{||z_2||_2} D(p1​,z2​)=−∣∣p1​∣∣2​p1​​∣∣z2​∣∣2​z2​​ (论文中说这个与l2正则化的mse相同?)

总的网络的loss 为 L = D ( p 1 , z 2 ) / 2 + D ( p 2 , z 1 ) / 2 L=D(p_1, z_2)/2 + D(p_2, z_1)/2 L=D(p1​,z2​)/2+D(p2​,z1​)/2

# f: backbone + projection mlp
# h: prediction mlp
for x in loader: # load a minibatch x with n samplesx1, x2 = aug(x), aug(x) # random augmentation对图像进行随机数据增强,这样就生成 z1, z2 = f(x1), f(x2) # projections, n-by-d encodeer的计算p1, p2 = h(z1), h(z2) # predictions, n-by-d predictor的计算L = D(p1, z2)/2 + D(p2, z1)/2 # loss  两个向量的负余弦相似度L.backward() # back-propagateupdate(f, h) # SGD update
def D(p, z): # negative cosine similarityz = z.detach() # stop gradientp = normalize(p, dim=1) # l2-normalizez = normalize(z, dim=1) # l2-normalizereturn -(p*z).sum(dim=1).mean()

在backward()时,如果y是标量,则不需要为backward()传入任何参数;否则,需要传入一个与y同形的Tensor。

如果不想要被继续追踪,可以调用.detach()将其从追踪记录中分离出来,这样就可以防止将来的计算被追踪,这样梯度就传不过去了。还可以用with torch.no_grad()将不想被追踪的操作代码块包裹起来,这种方法在评估模型的时候很常用,因为在评估模型时,我们并不需要计算可训练参数(requires_grad=True)的梯度。

上面将z给detach了, z 2 ∣ ∣ z 2 ∣ ∣ 2 \frac{z_2}{||z_2||_2} ∣∣z2​∣∣2​z2​​所以会被看成为常数只有 p 1 ∣ ∣ p 1 ∣ ∣ 2 \frac{p_1}{||p_1||_2} ∣∣p1​∣∣2​p1​​会产生梯度,

为了进一步确认那一部分的设计在本文的框架中是至关重要的,作者设计了以下的消融实验。


Empirical Study

stop grad

显然如果使两侧的梯度都进行传递网络的loss是非常小的,因为两个网络的参数是接近一模一样的所以两个网络很容易就达到一致了。而且这样的性能表现是非常差的,因为很容易达到两个网络参数一样,最后导致模型坍塌。实际上并不能学到什么有效的特征。


使用不同的predictor的结果

如果没有predictor模型不work(原因作者没说);

如果预测MLP头模块h固定为随机初始化,该模型同样不再有效,这是因为模型不收敛,loss太高;

当预测MLP头模块采用常数学习率时,该模型甚至可以取得比基准更好的结果,作者也提出了一个可能的解释:h应当适应最新的表征,所以不需要在表征充分训练之前使用降低学习率的方法迫使其收敛。

不同Batch Size

探究了不同的batch对精度的影响,虽然基础 l r lr lr是0.05,但是学习率会随着batch的变化做线性缩放 l r × B a t c h S i z e / 256 lr×BatchSize/256 lr×BatchSize/256 ,对于batch大于1024时,会采用10个epoch的warm-up学习率。

作者探究了SGD在较大batch上会导致性能退化,但同时也证明了优化器不是防止崩溃的必要条件。


Batch Normalization

移除BN之后可能因为难优化造成了性能下降,但是并没有造成collapsing,只加在隐层精度会提高到67.4%,如果在投影MLP中也加上BN则会提升到68.1%。但是如果把BN加到预测MLP上,就不work了,作者探究了这也不是崩溃问题,而是训练不稳定,loss震荡。

总结下来就是,BN在监督学习和非监督学习中都会使模型易于优化,但是并不能防止collapsing。


Similarity Function

除了余弦相似函数之外,该方法在交叉熵相似函数下也work,这里的softmax是channel维度的,softmax的输出可以认为是属于d个类别中每个类别的概率。

可以看出使用交叉熵相似性依然可以很好地收敛,并没有崩溃,所以避免collapsing与余弦相似性无关。

结果比较

如下图7所示,SimSiam小的batch和没有负样本、momentum encoder的情况下仍然能取得较好的效果。

Hypothesis

为什么这样简单的网络能够work呢?作者提出了一种猜想:SimSiam实际上是一种Expectation-Maximization(EM)的算法。——最大期望算法。

我们最熟悉的最大期望算法就是k-means算法。

L ( θ , η ) = E x , T [ ∥ F θ ( T ( x ) ) − η x ∥ 2 2 ] L(\theta,\eta)=\mathbb{E}_{x,\mathcal{T} }[\|\mathcal{F} _\theta(\mathcal{T}(x)) - \eta_x\|_2 ^2 ] L(θ,η)=Ex,T​[∥Fθ​(T(x))−ηx​∥22​]

这里x输入图像 T \mathcal{T} T是图像的一种增强, F θ \mathcal{F} _\theta Fθ​是encoder, η x \eta _x ηx​不一定局限于图像表征,在训练网络时我们希望找到一个 θ \theta θ,找到一个 η \eta η,使得loss的期望是最小的。

在每一步中首先会确定一个 θ \theta θ使得 loss 最小,这时使用的是一个固定的 η \eta η,从而得到 θ t \theta^t θt

θ t ← arg ⁡ min ⁡ θ L θ η t − 1 \theta^t \gets \mathop{\arg\min}_{\theta} \mathcal{L}\theta\eta^{t-1} θt←argminθ​Lθηt−1(公式 2)

锁定 θ \theta θ,寻找一个使 loss 达到最小的 η \eta η

η t ← arg ⁡ min ⁡ η L ( θ t \eta^t \gets \mathop{\arg \min}_\eta \mathcal{L}(\theta^t%2C \eta ηt←argminη​L(θt))

反复进行以上两步最终使训练得到一个满意的结果。

对比学习-SimSiam-论文精读解析相关推荐

  1. 自监督对比学习系列论文(二):有引导对比学习--SCCL,SwAV,PCL,SupervisedCon

    自监督对比学习  本篇承接前一篇自监督对比学习的介绍自监督对比学习系列论文(一):无引导对比学习–MOCO,SimCLR,这次主要关注于包含先验指导的对比学习,这一指导更进一步的可以区分为聚类指导以及 ...

  2. 对比学习 ——simsiam 代码解析。

    Python微信订餐小程序课程视频 https://blog.csdn.net/m0_56069948/article/details/122285951 Python实战量化交易理财系统 https ...

  3. 对比学习 ——simsiam 代码解析。:

    目录 1 : 事先准备 . 2 : 代码阅读. 2.1: 数据读取 2.2: 模型载入 3 训练过程: 4 测试过程: 5 :线性验证 6 : 用自己数据集进行对比学习. 第一:  改数据集 : 2 ...

  4. 对比学习系列论文SDCLR(二)-Self-Damaging Contrastive Learning

    目录 0.Abstract 0.1逐句翻译 0.2总结 1. Introduction 1.1. Background and Research Gaps 1.1.1逐句翻译 第一段(引出对比学习是否 ...

  5. 对比学习系列论文MoCo v1(二):Momentum Contrast for Unsupervised Visual Representation Learning

    0.Abstract 0.1逐句翻译 We present Momentum Contrast (MoCo) for unsupervised visual representation learni ...

  6. 对比学习系列论文CPCforHAR(一):Contrastive Predictive Coding for Human Activity Recognition

    0.Abusurt 0.1逐句翻译 Feature extraction is crucial for human activity recognition (HAR) using body-worn ...

  7. 对比学习系列论文CPC(二)—Representation Learning with Contrastive Predictive Coding

    0.Abstract 0.1逐句翻译 While supervised learning has enabled great progress in many applications, unsupe ...

  8. 对比学习simSiam(一)--Exploring Simple Siamese Representation Learning总体理解

    1.从名字上把握 sim是我们熟知的相似的那个单词,这个Siam是孪生的意思,这里使用这个来命名应该是为了指出孪生的重要性.这里的核心其实是在提出一个思想,对比学习这种由孪生网络结构构成的无监督学习的 ...

  9. 对比学习系列论文SDCLR(一)-Self-Damaging Contrastive Learning论文的概括

    1.研究背景(研究的问题) 一切的大背景是对比学习大发展 1.传统深度学习当中就存在这个问题. 2.虽然之前的工作指出,对比学习受到长尾问题影响比较小:但是作者实验发现长尾问题可能对对比学习影响更大. ...

最新文章

  1. 组合,多态,封装, @property
  2. 1.4 @SuppressWarnings:抑制编译器警告
  3. 笔记-中项案例题-2017年上-计算题
  4. ASP.NET的MVC请求处理流程
  5. 后台管理,有无限可能
  6. SQL实战之查找所有员工入职时候的薪水情况
  7. element时间范围选择添加限制条件
  8. SSM整合时IDE: File is included in 4 contexts
  9. EditPlus Version 3 价格 代理商 销售价格 正版软件价格
  10. 分享一款WIFI抓包工具,支持Mac和Windows
  11. Excel常用函数——count
  12. 来客电商之微信小程序怎么取名字
  13. 在sagemath中运行python文件
  14. 第4章_1——SQL语句实现MySQL增删改查
  15. *p++,*(p++),*++p,(*p)++区别?
  16. folly库安装(5)folly的安装
  17. C语言实战篇-----调试关键参数+printf输出_文件名_函数名_执行数!!!
  18. 图神经网络:GAT学习、理解、入坑
  19. iOS证书申请打包上传App Store审核完整流程(7个步骤)
  20. 猿课,linux系统精讲

热门文章

  1. 解密Google Deepmind AlphaGo围棋算法:真人工智能来自于哪里?
  2. 笑死:Welcome to Skip Thompson's Homepage
  3. Android lunch分析以及产品分支构建
  4. 怎么撰写一份优秀的数据分析报告(五)
  5. web版拳皇,使用html,css,js来制作一款拳皇游戏
  6. 【STM32F407的DSP教程】第28章 FFT和IFFT的Matlab实现(幅频响应和相频响应)
  7. Linux-alias设置命令别名
  8. 谈及区块链,我们脑海当中首先浮现出来的是,狂热、浮躁的场景
  9. 损失函数、代价(成本)函数、目标函数
  10. 详解modprobe的用法