目录

  • 前言
  • 1 引言
  • 2 Per-FedAvg
    • 2.1 初始模型
    • 2.2 本地自适应
  • 3. 总结

前言


题目: Personalized Federated Learning: A Meta-Learning
Approach
论文地址:Personalized Federated Learning: A Meta-Learning
Approach

元学习是当下比较热门的一个研究方向,本篇文章将联邦学习和一种模型不可知元学习方法MAML结合起来,提出了一种新的个性化联邦技术Per-FedAvg。

所谓元学习,就是学会学习。利用元学习得到的模型,当我们在面临一个新的任务时,经过很少的训练步骤就可以得到一个比较好的模型,而不必像经典机器学习一样,需要在一个数据集上进行大量训练。Per-FedAvg的思想类似,我们利用所有客户端的数据得到一个初始模型,然后各个客户端使用该初始模型在本地进行几次梯度下降就能得到最终模型。

1 引言

联邦学习框架中,假设一共nnn个客户端,那么优化函数为:

也就是最小化各个客户端损失的均值。对每个客户端来讲,损失函数可以定义为:

传统算法的局限显而易见:在用户数据分布不完全相同的异质环境中,通过最小化平均损失得到的全局模型一旦应用于每个用户的本地数据集,可能会表现得比较糟糕。

为了应对数据的统计异质性和非IID分布所带来的挑战,需要对全局模型进行个性化处理。前面已经讲过一个比较简单的联邦个性化算法FedPer,FedPer中每个客户端都有自己的模型,所有客户端模型共享神经网络的基础层,而个性化层通过自己本地数据进行训练。如果数据量不足,通过基础层的共享可以获得一个训练比较充分的底层模型,而顶端个性化层又可以保证模型对本地数据具有较好的适应性。

与FedPer不同,Per-FedAvg的目的是获得一个初始模型,然后使用该初始模型在各个客户端的数据上进行少数几轮训练就可以得到一个较好的本地模型。通过这种方式,虽然初始模型是在所有用户上以分布式方式导出的,但每个用户的最终模型都与其他客户端的模型不同,这一点与FedPer一致。

2 Per-FedAvg

2.1 初始模型

我们假设所有用户的都得到了自己的初始模型,然后在本地数据上使用少数几次(比如一次)梯度下降,就可以得到自己需要的模型,那么优化目标可以定义为:

这里α\alphaα为学习率。可以发现,上式与FedAvg的差别在于,Per-FedAvg中客户端需要优化的函数是在FedAvg函数的基础上进行一次梯度下降后得到的。这样,上述公式不仅可以保持联邦学习的优势(联合所有客户端数据),也可以捕捉不同用户间的差异:客户端可以根据自己的数据修改初始模型,进而得到自己的模型

对于每个客户端,我们定义它的元函数Fi(w)F_i(w)Fi​(w):

Per-FedAvg的伪代码描述如下:

为了在本地训练中对Fi(w)F_i(w)Fi​(w)进行更新,我们需要计算其梯度:

可以观察到,由于Fi(w)F_i(w)Fi​(w)的表达式中有fi(w)f_i(w)fi​(w)的梯度∇fi(w)\nabla f_i(w)∇fi​(w),所以在计算∇Fi(w)\nabla F_i(w)∇Fi​(w)时我们需要计算参数的Hessian矩阵∇2fi(w)\nabla^2 f_i(w)∇2fi​(w)。

计算∇fi(w)\nabla f_i(w)∇fi​(w)的代价很大,因此论文中的计算方式为:在客户端本地选取一批数据DiD^iDi,然后利用这批数据来得到∇fi(w)\nabla f_i(w)∇fi​(w)的一个无偏估计。即:

也就是DiD^iDi中所有梯度求均值。

类似地,对于∇2fi(w)\nabla^2 f_i(w)∇2fi​(w),我们同样可以取一批数据得到其无偏估计。

与FedAvg类似,Per-FedAvg中第kkk轮通信时,服务器将模型发送给选中的客户端,然后每个客户端执行τ\tauτ轮本地梯度更新:

