FATE系统主要支持表格数据作为其标准数据格式。然而,通过使用NN模块的数据集特性,可以在神经网络中使用非表格数据,例如图像、文本、混合数据或关系数据。NN模块中的数据集模块允许自定义数据集,以便用户可以在更复杂的数据场景中使用它们。本教程将介绍Hetero NN中数据集功能的使用。为了便于演示,我们将使用MNIST手写识别数据集作为示例来模拟Hetero Federation任务以说明这些概念。

准备MNIST数据

请从以下链接下载来宾/主机MNIST数据集,并将其放在项目示例/数据文件夹中:

  • gust数据:guest地址

  • host数据:host地址

mnist_guest是mnist数据集的简化版本,共有十个类别,根据标签分为0-9个10个文件夹。mnist_host具有与mnist_guest相同的图像,但未标记。

! ls /mnt/hgfs/Hetero/mnist_guest  # 根据自己的文件位置进行调整

0 1 2 3 4 5 6 7 8 9

! ls /mnt/hgfs/Hetero/mnist_host

not_labeled

数据集

在FATE-1.10版本中,FATE为数据集引入了一个新的基类,称为Dataset,它基于PyTorch的Dataset类。此类允许用户根据其特定需求创建自定义数据集。其用法与PyTorch的Dataset类类似,在使用FATE-NN进行数据读取和训练时,需要实现两个额外的接口:load()和get_sample_ids()。

要在Hetero NN中创建自定义数据集,用户需要:

开发继承自dataset类的新数据集类

实现__len__()和__getitem__()方法,它们与PyTorch的数据集用法一致。__len__()方法应返回数据集的长度,而__getitem_()方法则应返回指定索引处的相应数据。但是,请注意,不同的__getitem_()方法在不同方之间可能有不同的行为。在来宾方(带有标签的方)中,_getitem\_()方法返回功能和标签,而在宿主方(没有标签的方,_getiitem\_(()方法仅返回功能。

实现load()、get_sample_ids()和get_classes()方法

对于不熟悉PyTorch的数据集类的人,可以在PyTorch文档中找到更多信息:PyTorch数据集文档

load()

所需的第一个附加接口是load()。此接口接收文件路径,并允许用户直接从本地文件系统读取数据。提交任务时,可以通过读取器组件指定数据路径。Hetero NN将使用用户指定的Dataset类,利用load()接口从指定路径读取数据,并完成数据集的加载以进行训练。有关更多信息,请参阅base.py中的源代码。

get_sample_ids()

第二个附加接口是get_sample_ids()。此接口应返回一个样本ID列表,该列表可以是整数或字符串,并且长度应与数据集相同。该功能在Hetero NN中很重要,您需要了解以下几点:

  • 在Hetero NN中使用自定义数据集时,确保您的样本ID与其他方的样本ID一致非常重要。您可以通过使用交集组件并提取结果,或者通过与其他方商定要使用的样本ID来实现这一点。

  • 你不必把你的id按顺序排列,Hetero-NN组件会对它们进行排序。

get_classes()

第三个函数返回所有唯一标签的列表。这将在宾客聚会上进行。如果不是分类任务,只需返回一个空列表。

示例:实现一个简单的图像数据集

为了更好地理解数据集的定制,这里我们实现了一个简单的图像数据集来读取MNIST图像,然后在垂直场景中完成联合图像分类任务。为了方便起见,我们使用save_to_rate的jupyter接口将代码更新为federatedml.nn.dataset(名为mnist_dataset.py),当然,您可以手动将代码文件复制到目录中。

  • 此数据集有一个参数“return_label”,当来宾方(带标签的方)使用它时,我们将return_Label=True,否则return_lappel=False

  • 它是基于ImageFolder开发的,我们将图像名称作为示例id。

from pipeline.component.nn import save_to_fate
%%save_to_fate dataset mnist_dataset.py
import numpy as np
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transformsclass MNISTDataset(Dataset):def __init__(self, return_label=True):  super(MNISTDataset, self).__init__() self.return_label = return_labelself.image_folder = Noneself.ids = Nonedef load(self, path):  self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))ids = []for image_name in self.image_folder.imgs:ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))self.ids = idsreturn selfdef get_sample_ids(self, ):return self.idsdef get_classes(self, ):return np.unique(self.image_folder.targets).tolist()def __len__(self,):  return len(self.image_folder)def __getitem__(self, idx): # get item ret = self.image_folder[idx]img = ret[0][0].flatten() # flatten tensor 784 dimsif self.return_label:return img, ret[1] # img & labelelse:return img # no label, for host

