论文地址点这里

一. 介绍

联邦学习中由于各个客户端上数据异构问题,导致全局训练模型无法适应每一个客户端的要求。作者通过利用客户端之间的共同代表来解决这个问题。具体来说,将数据异构的联邦学习问题视为并行学习任务,这些任务之间可能存在一些共同的结构,作者的目标是学习和利用这种共同的表示来提高每个客户端的模型质量,基于此提出了FedRep(联邦表示学习)。
FedRep: 联邦表示学习利用跨客户机的存储的所有数据,使用基于梯度的更新来学习全局低维表示。此外,使得每个客户端能够计算一个个性化的、低维的分类器,负责每个客户端的本地数据的唯一标识。

二. 问题定义

传统的联邦学习从n个客户端上优化下面目标:

min⁡(q1,...,qn)∈Qn1n∑i=1nfi(qi)\min_{(q_1,...,q_n)\in\mathcal{Q_n}}\frac{1}{n}\sum_{i=1}^nf_i(q_i)(q1​,...,qn​)∈Qn​min​n1​i=1∑n​fi​(qi​)

其中fif_ifi​表示第i个客户端上的损失函数,qiq_iqi​表示第i个客户端上的模型。但由于客户端上的数据较少,同时客户端数量庞大,客户端无法学习到一个很小损失的模型,因此联邦学习允许客户端之间进行参数交互。传统的方式是想让客户端学习到一个共同的模型,也就是q1=q2=...=qnq_1=q_2=...=q_nq1​=q2​=...=qn​,但当客户端数据异构明显时,客户端的模型应该更接近于本地的数据。因此我们有必要去学习到一组{qi}\{q_i\}{qi​}使得其满足于自身的数据。
学习一个共同的表示(Learning a Common Representation)。我们考虑一个全局的表示ϕ:Rd→Rk\phi:\mathbb{R}^d \to \mathbb{R}^kϕ:Rd→Rk,将数据映射到一个更低的维度k;客户端的特殊表示头:Rk→Y\mathbb{R}^k \to \mathcal{Y}Rk→Y。根据此,第i个客户端上的模型是客户端上的局部参数和全局表示的组合:qi(x)=(hi∘ϕ)(x)q_i(x)=(h_i \circ\phi)(x)qi​(x)=(hi​∘ϕ)(x)。值得注意的是,k远远小于d,也就是说每个客户端必须在本地学习的参数数量很少。我们根据新的内容重新改写我们的全局优化目标:

min⁡ϕ∈Φ1n∑i=1nmin⁡hi∈Hfi(hi∘ϕ)\min_{\phi \in \Phi}\frac{1}{n}\sum_{i=1}^n\min_{h_i\in\mathcal{H}}f_i({h_i} \circ\phi)ϕ∈Φmin​n1​i=1∑n​hi​∈Hmin​fi​(hi​∘ϕ)

其中Φ\PhiΦ为可行的表示类,而H\mathcal{H}H为可行的头类。客户端使用所有客户的数据协同学习全局模型,同时使用自己的本地信息学习个性化的头部。

三. FedRep算法

算法思想如图所示:

服务器和客户端共同学习ϕ\phiϕ,客户端自己学习自己的参数头hhh。
客户端更新: 在每一轮,被选中的客户端进行训练。这些客户端通过服务端来的ϕi\phi_iϕi​进行更新自己的hih_ihi​,如下:

hit,s=GRD(fi(hit,s−1,ϕt),hit,s−1,α)h_i^{t,s} = GRD(f_i(h_i^{t,s-1},\phi^t),h_i^{t,s-1},\alpha)hit,s​=GRD(fi​(hit,s−1​,ϕt),hit,s−1​,α)

GRD为一个梯度下降的优化表示,其意思为我们对参数h在f上使用一次梯度下降以α\alphaα为步长进行更新。训练完τh\tau_hτh​步更新h后,我们同样ϕ\phiϕ进行τϕ\tau_\phiτϕ​次更新,如下:

