写在前面

最近阅读了一篇CVPR上关于联邦学习的文章(将对比学习的思想融入到联邦学习中),作者是新加坡国立大学的Qinbin Li(博士生,导师 何炳胜),Bingsheng He(何炳胜教授,导师 宋晓东)以及加州大学伯克利分校的Dawn Song(宋晓东教授,论文总引用量7万+)。

  • 论文一作个人主页: https://qinbinli.com
  • 论文链接:CVPR版本,Arxiv版本
  • 代码:https://github.com/QinbinLi/MOON
  • 会议:CVPR 2021

CVPR作为计算机视觉领域的顶级会议(CCF-A),目前有4篇联邦学习相关的论文

  1. Multi-Institutional Collaborations for Improving Deep Learning-Based Magnetic Resonance Image Reconstruction Using Federated Learning
  2. Model-Contrastive Federated Learning
  3. FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space
  4. Soteria: Provable Defense Against Privacy Leakage in Federated Learning From Representation Perspective

今天要介绍的就是其中一篇论文《Model-Contrastive Federated Learning》

一、Motivation

  • 联邦学习的关键挑战是客户端之间数据的异质性(Non-IID),尽管已有很多方法(例如FedProx,SCAFFOLD)来解决这个问题,但是他们在图像数据集上的效果欠佳(见实验Table1)。
  • 传统的对比学习是data-level的,本文改进了FedAvg的本地模型训练阶段,提出了model-level的联邦对比学习(Model-Contrastive Federated Learning)
  • 作者从NT-Xent loss中获得灵感,提出了model-contrastive loss。model-contrastive loss可以从两方面影响本地模型 1. 本地模型能够学到接近于全局模型的representation 2. 本地模型可以学到比上一轮本地模型更好的representation

简单来说,作者在本地模型训练的时候加了个model-contrastive loss,使得在Non-IID的图片数据集上训练的联邦学习模型效果很好。

二、背景知识

联邦学习FedAvg训练过程

本文主要针对客户端本地训练阶段进行了改进(说白了就是加了个loss)。

对比学习SimCLR

对比学习的基本想法是同类相聚,异类相离

从不同的图像获得的表征应该相互远离,从相同的图像获得的表征应该彼此靠近

上图来自blog

这个想法是凭直觉获知的,但是已经被证明效果很好

SimCLR是对比学习中经典的方法。

每次采样N=128张图片,对这128张图片做两次augmentation,所以输入图片数量其实是256,然后把同一张图片的两个augmentation当作一对正样本xix_ixi​, xjx_jxj​,计算l(i,j)l(i,j)l(i,j)时,iii是锚点,分子是正样本对(xi,xj)(x_i,x_j)(xi​,xj​),分母是正样本对(xi,xj)(x_i,x_j)(xi​,xj​) + 2N-2个负样本对(xi,xk)(x_i,x_k)(xi​,xk​),其中k≠i,jk\neq i, jk=i,j

常用NT-Xent loss(the normalized temperature-scaled cross entropy loss)
li,j=−log⁡exp⁡(sim⁡(xi,xj)/τ)∑k=12NI[k≠i]exp⁡(sim⁡(xi,xk)/τ)l_{i, j}=-\log \frac{\exp \left(\operatorname{sim}\left(x_{i}, x_{j}\right) / \tau\right)}{\sum_{k=1}^{2 N} \mathbb{I}_{[k \neq i]} \exp \left(\operatorname{sim}\left(x_{i}, x_{k}\right) / \tau\right)} li,j​=−log∑k=12N​I[k=i]​exp(sim(xi​,xk​)/τ)exp(sim(xi​,xj​)/τ)​

SimCLR伪代码 paper

Preliminary Experiment

本文基于这样一个直观的想法来解决Non-IID问题:

the model trained on the whole dataset is able to extract a better feature representation than the model trained on a skewed subset.

作者在CIFAR-10做了个实验,来验证他的这种直觉。

做法:用t-SNE可视化训练好的CNN模型在测试集上获得的隐藏层的表征向量。

  • 2a:用所有数据集放在一起训练一个CNN模型。
  • 2b:将所有数据集以Non-IID的方式划分10个客户端,各自训练CNN模型,最后随机选择一个客户端的模型。
  • 2c:在10个客户端上使用FedAvg算法训练得到一个global model(10个本地模型加权平均)
  • 2d:在10个客户端上使用FedAvg算法训练,然后随机选择一个客户端的local model。(2d学习到的蓝色的类别表征明显比2c差)

通过T-SNE可视化表征向量,证实了如下观点:全局模型应该要比本地模型的性能好(全局模型能学到一个更好的表征),因此在non-iid的场景下,我们应该控制这种drift以及处理好由全局模型和本地模型学到的表征。

三、方法:MOON

问题定义

MOON的目标

