SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
背景
传统的联邦学习在数据异构(non-iid)的场景中很容易产生“客户漂移”(client-drift)的现象,这会导致系统的收敛不稳定或者缓慢
贡献
提出了考虑到client sampling和数据异构的一个更接近的收敛边界
证明即便没有client sampling,使用全批次梯度(full batch gradients),传统的FedAvg依旧会因为client-drift而比SGD收敛速度更慢
提出Stochastic Controlled Averaging algorithm(SCAFFOLD),目的便是为了解决client-drift的问题,并证明了SCAFFOLD算法在数据异构的情况下收敛速度至少和SGD一样快
SCAFFOLD算法还可以利用client之间的相似度来减少通信开销
证明了SCAFFOLD算法不会被client sampling所影响,这使得SCAFFOLD算法更适合联邦学习
论文思想
传统联邦学习的方法FedAvg算法在异构数据集上表现不好的原因是有一些client会带偏整个系统的收敛结果,如下图所示:
在上图中,黑色点是全局模型,也就是每个训练轮次各个局部模型的“训练起点”,假设在某一轮训练中,服务器选择了client1和client2两个客户端来训练,然后client1是偏离整个系统的客户端,那么在客户端上训练三个轮次中,我们可以看到client1上的局部模型已经偏离了训练的方向(x∗x^*x∗所在的方向),然后聚合得到的server model也会稍微偏离x∗x^*x∗,使得系统向着偏离学习模型的方向上收敛。最终的结果不是造成整个系统的性能下降就是导致整个系统收敛缓慢
为了解决这个问题,论文使用一个“控制变量”(control variate)ccc来“纠正”系统训练的方向,在client对模型进行更新的时候,也会对该变量进行更新
算法
与传统的联邦学习类似,SCAFFOLD算法也分为三个主要的部分:
- 局部更新模型(local updates to the client model)
- 局部更新控制变量(local updates to the client control variate)
- 对局部的更新进行聚合
先给出算法的流程,后面再做出解释:
算法具体流程:
- 服务器输入:初始化的xxx和ccc,全局更新步长ηg\eta_gηg
- 客户端输入:控制变量cic_ici和局部更新步长ηl\eta_lηl
- 在每一个全局轮次
- 对clients进行sample
- 对于每一个client iii
- 发送全局模型xxx和控制变量ccc
- 用全局模型xxx来初始化yiy_iyi
- 在每一个局部轮次:
- 计算mini-batch梯度gi(yi)g_i(y_i)gi(yi)
- 用yi←yi−ηl(gi(yi)−ci+c)y_i \leftarrow y_i - \eta_l(g_i(y_i)-c_i+c)yi←yi−ηl(gi(yi)−ci+c)进行更新
- 局部训练完毕之后,得到新的控制变量ci+←(i)gi(x)or(ii)ci−c+1Kηl(x−yi)c_i^+ \leftarrow (i)g_i(x)or(ii)c_i - c + \frac{1}{K\eta_l}(x-y_i)ci+←(i)gi(x)or(ii)ci−c+Kηl1(x−yi)
- 上传梯度yi−xy_i-xyi−x和ci+−cic_i^+-c_ici+−ci
- 更新控制变量ci←ci+c_i\leftarrow c_i^+ci←ci+
- 对上传的梯度进行聚合(Δx,Δc)←1∣S∣∑i∈S(Δyi,Δci)(\Delta x,\Delta c)\leftarrow \frac{1}{|S|}\sum_{i\in S}(\Delta y_i,\Delta c_i)(Δx,Δc)←∣S∣1∑i∈S(Δyi,Δci)
- 全局更新:x←x+ηgΔxx\leftarrow x + \eta_g \Delta xx←x+ηgΔx和c←c+∣S∣NΔcc\leftarrow c + \frac{|S|}{N}\Delta cc←c+N∣S∣Δc
局部更新方式
SCAFFOLD算法在局部的更新方式是:
yi←yi−ηl(gi(yi)−ci+c)y_i \leftarrow y_i - \eta_l(g_i(y_i)-c_i+c) yi←yi−ηl(gi(yi)−ci+c)
其中控制变量ccc的作用很明显,便是用全局模型的知识去约束局部模型的训练,以防止其偏离系统的正确训练方向,如下图所示:
并且该控制变量也会更新,以下面的方式:
论文给给出了上面两种更新方式的选择,其中第一种是用局部的梯度来更新全局模型中的控制变量ccc,第二种复用了全局模型的知识,直观上理解是根据全局模型与局部模型的差异来更新ccc。论文中给出的上面两种选择的区别是第一种方法要更稳定,第二种方法更加取决于应用场景,但是第二种方法更容易计算并且在通常情况下也已经足够优秀
全局更新方式
对于模型的更新与传统联邦并无太大区别:
x←x+ηgΔxx\leftarrow x + \eta_g \Delta x x←x+ηgΔx
控制变量的更新:
c←c+1N∑i∈S(ci+−ci)c\leftarrow c + \frac{1}{N}\sum_{i\in S}(c_i^+ - c_i) c←c+N1i∈S∑(ci+−ci)
控制变量的更新方法也是和模型的更新方法差不多,本质上都是将局部模型的知识更新传递到全局模型
实验
实验在EMNIST数据集上进行,结果证明了SCAFFOLD算法与FedAvg算法和FedProx算法相比是表现最好的,如下图所示:
上面的3幅图表示的是SGD和FedAvg的比较实验,可以看到当梯度差异(G)很小是,FebAvg在训练刚刚开始的时候要比SGD好,但是在当G比较大的时候,由于客户容易发生“客户漂移”现象,容易带偏系统的训练方向,因此收敛效果和速度都会变差。下面的3幅图表示的是论文提出的算法Scaffold与SGD的比较,可以看到Scaffold算法无论是收敛速度和效果都比SGD要好
总结
论文的一个基本思想本质上便是用全局模型的知识去限制局部模型的训练方向,以防止那些与全局模型相差较大的局部模型带偏整个系统的训练方向
SCAFFOLD: Stochastic Controlled Averaging for Federated Learning相关推荐
- SCAFFOLD: Stochastic Controlled Averaging for Federated Learning学习
SCAFFOLD: Stochastic Controlled Averaging for Federated Learning学习 背景 贡献 论文思想 算法 局部更新方式 全局更新方式 实验 总结 ...
- 【阅读笔记】Towards Personalized Federated Learning个性化联邦综述
文章目录 前言 1 背景 1.1 机器学习.联邦学习 1.2 促进个性化联邦学习的动机 2 个性化联邦学习的策略 2.1 全局模型个性化 2.1.1 基于数据的方法 2.1.1.1 数据增强 Data ...
- Federated Learning in Mobile Edge Networks: AComprehensive Survey(翻译)
名词:联邦学习(FL).ML.MEC BAA(宽带模拟聚合).CNN(卷积神经网络).CV(计算机视觉). DDQN(双深度Q网络).DL(深度学习)DNN(深度神经网络). DP(差分隐私).DQL ...
- 【Paper Reading】BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning
BatchCrypt: Efficient Homomorphic Encryption for Cross-Silo Federated Learning 原文来源:[ATC 2020] Batch ...
- Stochastic Weight Averaging (SWA) 随机权重平均
文章目录 相关链接 基础 思路 主要内容 概括 SWA图示 SWA算法 LR The Algorithm Batch normalization 在PyTorch中使用swa 最佳实践 Demo 最近 ...
- Stochastic Weight Averaging
PyTorch从1.6.0版本以后开始支持Stochastic Weight Averaging. That is, after the conventional training of an obj ...
- 【联邦学习】横向联邦学习(Horizontal Federated Learning,HFL)
文章目录 一.横向联邦学习的定义 二.横向联邦学习的安全性 三.横向联邦学习架构 1. 客户-服务器架构 2. 对等网络架构 四.联邦优化 五.联邦平均算法 参考链接 一.横向联邦学习的定义 横向联邦 ...
- 2.Paper小结——《Privacy-preserving blockchain-based federated learning for traffic flow prediction》
题目: 基于区块链的基于隐私保护的交通流量预测的联邦学习 0.Abstract: 交通流量预测已成为智能交通系统的重要组成部分.然而,现有的基于集中式机器学习的交通流量预测方法需要收集原始数据以进行模 ...
- 论文阅读(联邦学习):Exploiting Shared Representations for Personalized Federated Learning
Exploiting Shared Representations for Personalized Federated Learning 原文传送门 http://proceedings.mlr.p ...
最新文章
- moss管理中心崩溃之解决
- C++返回栈上的数组(局部变量)问题探索
- QEMU-KVM中的多线程压缩迁移技术
- 基于树的模型的机器学习
- 解决 Unmapped Spring configuration files found.Please configure Spring facet.
- Linux awk的 if语句,AWK if(条件)语句与循环简介
- 华为P30系列再曝光:屏幕参数揭晓 还要用水滴全面屏?
- C++ 多线程 atomic
- 网络编程中客户端链接的合法性,socketserver模块
- 初学 快速幂 的理解
- 特征提取之——Haar特征
- 判断URL的HTTP状态
- carplay是否可以用安卓系统_Microsoft Teams正在支持CarPlay通话,我还是期待微信支持CarPlay...
- MUSIC算法相关原理知识(物理解读+数学推导+Matlab代码实现)
- CSS综合案例——淘宝焦点图(轮播图)布局及网页布局总结
- matlab排序算法,相同位置返回元素排名
- 计算机1946考试试题,统考计算机考试试题及答案
- 中小企业数据防泄密怎么做,墨门云数据防泄密指南
- Python个人项目2 --------青蛙旅行项目
- design contains shelved or modified (but not repoured) polygons. the result....继续铺铜还是报警,解决方案如下: