TFF 全称 tensorflow_federated,为谷歌的联邦学习框架。

在TFF官网的Building Your Own Federated Learning Algorithm
界面中,介绍了如何尽可能多的利用现有的TensorFlow代码,构建一个TFF的模型。

本文的行文结构

本文共分为3个章节,其中第1章介绍了TFF的框架,然后给出了客户端和服务器的模型参数更新函数。第2章到主要介绍Federated core的内容。第3章主要把前2章的内容串起来,构建自己的TFF框架。

1.TFF框架的构成

1.1 TFF框架可以分成4个步骤:

  1. 服务器向客户端(server-to-client)传递初始模型参数
  2. 客户端更新模型参数
  3. 客户端向服务器(client-to-server)传递参数
  4. 服务器更新参数

1.2 这4个步骤又可以根据是否使用纯TensorFlow的代码分为两类:

第一类:全部使用TensorFlow代码构建
包括第2和第4步,客户端更新模型参数,服务器更新参数。

第二类:使用Federated Core代码构建
包括第1和第3步,server-to-client和client-to-server

下面就详细介绍纯TensorFlow环节的要点。需要Federated Core构建的放在后面第5部分。

1.2.1 客户端更新参数

可以分成两步:
(1)从服务器模型获取客户端模型参数,此处的服务器模型是经过第一步传递过来的模型
(2)客户端模型在客户端数据集上训练和更新参数

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):"""Performs training (using the server model weights) on the client's dataset."""# Initialize the client model with the current server weights.client_weights = model.trainable_variables# Assign the server weights to the client model.tf.nest.map_structure(lambda x, y: x.assign(y),client_weights, server_weights)# Use the client_optimizer to update the local model.for batch in dataset:with tf.GradientTape() as tape:# Compute a forward pass on the batch of dataoutputs = model.forward_pass(batch)# Compute the corresponding gradientgrads = tape.gradient(outputs.loss, client_weights)grads_and_vars = zip(grads, client_weights)# Apply the gradient using a client optimizer.client_optimizer.apply_gradients(grads_and_vars)return client_weights

1.2.2 服务器更新参数

在服务器上更新参数主要涉及更新策略。此处采用了较为简单的"vanilla" 联合平均算法,直接取各个客户端模型参数的平均值作为服务器的模型参数。这里的模型参数只包括可训练的参数。

@tf.function
def server_update(model, mean_client_weights):"""Updates the server model weights as the average of the client model weights."""model_weights = model.trainable_variables# Assign the mean client weights to the server model.tf.nest.map_structure(lambda x, y: x.assign(y),model_weights, mean_client_weights)return model_weights

2.关于Federated Core (FC)

FC包含底层和顶层两个维度的API。具体而言,FC是服务于tff.learningAPI的底层的接口。然而,FC又是一个顶层的开发环境,它提供了一种更加紧凑的程序逻辑,把TensorFlow的代码和分布式通信操作(包括分布式求和和广播)结合起来。

FC的目标是允许开发者明确地控制系统中的分布式通信(例如点对点的网络消息交换),而不需要了解实施的细节。

TFF的设计之初就是为了数据的隐秘性,所以FC允许用户明确的控制数据应该在哪一个层面,防止数据泄露。

2.1 联邦数据(Federated data)

TFF中的一个关键概念是“联合数据”(Federated data),它是指分布式系统中一组设备上托管的数据项的集合(例如,客户端数据集或服务器模型权重)。跨所有设备的整个值集合表示为单个联合值

例如,假设存在客户端设备,每个客户端设备都有一个表示传感器温度的浮点。这些浮点可以通过下面的式子表示为联合浮点

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

联合类型由其成员成分的变量类型“T”(例如“tf.float32”)和设备组“G”指定。通常,“G”是tff.CLIENTStff.SERVER。这种联合类型表示为{T}@G’,如下所示:

str(federated_float_on_clients)
# '{float32}@CLIENTS'

2.2 联邦计算(Federated computations)

TFF接受联合值作为输入,并且把联合值作为输出。例如,假设您想要平均客户端传感器上的温度。您可以定义以下内容(使用我们的联合浮点):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):return tff.federated_mean(client_temperatures)

对于联邦计算,作者用了一句话来定义:

It is a specification of a distributed system in an internal platform-independent glue language.
它是一个分布式系统的规范,采用独立于内部平台的“粘合语言”

此处的tff.federated_computation接受联合类型{float32}@CLIENTS的参数,并返回联合类型{float 32}@SERVER的值。联邦计算也可以从服务器到客户机、从客户机到客户端或从服务器到服务器。联邦计算也可以像普通函数一样组成,只要它们的类型签名匹配即可。

get_average_temperature([68.5, 70.3, 69.8])
# 相当于(68.5+70.3+69.8)/3

2.3 关于非渴望计算和TensorFlow

  • TFF操作的是联合数值
  • 每一个联合值都有一个联合类型(Federated type),包括类型(type)和分配(placement)
  • 联合数值可以使用联合计算来传递,必须使用tff.federated_computation加上联合类型去修饰。
  • TensorFlow code必须包含在tff.tf_computation的修饰块里面,然后可以将这些块合并到federated_computation