Since there is always drift in local training and the global model learns a better representation than the local model, MOON aims to decrease the distance between the representation learned by the local model and the representation learned by the global model, and increase the distance between the representation learned by the local model and the representation learned by the previous local model.

MOON的loss函数

MOON在本地训练阶段,会有三个表征(representation)

  • zprev=Rwit−1(x)z_{prev}=R_{w_i^{t-1}}(x)zprev​=Rwit−1​​(x)(上一轮本地训练好的发往server的模型得到的表征)固定
  • zglob =Rwt(x)z_{\text {glob }}=R_{w^{t}}(x)zglob ​=Rwt​(x)(这轮开始时发送到本地的全局模型得到的表征)固定
  • z=Rwit(x)z=R_{w_i^{t}}(x)z=Rwit​​(x) (这轮正在被更新的本地模型得到的表征)不断被更新

With model weight www,Rw(⋅)R_w(·)Rw​(⋅) to denote the network before the output layer (i.e., Rw(x)R_w (x)Rw​(x) is the mapped representation vector of input x).

我们的目标是让zzz靠近zglob z_{\text {glob }}zglob ​(固定),让zzz远离zprevz_{\text {prev}}zprev​(固定)。

我们的本地模型训练时的loss有两部分组成:传统的交叉熵损失lsup\mathcal{l}_{sup}lsup​以及本文提出的model-contrastive loss(lcon\mathcal{l}_{con}lcon​)

类似对比学习中的NT-Xent loss,我们定义model-contrastive loss

ℓcon =−log⁡exp⁡(sim⁡(z,zglob )/τ)exp⁡(sim⁡(z,zglob )/τ)+exp⁡(sim⁡(z,zprev )/τ)\ell_{\text {con }}=-\log \frac{\exp \left(\operatorname{sim}\left(z, z_{\text {glob }}\right) / \tau\right)}{\exp \left(\operatorname{sim}\left(z, z_{\text {glob }}\right) / \tau\right)+\exp \left(\operatorname{sim}\left(z, z_{\text {prev }}\right) / \tau\right)} ℓcon ​=−logexp(sim(z,zglob ​)/τ)+exp(sim(z,zprev ​)/τ)exp(sim(z,zglob ​)/τ)​
其中τ\tauτ为温度系数,分子是正样本对(z,zglob)(z, z_{\text {glob}})(z,zglob​),分母是正样本对(z,zglob)(z, z_{\text {glob}})(z,zglob​)+负样本对(z,zprev)(z, z_{\text {prev}})(z,zprev​)


MOON的优化目标(loss)如下:
ℓ=ℓsup (wit;(x,y))+μℓcon (wit;wit−1;wt;x)\ell=\ell_{\text {sup }}\left(w_{i}^{t} ;(x, y)\right)+\mu \ell_{\text {con }}\left(w_{i}^{t} ; w_{i}^{t-1} ; w^{t} ; x\right)ℓ=ℓsup ​(wit​;(x,y))+μℓcon ​(wit​;wit−1​;wt;x)


The network has three components: a base encoder, a projection head, and an output layer.

MOON伪代码

和FedAvg相比,MOON只在客户端本地训练过程中添加了lconl_{con}lcon​项


SimCLR和MOON

作者还对比了下SimCLR和MOON框架

  • SimCLR是想让同一张图片(数据层面)的不同view的表征ziz_izi​和zjz_jzj​最大程度地相近
  • MOON是想让全局模型和本地模型的参数(模型层面)对应的表征zglobz_{glob}zglob​和zlocalz_{local}zlocal​最大程度地相近。

作者还提到,理想情况下(IID),全局模型和本地模型训练得到的表征应该是一样好的,那么lconl_{con}lcon​是一个常数,此时会得到FedAvg一样的效果。在这种意义上,MOON比FedAvg更具鲁棒性(能处理Non-IID的情况)

四、实验

数据集

Image classification datasets:CIFAR-10, CIFAR-100, and Tiny-Imagenet

作者通过实验展示了在数据集Non-IID的情况下FedProx,SCAFFOLD这些方法应用到图片数据集的效果会大打折扣,甚至和FedAvg一样差。


SOLO表示每个客户端只利用自己本地数据训练模型


五、总结

本文从对比学习中常用的NT-Xent loss中获得灵感,提出了联邦模型对比学习MOON。

一句话总结:作者在联邦学习本地模型训练的时候加了个model-contrastive loss,使得在Non-IID的图片数据集上训练的联邦学习模型效果很好。


嘿嘿,完结撒花❀

