关注公众号,发现CV技术之美

这是一篇最新ICLR2022论文,Acceleration of Federated Learning with Alleviated Forgetting in Local Training,作者通过从一个灾难性遗忘的角度分析联邦学习性能不佳的原因,并进行改进提升收敛速度与精度。

《Acceleration of Federated Learning with Alleviated Forgetting in Local Trainin》

论文:https://arxiv.org/abs/2203.02645

代码:https://github.com/Zoesgithub/FedReg

 1 Abstract

作者观察到,现有方法收敛速度缓慢是由于每个客户端局部训练阶段的灾难性遗忘问题造成的,这导致其他客户的先前训练数据的损失函数大幅增加。

因此作者提出了一种FedReg算法,通过对生成的伪数据的损失来调整局部训练的参数,并对全局模型学习到的先前训练数据的知识进行编码,从而大大提高收敛速度,同时可以更好的保护隐私。

 2 Introduction

一些FL算法被设计要通过减少异质性问题的差异来改进FedAvg,但是当采用深度神经网络架构时,这些方法的性能仍然远不能令人满意,另一方面,最近的文献工作表明训练后的模型参数的传输并不能保证对隐私的保护,虽然DP可以防止隐私泄露,但是当DP加入FL时模型的性能持续衰减。

作者观察到,当数据为non-i.i.d时在整个客户中,本地训练的模型严重忘记了其他客户对以前的训练数据的知识(即众所周知的灾难性遗忘问题),这可能是由于本地数据分布和全局数据分布之间的差异。这种遗忘问题导致客户端损失大幅增加,我们提出FedReg通过减轻局部训练阶段的灾难性遗忘问题来降低训练中的通信成本。

FedReg通过使用生成的伪数据对局部训练参数进行正则化来减少知识遗忘,这些伪数据是通过使用修改后的局部数据对全局模型学习到的先前训练数据的知识进行编码而获得的。伪数据与本地数据中的知识的潜在冲突通过使用扰动数据得到抑制,扰动数据是通过对本地数据进行小扰动而产生的,它们有助于确保其预测值。伪数据和扰动数据的生成只依赖于从服务器接收到的全局模型和当前客户端的本地数据。

作者证明,当跨客户端的数据是非独立同分布的时,本地训练阶段的灾难性遗忘是减慢 FL 训练过程的重要因素,因此提出了一种算法 FedReg,它通过使用生成的伪数据减轻灾难性遗忘来加速 FL。

灾难性遗忘:指的是人工智能系统,如深度学习模型,在学习新任务或适应新环境时,忘记或丧失了以前习得的一些能力。当神经网络在多个任务上按顺序训练时,就会发生灾难性遗忘,在这种情况下,当前任务的最佳参数可能在先前任务的目标上表现不佳。在深度神经网络学习不同任务的时候,相关权重的快速变化会损害先前任务的表现,造成人工智能系统在原有任务或环境性能大幅下降。

 3 Method

主要挑战是如何减轻每个客户对先前学习知识的遗忘,而不必在本地培训阶段访问其他客户的数据。我们首先生成伪数据,然后通过使用伪数据上的损失对局部训练的参数进行正则化来缓解灾难性遗忘问题。

生成伪数据:fast gradient sign method,FGSM,是一种对抗样本生成方法,根据本轮全局模型梯度反方向生成对抗样本:

通过生成对抗样本,基于本地模型分类结果与数据标签相差较大的数据,可以对本地模型很好起到正则化效果。

尽管上面生成的伪数据放松了约束,但由于训练过程中全局模型的不准确性,可能会导致对抗样本与本地数据冲突,导致模型学习到一些错误信息,为了进一步消除这种相互冲突的信息,对本地数据进行轻微扰动,个人理解可以增强鲁棒性:

其中扰动程度非常小,即n_p<<n_s,以确保扰动数据比伪数据更接近本地数据。

接下来,使用生成的伪数据和扰动数据,进行正则化以减轻灾难性遗忘:

其中约束 (4) 缓解了灾难性的遗忘问题,约束 (5) 消除了 (4) 中引入的冲突信息,约束 (5) 还有助于提高结果模型的鲁棒性。

再通过求解以下约束优化问题来逼近最优参数 θ(t,i)∗:

更进一步,本地模型参数在每个训练步骤中更新为:

进一步注意到,在分类问题中,伪数据可用于修改梯度从而增强隐私保护,伪数据与真实数据相比可能是包含相似的语义信息但不同的分类信息,因此伪数据增强FL隐私保护能力而不会严重降低所得模型的性能,如下所示Di代表原数据、Di_s代表伪数据:

从而通过修正梯度来增强隐私保护。

关于FedReg方法流程图如下图1所示,通过对本地模型施加正则化,加快全局模型收敛速度并提升精度。

图1:FedReg方法

 4 Experiments

数据集:MNIST、EMNIST、CIFAR-10、CIFAR-100、CT images related to COVID-19(一个基于COVID-19胸部CT图片)。

收敛速度比较:与基线方法相比,FedReg 只需要更少的通信轮次达到收敛,并获得更高的最终准确度,如下表1所示。

表1:收敛率比较

减轻灾难性遗忘:如下图2所示,为了证明 FedReg 确实减轻了灾难性遗忘,在 FedReg 和 FedAvg 之间比较了其他客户端先前训练数据的损失值的增加情况。FedReg 中损失的增加幅度明显低于 FedAvg,这表明虽然 FedAvg 和 FedReg 都忘记了一些学习知识,但 FedReg 的遗忘问题并不严重。

顶行表示FedReg与FedAvg的损失情况。底行关于信息相似性数据表明在伪数据上具有高性能的模型在先前的训练数据上表现良好的概率相对较高(分布较为均匀),使用伪数据对参数进行正则化有助于减少先前训练数据的性能下降并缓解知识遗忘问题。

图2:减轻灾难性遗忘

步长超参数:当生成伪数据的步长n_s太大时,生成的伪数据与本地数据的距离/差异性会太大,无法有效地对模型参数进行正则化。另一方面,当n_s太小时,来自不准确的全局模型的冲突信息会减慢收敛速度。当 ηs 太小时,需要更多的通信轮次才能达到相同的精度;随着 ηs 的增加,通信轮数减少到最小,然后反弹回来,表明此时在局部训练阶段正则化对模型参数不太有效。

隐私保护:如下表2所示,梯度反转攻击用于从分类问题中的更新梯度中恢复信息,从下表中可以看出,在基线方法中,使用 DP 保护隐私信息的性能下降幅度很大,而使用 FedReg 在保持相似的隐私保护级别时,性能下降幅度要小得多。相比较而言,FedReg 能够保护敏感的时间信息,但是模型性能只有轻微的衰减。

表2:隐私保护比较

 5 Conclusion

三点总结:

  1. 在这项工作中,作者提出了一种新的算法 FedReg,通过减轻局部训练阶段的灾难性遗忘问题来加快 FL 的收敛速度。生成伪数据以携带有关全局模型学习的先前训练数据的知识,而不会产生额外的通信成本或访问其他客户端提供的数据。

  2. 生成的伪数据包含与其他客户端之前的训练数据相似信息,因此可以通过使用伪数据对本地训练的参数进行正则化来缓解遗忘问题。

  3. 伪数据还可以用于防御分类问题中的梯度反转攻击,与 DP 相比,结果模型的性能只有轻微的衰减。

 6 补充: Fast Gradient Sign Method

我们说到作者生成伪数据是通过Fast Gradient Sign Method,这是一种对抗样本数据生成方法,如下图3所示,横坐标表示单维x输入值,纵坐标表示损失值,函数图像是损失函数,损失值越大表示越大概率分类错误,假设灰的线上方为分类错误,下方为分类正确;

以样本点x1为例,根据公式,此时的偏导函数为负,则黑色箭头方向为扰动方向,同理x2样本在取值为正时,也沿着黑色箭头方向变化,只要我们的取值合适,就能生成对抗样本,使得分类错误。总之扰动方向就是使得损失函数变大的方向,通过扰动使得样本被分类错误。

图3:FGSM

参考文献

Xu, C., Hong, Z., Huang, M., & Jiang, T. (2022). Acceleration of Federated Learning with Alleviated Forgetting in Local Training. International Conference on Learning Representations 2022. https://openreview.net/forum?id=541PxiEKN3F.

END

欢迎加入「计算机视觉交流群

