联邦学习基本算法FedAvg的代码实现
目录
- I. 前言
- II. 数据介绍
- 1. 特征构造
- III. 联邦学习
- 1. 整体框架
- 2. 服务器端
- 3. 客户端
- 4. 代码实现
- 4.1 初始化
- 4.2 服务器端
- 4.3 客户端
- 4.4 测试
- IV. 实验及结果
- V. 源码及数据
I. 前言
联邦学习(Federated Learning) 是人工智能的一个新的分支,这项技术是谷歌2016年于论文Communication-Efficient Learning of Deep Networks from Decentralized Data中首次提出。
在我的另一篇博文联邦学习:《Communication-Efficient Learning of Deep Networks from Decentralized Data中详细解析了该篇论文,而本篇博文的目的是利用这篇解读文章对原始论文中的FedAvg方法进行复现。
因此,阅读本文前建议先阅读联邦学习:《Communication-Efficient Learning of Deep Networks from Decentralized Data。
II. 数据介绍
联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。
本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。
我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。
除了电力负荷数据意外,还有风功率数据,两个数据通过参数type指定:type == 'load’表示负荷数据,'wind’表示风功率数据。
1. 特征构造
用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。
对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。
各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。
III. 联邦学习
1. 整体框架
原始论文中提出的FedAvg的框架为:
由于本文中需要利用各个客户端的模型参数来对服务器端的模型参数进行更新,因此本文决定采用numpy搭建一个四层的神经网络模型。模型的具体搭建过程可以参考上一篇博文:从矩阵链式求导的角度来深入理解BP算法(原理+代码)。在这篇博文里面我详细得介绍了神经网络参数更新的过程,这将有助于理解本文中的模型参数更新过程。
神经网络由1个输入层、3个隐藏层以及1个输出层组成,激活函数全部采用Sigmoid函数。
网络各层间的运算关系,也就是前向传播过程如下所示:
z1=Iw1,h1=σ(z1)z2=h1w2,h2=σ(z2)z3=h2w3,h3=σ(z3)z4=h3w4,O=σ(z4)loss=12(O−y)2z_1=Iw_1,h_1=\sigma(z_1)\\ z_2=h_1w_2,h_2=\sigma(z_2)\\ z_3=h_2w_3,h_3=\sigma(z_3)\\ z_4=h_3w_4,O=\sigma(z_4)\\ loss=\frac{1}{2} (O-y)^2z1=Iw1,h1=σ(z1)z2=h1w2,h2=σ(z2)z3=h2w3,h3=σ(z3)z4=h3w4,O=σ(z4)loss=21(O−y)2
其中:
- 输入III的shape为100×inputdim100 \times inputdim100×inputdim,每一次输入100个样本,每个样本的特征数为32
- w1w_1w1的shape为inputdim×20inputdim \times 20inputdim×20
- z1z_1z1和h1h_1h1的shape都为100×20100 \times 20100×20
- w2w_2w2的shape为20×2020 \times 2020×20
- z2z_2z2和h2h_2h2的shape都为100×20100 \times 20100×20
- w3w_3w3的shape为20×2020 \times 2020×20
- z3z_3z3和h3h_3h3的shape都为100×20100 \times 20100×20
- w4w_4w4的shape为20×120 \times 120×1
- z4z_4z4和OOO的shape都为100×1100 \times 1100×1,表示100个样本的输出
- losslossloss为损失函数,shape和输出OOO一致
因此,客户端参数更新实际上就是更新四个www。
2. 服务器端
服务器端执行以下步骤:
- 初始化参数
- 对第ttt轮训练来说:首先计算出m=max(C⋅K,1)m=max(C \cdot K, 1)m=max(C⋅K,1),然后随机选择mmm个客户端,对这mmm个客户端做如下操作(所有客户端并行执行):更新本地的wtkw_t^{k}wtk得到wt+1kw_{t+1}^{k}wt+1k。mmm个客户端更新结束后,将wt+1kw_{t+1}^{k}wt+1k传到服务器,服务器整合所有wt+1kw_{t+1}^{k}wt+1k得到最新的全局参数wt+1w_{t+1}wt+1。
简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总客户端更新后的参数形成最新的全局参数。下一轮通信时,服务器端将最新的参数分发给被选中的客户端,进行下一轮更新。
3. 客户端
客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。
4. 代码实现
4.1 初始化
参数:
- K,客户端数量,本文为10个,也就是10个地区。
- C:选择率,每一轮通信时都只是选择C * K个客户端。
- E:客户端更新本地模型的参数时,在本地数据集上训练E轮。
- B:客户端更新本地模型的参数时,本地数据集batch大小为B
- r:服务器端和客户端一共进行r轮通信。
- clients:客户端集合。
- type:指定数据类型,负荷预测or风功率预测。
- lr:学习率。
- input_dim:数据输入维度。
- nn:全局模型。
- nns: 客户端模型集合。
代码实现:
class FedAvg:def __init__(self, options):self.C = options['C']self.E = options['E']self.B = options['B']self.K = options['K']self.r = options['r']self.clients = options['clients']self.type = options['type']self.lr = options['lr']self.input_dim = options['input_dim']self.nn = BP(file_name='server', B=B, E=E, input_dim=self.input_dim, type=self.type, lr=self.lr)self.nns = []# distributionfor i in range(self.K):s = copy.deepcopy(self.nn)s.file_name = self.clients[i]self.nns.append(s)
其中self.nnself.nnself.nn是服务器端初始化的全局参数,由于服务器端不需要进行反向传播更新参数,因此不需要定义各个隐层以及输出。
4.2 服务器端
服务器端代码如下:
def server(self):for t in range(self.r):print('第', t + 1, '轮通信:')m = np.max([int(self.C * self.K), 1])# samplingindex = random.sample(range(0, self.K), m)# dispatchself.dispatch(index)# local updatingself.client_update(index)# aggregationself.aggregation(index)# return global modelreturn self.nn
其中client_update(index):
def client_update(self, index): # update nnfor k in index:self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index):# update ws = 0for j in index:# normals += self.nns[j].lenw1 = np.zeros_like(self.nn.w1)w2 = np.zeros_like(self.nn.w2)w3 = np.zeros_like(self.nn.w3)w4 = np.zeros_like(self.nn.w4)for j in index:# normalw1 += self.nns[j].w1 * (self.nns[j].len / s)w2 += self.nns[j].w2 * (self.nns[j].len / s)w3 += self.nns[j].w3 * (self.nns[j].len / s)w4 += self.nns[j].w4 * (self.nns[j].len / s)# update serverself.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4 = w1, w2, w3, w4
dispatch(index):
def dispatch(self, index):# dispatchfor i in index:self.nns[i].w1, self.nns[i].w2, self.nns[i].w3, self.nns[i].w4 = self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4
下面对重要代码进行分析:
- 客户端的选择
m = np.max([int(self.C * self.K), 1])
index = random.sample(range(0, self.K), m)
index中存储中m个0~10间的整数,表示被选中客户端的序号。
- 客户端的更新
for k in index:self.client_update(self.nns[k])
- 服务器端汇总客户端模型的参数
关于模型汇总方式,可以参考一下我的另一篇文章:对FedAvg中模型聚合过程的理解。
当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:
- normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。
- LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。
- LS:根据损失与样本数量的乘积所占的比重来决定。
- 将更新后的参数分发给客户端
def dispatch(self, inidex):# dispatchfor i in index:self.nns[i].w1, self.nns[i].w2, self.nns[i].w3, self.nns[i].w4 = self.nn.w1, self.nn.w2, self.nn.w3, self.nn.w4
4.3 客户端
客户端只需要利用本地数据来进行更新就行了:
def client_update(self, index): # update nnfor k in index:self.nns[k] = train(self.nns[k])
4.4 测试
def global_test(self):model = self.nnc = clients if self.type == 'load' else clients_windfor client in c:model.file_name = clienttest(model)
IV. 实验及结果
本次实验的参数选择为:
K | C | E | B | r |
---|---|---|---|---|
10 | 0.5 | 50 | 50 | 5 |
if __name__ == '__main__':K, C, E, B, r = 10, 0.5, 50, 50, 5type = 'load'input_dim = 30 if type == 'load' else 28_client = clients if type == 'load' else clients_windlr = 0.08options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client,'input_dim': input_dim, 'lr': lr}fedavg = FedAvg(options)fedavg.server()fedavg.global_test()
各个客户端单独训练(训练50轮,batch大小为50)后在本地的测试集上的表现为:
客户端编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 5.79 | 6.73 | 6.18 | 5.82 | 5.49 | 4.55 | 6.23 | 9.59 | 4.84 | 5.49 |
可以看到,由于各个客户端的数据都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。
服务器与客户端通信5轮后,服务器上的全局模型在10个客户端测试集上的表现如下所示:
客户端编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
可以看到,经过联邦学习框架得到全局模型在各个客户端上表现同样很好,这是因为十个地区上的数据是独立同分布的。
V. 源码及数据
后面将陆续公开~
联邦学习基本算法FedAvg的代码实现相关推荐
- 联邦学习笔记(四):使用底层API设计联邦学习平均算法
设计联邦学习平均算法 写在前面 联邦平均算法流程 数据处理 加载数据集 数据集预处理 获取处理后的数据集 前向运算和损失函数 获取代码输入类型 创建初始模型 联邦学习训练和梯度下降 单batch梯度下 ...
- 深度强化学习-D3QN算法原理与代码
Dueling Double Deep Q Network(D3QN)算法结合了Double DQN和Dueling DQN算法的思想,进一步提升了算法的性能.如果对Doubel DQN和Duelin ...
- 深度强化学习-DQN算法原理与代码
DQN算法是DeepMind团队提出的一种深度强化学习算法,在许多电动游戏中达到人类玩家甚至超越人类玩家的水准,本文就带领大家了解一下这个算法,论文和代码的链接见下方. 论文:Human-level ...
- 深度强化学习DDPG算法高性能Pytorch代码(改写自spinningup,低环境依赖,低阅读障碍)
写在前面 DRL各种算法在github上各处都是,例如莫凡的DRL代码.ElegantDRL(推荐,易读性NO.1) 很多代码不是原算法的最佳实现,在具体实现细节上也存在差异,不建议直接用在科研上. ...
- 利用谷歌的联邦学习框架Tensorflow Federated实现FedAvg(详细介绍)
目录 I. 前言 II. 数据介绍 III. 联邦学习 1. 整体框架 2. 服务器端 3. 客户端 IV. Tensorflow Federated 1. 数据处理 2. 构造TFF的Keras模型 ...
- 虚拟专题:联邦学习 | 联邦学习算法综述
来源:大数据期刊 联邦学习算法综述 王健宗1 ,孔令炜1 ,黄章成1 ,陈霖捷1 ,刘懿1 ,何安珣1 ,肖京2 1. 平安科技(深圳)有限公司,广东 深圳 518063 2. 中国平安保险(集团)股 ...
- 深度强化学习-Double DQN算法原理与代码
深度强化学习-Double DQN算法原理与代码 引言 1 DDQN算法简介 2 DDQN算法原理 3 DDQN算法伪代码 4 仿真验证 引言 Double Deep Q Network(DDQN)是 ...
- 联邦学习攻击与防御综述
联邦学习攻击与防御综述 吴建汉1,2, 司世景1, 王健宗1, 肖京1 1.平安科技(深圳)有限公司,广东 深圳 518063 2.中国科学技术大学,安徽 合肥 230026 摘要:随着机器学习技术的 ...
- 联邦学习框架和数据隐私综述
联邦学习 --新型的分布式机器学习技术. 一.联邦学习开源框架 1.联邦学习框架(按架构分类) 联邦学习常用的框架分为2种:中心化框架.去中心化框架,以中心化框架为主. 2.联邦学习的分类(按照参与方 ...
最新文章
- 单张图像重建3D人手、人脸和人体
- 红帽中出现”This system is not registered with RHN”的解决方案
- kafka 在阿里云部署
- 进程句柄表初始化,扩展,插入删除句柄源码分析
- 易中天与单田芳的区别在哪儿
- 395. Longest Substring with At Least K Repeating Characters
- python期末知识点_史上最全的Python知识点整理之基本语法
- REVERSE-PRACTICE-BUUCTF-10
- Java之Base64实现文件和字符串之间的转换
- 浅谈opencl之错误码
- 软件测试 学习之路 MYSQL安装
- 教你如何在Python中读,写和解析CSV文
- IMX8 Audio声卡
- Selenium WebDriver 数据驱动测试框架
- 书摘---创业36条军规3:创业人七大须知
- 携程线上测评测试题目,答案解析
- 云硬盘(Elastic Volume Service,EVS)
- 电子计算机技术人才需求,电子与信息技术专业人才需求调研报告.pdf
- 忍者理论谈《嗜血边缘》如何做出独特的4V4战斗音效
- 硅谷一万清华人,为何打不过印度人
热门文章
- SSD固态硬盘 4K对齐
- 1.1.1.1校园网_高一数学上册必修1第一章知识点:1.1.1集合的含义与表示
- SpringBoot导出pdf文件学习
- 忽略' scanf '的返回值,用属性warn_unused_result声明的疑问
- python爬取页面内容由京东提_python制作爬虫爬取京东商品评论教程
- 5.Apache Kylin 构建 第一步报错 Container complete event for unknown container
- std::true_type和std::false_type详解
- VINS-FUSION GPS融合坐标转换细节分析
- PTA 数组 7-5 按字母顺序排列出场国家名称
- 利用ACM服务,快速申请免费的公有证书,你get到了吗?