【CVPR 2021联邦学习论文解读】Model-Contrastive Federated Learning (MOON) 联邦学习撞上对比学习相关推荐

  1. 【论文解读】CVPR 2021 妆容迁移 论文解读Spatially-invariant Style-codes Controlled Makeup Transfer

    [论文解读]CVPR 2021 妆容迁移 论文解读 Spatially-invariant Style-codes Controlled Makeup Transfer 摘要 方法特点 实现方法 公式 ...

  2. 今晚直播 | AAAI 2022论文解读:重新思考图像融合策略和自监督对比学习

    「AI Drive」是由 PaperWeekly 和 biendata 共同发起的学术直播间,旨在帮助更多的青年学者宣传其最新科研成果.我们一直认为,单向地输出知识并不是一个最好的方式,而有效地反馈和 ...

  3. 联邦学习-论文阅读-NDSS-FLTrust: Byzantine-robust Federated Learning via Trust Bootstrapping

    1.FLTrust: Byzantine-robust Federated Learning via Trust Bootstrapping 1.概要 拜占庭式的鲁棒联邦学习方法中没有信任的根(即不知 ...

  4. 华人占大半壁江山!CVPR 2021 目标检测论文大盘点(65篇论文)

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Amusi  |  来源:CVer 前言 CVer 正式盘点CVPR 2021上各个方向的工作,本篇是 ...

  5. 【强化学习论文解读 1】 NAF

    [强化学习论文解读 1] NAF 1. 引言 2. 论文解读 2.1 背景 2.2 NAF算法原理 2.3 Imagination Rollouts方法 3. 总结 1. 引言 本文介绍一篇2016年 ...

  6. 最新!CVPR 2021 医学图像分割论文大盘点(5篇论文)

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Amusi  |  来源:CVer 前言 昨天分享了MICCAI 2021上Transformer+医 ...

  7. 知识图谱-生物信息学-医学顶刊论文(Bioinformatics-2022)-SGCL-DTI:用于DTI预测的监督图协同对比学习

    14.(2022.5.21)Bioinformatics-SGCL-DTI:用于DTI预测的监督图协同对比学习 论文标题: Supervised graph co-contrastive learni ...

  8. Zero-shot Learning零样本学习 论文阅读(一)——Learning to detect unseen object classes by between-class attribute

    Zero-shot Learning零样本学习 论文阅读(一)--Learning to detect unseen object classes by between-class attribute ...

  9. 【阅读笔记】Towards Personalized Federated Learning个性化联邦综述

    文章目录 前言 1 背景 1.1 机器学习.联邦学习 1.2 促进个性化联邦学习的动机 2 个性化联邦学习的策略 2.1 全局模型个性化 2.1.1 基于数据的方法 2.1.1.1 数据增强 Data ...

  10. 【论文解读】CVPR 2021 妆容迁移 论文+ 代码 汇总,美得很美得很!

    妆容迁移是指将目标图上的妆容直接迁移到原图上的技术.相比传统贴妆技术,妆容迁移具有极高的自由度,它可以让用户不再局限于设计师设计好的妆容,而是可以自主.任意地从真实模特图中获取妆容,极大地丰富了妆容的 ...

最新文章

  1. Exchange Server 2013之CAS服务器NLB负载均衡
  2. node debug包
  3. 人脸识别 轻量级高精度网络推荐
  4. 14-jQuery补充
  5. Echarts使用笔记
  6. JavaFX 架构与框架 (译)
  7. createprocess重启程序_C++_VC程序设计中CreateProcess用法注意事项,对于windows程序设计来说,启动 - phpStudy...
  8. 当杯子中的空气被抽走会发生什么?
  9. python自动化运维书籍推荐_《Python 自动化运维:技术与最佳实践》
  10. 美国对特斯拉“幽灵刹车”问题展开调查 涉及41.6万辆Model 3/Y
  11. 2012浙大878计算机专业基础综合大题答案解析
  12. 解决 ‘Response‘ object has no attribute ‘body‘
  13. ubuntu mysql 操作_Ubuntu系统下MySQL数据库基本操作
  14. 北京思源培训中心---C#下用P2P技术实现点对点聊天(2)
  15. linux桌面版本安装MSDM,Parallel_s desktop怎么安装linux系统
  16. Matlab常用的标记符号和颜色
  17. redis(版本redis-5.0.2)的安装步骤
  18. c语言碰撞的小球,小球碰撞(完全弹性碰撞)
  19. java毕业设计学生考勤系统Mybatis+系统+数据库+调试部署
  20. 精读:理论与实践融合 学者与干将统一

热门文章

  1. dll封装成activex控件_Qt编写自定义控件26-平铺背景控件
  2. docker privileged作用_docker容器性能监控cAdvisor+influxDB+grafana监控系统安装部署
  3. head first html与css 代码_手把手教你使用Flask轻松部署机器学习模型(附代码amp;链接) | CSDN博文精选...
  4. 零基础多久能学会python_零基础小白多久能学会python
  5. layui中table表格的操作列(删除,编辑)等按钮的操作
  6. java.util.Arrays$ArrayList addAll报错
  7. netstat 查看网络状态
  8. 总结---JavaScript数组
  9. 实现multbandblend
  10. word-break: break-all与word-wrap:break-word的区别