论文地址点这里

一. 介绍

联邦学习强调确保本地隐私情况下,对多个客户端进行训练,客户端之间不交换数据而交换参数来进行通信。目的是聚合成一个全局的模型,使得这个模型再各个客户端上读能取得较好的成果。联邦学习中FedAvg方法最为广泛,但由于本地数据分片之间的固有多样性和数据再客户端的高度非iid(独立同分布),FedAvg对超参数非常敏感,不能从良好的手链保证中获益。因此在设备异质性存在的情况下,全局模型不能很好的概括每个客户单独的本地数据。
随着客户端数据的多样性增加,全局模型和个性化(客户端)模型的误差将会越来越大,好的全局模型回导致一个差的本地客户端的模型。
在这项工作中,作者提出了一个新的联邦学习框架,该框架优化了所有客户急的性能。减少泛化误差依赖于局部数据的分布特征。因此,该模型目标为倾向于学习一种混合了全球模式和本地模式的个性化模式。但难点在于如何确保局部数据是适合所有客户的全局模型。

二. 相关工作

联邦学习的主要目标为学习一个全局模型,这个全局模型对与尚未看到的数据足够好,并且能快速收敛到局部最优,这一点和元学习有一些相似之处。但尽管存在这种相似性,元学习方法主要是常识学习多个模型,针对每个新任务进行个性化学习,而在大多数联邦学习中,更关注单个全局模型。而全局模型和本地模型的差异性就是个性化的重要表现。联邦学习中个性化主要有三大类,本地微调,多任务学习和情景化。
本地微调(Local fine tuning): 本地微调即每个客户端接收到一个全局模型,并使用自己的局部数据和几个梯度下降步骤对其进行调优,这种方法主要结合了元学习。
多任务学习(multi_task learning): 对个性化问题的另一种观点是视为多任务学习问题。这种设置下对每个客户端的优化可以看做是一个新的任务。
情景化(Contextualization): 个性化联邦学习中的一个重要应用是在不同情境下使用模型。我们需要在不同的环境下对一个客户端进行个性化的模型。
通过模型混合进行个性化(Personalization via mixing models): 通过混合全局和局部的模型引入不同的个性化方法来进行联邦学习。基于此,有三种不同的个性化方法,即客户聚类、数据插值和模型插值。而前两种对数据隐私性造成破坏,只有第三种是较为合理的模式。

三. 个性化联邦学习

3.1 定义:

Di:D_i :Di​:第i个客户端上数据集(有标签)
Dˉ=(1/n)∑i=1nDi\bar{D} = (1/n)\sum_{i=1}^{n}D_iDˉ=(1/n)∑i=1n​Di​:所有客户端的平均分布
LDi(h)=E(x,y)∈Di[l(h(x),y)]:\mathcal{L}_{D_i}(h) = \mathbb{E}_{(x,y)\in D_i}[\mathcal{l(h(x),y)}]:LDi​​(h)=E(x,y)∈Di​​[l(h(x),y)]:在客户端i上的真实风险。
L^Di(h):\widehat{\mathcal{L}}_{D_i}(h):LDi​​(h):在客户端i上对于h的经验风险

3.2 个性化模型

在一个标准的联邦学习场景中,目的是为所有设备合作学习一个全局模型。同时各个客户端存在着本地模型,在自适应个性化联邦学习中,目标是找到全局模型和局部模型的最优组合,以实现更好的针对客户的模型。在这种设置中,每个用户训练一个局部模型,同时合并部分全局模型,并使用一些混合权重,数学表达如下:
hαi=αih^i∗+(1−αi)hˉ∗h_{\alpha_i} = \alpha_i \widehat{h}_i^*\ +\ (1 -\alpha_i )\bar{h}^*hαi​​=αi​hi∗​ + (1−αi​)hˉ∗
其中hˉ∗=argmin⁡h∈HL^Dˉ(h)\bar{h}^* = arg\min_{h\in\mathcal{H}}\widehat{\mathcal{L}}_{\bar{D}}(h)hˉ∗=argminh∈H​LDˉ​(h)为全局的经验优化最小,
h^i∗=argmin⁡h∈HL^Dˉ(αih+(1−αi)hˉ∗)\widehat{h}_i^* = arg\min_{{h\in\mathcal{H}}}\widehat{\mathcal{L}}_{\bar{D}}(\alpha_ih+(1-\alpha_i)\bar{h}^*)hi∗​=argminh∈H​LDˉ​(αi​h+(1−αi​)hˉ∗)是一个在第i个客户端上取得最小损失的混合模型。
(这里我解释一下,就是说我们的模型由两部分组成,一个是全局的模型,另一个是客户端的模型,至于为什么客户端的模型又是由一个混合组成呢?这里考虑成多轮训练即可,假设t-1轮全局模型为w,本地模型为v,然后我们融合成混合模型为h=w+v。在t轮的时候,全局模型为新的w,而本地模型则是继承t-1轮的混合模型h,所以对应的v可以代指为本地模型)