现在我们测试数据集类:

# guest party
! ls /mnt/hgfs/Hetero/mnist_guest
ds = MNISTDataset().load('/mnt/hgfs/Hetero/mnist_guest')
print(len(ds))
print(ds[0][0])
print(ds.get_classes())
print(ds.get_sample_ids()[0: 10])
# host party
! ls /mnt/hgfs/Hetero/mnist_host  # no label
ds = MNISTDataset(return_label=False).load('/mnt/hgfs/Hetero/mnist_host')
print(len(ds))
print(ds[0]) # no label

好的它已经准备好使用了,所以让我们使用这个开发的数据集来运行一个Hetero-NN模型,并且双方都使用两个数据集mnist_guest和mnist_host来执行一个异类联合训练

与Homo NN(请参阅自定义数据集)相同,这里我们不会遵循传统FATE组件的用法,而是将数据路径直接绑定到FATE名称和命名空间,并通过读取器将其传递给Hemo NN组件,Hemo NN通过您设置的DatasetParam导入自定义数据集类,然后从路径读取数据。

pipeline初始化

在这里,我们定义了运行异类任务的pipeline

import os
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model
from pipeline.component.nn import save_to_fatefate_torch_hook(t)# bind path to fate name&namespace
# fate_project_path = os.path.abspath('/mnt/hgfs/Hetero/')  # 自定义文件位置
guest = 10000
host = 9999pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}# 自定义文件位置
guest_data_path = '/mnt/hgfs/Hetero/mnist_guest/'
host_data_path = '/mnt/hgfs/Hetero/mnist_host/'
pipeline_img.bind_table(name='mnist_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='mnist_host', namespace='experiment', path=host_data_path)

{'namespace': 'experiment', 'table_name': 'mnist_host'}

guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)
hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=3,interactive_layer_lr=0.01, batch_size=512, task_type='classification', seed=100)
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)# define model
# image features 784, guest bottom model
# our simple classification model:
guest_bottom = t.nn.Sequential(t.nn.Linear(784, 8),t.nn.ReLU()
)# image features 784, host bottom model
host_bottom = t.nn.Sequential(t.nn.Linear(784, 8),t.nn.ReLU()
)# Top Model, a classifier
guest_top = t.nn.Sequential(nn.Linear(8, 10),nn.Softmax(dim=1)
)# interactive layer define
interactive_layer = t.nn.InteractiveLayer(out_dim=8, guest_dim=8, host_dim=8)# add models
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)# opt, loss
optimizer = t.optim.Adam(lr=0.01)
loss = t.nn.CrossEntropyLoss()# use DatasetParam to specify dataset and pass parameters
guest_nn_0.add_dataset(DatasetParam(dataset_name='mnist_dataset', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='mnist_dataset', return_label=False))hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)
pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()
pipeline_img.fit()
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result

