[联邦学习TFF]构建自己的联邦学习模型
TFF 全称 tensorflow_federated,为谷歌的联邦学习框架。
在TFF官网的Building Your Own Federated Learning Algorithm
界面中,介绍了如何尽可能多的利用现有的TensorFlow代码,构建一个TFF的模型。
本文的行文结构
本文共分为3个章节,其中第1章介绍了TFF的框架,然后给出了客户端和服务器的模型参数更新函数。第2章到主要介绍Federated core的内容。第3章主要把前2章的内容串起来,构建自己的TFF框架。
1.TFF框架的构成
1.1 TFF框架可以分成4个步骤:
- 服务器向客户端(server-to-client)传递初始模型参数
- 客户端更新模型参数
- 客户端向服务器(client-to-server)传递参数
- 服务器更新参数
1.2 这4个步骤又可以根据是否使用纯TensorFlow的代码分为两类:
第一类:全部使用TensorFlow代码构建
包括第2和第4步,客户端更新模型参数,服务器更新参数。
第二类:使用Federated Core代码构建
包括第1和第3步,server-to-client和client-to-server
下面就详细介绍纯TensorFlow环节的要点。需要Federated Core构建的放在后面第5部分。
1.2.1 客户端更新参数
可以分成两步:
(1)从服务器模型获取客户端模型参数,此处的服务器模型是经过第一步传递过来的模型
(2)客户端模型在客户端数据集上训练和更新参数
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):"""Performs training (using the server model weights) on the client's dataset."""# Initialize the client model with the current server weights.client_weights = model.trainable_variables# Assign the server weights to the client model.tf.nest.map_structure(lambda x, y: x.assign(y),client_weights, server_weights)# Use the client_optimizer to update the local model.for batch in dataset:with tf.GradientTape() as tape:# Compute a forward pass on the batch of dataoutputs = model.forward_pass(batch)# Compute the corresponding gradientgrads = tape.gradient(outputs.loss, client_weights)grads_and_vars = zip(grads, client_weights)# Apply the gradient using a client optimizer.client_optimizer.apply_gradients(grads_and_vars)return client_weights
1.2.2 服务器更新参数
在服务器上更新参数主要涉及更新策略。此处采用了较为简单的"vanilla" 联合平均算法,直接取各个客户端模型参数的平均值作为服务器的模型参数。这里的模型参数只包括可训练的参数。
@tf.function
def server_update(model, mean_client_weights):"""Updates the server model weights as the average of the client model weights."""model_weights = model.trainable_variables# Assign the mean client weights to the server model.tf.nest.map_structure(lambda x, y: x.assign(y),model_weights, mean_client_weights)return model_weights
2.关于Federated Core (FC)
FC包含底层和顶层两个维度的API。具体而言,FC是服务于tff.learning
API的底层的接口。然而,FC又是一个顶层的开发环境,它提供了一种更加紧凑的程序逻辑,把TensorFlow的代码和分布式通信操作(包括分布式求和和广播)结合起来。
FC的目标是允许开发者明确地控制系统中的分布式通信(例如点对点的网络消息交换),而不需要了解实施的细节。
TFF的设计之初就是为了数据的隐秘性,所以FC允许用户明确的控制数据应该在哪一个层面,防止数据泄露。
2.1 联邦数据(Federated data)
TFF中的一个关键概念是“联合数据”(Federated data),它是指分布式系统中一组设备上托管的数据项的集合(例如,客户端数据集或服务器模型权重)。跨所有设备的整个值集合表示为单个联合值。
例如,假设存在客户端设备,每个客户端设备都有一个表示传感器温度的浮点。这些浮点可以通过下面的式子表示为联合浮点:
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)
联合类型由其成员成分的变量类型“T”(例如“tf.float32”)和设备组“G”指定。通常,“G”是tff.CLIENTS
或tff.SERVER
。这种联合类型表示为{T}@G’,如下所示:
str(federated_float_on_clients)
# '{float32}@CLIENTS'
2.2 联邦计算(Federated computations)
TFF接受联合值作为输入,并且把联合值作为输出。例如,假设您想要平均客户端传感器上的温度。您可以定义以下内容(使用我们的联合浮点):
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):return tff.federated_mean(client_temperatures)
对于联邦计算,作者用了一句话来定义:
It is a specification of a distributed system in an internal platform-independent glue language.
它是一个分布式系统的规范,采用独立于内部平台的“粘合语言”
此处的tff.federated_computation
接受联合类型{float32}@CLIENTS
的参数,并返回联合类型{float 32}@SERVER
的值。联邦计算也可以从服务器到客户机、从客户机到客户端或从服务器到服务器。联邦计算也可以像普通函数一样组成,只要它们的类型签名匹配即可。
get_average_temperature([68.5, 70.3, 69.8])
# 相当于(68.5+70.3+69.8)/3
2.3 关于非渴望计算和TensorFlow
- TFF操作的是联合数值
- 每一个联合值都有一个联合类型
(Federated type)
,包括类型(type)
和分配(placement)
。 - 联合数值可以使用联合计算来传递,必须使用
tff.federated_computation
加上联合类型去修饰。 - TensorFlow code必须包含在
tff.tf_computation
的修饰块里面,然后可以将这些块合并到federated_computation
中
3.关于构建自己的联邦学习算法
我们定义了 initialize_fn
and next_fn
来完成联邦学习的步骤。
在1.2中介绍了服务器参数更新server_update
和客户端参数client_update
更新,都是由TensorFlow代码构成,
但是为了能够实现联邦计算,要把initialize_fn
and next_fn
变成一个tff.federated_computation
.
3.1 TFF blocks
3.1.1 创建初始化计算
使用model_fn来创建一个我们的模型,然后使用tff.tf_computation
来把TF的代码分开。
@tff.tf_computation
def server_init():model = model_fn()return model.trainable_variables
然后我们可以通过tff.federated_value
来把服务器初始化参数传递到联邦计算中:
@tff.federated_computation
def initialize_fn():return tff.federated_value(server_init(), tff.SERVER)
3.1.2 创建next_fn函数
服务器和客户端更新代码可以用于编写实际算法。首先,需要把 client_update
转变成tff.tf_computation
,接收一个客户端的数据集和服务器的参数,并且输出一个更新后的客户端参数tensor。
需要对函数添加相应的变量类型修饰。幸运的是,服务器的权重类型可以直接通过模型导出。
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
让我们看一看数据集类型签名,假设采用的是Mnist数据集中的数据,里面的样本是是28*28像素的图片,可以展开成784,然后标签是1。
也可以通过server_init
函数提取权重类型:
model_weights_type = server_init.type_signature.result
然后通过str直接打印模型的结构:
str(model_weights_type)
# '<float32[784,10],float32[10]>'
现在,我们知道了tf_dataset_type
和model_weights_type
,然后我们可以为client_update
创建 tff.tf_computation
了:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):model = model_fn()client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)return client_update(model, tf_dataset, server_weights, client_optimizer)
为 server update
创建 tff.tf_computation
的方式是类似的:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):model = model_fn()return server_update(model, mean_client_weights)
最后最重要的是:我们需要创建一个tff.federated_computation
把他们都放在一起。这个函数将会接收2个联邦数值(以及分配情况),一个是服务器相应的权重(分配给tff.SERVER
),另一个是对应的客户端的数据集(分配给tff.CLIENTS
)。
这两个变量的类型上面都已经定义了,也就是model_weights_type
和tf_dataset_type
。
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
3.1.3 构建federated_computation
至此,TFF所需要的组件都已经构建好了,下面就开始把各个组件放到一起。
按照联邦学习的4个步骤,构建next_fn
:
- 服务器参数传递
- 更新客户端参数
- 根据客户端参数计算服务器参数
- 根据服务器参数更新客户端参数
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):# Broadcast the server weights to the clients.server_weights_at_client = tff.federated_broadcast(server_weights)# Each client computes their updated weights.client_weights = tff.federated_map(client_update_fn, (federated_dataset, server_weights_at_client))# The server averages these updates.mean_client_weights = tff.federated_mean(client_weights)# The server updates its model.server_weights = tff.federated_map(server_update_fn, mean_client_weights)return server_weights
3.1.4 tff.templates.IterativeProcess
为了完成我们的算法,还需要把initialize_fn和next_fn传给tff.templates.IterativeProcess
。
federated_algorithm = tff.templates.IterativeProcess(initialize_fn=initialize_fn,next_fn=next_fn
)
可以通过str查看federated_algorithm
的类型:
str(federated_algorithm.initialize.type_signature)
#'( -> <float32[784,10],float32[10]>@SERVER)'
str(federated_algorithm.next.type_signature)
# '(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'
->的左边是传入的参数结构,右边是输出的参数结构。可以清楚的看到,next_fn的参数传入的是服务器的参数,客户端的数据集,输出的是更新的服务器的参数。
3.2 评估算法
终于来到了机动人心的评估算法编写环节。
首先需要构建一个集中的评估数据集,需要对其做训练集一样的预处理:
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)
然后,我们需要写一个函数接收一个服务器的状态server_state
,并且使用keras在测试数据集上进行评估。
def evaluate(server_state):keras_model = create_keras_model()keras_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()] )keras_model.set_weights(server_state)keras_model.evaluate(central_emnist_test)
然后,让我们初始化算法并且在测试集上进行评估。
server_state = federated_algorithm.initialize()
evaluate(server_state)
# 2042/2042 [==============================] - 8s 3ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027
然后我们进行联合训练15轮再次评估:
for round in range(15):server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
# 2042/2042 [==============================] - 5s 2ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980
可以看到loss小幅度降低了。
3.3 构建自己的算法
经过上面的步骤,我们已经可以导入emnist数据集,基于keras构建模型,然后编写服务器更新函数和客户端更新函数,将其转换到TFF框架下,然后在TFF测试数据集进行训练的评估。
那么如果我们需要构建自己的模型,然后在自己的数据集上进行训练只需要把其中纯TensorFlow构建的模型部分进行修改就可以了。
[联邦学习TFF]构建自己的联邦学习模型相关推荐
- 【阅读笔记】联邦学习实战——构建公平的大数据交易市场
联邦学习实战--构建公平的大数据交易市场 前言 1. 大数据交易 1.1 数据交易定义 1.2 数据确权 1.3 数据定价 2. 基于联邦学习构建新一代大数据交易市场 3. 联邦学习激励机制助力数据交 ...
- 联邦学习综述(二)——联邦学习的分类、框架及未来研究方向
文章目录 第三章 联邦学习分类 3.1 横向联邦学习 3.2 纵向联邦学习 3.3 联邦迁移学习 第四章 联邦学习框架 4.1 开源框架介绍 4.2 FATE--企业级框架 第五章 未来研究方向 5. ...
- 【联邦学习+区块链】《联邦学习vs区块链:谁是“可信媒介”技术领域最强王者?》疑问解答
联邦学习[1]VS 区块链 [问1]联邦学习,何为"联邦"? 作为一种分布式机器学习技术,联邦学习可以实现各个企业的自有数据不出本地,而是通过加密机制下的参数交换方式共建模型,即在 ...
- 【联邦学习 + 区块链】《联邦学习vs区块链:谁是“可信媒介”技术领域最强王者?》阅读记录与提问
[注]块引用部分是博主自己的思考.. 题目:<联邦[1]学习vs区块链:谁是"可信媒介"技术领域最强王者?> [问1]联邦学习,何为"联邦"? 在互 ...
- 通过 DLPack 构建跨框架深度学习编译器
通过 DLPack 构建跨框架深度学习编译器 深度学习框架,如Tensorflow, PyTorch, and ApacheMxNet,快速原型化和部署深度学习模型提供了强大的工具箱.不幸的是,易用性 ...
- R使用深度学习LSTM构建时间序列预测模型
R使用深度学习LSTM构建时间序列预测模型 LSTM的全称是Long Short Term Memory,顾名思义,它具有记忆长短期信息的能力的神经网络.LSTM首先在1997年由Hochreiter ...
- “领域知识图谱的构建与应用”讲座学习笔记
知识图谱是一种基于先进信息技术的型语义工具,它以实体或概念为节点,通过语义关系连接成大规模语义网络,可以帮助机器或信息系统理解语义.组织知识.发现知识,从而为人们提供知识.情报的智慧搜索和智能交互. ...
- 深度学习预测酶活性参数提升酶约束模型构建从头环境搭建
前言 这项工作开发了一种用深度学习来预测酶活性参数的方法(DLKcat),主要采用了针对底物的图神经网络和针对蛋白质的卷积神经网络.通过从公开的数据库中获取和数据预处理,最终获得了超过一万六千条高质量 ...
- 干货 :六招教你用Python分分钟构建好玩的深度学习应用
[导读]深度学习是近来数据科学中研究和讨论最多的话题.得益于深度学习的发展,数据科学在近期得到了重大突破,深度学习也因此得到了很多关注.据预测,在不久的将来,更多的深度学习应用程序会影响人们的生活.实 ...
最新文章
- Caffe学习系列(6):Blob,Layer and Net以及对应配置文件的编写
- 远程java接口说明
- 【Python基础】加密你的Python源码顺便再打个包如何?
- java代码简单操作Redis数据Jedis jar
- 2021-10-12Spring缓存注解@Cacheable、@CacheEvict、@CachePut使用
- 解决微信小程序 [Component] slot ““ is not found.
- 如何将微商管理模式流程化
- java treeset 删除_删除Java TreeSet中的最低元素
- Windows XP 默认蓝色桌面的 RGB
- 大学计算机模拟2014网络应用,2014全校大学计算机基础模拟考试.doc
- 如何拼局域网所有ip_怎么查看 同一个局域网内连的所有的IP地址 - 卡饭网
- 时间序列学习(1):平稳性、自相关性
- 嵌入式软件工程师面试题总结
- 在兼容系统上升级DELL SATA硬盘的固件
- 用对象的上转型对象、方法重写,抽象编程:求柱体的体积。
- 救命稻草VirtualBox,失之交臂VMware—— 2者的guest OS对 恒通笔记本并口卡的支持
- 人人网(cookie登录)
- VMware虚拟机下Ubuntu18.04学校宽带拨号连接网络
- WORD出错:设置为稿纸后,右键的字体、段落项目变灰
- deepin访问不了网页