3.3 APFL算法

就像传统的联邦学习意义,服务器需要解决目标如下:
min⁡w∈Rd[F(w)=1n∑i=1n{fi(w)=Eξ[fi(w,ξi)]}]\min_{\mathcal{w}\in \mathbb{R^d}}[F(w)=\frac{1}{n}\sum_{i=1}^n\{f_i(w)=\mathbb{E_\xi[f_i(w,\xi_i)]}\}]w∈Rdmin​[F(w)=n1​i=1∑n​{fi​(w)=Eξ​[fi​(w,ξi​)]}]
而客户端采取上面的方式(个性化)
min⁡v∈Rdfi(αiv+(1−alphai)w∗)\min_{\mathcal{v}\in \mathbb{R^d}}f_i(\alpha_iv+(1-alpha_i)w^*)v∈Rdmin​fi​(αi​v+(1−alphai​)w∗)
其中w∗=argmin⁡wF(w)w^*=arg\min_w F(w)w∗=argminw​F(w)
具体步骤如下:

对于参与训练的客户端来说,存在着两个参数。一个是w:全局参数,一个是v:自己的参数。首先根据数据集对w进行更新(用t-1轮的参数)。对v更新方式也是如此,v通过混合参数(vˉ\bar{v}vˉ)计算梯度来进行更新,之后将新的w和v合成我们当前的混合参数,再把w传到服务端进行合并。

3.4 α\alphaα的取值

直观来看,当本地数据较为均匀,每个客户端的局部模型接近全局模型时我们需要较小的α\alphaα;相反,当本地数据多样性较强时,α\alphaα应该接近1。我们需要再不同分布的场景下更新我们的α\alphaα:
αi∗=argmin⁡αi∈[0,1]fi(αiv+(1−αi)w)\alpha^*_i = arg\min_{\alpha_i \in[0,1]}f_i(\alpha_iv+(1-\alpha_i)w)αi∗​=argαi​∈[0,1]min​fi​(αi​v+(1−αi​)w)
我们可以使用梯度下降来更新一次α\alphaα。
αi(t)=αi(t−1)−ηt∇αfi(vˉi(t−1);ξit)=αi(t−1)−ηt<vi(t−1)−wi(t−1),∇fi(vˉi(t−1);ξit)>\begin{aligned} \alpha_i^{(t)}&=\alpha_i^{(t-1)}-\eta_t \nabla_\alpha f_i(\bar{v}_i^{(t-1)};\xi_i^t)\\ &=\alpha_i^{(t-1)}-\eta_t <v_i^{(t-1) }-w_i^{(t-1)},\nabla f_i(\bar{v}_i^{(t-1)};\xi_i^t)> \end{aligned} αi(t)​​=αi(t−1)​−ηt​∇α​fi​(vˉi(t−1)​;ξit​)=αi(t−1)​−ηt​<vi(t−1)​−wi(t−1)​,∇fi​(vˉi(t−1)​;ξit​)>​

四. 关键代码解析

作者的代码github地址点这里,这个github还包括很多其他的联邦学习算法,这里只针对APFL算法进行讲解。
APFL主要不同在客户端的更新上,因此我们针对客户端训练进行解读。
首先是对全局模型参数w的更新
直接读取数据,求损失,用SGD更新