FATE —— 二.3.1 Hetero-NN自定义数据集相关推荐

  1. 行人属性识别二:添加新网络训练和自定义数据集训练

    序言 上一篇记录了训练过程,但是项目中提供的模型网络都是偏大的,如果想要在边缘设备上部署,还是比较吃力的,所以本文记录如何加入新的网络模型进行训练,以repvgg为例,加入mobilenet.shuf ...

  2. ML:基于自定义数据集利用Logistic、梯度下降算法GD、LoR逻辑回归、Perceptron感知器、SVM支持向量机、LDA线性判别分析算法进行二分类预测(决策边界可视化)

    ML:基于自定义数据集利用Logistic.梯度下降算法GD.LoR逻辑回归.Perceptron感知器.支持向量机(SVM_Linear.SVM_Rbf).LDA线性判别分析算法进行二分类预测(决策 ...

  3. 使用RDLC报表(二)--使用自定义数据集

    使用RDLC报表(二)--使用自定义数据集 <!--[if !supportLists]-->1<!--[endif]-->新建窗体 <!--[if !supportLi ...

  4. FATE —— 二.3.2 Hetero-NN使用CustModel设置顶部、底部模型

    我们可以在Hetero NN中定制顶部模型和底部模型.model_zoo模块在FATE 1.10中引入,位于federatedml.nn.model_zoo下.该模块允许您自定义自己的PyTorch模 ...

  5. 我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    大家好,我是红色石头! 在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch ...

  6. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!

    在上三篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)! 我用 PyTorch 复现了 LeNet-5 ...

  7. 〖TensorFlow2.0笔记21〗自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where!

    自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where! 文章目录 一. 数据集介绍以及加载 1.1. 数据集简单描述 1.2. 程序实现步骤 1.3. 加载数据的格式 1.4. map函 ...

  8. 利用PyTorch自定义数据集实现猫狗分类

    看了许多关于PyTorch的入门文章,大抵是从torchvision.datasets中自带的数据集进行训练,导致很难把PyTorch运用于自己的数据集上,真正地灵活运用PyTorch. 这里我采用从 ...

  9. PaddlePaddle飞桨《高层API助你快速上手深度学习》『深度学习7日打卡营』--自定义数据集OCEMOTION–中文情感分类

    赛题背景 自从2017年具有划时代意义的Transformer模型问世以来,短短两年多的时间内,如雨后春笋般的出现了大量的预训练模型,比如:Bert,Albert,ELECTRA,RoBERta,T5 ...

最新文章

  1. 激光雷达与摄影测量相结合如何提高点云质量?
  2. CSDN网友挑选的2007年最有价值文章
  3. 防火墙(二)SNAT和DNAT
  4. git--命令行放弃修改
  5. 数据结构——算法的基本概念
  6. 池化方法总结(Pooling)
  7. JQuery-FullCalendar 多数据源实现日程展示
  8. OpenCASCADE绘制测试线束:几何命令之Intersections
  9. idea 修改Git密码和账号方法
  10. 手机支付:电信运营商会被边缘化吗?!
  11. JavaScript学习手册十一:JSON
  12. 短视频去水印威信小程序源码下载,内附去水印解析接口
  13. 八大数据分析模型之——用户模型(一)
  14. LabVIEW顺序结构
  15. (转)sql server 排序规则
  16. 如何制作校园平面图及路线导图
  17. 怎样用计算机合并视频,怎么合并视频和字幕 格式工厂视频字幕合并教程-电脑教程...
  18. 努比亚z11mini 使用 移动物联卡
  19. overleaf怎么输入中文_【LATEX】在线latex排版工具Overleaf-制作中文简历-详细教程...
  20. 要成为鸿蒙开发者,应该学习哪些编程语言

热门文章

  1. Win7Codecs+设置程序中英文对照
  2. 35、sparkSQL及DataFrame
  3. STM32F4移植EMWIN(RA8875驱动显示屏)
  4. MyBatis中的日志(LOG4J)
  5. 百度地图WEB服务-逆地理编码使用心得
  6. 安装pytorch报错torch.cuda.is_available()=false的解决方法
  7. 紫书《算法竞赛入门经典》
  8. Vulnhub_hacksudo_fog
  9. python电脑推荐_kk视频app下载安装|腾讯视频app下载_电脑知识学习网
  10. ubuntukylin-16.04安装