ϕit,s=GRD(fi(hit,τh,ϕit,s−1),ϕit,s−1,α)\phi_i^{t,s}=GRD(f_i(h_i^{t,\tau_h},\phi_i^{t,s-1}),\phi_i^{t,s-1},\alpha)ϕit,s​=GRD(fi​(hit,τh​​,ϕit,s−1​),ϕit,s−1​,α)

服务端更新: 客户端完成更新后返回给服务端ϕit,τϕ\phi_i^{t,\tau_\phi}ϕit,τϕ​​,服务端聚合后后求平均。
具的算法如下图:

四. 代码详解

作者的代码点这里
相信这一篇文章大家应该理解起来不会有困难,就是分层进行一次处理即可。
首先,最主要关心的就是怎么分层。根据思路,我们需要分为rep层和head层,head为自己的参数。rep则是参与共享,在分层前,我们看一看网络:

class CNNCifar100(nn.Module):def __init__(self, args):super(CNNCifar100, self).__init__()self.conv1 = nn.Conv2d(3, 64, 5)self.pool = nn.MaxPool2d(2, 2)self.drop = nn.Dropout(0.6)self.conv2 = nn.Conv2d(64, 128, 5)self.fc1 = nn.Linear(128 * 5 * 5, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, args.num_classes)self.cls = args.num_classesself.weight_keys = [['fc1.weight', 'fc1.bias'],['fc2.weight', 'fc2.bias'],['fc3.weight', 'fc3.bias'],['conv2.weight', 'conv2.bias'],['conv1.weight', 'conv1.bias'],]def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 128 * 5 * 5)x = F.relu(self.fc1(x))x = self.drop((F.relu(self.fc2(x))))x = self.fc3(x)return F.log_softmax(x, dim=1)

一个很简单的CNN网络,其中我们把每一层名称储存下来,方便进行分层。

if args.alg == 'fedrep' or args.alg == 'fedper':if 'cifar' in  args.dataset:w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,3,4]]elif 'mnist' in args.dataset:w_glob_keys = [net_glob.weight_keys[i] for i in [0,1,2]]elif 'sent140' in args.dataset:w_glob_keys = [net_keys[i] for i in [0,1,2,3,4,5]]else:w_glob_keys = net_keys[:-2]

这里就是简要的分层操作,可以看到对于我们处理cifar100的话,rep层取得是 0 1 3 4,对应的也就是除了fc3的最后一层。因此是最后一层为head,其余为rep。
之后开始训练,训练首先对于客户端来说是获取服务端的参数rep再加上自己的参数head,代码为:

if args.alg != 'fedavg' and args.alg != 'prox':for k in w_locals[idx].keys():if k not in w_glob_keys:w_local[k] = w_locals[idx][k]

其中w_glob_keys 就是rep的参数,w_local为所有的参数。
最后就是训练:

for iter in range(local_eps):done = False# for FedRep, 首先我们训练head固定rep,训练个几轮if (iter < head_eps and self.args.alg == 'fedrep') or last:for name, param in net.named_parameters():if name in w_glob_keys:param.requires_grad = Falseelse:param.requires_grad = True# 然后训练rep固定headelif iter == head_eps and self.args.alg == 'fedrep' and not last:for name, param in net.named_parameters():if name in w_glob_keys:param.requires_grad = Trueelse:param.requires_grad = False

