【论文代码复现】Clustered Sampling: Low-Variance and Improved Representativity for Clients Selection in Fede
目录
一、前言
二、论文内容概要
1. 论文背景:
2. 已有解决方案
3. 论文方法
三、实验
1. 实验设置
2. 代码
3. 修改后试验结果和论文结果对比:
参考:
一、前言
1. 论文地址:
[2105.05883v2] Clustered Sampling: Low-Variance and Improved Representativity for Clients Selection in Federated Learning (arxiv.org)
因为电脑性能有限,所以把四种聚类方式迭代次数都降低到迭代100次,而且只在MNist上,CIFAR10上跑起来也巨慢。官方给的代码跑不通,所以自己就改写了一下。
二、论文内容概要
因为目前只完整地跑完了MNist数据集上的实验,因此暂时先介绍到算法1。
1. 论文背景:
1.1 存在问题或现象:
1)抽样方法有偏倚
2)在服务器-客户端通信和训练的收敛稳定性方面不是最佳
1.2 论文提出的方法特点:
1) 聚类抽样选出的客户机具备更好的客户代表性
2)减少客户端在FL中随机聚集权重的差异(方差)
3)在客户端无需额外操作,可无缝集成至标准FL
4)与现有方法和技术兼容达到隐私增强
5)通过模型压缩减少通信量
2. 已有解决方案
1)FedAvg算法——随机选择m个客户端采样,对这m个客户端的梯度更新进行平均以形成全局更新同时用当前全局模型替换未采样的客户端
优点:相对于FedSGD在相同效果情况下,通讯成本大大降低
缺点:最终的模型是有偏倚的,不同于预期的每个客户端确定性聚合后的模型。
2)多项式分布抽样(MD抽样)算法——客户端抽样的概率对应于他们的相对样本量
优点:
(1)客户端抽样无偏性;
(2)通信量小(FedAvg和MD抽样是服务端-客户端通信最少的唯二方案)
缺点:
(1)仍然可能导致客户选择上有大的差异——选择客户端替换全局模型的次数差异
(2)这种差异导致了FL收敛性变化很大——在non-iid情况下,抽样的客户端都是基于自身数据分布改进全局模型,而未被抽样的客户端的全局模型则被直接替换
(3)损害了非抽样客户端的数据特异性
3. 论文方法
聚类采样方法:
1)Algorithm_1: sample size——基于样本大小的聚类采样聚合客户端的实现方法,该方法减少了客户端聚合权重的方差
2)Algorithm_2: models similarity——基于模型相似性,根据代表性梯度,将客户聚类,使得采样的客户端更具有代表性
优点:
1)增加了在全局模型中客户端的代表性,具有唯一分布的客户端,更有可能被采样,
2)并有可能导致更平滑、更快速的FL收敛,
3)两种方法都具备无偏性。
3)算法1——基于样本量大小的聚类
1) 基于样本数量的聚类抽样
参考[1]Blog 文中写的。
三、实验
1. 实验设置
参考[1]Blog 文中写的。
2. 代码
def get_clusters_with_alg1(n_sampled: int, weights: np.array):"Algorithm 1"epsilon = int(10 ** 6)# associate each client to a clusteraugmented_weights = np.array([w * n_sampled * epsilon for w in weights])ordered_client_idx = np.flip(np.argsort(augmented_weights))n_clients = len(weights)distri_clusters = np.zeros((n_sampled, n_clients)).astype(int)k = 0for client_idx in ordered_client_idx:while augmented_weights[client_idx] > 0:sum_proba_in_k = np.sum(distri_clusters[k])u_i = min(epsilon - sum_proba_in_k, augmented_weights[client_idx])#u_i = augmented_weights[client_idx]distri_clusters[k, :client_idx] = u_isum_proba_k = np.sum(distri_clusters[k])if sum_proba_k == client_idx * int(augmented_weights[client_idx]):augmented_weights[client_idx] += -u_ik += 1distri_clusters = distri_clusters.astype(float)for l in range(n_sampled):distri_clusters[l] /= np.sum(distri_clusters[l])return distri_clusters
参考官网代码出现两个问题:
1. k值始终无法增加,原因是初始的epsilon设置的太大了,不过也可能是我电脑问题,sum_proba_k=np.sum(distri_clusters[k])时,无法表示int64的数。
2. n_sampled给定的是10,但是client_idx循环次数远超过10次(100次)。
3. 修改后试验结果和论文结果对比:
1)代码结果:
2) 论文中结果:
训练次数比较少,不过大致看起来比较像,可能我改的也有问题,后面再继续看下无偏性和其他的。CIFAR10数据集的跑了α=0.001的,但是跑起来太慢了。
参考:
[1]Blog: 联邦学习——基于聚类抽样进行客户选择_联邦学习小白-CSDN博客https://blog.csdn.net/weixin_42534493/article/details/119330027
[2]Official Code: [2105.05883v2] Clustered Sampling: Low-Variance and Improved Representativity for Clients Selection in Federated Learning (arxiv.org)https://arxiv.org/abs/2105.05883v2
[3]分享一下画图宝库:matplotlib.pyplot.subplots — Matplotlib 3.4.3 documentationhttps://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplots.html
【论文代码复现】Clustered Sampling: Low-Variance and Improved Representativity for Clients Selection in Fede相关推荐
- 【论文代码复现2】Clustered sampling based on sample size
[论文代码复现]Clustered Sampling: Low-Variance and Improved Representativity for Clients Selection in Fede ...
- 论文代码复现常见问题
论文代码复现常见问题 场景1:代码太慢 1:写出基础代码 使用少量数据集来实践你的思路,代码可以不用很优化,优先写出来即可.写完后建议封装为函数,方便调用. 2:优化代码逻辑 代码本身足够高效吗? 代 ...
- 深度学习论文 代码复现 环境配置操作
***深度学习论文代码复现 前置工作 安装Ubuntu18.04 安装Nvidia显卡驱动 安装anaconda 安装CUDA与cuDNN 通过软链接的修改实现多版本CUDA间的切换 将~/.bash ...
- stylegan2论文代码复现超详细
stylegan2论文解读 论文就略过啦,参考别人博客了解一下 https://blog.csdn.net/g11d111/article/details/109187245 stylegan2原论文 ...
- 论文代码复现环境配置
论文代码复现环境配置 报错1:python导入tensorflow出现_np_qint8 = np.dtype([("qint8", np.int8, 1)]) 报错2:Value ...
- 进阶必备:CNN经典论文代码复现 | 附下载链接
经常会看到类似的广告<面试算法岗,你被要求复现论文了吗?>不好意思,我真的被问过这个问题.当然也不是所有面试官都会问,究其原因,其实也很好理解.企业肯定是希望自己的产品是有竞争力,有卖点的 ...
- 论文代码复现Enhancing the Transferability of Adversarial Attacks through Variance Tuning
<Enhancing the Transferability of Adversarial Attacks through Variance Tuning>CVPR2021 论文下载地址 ...
- 论文代码复现之:GPT-too: A Language-Model-First Approach for AMR-to-Text-Generation(ARM-to-text)
文章目录 资源引用 复现过程 虚拟环境创建 通过 pip 或者 anaconda 安装依赖 数据集 下载 GPT-2 预训练模型(medium尺寸的)并进行训练 解决作者的代码错误 解决安装包的版本问 ...
- AUTOVC: Zero-Shot Voice Style Transfer with Only Autoencoder Loss 论文代码复现
0. 说明 https://github.com/auspicious3000/autovc 但是听Demo中, 涉及到unseen的情况, 合成音色确实像, 但是质量不满足商用 复现Git的代码, ...
最新文章
- docker 离线安装 mysql_Oracle数据库之docker 离线环境安装oracle
- 需要用到的各种Jar包
- sql order by+字段,指定按照哪个字段来排序
- 问答专场 | 我是高级商业产品总监吴波,你有什么想问的?
- matlab GUI——按下按钮在指定的坐标下绘制函数图像
- 数据结构—图的基本概念
- finally中关闭资源
- mysql环境变量配置还是不行_为什么要配置mysql环境变量
- Python 百度智能云文字识别 实现手写文字识别
- linux 提取网卡驱动,linux(ubuntu18.04)系统上安装RTL8822CE网卡驱动
- 来了!5G和AI的未来 这10位行业领袖这么说
- GGSN - SCP 业务控制点
- ChatGPT 体验和思考
- 京东和区块链的那些事儿
- 前端:Tomcat服务器部署Web项目
- Java项目:springboot私人牙医管理系统
- CF-1200D White Lines(前缀和来两发么小老弟?)
- Web 2.0时代RSS的.Net实现
- 02 电商数仓(数据采集模块)
- 各大搜索引擎下拉词长尾词API接口