_input, _target = load_data_batch(client.args, _input, _target, tracker)
# Skip batches with one sample because of BatchNorm issue in some models!
if _input.size(0)==1:is_sync = is_sync_fed(client.args)break# inference and get current performance.
client.optimizer.zero_grad()
loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)# compute gradient and do local SGD step.
loss.backward()
client.optimizer.step(apply_lr=True,apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
)

接下来是本地模型参数的更新v:

client.optimizer_personal.zero_grad()
loss_personal, performance_personal = inference_personal(client.model_personal, client.model, client.args.fed_personal_alpha, client.criterion, client.metrics, _input, _target)# compute gradient and do local SGD step.
loss_personal.backward()
client.optimizer_personal.step(apply_lr=True,apply_in_momentum=client.args.in_momentum, apply_out_momentum=False
)

也是一样的,拿到一个batch数据,求损失,注意这里求损失对应的参数为上一轮的混合参数,而不仅仅是本地参数,求混合参数损失代码如下:
其实就是用α\alphaα来合成混合参数算损失

def inference_personal(model1, model2, alpha, criterion, metrics, _input, _target):"""Inference on the given model and get loss and accuracy."""# TODO: merge with inferenceoutput1 = model1(_input)output2 = model2(_input)output = alpha * output1 + (1-alpha) * output2loss = criterion(output, _target)performance = accuracy(output.data, _target, topk=metrics)return loss, performance

到这里其实就实现了APFL,但还有个关键的地方在于,每一轮在训练前更细呢一次α\alphaα,通过3.4节讲解的方式更新:

def alpha_update(model_local, model_personal,alpha, eta):grad_alpha = 0for l_params, p_params in zip(model_local.parameters(), model_personal.parameters()):## 这里为 v - wdif = p_params.data - l_params.data## 这里为f(\bar{v}的损失)grad = alpha * p_params.grad.data + (1-alpha)*l_params.grad.data## 乘起来即可grad_alpha += dif.view(-1).T.dot(grad.view(-1))grad_alpha += 0.02 * alpha## 进行更新alpha_n = alpha - eta*grad_alpha## 确保在0,1之间alpha_n = np.clip(alpha_n.item(),0.0,1.0)return alpha_n

到这里,APFL算法就介绍完了,最后附上apfl整个训练的代码,方便大家查看:

def train_and_validate_federated_apfl(client):"""The training scheme of Personalized Federated Learning.Official implementation for https://arxiv.org/abs/2003.13461"""log('start training and validation with Federated setting.', client.args.debug)if client.args.evaluate and client.args.graph.rank==0:# Do the testing on the server and returndo_validate(client.args, client.model, client.optimizer,  client.criterion, client.metrics,client.test_loader, client.all_clients_group, data_mode='test')returntracker = define_local_training_tracker()start_global_time = time.time()tracker['start_load_time'] = time.time()log('enter the training.', client.args.debug)# Number of communication rounds in federated setting should be definedfor n_c in range(client.args.num_comms):client.args.rounds_comm += 1client.args.comm_time.append(0.0)# Configuring the devices for this round of communication# TODO: not make the server rank hard codedlog("Starting round {} of training".format(n_c), client.args.debug)online_clients = set_online_clients(client.args)if (n_c == 0) and  (0 not in online_clients):online_clients += [0]online_clients_server = online_clients if 0 in online_clients else online_clients + [0]online_clients_group = dist.new_group(online_clients_server)if client.args.graph.rank in online_clients_server: client.model_server = distribute_model_server(client.model_server, online_clients_group, src=0)client.model.load_state_dict(client.model_server.state_dict())if client.args.graph.rank in online_clients:is_sync = Falseep = -1 # counting number of epochswhile not is_sync:ep += 1for i, (_input, _target) in enumerate(client.train_loader):client.model.train()# update local step.logging_load_time(tracker)# update local index and get local stepclient.args.local_index += 1client.args.local_data_seen += len(_target)get_current_epoch(client.args)local_step = get_current_local_step(client.args)# adjust learning rate (based on the # of accessed samples)lr = adjust_learning_rate(client.args, client.optimizer, client.scheduler)# load data_input, _target = load_data_batch(client.args, _input, _target, tracker)# Skip batches with one sample because of BatchNorm issue in some models!if _input.size(0)==1:is_sync = is_sync_fed(client.args)break# inference and get current performance.client.optimizer.zero_grad()loss, performance = inference(client.model, client.criterion, client.metrics, _input, _target)# compute gradient and do local SGD step.loss.backward()client.optimizer.step(apply_lr=True,apply_in_momentum=client.args.in_momentum, apply_out_momentum=False)client.optimizer.zero_grad()client.optimizer_personal.zero_grad()loss_personal, performance_personal = inference_personal(client.model_personal, client.model, client.args.fed_personal_alpha, client.criterion, client.metrics, _input, _target)# compute gradient and do local SGD step.loss_personal.backward()client.optimizer_personal.step(apply_lr=True,apply_in_momentum=client.args.in_momentum, apply_out_momentum=False)# update alphaif client.args.fed_adaptive_alpha and i==0 and ep==0:client.args.fed_personal_alpha = alpha_update(client.model, client.model_personal, client.args.fed_personal_alpha, lr) #0.1/np.sqrt(1+args.local_index))average_alpha = client.args.fed_personal_alphaaverage_alpha = global_average(average_alpha, client.args.graph.n_nodes, group=online_clients_group)log("New alpha is:{}".format(average_alpha.item()), client.args.debug)# logging locally.# logging_computing(tracker, loss, performance, _input, lr)# display the logging info.# logging_display_training(args, tracker)# reset load time for the tracker.tracker['start_load_time'] = time.time()is_sync = is_sync_fed(client.args)if is_sync:breakelse:log("Offline in this round. Waiting on others to finish!", client.args.debug)do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics, client.train_loader, online_clients_group, data_mode='train', personal=True, model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)if client.args.fed_personal:do_validate(client.args, client.model, client.optimizer_personal, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation', personal=True, model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)# Sync the model server based on model_clientslog('Enter synching', client.args.debug)tracker['start_sync_time'] = time.time()client.args.global_index += 1client.model_server = fedavg_aggregation(client.args, client.model_server, client.model, online_clients_group, online_clients, client.optimizer)# evaluate the sync timelogging_sync_time(tracker)# Do the validation on the server modeldo_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.train_loader, online_clients_group, data_mode='train')if client.args.fed_personal:do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.val_loader, online_clients_group, data_mode='validation')# logging.logging_globally(tracker, start_global_time)# reset start round time.start_global_time = time.time()# validate the models at the test dataif client.args.fed_personal_test:do_validate(client.args, client.model_client, client.optimizer_personal, client.criterion, client.metrics, client.test_loader, online_clients_group, data_mode='test', personal=True,model_personal=client.model_personal, alpha=client.args.fed_personal_alpha)elif client.args.graph.rank == 0:do_validate(client.args, client.model_server, client.optimizer, client.criterion, client.metrics, client.test_loader, online_clients_group, data_mode='test')else:log("Offline in this round. Waiting on others to finish!", client.args.debug)dist.barrier(group=client.all_clients_group)

注意,这里求α\alphaα是在每一轮训练的第一个batch之后进行更新,我觉得目的是防止一开始的w和v初始化的结果影响太大,因此改为训练一个batch后更新。

