我们可以在Hetero NN中定制顶部模型和底部模型。model_zoo模块在FATE 1.10中引入,位于federatedml.nn.model_zoo下。该模块允许您自定义自己的PyTorch模型,前提是它是基于torch.nn.module开发的,并实现前向接口。有关更多信息,请参阅自定义模块上的PyTorch文档PyTorch模块。要在联合任务中使用自定义模型,只需将其放置在federatedml/nn/model_zoo目录中,并在提交任务时通过接口指定模块和模型类。Hetero-NN组件将自动搜索并导入您已实现的模型。

您也可以用类似的方式定义自己的损失类别。您可以将损失类放在位于federatedml.nn.loss下的损失模块下。我们建议您在阅读本教程之前阅读以下两个教程:自定义损失函数、自定义模型

作为一个例子,我们考虑重用上一篇杂文教程的MNIST手写识别任务。

准备MNIST数据

请从以下链接下载来宾/主机MNIST数据集,并将其放在项目示例/数据文件夹中:

  • gust数据:guest地址

  • host数据:host地址

mnist_guest是mnist数据集的简化版本,共有十个类别,根据标签分为0-9个10个文件夹。mnist_host具有与mnist_guest相同的图像,但未标记。

! ls /mnt/hgfs/Hetero/mnist_guest  # 根据自己的文件位置进行调整

0 1 2 3 4 5 6 7 8 9

! ls /mnt/hgfs/Hetero/mnist_host

not_labeled

数据集

在FATE-1.10版本中,FATE为数据集引入了一个新的基类,称为Dataset,它基于PyTorch的Dataset类。此类允许用户根据其特定需求创建自定义数据集。其用法与PyTorch的Dataset类类似,在使用FATE-NN进行数据读取和训练时,需要实现两个额外的接口:load()和get_sample_ids()。