3.关于构建自己的联邦学习算法

我们定义了 initialize_fn and next_fn 来完成联邦学习的步骤。
在1.2中介绍了服务器参数更新server_update和客户端参数client_update更新,都是由TensorFlow代码构成,
但是为了能够实现联邦计算,要把initialize_fn and next_fn 变成一个tff.federated_computation.

3.1 TFF blocks

3.1.1 创建初始化计算

使用model_fn来创建一个我们的模型,然后使用tff.tf_computation来把TF的代码分开。

@tff.tf_computation
def server_init():model = model_fn()return model.trainable_variables

然后我们可以通过tff.federated_value来把服务器初始化参数传递到联邦计算中:

@tff.federated_computation
def initialize_fn():return tff.federated_value(server_init(), tff.SERVER)

3.1.2 创建next_fn函数

服务器和客户端更新代码可以用于编写实际算法。首先,需要把 client_update 转变成tff.tf_computation,接收一个客户端的数据集和服务器的参数,并且输出一个更新后的客户端参数tensor。

需要对函数添加相应的变量类型修饰。幸运的是,服务器的权重类型可以直接通过模型导出。

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

让我们看一看数据集类型签名,假设采用的是Mnist数据集中的数据,里面的样本是是28*28像素的图片,可以展开成784,然后标签是1。

也可以通过server_init函数提取权重类型:

model_weights_type = server_init.type_signature.result

然后通过str直接打印模型的结构:

str(model_weights_type)
# '<float32[784,10],float32[10]>'

现在,我们知道了tf_dataset_typemodel_weights_type,然后我们可以为client_update创建 tff.tf_computation了:

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):model = model_fn()client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)return client_update(model, tf_dataset, server_weights, client_optimizer)

server update创建 tff.tf_computation的方式是类似的:

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):model = model_fn()return server_update(model, mean_client_weights)

最后最重要的是:我们需要创建一个tff.federated_computation把他们都放在一起。这个函数将会接收2个联邦数值(以及分配情况),一个是服务器相应的权重(分配给tff.SERVER),另一个是对应的客户端的数据集(分配给tff.CLIENTS)。

这两个变量的类型上面都已经定义了,也就是model_weights_typetf_dataset_type

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

3.1.3 构建federated_computation

至此,TFF所需要的组件都已经构建好了,下面就开始把各个组件放到一起。
按照联邦学习的4个步骤,构建next_fn

  1. 服务器参数传递
  2. 更新客户端参数
  3. 根据客户端参数计算服务器参数
  4. 根据服务器参数更新客户端参数
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):# Broadcast the server weights to the clients.server_weights_at_client = tff.federated_broadcast(server_weights)# Each client computes their updated weights.client_weights = tff.federated_map(client_update_fn, (federated_dataset, server_weights_at_client))# The server averages these updates.mean_client_weights = tff.federated_mean(client_weights)# The server updates its model.server_weights = tff.federated_map(server_update_fn, mean_client_weights)return server_weights

3.1.4 tff.templates.IterativeProcess

为了完成我们的算法,还需要把initialize_fn和next_fn传给tff.templates.IterativeProcess

federated_algorithm = tff.templates.IterativeProcess(initialize_fn=initialize_fn,next_fn=next_fn
)

可以通过str查看federated_algorithm 的类型:

str(federated_algorithm.initialize.type_signature)
#'( -> <float32[784,10],float32[10]>@SERVER)'
str(federated_algorithm.next.type_signature)
# '(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

->的左边是传入的参数结构,右边是输出的参数结构。可以清楚的看到,next_fn的参数传入的是服务器的参数,客户端的数据集,输出的是更新的服务器的参数。

3.2 评估算法

终于来到了机动人心的评估算法编写环节。
首先需要构建一个集中的评估数据集,需要对其做训练集一样的预处理:

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

然后,我们需要写一个函数接收一个服务器的状态server_state,并且使用keras在测试数据集上进行评估。

def evaluate(server_state):keras_model = create_keras_model()keras_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  )keras_model.set_weights(server_state)keras_model.evaluate(central_emnist_test)

然后,让我们初始化算法并且在测试集上进行评估。

server_state = federated_algorithm.initialize()
evaluate(server_state)
# 2042/2042 [==============================] - 8s 3ms/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

然后我们进行联合训练15轮再次评估:

for round in range(15):server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
# 2042/2042 [==============================] - 5s 2ms/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

可以看到loss小幅度降低了。

3.3 构建自己的算法

经过上面的步骤,我们已经可以导入emnist数据集,基于keras构建模型,然后编写服务器更新函数和客户端更新函数,将其转换到TFF框架下,然后在TFF测试数据集进行训练的评估。

那么如果我们需要构建自己的模型,然后在自己的数据集上进行训练只需要把其中纯TensorFlow构建的模型部分进行修改就可以了。