Adaptive Personalized Federated Learning 论文解读+代码解析相关推荐

  1. Data-Free Knowledge Distillation for Heterogeneous Federated Learning论文阅读+代码解析

    论文地址点这里 一. 介绍 联邦学习具有广阔的应用前景,但面临着来自数据异构的挑战,因为在现实世界中用户数据均为Non-IID分布的.在这样的情况下,传统的联邦学习算法可能会导致无法收敛到各个客户端的 ...

  2. Exploiting Shared Representations for Personalized Federated Learning 论文笔记+代码解读

    论文地址点这里 一. 介绍 联邦学习中由于各个客户端上数据异构问题,导致全局训练模型无法适应每一个客户端的要求.作者通过利用客户端之间的共同代表来解决这个问题.具体来说,将数据异构的联邦学习问题视为并 ...

  3. 【个性化联邦学习】Towards Personalized Federated Learning 论文笔记整理

    Towards Personalized Federated Learning 一.背景 二.解决策略 2.1 策略一.全局模型个性化 2.2 策略二.学习个性化模型 三.具体方案 3.1 全局模型个 ...

  4. Gradient Episodic Memory for Continual Learning 论文阅读+代码解析

    一. 介绍 在开始进行监督学习的时候我们需要收集一个训练集 D t r = { ( x i , y i ) } i = 1 n D_{tr}=\{(x_i,y_i)\}^n_{i=1} Dtr​={( ...

  5. ONLINE CORESET SELECTION FOR REHEARSAL-BASED CONTINUAL LEARNING 论文阅读+代码解析

    本篇依旧是针对持续学习的工作,也是FedWEIT的团队进行的研究,论文地址点这里 一. 介绍(简单描述) 在持续学习中为了应对灾难性遗忘,常见的方法有基于正则化.基于记忆重塑以及基于动态架构.其中基于 ...

  6. Improved Schemes for Episodic Memory-based Lifelong Learning 论文阅读+代码解析

    论文地址点这里 一. 介绍 目前深度神经网络能够在单一任务上取得了显著的性能,然而当网络被重新训练到一个新的任务时,他的表现在以前训练过的任务上急剧下降,这种现象被称为灾难性遗忘.与之形成鲜明对比的是 ...

  7. 【阅读笔记】Towards Personalized Federated Learning个性化联邦综述

    文章目录 前言 1 背景 1.1 机器学习.联邦学习 1.2 促进个性化联邦学习的动机 2 个性化联邦学习的策略 2.1 全局模型个性化 2.1.1 基于数据的方法 2.1.1.1 数据增强 Data ...

  8. Memory-Associated Differential Learning论文及代码解读

    Memory-Associated Differential Learning论文及代码解读 论文来源: 论文PDF: Memory-Associated Differential Learning论 ...

  9. 论文阅读:Personalized Federated Learning with Moreau Envelopes

    论文名字 Personalized Federated Learning with Moreau Envelopes 来源   年份 2021.3.3 作者 Canh T. Dinh, Nguyen ...

最新文章

  1. 计算机网络面试题(一)
  2. 1077: 字符串加密
  3. android 实现悬架控制
  4. 如何检查对象的类型[iOS/Android/Windows Phone]
  5. 数据库授予用户增删改查的权限的语句_mysql创建本地用户及赋予数据库权限的方法示例...
  6. (chap8 确认访问用户身份的认证) 基于表单认证
  7. 透过三翼鸟,看品牌背后的“有效创新”
  8. 为什么学Python
  9. 中断方式下进行串口通讯的正确方法
  10. intellij导入scala工程不识别scala语言
  11. opengl与Directx的区别
  12. 现代软件工程 作业 最后一周总结
  13. 微异构Embree照片级光线追踪解决方案
  14. 精选CSDN的ACM-ICPC活跃博客
  15. mysql加索引后查询时间变长了(终于有头绪了)
  16. 【Hive】集合函数
  17. 安装mathtype打开word报错 mathtype.Dll cannot be found 解决方式
  18. web前端开发面试题(五)
  19. 分式加法JAVA程序_分式加减运算的八种技巧,有几种方法学校老师没讲过,记得收藏...
  20. ^^^ 存货盘盈盘亏的账务处理 Accounting for Inventory Profit and Inventory Loss with Goods Stock...

热门文章

  1. python拼音检查
  2. 驭梦KTV点歌系统简介
  3. python语言表白超炫图形_经验分享 篇二:三分钟教你用Excel制作各种尺寸、底色的证件照...
  4. git为私有仓库设置密码_我搭建了一套企业级私有Git服务,抗住了每天上万次攻击!...
  5. Java连接并操纵MySQL数据库的全过程
  6. 一文了解什么是嵌入式?
  7. java有阴历年算法吗_中国农历算法java实现
  8. android 禁用触摸屏,animation时禁用所有触摸屏交互
  9. python画三角函数--cosx
  10. 安徽省2019c语言二级答案,安徽省计算机等级二级考试真题C语言.doc