提升速度与精度,FedReg: 减轻灾难性遗忘加速联邦收敛(ICLR 2022)相关推荐

  1. 怎样缓解灾难性遗忘?持续学习最新综述三篇

    本文转载自公众号"夕小瑶的卖萌屋",专业带逛互联网算法圈的神操作 ----->我是传送门 关注后,回复以下口令: 回复[789] :领取深度学习全栈手册(含NLP.CV海量综 ...

  2. 弹性响应蒸馏 | 用弹性响应蒸馏克服增量目标检测中的灾难性遗忘

      欢迎关注我的公众号 [极智视界],获取我的更多笔记分享   大家好,我是极智视界,本文解读一下 用弹性蒸馏克服增量目标检测中的灾难性遗忘.   传统的目标检测不适用于增量学习.然而,仅用新数据直接 ...

  3. 如何利用增量学习的方法来解决灾难性遗忘的问题?

    增量学习是一种逐步学习新数据的方法,通过在新数据上更新模型而不是从头开始训练.这种方法在很大程度上可以缓解灾难性遗忘问题,因为它试图在学习新知识的同时保留已有知识.以下是一些使用增量学习解决灾难性遗忘 ...

  4. 克服神经网络中的灾难性遗忘(EWC):Overcoming catastrophic forgetting inneural networks

    克服神经网络中的灾难性遗忘 Introduction Results EWC Extends Memory Lifetime for Random Patterns EWC Allows Contin ...

  5. 机器人操作持续学习论文(1)原文阅读与翻译——机器人操作中无灾难性遗忘的原语生成策略学习

    Primitives Generation Policy Learning without Catastrophic Forgetting for Robotic Manipulation 1机器人操 ...

  6. YOLO V4 Tiny改进版来啦!速度294FPS精度不减YOLO V4 Tiny

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 此YOLO V4 Tiny改进在保证精度的同时帧率可以达到294FPS!具有比YOLOv4-tiny( ...

  7. 论文速递:通过模拟大脑-解决深度学习中的灾难性遗忘

    来源:混沌巡洋舰 灾难性遗忘指的是:不像人类,当深度学习模型被训练完成新的任务时,他们很快就会忘记以前学过的东西.8月13号的自然通讯论文Brain-inspired replay for conti ...

  8. TNNLS 22|分数不是关键,排名才是关键:针对排行榜的模型“行为”保持与灾难性遗忘的克服...

    本文认为对于类增量学习任务而言,单个点在特征空间的位置不是关键,它们之间距离值也不是关键,它们两两距离的排序才是重中之重.为此我们提出了一种新的类增量学习模型并设计了一个可导的排序算法,已被 IEEE ...

  9. 不能兼顾速度与精度,STOC 2021最佳论文揭示梯度下降复杂度理论

    ©作者 | 机器之心编辑部 来源 | 机器之心 梯度下降算法具有广泛的用途,但是关于它的计算复杂度的理论研究却非常少.最近,来自利物浦大学.牛津大学的研究者从数学的角度证明了梯度下降的计算复杂度,这项 ...

最新文章

  1. python property内建函数的介绍
  2. matlab上位机串口通信,MATLAB GUIDE 上位机串口通信开发 绘制图形
  3. html 获得文本节点,JavaScript获取节点——获取文本节点
  4. 7 centos 源码安装samba_centos 7 安装 samba 服务
  5. 使用ABAP代码获得tcode RZ11里的参数值
  6. js 给动态li添加动态点击事件
  7. 一步一步分析vue之_data属性
  8. JAVA面试要点010---重入锁_ReentrantLock 详解
  9. ssm项目打包到云服务器,ssm项目打包到云服务器
  10. 1 月份 Github 上最热门最有价值的开源项目
  11. 解决java环境变量配置不生效
  12. GX Works2 安装详细过程
  13. JAVA数据库的操作(增、删、改、查)
  14. WinCC V7.4 过程值归档概述及流程演示
  15. 冬瓜哥祝大家新年快乐!
  16. 认识kata-containers
  17. 无人驾驶一 协方差矩阵的几何意义
  18. 腾讯云服务器公网流量是如何计算的?出流量还是入流量?
  19. 公司要求实时监控服务器,写个Web的监控系统
  20. Charles ios无法下载证书- chls.pro/ssl一直加载治标办法

热门文章

  1. 【Proxy SwitchyOmega】Chrome安装插件【提示程序包无效:“CRX_HEADER_INVALID“】【解决方法】
  2. H264视频压缩编码标准简介(一)
  3. [NOIP2018]铺设道路
  4. ai作文批改_好未来:AI智能批改中英文作文为老师“减负”
  5. jodd忽略ssl证书_关于java访问https资源时,忽略证书信任问题
  6. pythontime模块计算时长_用python的time模块查看你出生多长时间了
  7. Android 自动动画布局更新 使用,在RecyclerView上使用布局动画(Layout animation)
  8. 中表名字必须大写吗_pi network改名字的重要性—非常之重要!!
  9. python多目标跟踪卡尔曼滤波_卡尔曼多目标跟踪的例子?
  10. PHP群发300万,mysql 300万数据查询500多秒如何优化