Exploiting Shared Representations for Personalized Federated Learning 论文笔记+代码解读相关推荐

  1. 论文阅读(联邦学习):Exploiting Shared Representations for Personalized Federated Learning

    Exploiting Shared Representations for Personalized Federated Learning 原文传送门 http://proceedings.mlr.p ...

  2. Exploiting Shared Representations for Personalized Federated Learning【2021 icml】

    1.动机: 其中 fi表示第i个客户端上的损失函数, qi表示第i个客户端上的模型.但由于客户端上的数据较少,同时客户端数量庞大,客户端无法学习到一个很小损失的模型,因此在联邦学习中通过与服务器交换消 ...

  3. 【个性化联邦学习】Towards Personalized Federated Learning 论文笔记整理

    Towards Personalized Federated Learning 一.背景 二.解决策略 2.1 策略一.全局模型个性化 2.2 策略二.学习个性化模型 三.具体方案 3.1 全局模型个 ...

  4. Memory-Associated Differential Learning论文及代码解读

    Memory-Associated Differential Learning论文及代码解读 论文来源: 论文PDF: Memory-Associated Differential Learning论 ...

  5. Adaptive Personalized Federated Learning 论文解读+代码解析

    论文地址点这里 一. 介绍 联邦学习强调确保本地隐私情况下,对多个客户端进行训练,客户端之间不交换数据而交换参数来进行通信.目的是聚合成一个全局的模型,使得这个模型再各个客户端上读能取得较好的成果.联 ...

  6. Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析

    论文地址点这里 一. 介绍 联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的.在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的 ...

  7. DBA: Distributed Backdoor Attacks against Federated Learning论文笔记

      作者:Chulin Xie  Keli Huang  Pin-Yu Chen  Bo Li 来源:ICLR 2020 发表时间:May 26,2020   背景: 联邦学习能够聚合各方提供的信息, ...

  8. 【论文笔记+代码解读】《ATTENTION, LEARN TO SOLVE ROUTING PROBLEMS!》

    介绍 本文提出了一种注意力层+强化学习的训练模型,以解决TSP.VRP.OP.PCTSP等路径问题.文章致力于使用相同的超参数,解决多种路径问题.文中采用了贪心算法作为基线,相较于值函数效果更好. 注 ...

  9. 论文阅读:Personalized Federated Learning with Moreau Envelopes

    论文名字 Personalized Federated Learning with Moreau Envelopes 来源   年份 2021.3.3 作者 Canh T. Dinh, Nguyen ...

最新文章

  1. Kubernetes集群部署
  2. 深入理解Java:注解
  3. Bumblebee微服务网关之请求统一验证
  4. 一款好看新颖的404页面源码
  5. SetInterval(循环计时器)
  6. 如何用blend创建自定义窗口
  7. 无外网环境下CentOS 7安装MySQL 5.7.18
  8. 验证DG最大性能模式下使用ARCH/LGWR及STANDBY LOG的不同情况
  9. gitlab服务器性能,gitlab服务器搭建
  10. Atitit 安全措施流程法 目录 1. 常见等安全措施方法 2 1.1. 安全的语言 代码法,编译型 java 2 1.2. 安全编码法 2 1.3. 安全等框架类库 api 2 1.4. 加密法
  11. matlab生成点的坐标,根据点的发展坐标,将点的轨迹画出来
  12. 腾讯电脑管家卸载后的残留信息有哪些
  13. 统计学③——总体与样本的差异在哪里
  14. 消费评价网 | 线上保险消费调查报告 虚假宣传多 捆绑销售坑人
  15. Vue+ elementui 布局混乱
  16. 利用FFmpeg编码器将JPG图片进行H.264编码原理
  17. IRT和DINA模型学习总结
  18. USACO Score Inflation 总分
  19. unzip error 22 - Invalid argument
  20. 计组学习笔记2(RISC v版)

热门文章

  1. 反思与总结-10月与11月半
  2. python求解欧拉Euler公式
  3. 经典问题解决办法(转)
  4. 2022 需求工程选择填空题【太原理工大学】
  5. 用noMeiryoUI为Windows10换个OPPO Sans字体吧
  6. spring-cloud服务网关中的Timeout设置
  7. 办公软件excel表格_【办公软件】文字排版 表格制作 PPT培训
  8. TWS耳机及相关蓝牙协议
  9. 客户端deeplink技术
  10. BIGEMAP通过离线地图二次开发接口(离线地图API)