要在Hetero NN中创建自定义数据集,用户需要:

  • 开发继承自dataset类的新数据集类

  • 实现__len__()和__getitem__()方法,它们与PyTorch的数据集用法一致。__len__()方法应返回数据集的长度,而__getitem_()方法则应返回指定索引处的相应数据。但是,请注意,不同的__getitem_()方法在不同方之间可能有不同的行为。在guest方(带有标签的方)中,_getitem\_()方法返回功能和标签,而在宿主方(没有标签的方,_getiitem\_(()方法仅返回功能。

  • 实现load()、get_sample_ids()和get_classes()方法

对于不熟悉PyTorch的数据集类的人,可以在PyTorch文档中找到更多信息:PyTorch数据集文档

自定义底部/顶部模型

将模型代码命名为bottom_net.py,您可以将其直接放在fedratedml/nn/model_zoo下,或使用jupyter的快捷界面:save_to_date,将其直接保存到fedratedml/nn/model_zoo。这是我们为特征提取定义的底部模型结构。

from pipeline.component.nn import save_to_fate
%%save_to_fate model bottom_net.py
import torch as t
from torch import nn
from torch.nn import Moduleclass BottomNet(nn.Module):def __init__(self):super(BottomNet, self).__init__()self.seq = t.nn.Sequential(nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5),nn.MaxPool2d(kernel_size=3),nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3),nn.AvgPool2d(kernel_size=3))self.fc = t.nn.Sequential(   # extracted feature is a 8-dim embeddingnn.Linear(48, 32),nn.ReLU(),nn.Linear(32, 8),nn.ReLU())def forward(self, x):x = self.seq(x)x = x.flatten(start_dim=1)x = self.fc(x)return x

这是我们为分类定义的顶级模型,我们将其命名为top_model.py。

%%save_to_fate model top_net.py
import torch as t
from torch import nn
from torch.nn import Moduleclass TopNet(nn.Module):def __init__(self):super(TopNet, self).__init__()self.fc = t.nn.Sequential(   nn.Linear(8, 10))self.softmax = t.nn.Softmax(dim=1)def forward(self, x):x = self.fc(x)return self.softmax(x)

使用Cust Loss

使用Cust Loss与Homo NN完全相同,请参见:自定义损失函数。这里我们使用一个新的CrossEntropyLoss。

%%save_to_fate loss ce.py
import torch as t
from federatedml.util import consts
from torch.nn.functional import one_hotdef cross_entropy(p2, p1, reduction='mean'):p2 = p2 + consts.FLOAT_ZERO  # to avoid nanassert p2.shape == p1.shapeif reduction == 'sum':return -t.sum(p1 * t.log(p2))elif reduction == 'mean':return -t.mean(t.sum(p1 * t.log(p2), dim=1))elif reduction == 'none':return -t.sum(p1 * t.log(p2), dim=1)else:raise ValueError('unknown reduction')class CrossEntropyLoss(t.nn.Module):"""A CrossEntropy Loss that will not compute Softmax"""def __init__(self, reduction='mean'):super(CrossEntropyLoss, self).__init__()self.reduction = reductiondef forward(self, pred, label):one_hot_label = one_hot(label.flatten())loss_ = cross_entropy(pred, one_hot_label, self.reduction)return loss_

然后,我们可以在Hetero NN MNIST任务中使用我们的模型和损失!用法与Homo NN相同:我们通过NN.CustModel和NN.CustLoss接口指定模型和损失。

pipeline初始化

在这里,我们定义了运行异类任务的pipeline

import os
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
from pipeline.component.nn import save_to_fatefate_torch_hook(t)# bind path to fate name&namespace
# fate_project_path = os.path.abspath('../../../../')
guest = 10000
host = 9999pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}# 自定义文件位置
guest_data_path = '/mnt/hgfs/Hetero/mnist_guest/'
host_data_path = '/mnt/hgfs/Hetero/mnist_host/'
pipeline_img.bind_table(name='mnist_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='mnist_host', namespace='experiment', path=host_data_path)

{'namespace': 'experiment', 'table_name': 'mnist_host'}

hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=3,interactive_layer_lr=0.01, batch_size=512, task_type='classification', seed=100)
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)# define model
# use cust model here
# our simple classification model:
guest_bottom = t.nn.CustModel(module_name='bottom_net.py', class_name='BottomNet')# use cust model here
host_bottom = t.nn.CustModel(module_name='bottom_net.py', class_name='BottomNet')# use new top model here
guest_top = t.nn.CustModel(module_name='top_net.py', class_name='TopNet')# interactive layer define
interactive_layer = t.nn.InteractiveLayer(out_dim=8, guest_dim=8, host_dim=8)# add models
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)# opt, loss
optimizer = t.optim.Adam(lr=0.01)
loss = t.nn.CustLoss(loss_module_name='ce', class_name='CrossEntropyLoss')# use DatasetParam to specify dataset and pass parameters
guest_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='image', return_label=False))hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)
pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()
pipeline_img.fit()
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result