由于Fi(w)F_i(w)Fi​(w)实际上是损失函数fi(w)f_i(w)fi​(w)进行一步梯度更新之后得到的,即:

那么我们首先需要选取一批数据,然后对原始的损失函数求梯度,然后更新:

这个时候我们得到了参数w~k+1,ti\tilde{w}_{k+1,t}^iw~k+1,ti​,这其中iii为客户端编号,k+1k+1k+1表示当前的全局轮数,ttt表示本地更新的轮数。

得到w~k+1,ti\tilde{w}_{k+1,t}^iw~k+1,ti​后,然后观察元函数的真实梯度计算公式:

可以发现,我们还需要对w~k+1,ti\tilde{w}_{k+1,t}^iw~k+1,ti​进行二次求导,也就是上式右边那部分,这部分同样选取一批数据求无偏梯度。上式左边部分的二阶梯度就是对原始损失函数求二阶梯度,比较简单。

然后元函数梯度可以表示为:

这里的DtiD_t^iDti​,Dt′iD_t^{'i}Dt′i​以及Dt′′iD_t^{''i}Dt′′i​是在本地选取的三批独立的数据。

此时我们就可以对客户端的本地模型进行参数进行更新了:

更新完毕后将最新的参数传到服务器进行聚合:

然后重复上述步骤。

简单总结下Per-FedAvg:

  1. 服务器初始化模型。
  2. 服务器选择一部分客户端发送模型。
  3. 对被选中的客户端来讲,需要进行τ\tauτ轮本地更新,在每一轮本地更新中:首先选择一批数据计算损失函数fi(w)f_i(w)fi​(w)的梯度,然后进行一步梯度下降得到元函数Fi(w)F_i(w)Fi​(w);然后再选择一批数据对Fi(w)F_i(w)Fi​(w)进行梯度下降得到更新后的元函数。
  4. 客户端将更新好的元函数上传到服务器进行聚合。
  5. 服务器将更新后的模型发往被选中的客户端,然后重复上述步骤。

经过多轮通信后,我们得到一个初始模型F(w)F(w)F(w)。

2.2 本地自适应

经过2.1,我们得到了一个初始模型F(w)F(w)F(w),然后每个客户端利用该模型在自己本地训练很少的轮数,就可以得到一个表现比较好的模型。

本文的重点应该是有关Per-FedAvg收敛性的推导,看着华丽的数学推导,我只能感叹人与人之间的差异之大。

3. 总结

所谓元学习,就是学会如何学习。利用元学习我们可以得到一个初始模型,该初始模型在一批新的数据上进行少数几轮迭代后就能快速收敛,得到一个不错的个性化模型。Per-FedAvg借鉴了这一思想,设计了一个新的优化函数,该优化函数是所有客户端元函数的平均,而客户端元函数则是本地损失函数进行一步梯度下降后的得到的模型。对新的优化函数进行优化后,我们得到的初始模型就能对客户端进行快速自适应。

arXiv | Per-FedAvg:一种联邦元学习方法相关推荐

  1. 今日 Paper | 虚拟试穿网络;人群计数基准;联邦元学习;目标检测等

    2020-01-15 05:41:40 为了帮助各位学术青年更好地学习前沿研究成果和技术,AI科技评论联合Paper 研习社(paper.yanxishe.com),推出[今日 Paper]栏目, 每 ...

  2. 【联邦元学习】论文解读:Federated Meta-Learning for Fraudulent Credit Card Detection

    论文:Zheng W, Yan L, Gou C, et al. Federated Meta-Learning for Fraudulent Credit Card Detection[C], Pr ...

  3. 【论文极速读】VQ-VAE:一种稀疏表征学习方法

    [论文极速读]VQ-VAE:一种稀疏表征学习方法 FesianXu 20221208 at Baidu Search Team 前言 最近有需求对特征进行稀疏编码,看到一篇论文VQ-VAE,简单进行笔 ...

  4. 李飞飞点赞「ARM」:一种让模型快速适应数据变化的元学习方法 | 开源

    鱼羊 编译整理 量子位 报道 | 公众号 QbitAI 训练好的模型,遇到新的一组数据就懵了,这是机器学习中常见的问题. 举一个简单的例子,比如对一个手写笔迹识别模型来说,它的训练数据长这样: 那么当 ...

  5. 新框架ES-MAML:基于进化策略、简易的元学习方法

    作者 | Xingyou Song.Wenbo Gao.Yuxiang Yang.Krzysztof Choromanski.Aldo Pacchiano.Yunhao Tang 译者 | TroyC ...

  6. 001 A Comprehensive Survey of Privacy-preserving Federated Learning(便于寻找:FedAvg、垂直联邦学习的基本步骤)

    这是我看的第一篇关于联邦学习的论文,综述文章,让我对联邦学习有了初步的了解. A Comprehensive Survey of Privacy-preserving Federated Learni ...

  7. 论文阅读:基于区块链的分布式软件定义车载网络一种深度q -学习方法

    关键技术运用:许可区块链+软件定义的vanet+基于优先级体验重放的深度q学习 传统的问题及文章解决方法 问题1:缺乏基础设施和动态性 方法:使用软件定义的vanet动态和安全地管理vanet. 问题 ...

  8. 走进元学习:概述不同类型的元学习方法

    2020-10-04 12:30:00 全文共1596字,预计学习时长4分钟 图源:unsplash 元学习是深度学习领域中最活跃的研究领域之一.人工智能界的一些学派赞同这样一种观点:元学习是开启人工 ...

  9. 五种有效的学习方法 – 方法比努力重要

    说明:此文是有一个博士(王晶)写的,我就是复制过来的,大家有时间可以看这个博主的博客,非常不错.链接 1 .目标学习法 掌握目标学习法是美国心理学家布卢姆所倡导的.布卢姆认为只要有最佳的教学,给学生以 ...

最新文章

  1. shell 取中间行的第一列_shell脚本的使用该熟练起来了,你说呢?(篇三)
  2. 【机器学习算法-python实现】Adaboost的实现(1)-单层决策树(decision stump)
  3. vue引入外部文件_vue文件中引入外部js
  4. Oracle 常用命令
  5. 如何分享文件_分布式文件存储系统如何分享文件
  6. 计算机应用技术一级考试成绩,《计算机应用基础》课程与等级考试成绩的关系...
  7. python权限不够cmd安装不了_python环境配置+matplotlib
  8. Java:一步步带你深入了解神秘的Java反射机制
  9. EMUI10 亮相开发者大会:分布式设计打造全场景体验
  10. 急救模式下安装rpm包
  11. CAsyncSocket使用总结
  12. 关于HTML的FORM上传文件问题
  13. OpenLayers自定义投影,转换OpenLayers中加载的OSM的默认投影坐标
  14. cocoStudio UI编辑器设置自定义字体
  15. 【IDEA】项目集成svn
  16. android定义颜色数组,Colours——一套漂亮的预定义颜色库和方法
  17. [无视][mark]退役记
  18. SICP Python 描述 中文版
  19. 爬虫第3课 -豆瓣TOP250电影爬取
  20. 根据UA获取用户访问操作系统、浏览器名

热门文章

  1. Es Bucket聚合(桶聚合) 第一篇(常用桶聚合一览)
  2. dispatch_source
  3. AI服装生成,帮你完成服装设计的最后一步
  4. 用格里高利公式求π的近似值
  5. 多视角学习 (Multi-View Learning)
  6. 巨变的中国与数字化转型,创造了中国企业技术出海的历史机遇
  7. Xcode7.3.1中通过最新的CocoaPod安装pop动画引擎
  8. 【JavaWeb学习】Vue核心
  9. css-doodle_如何使用CSS Doodle用CSS绘制图案
  10. Matlab + Adobe illustrator科研作图