[联邦学习TFF]构建自己的联邦学习模型相关推荐

  1. 【阅读笔记】联邦学习实战——构建公平的大数据交易市场

    联邦学习实战--构建公平的大数据交易市场 前言 1. 大数据交易 1.1 数据交易定义 1.2 数据确权 1.3 数据定价 2. 基于联邦学习构建新一代大数据交易市场 3. 联邦学习激励机制助力数据交 ...

  2. 联邦学习综述(二)——联邦学习的分类、框架及未来研究方向

    文章目录 第三章 联邦学习分类 3.1 横向联邦学习 3.2 纵向联邦学习 3.3 联邦迁移学习 第四章 联邦学习框架 4.1 开源框架介绍 4.2 FATE--企业级框架 第五章 未来研究方向 5. ...

  3. 【联邦学习+区块链】《联邦学习vs区块链:谁是“可信媒介”技术领域最强王者?》疑问解答

    联邦学习[1]VS 区块链 [问1]联邦学习,何为"联邦"? 作为一种分布式机器学习技术,联邦学习可以实现各个企业的自有数据不出本地,而是通过加密机制下的参数交换方式共建模型,即在 ...

  4. 【联邦学习 + 区块链】《联邦学习vs区块链:谁是“可信媒介”技术领域最强王者?》阅读记录与提问

    [注]块引用部分是博主自己的思考.. 题目:<联邦[1]学习vs区块链:谁是"可信媒介"技术领域最强王者?> [问1]联邦学习,何为"联邦"? 在互 ...

  5. 通过 DLPack 构建跨框架深度学习编译器

    通过 DLPack 构建跨框架深度学习编译器 深度学习框架,如Tensorflow, PyTorch, and ApacheMxNet,快速原型化和部署深度学习模型提供了强大的工具箱.不幸的是,易用性 ...

  6. R使用深度学习LSTM构建时间序列预测模型

    R使用深度学习LSTM构建时间序列预测模型 LSTM的全称是Long Short Term Memory,顾名思义,它具有记忆长短期信息的能力的神经网络.LSTM首先在1997年由Hochreiter ...

  7. “领域知识图谱的构建与应用”讲座学习笔记

    知识图谱是一种基于先进信息技术的型语义工具,它以实体或概念为节点,通过语义关系连接成大规模语义网络,可以帮助机器或信息系统理解语义.组织知识.发现知识,从而为人们提供知识.情报的智慧搜索和智能交互. ...

  8. 深度学习预测酶活性参数提升酶约束模型构建从头环境搭建

    前言 这项工作开发了一种用深度学习来预测酶活性参数的方法(DLKcat),主要采用了针对底物的图神经网络和针对蛋白质的卷积神经网络.通过从公开的数据库中获取和数据预处理,最终获得了超过一万六千条高质量 ...

  9. 干货 :六招教你用Python分分钟构建好玩的深度学习应用

    [导读]深度学习是近来数据科学中研究和讨论最多的话题.得益于深度学习的发展,数据科学在近期得到了重大突破,深度学习也因此得到了很多关注.据预测,在不久的将来,更多的深度学习应用程序会影响人们的生活.实 ...

最新文章

  1. Caffe学习系列(6):Blob,Layer and Net以及对应配置文件的编写
  2. 远程java接口说明
  3. 【Python基础】加密你的Python源码顺便再打个包如何?
  4. java代码简单操作Redis数据Jedis jar
  5. 2021-10-12Spring缓存注解@Cacheable、@CacheEvict、@CachePut使用
  6. 解决微信小程序 [Component] slot ““ is not found.
  7. 如何将微商管理模式流程化
  8. java treeset 删除_删除Java TreeSet中的最低元素
  9. Windows XP 默认蓝色桌面的 RGB
  10. 大学计算机模拟2014网络应用,2014全校大学计算机基础模拟考试.doc
  11. 如何拼局域网所有ip_怎么查看 同一个局域网内连的所有的IP地址 - 卡饭网
  12. 时间序列学习(1):平稳性、自相关性
  13. 嵌入式软件工程师面试题总结
  14. 在兼容系统上升级DELL SATA硬盘的固件
  15. 用对象的上转型对象、方法重写,抽象编程:求柱体的体积。
  16. 救命稻草VirtualBox,失之交臂VMware—— 2者的guest OS对 恒通笔记本并口卡的支持
  17. 人人网(cookie登录)
  18. VMware虚拟机下Ubuntu18.04学校宽带拨号连接网络
  19. WORD出错:设置为稿纸后,右键的字体、段落项目变灰
  20. deepin访问不了网页

热门文章

  1. 网页中怎么实现客户端通过扫描仪把图像传到服务器上
  2. springboot 整合 oss进行文件上传
  3. Oracle 表类型-表值函数-过程 -例子
  4. Matlab 制作《最炫民族风》弱爆了,附代码
  5. finBERT-金融英文情感分析运行介绍
  6. iOS二维码生成(带logo)
  7. 华为ICT大赛辅导——双AC主备双链路备份
  8. Android获取用户位置
  9. autojs-抖音评论
  10. WEB UI设计网站收藏