FATE —— 二.3.2 Hetero-NN使用CustModel设置顶部、底部模型相关推荐

  1. FATE —— 二.3.1 Hetero-NN自定义数据集

    FATE系统主要支持表格数据作为其标准数据格式.然而,通过使用NN模块的数据集特性,可以在神经网络中使用非表格数据,例如图像.文本.混合数据或关系数据.NN模块中的数据集模块允许自定义数据集,以便用户 ...

  2. 【word2vec】篇二:基于Hierarchical Softmax的 CBOW 模型和 Skip-gram 模型

    文章目录 CBOW 模型 基本结构 目标函数 梯度计算 Skip-gram 模型 基本结构 梯度计算 优缺点分析 系列文章: [word2vec]篇一:理解词向量.CBOW与Skip-Gram等知识 ...

  3. 知识图谱论文阅读(八)【转】推荐系统遇上深度学习(二十六)--知识图谱与推荐系统结合之DKN模型原理及实现

    学习的博客: 推荐系统遇上深度学习(二十六)–知识图谱与推荐系统结合之DKN模型原理及实现 知识图谱特征学习的模型分类汇总 知识图谱嵌入(KGE):方法和应用的综述 论文: Knowledge Gra ...

  4. R语言epiDisplay包的kap函数计算配对列联表的计算一致性的比例以及Kappa统计量的值(总一致性、期望一致性)、使用xtabs函数生成二维列联表、使用wttable参数设置权重表参数为w2

    R语言使用epiDisplay包的kap函数计算配对列联表的计算一致性的比例以及Kappa统计量的值(总一致性.期望一致性).使用xtabs函数生成二维列联表.使用wttable参数设置权重表参数为w ...

  5. Java实现一行代码生成二维码,可传输到前端展示,可自定义二维码样式,可设置图片格式,可对二维码添加图片,可对二维码添加文字,可以设置二维码大小、字体大小、字体颜色、边框颜色、边框大小等等

    Java实现一行代码生成二维码,可传输到前端展示,可自定义二维码样式,可设置图片格式,可对二维码添加图片,可对二维码添加文字,可以设置二维码大小.字体大小.字体颜色.边框颜色.边框大小等等. 0.准备 ...

  6. 抓包工具Charles(二)-移动端APP抓包(设置手机代理、安装证书)

    安装好Charles之后,还只能捕获电脑的接口请求 想要抓取移动设备的APP还需要设置代理.安装证书. 文章目录 一.抓包原理 二.手机设置网络代理 1. 查看电脑的IP地址(local IP add ...

  7. ML之yellowbrick:基于titanic泰坦尼克是否获救二分类预测数据集利用yellowbrick对LoR逻辑回归模型实现可解释性(阈值图)案例

    ML之yellowbrick:基于titanic泰坦尼克是否获救二分类预测数据集利用yellowbrick对LoR逻辑回归模型实现可解释性(阈值图)案例 目录 基于titanic泰坦尼克是否获救二分类 ...

  8. FATE —— 二.4.2 Criteo上的联邦经典CTR模型训练

    在本教程中,我们将向您展示如何开发水平联合推荐模型.我们使用第三方库torch rechub调用一些经典的推荐模型,如FM.DeepFM等,并使用它在FATE中构建联邦任务.在数据集方面,我们使用了经 ...

  9. Lesson 8.38.4 二分类神经网络torch.nn.functional实现单层二分类网络的正向传播

    二.二分类神经网络:逻辑回归 1 二分类神经网络的理论基础 线性回归是统计学经典算法,它能够拟合出一条直线来描述变量之间的线性关系.但在实际中,变量之间的关系通常都不是一条直线,而是呈现出某种曲线关系 ...

最新文章

  1. CaaS环境下实践经验总结(二):监控系统部署
  2. 第四百一十四节,python常用算法学习
  3. Summer Training day6 coseforces339D 线段树、位操作
  4. springMVC开启声明式事务实现操作日志记录
  5. HTML学习二_HTML常用的行级标签,常用实体字符及表单标签
  6. 谷歌android wear智能腕表 价格,谷歌Android Wear 2.0更新推送:仅三款智能手表可享受...
  7. 【答辩问题】计算机专业本科毕业设计答辩问题
  8. Mac好用的文件对比工具Beyond Compare 4
  9. Cesium Terrain Builder 非压缩瓦片
  10. java判断数组值类型,判断(1分) Java语言中的数组元素只能是基本数据类型而不能为对象类型。...
  11. 详解阿里云第六代增强型实例,性能强劲,百万IOPS加持
  12. 10GE DWDM SFP+彩色光模块应用案例
  13. 计算机windows用户名密码怎么查,电脑密码怎么查看? 从零教你查询方式
  14. clickhouse之数据存储:JBOD vs RAID
  15. C++设计并测试一个名为Rectangle的矩形类,其属性为矩形的左下角与右上角两个点的坐标,根据坐标能计算矩形的面积。
  16. javaee.jar与servlet-api.jar
  17. 机器学习从零到一的基础知识总集篇
  18. 【刷题笔记】——day.6 有效的井字游戏
  19. Java文件复制的三种方法
  20. WebApi在MVC 4中一个Controll多个post方法报错处理

热门文章

  1. AUTOCAD——设置文字间距与行距
  2. HDLBits 系列(8)——Sequential Logic(Finite State Machines(一))
  3. 汽车电子EMC实验简介
  4. P2P通信中的NAT/FW穿越
  5. 首席新媒体黎想教程:活动形式和用户吸引逻辑!
  6. php 中英文查询字数,php统计中英文混合的文章字数
  7. Python中的pass的作用
  8. BSides Noida CTF 2021 web题wowooofreepoint writeup(两道反序列化)
  9. 【新手教程】51Sim-One Cloud 2.0如何构建一个V2X案例
  10. Luogu P4735(可持久化字典树)