联邦学习安全之后门攻击
本博客地址:https://security.blog.csdn.net/article/details/124067669
一、后门攻击定义
在联邦学习中,后门攻击是意图让模型对具有某种特定特征的数据做出错误的判断,但模型不会对主任务产生影响。
举个例子,在图像识别中,攻击者意图让带有红色的小车都被识别为小鸟,那攻击者会先通过修改其挟持的客户端样本标签,将带有红色的小车标注为小鸟,让模型重新训练,这样训练得到的最终模型在推断的时候,会将带有红色的小车错误判断为小鸟,但不会影响对其他图片的判断。
在联邦学习场景下进行后门攻击会比较困难,一个原因就是在服务端进行聚合运算时,平均化之后会很大程度消除恶意客户端模型的影响,另一个原因是由于服务端的选择机制,因为并不能保证被攻击者挟持的客户端在每一轮都能被选取,从而降低了被后门攻击的风险。
二、后门攻击策略
带有后门攻击行为的联邦学习,其客户端可以分为恶意客户端和正常客户端。不同类型的客户端,其本地训练策略各不相同。
2.1、正常客户端训练
正常客户端的训练算法如下,其执行过程就是常规的梯度下降过程。
正常客户端的训练算法:
---------------------------------------------------------------------------------------
input: 客户端ID: k ;
全局模型: ;
学习率: ;
本地迭代次数: E ;
每一轮训练的样本大小: B ;
output: 返回模型更新:
利用服务端下发的全局模型参数,更新本地模型:
for 对每一轮的迭代 i = 1, 2, 3, ……, E,执行下面的操作 do
将本地数据切分为 |B| 份数据 B
for 对每一个 batch
执行梯度下降:
end
end
---------------------------------------------------------------------------------------
2.2、恶意客户端训练
对于恶意客户端的本地训练,主要体现在两个方面:损失函数的设计和上传服务端的模型权重。
● 对于损失函数的设计,恶意客户端训练的目标,一方面是保证在正常数据集和被篡改毒化的数据集中都取得较好的性能;另一方面是保证本地训练的模型与全局模型之间的距离尽量小(距离越小,被服务端判断为异常模型的概率就越小)。
● 对于上传服务端的模型权重,根据以下公式: 可以看出,通过增大异常客户端m的模型权重,使其在后面的聚合过程中,对全局模型的影响和贡献尽量持久。
恶意客户端的训练算法:
---------------------------------------------------------------------------------------
input: 客户端ID: k ;
全局模型: ;
学习率: ;
本地迭代次数: E ;
每一轮的训练样本大小: B ;
output: 返回模型更新:
利用服务端下发的全局模型参数 ,更新本地模型
损失函数:
for 对每一轮的迭代 i = 1, 2, 3, ……, E,执行下面的操作 do
将本地数据切分为 |B| 份数据 B
for 对每一个 batch
数据集 中包含正常的数据集 和被篡改毒化的数据集
执行梯度下降:
end
end
---------------------------------------------------------------------------------------
三、后门攻击具体实现
3.1、客户端
人为篡改客户端client.py的代码,已对代码做出了具体的注释说明,具体细节阅读代码即可。
client.py
import models, torch, copy
import numpy as np
import matplotlib.pyplot as pltclass Client(object):def __init__(self, conf, model, train_dataset, id = -1):self.conf = confself.local_model = models.get_model(self.conf["model_name"]) self.client_id = idself.train_dataset = train_datasetall_range = list(range(len(self.train_dataset)))data_len = int(len(self.train_dataset) / self.conf['no_models'])train_indices = all_range[id * data_len: (id + 1) * data_len]self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"], sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))def local_train(self, model):for name, param in model.state_dict().items():self.local_model.state_dict()[name].copy_(param.clone())optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],momentum=self.conf['momentum'])self.local_model.train()for e in range(self.conf["local_epochs"]):for batch_id, batch in enumerate(self.train_loader):data, target = batchif torch.cuda.is_available():data = data.cuda()target = target.cuda()optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()print("Epoch %d done." % e) diff = dict()for name, data in self.local_model.state_dict().items():diff[name] = (data - model.state_dict()[name])return diffdef local_train_malicious(self, model):for name, param in model.state_dict().items():self.local_model.state_dict()[name].copy_(param.clone())# 设置优化数据optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],momentum=self.conf['momentum'])pos = []# 手动篡改数据,设置毒化数据的样式for i in range(2, 28):pos.append([i, 3])pos.append([i, 4])pos.append([i, 5])self.local_model.train()for e in range(self.conf["local_epochs"]):for batch_id, batch in enumerate(self.train_loader):data, target = batch# 在线修改数据,模拟被攻击场景for k in range(self.conf["poisoning_per_batch"]):img = data[k].numpy()for i in range(0,len(pos)):img[0][pos[i][0]][pos[i][1]] = 1.0img[1][pos[i][0]][pos[i][1]] = 0img[2][pos[i][0]][pos[i][1]] = 0target[k] = self.conf['poison_label']if torch.cuda.is_available():data = data.cuda()target = target.cuda()optimizer.zero_grad()output = self.local_model(data)# 类别损失class_loss = torch.nn.functional.cross_entropy(output, target)# 距离损失dist_loss = models.model_norm(self.local_model, model)# 总的损失函数为类别损失与距离损失之和loss = self.conf["alpha"]*class_loss + (1-self.conf["alpha"])*dist_lossloss.backward()optimizer.step()print("Epoch %d done." % e)diff = dict()# 计算返回值for name, data in self.local_model.state_dict().items():# 恶意客户端返回值diff[name] = self.conf["eta"]*(data - model.state_dict()[name])+model.state_dict()[name]return diff
3.2、服务端
由于服务端一般是难以攻破的,所以服务端代码不做改动。
server.py
import models, torchclass Server(object):def __init__(self, conf, eval_dataset):self.conf = conf self.global_model = models.get_model(self.conf["model_name"]) self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)def model_aggregate(self, weight_accumulator):for name, data in self.global_model.state_dict().items():update_per_layer = weight_accumulator[name] * self.conf["lambda"]if data.type() != update_per_layer.type():data.add_(update_per_layer.to(torch.int64))else:data.add_(update_per_layer)def model_eval(self):self.global_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id, batch in enumerate(self.eval_loader):data, target = batch dataset_size += data.size()[0]if torch.cuda.is_available():data = data.cuda()target = target.cuda()output = self.global_model(data)total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()acc = 100.0 * (float(correct) / float(dataset_size))total_l = total_loss / dataset_sizereturn acc, total_l
3.3、配置文件
已对代码做出了具体的注释说明,具体细节阅读代码即可。
conf.json
{"model_name" : "resnet18","no_models" : 10,"type" : "cifar","global_epochs" : 20,"local_epochs" : 3,"k" : 3,"batch_size" : 32,"lr" : 0.001,"momentum" : 0.0001,"lambda" : 0.3,"eta" : 2, // 恶意客户端的权重参数"alpha" : 1.0, // class_loss和dist_loss之间的权重比例"poison_label" : 2, // 约定将被毒化的数据归类为哪一类"poisoning_per_batch" : 4 // 当恶意客户端在本地训练时,在每一轮迭代过程中被篡改的数据量
}
3.4、模型文件
models.json
import torch
from torchvision import models
import mathdef get_model(name="vgg16", pretrained=True):if name == "resnet18":model = models.resnet18(pretrained=pretrained)elif name == "resnet50":model = models.resnet50(pretrained=pretrained) elif name == "densenet121":model = models.densenet121(pretrained=pretrained) elif name == "alexnet":model = models.alexnet(pretrained=pretrained)elif name == "vgg16":model = models.vgg16(pretrained=pretrained)elif name == "vgg19":model = models.vgg19(pretrained=pretrained)elif name == "inception_v3":model = models.inception_v3(pretrained=pretrained)elif name == "googlenet": model = models.googlenet(pretrained=pretrained)if torch.cuda.is_available():return model.cuda()else:return model # 定义两个模型的距离函数
def model_norm(model_1, model_2):squared_sum = 0for name, layer in model_1.named_parameters():squared_sum += torch.sum(torch.pow(layer.data - model_2.state_dict()[name].data, 2))return math.sqrt(squared_sum)
联邦学习安全之后门攻击相关推荐
- 论文 ❀《评价联邦学习中梯度泄漏攻击的框架》- A Framework for Evaluating Gradient Leakage Attacks in Federated Learning
摘要 联合学习(FL)是一个新兴的分布式机器学习框架,用于与客户网络(边缘设备)进行协作式模型训练.联合学习允许客户将其敏感数据保存在本地设备上,并且只与联合服务器共享本地训练参数更新,从而默认客户隐 ...
- 联邦学习后门攻击总结(2019-2022)
联邦学习后门攻击总结(2019-2022) 联邦学习安全性问题框架概览 下表和下图为联邦学习中常见的安全性问题,本文重点关注模型鲁棒性问题中的后门攻击问题. 攻击手段 安全性问题 攻击方与被攻击方 攻 ...
- 【阅读笔记】联邦学习实战——联邦学习攻防实战
联邦学习实战--联邦学习攻防实战 前言 1. 后门攻击 1.1 问题定义 1.2 后门攻击策略 1.3 详细实现 2. 差分隐私 2.1 集中式差分隐私 2.2 联邦差分隐私 2.3 详细实现 3. ...
- 数据分析综述:联邦学习中的数据安全和隐私保护问题
©作者 | Doreen 01 联邦学习的背景知识 近年来,随着大量数据.更强的算力以及深度学习模型的出现,机器学习在各领域的应用中取得了较大的成功. 然而在实际操作中,为了使机器学习有更好的效果,人 ...
- 联邦学习攻击与防御综述
联邦学习攻击与防御综述 吴建汉1,2, 司世景1, 王健宗1, 肖京1 1.平安科技(深圳)有限公司,广东 深圳 518063 2.中国科学技术大学,安徽 合肥 230026 摘要:随着机器学习技术的 ...
- 虚拟专题:联邦学习 | 联邦学习隐私保护研究进展
来源:大数据期刊 联邦学习隐私保护研究进展 王健宗, 孔令炜, 黄章成, 陈霖捷, 刘懿, 卢春曦, 肖京 平安科技(深圳)有限公司,广东 深圳 518063 摘要:针对隐私保护的法律法规相继出台,数 ...
- 联邦学习隐私保护研究进展
点击上方蓝字关注我们 联邦学习隐私保护研究进展 王健宗, 孔令炜, 黄章成, 陈霖捷, 刘懿, 卢春曦, 肖京 平安科技(深圳)有限公司,广东 深圳 518063 摘要:针对隐私保护的法律法规相继出台 ...
- 后门攻击经典背景文献(综述)
总结 攻击在各个场景都有体现,比如外包场景.迁移学习.联邦学习等,主要集中于前两个前景,联邦学习的攻击还有待发展. 攻击手段都集中在带触发器输入的构造上,无论是直接设计,还是使用目标模型的参数进行优化 ...
- 联邦学习入门(一)-Advances and Open Problems in Federated Learning详解
本文主要是联邦学习的入门级笔记,主要参考了论文Advances and Open Problems in Federated Learning和微众银行的联邦学习白皮书,笔者作为初次接触该领域的小白, ...
最新文章
- JavaScript异步流程控制的前世今生
- 最新县及县以上行政区划代码(截止2010年12月31日)
- 如何在页面加载完成后再去做某事?什么方法可以判断当前页面加载已完成?...
- [机器学习] 训练集(train set) 验证集(validation set) 测试集(test set)
- JS魔法堂:彻底理解0.1 + 0.2 === 0.30000000000000004的背后 1
- CVE-2019-0708 BlueKeep的扫描和打补丁
- 微信公众号后台开发总结
- GD32F130之DMA
- 2w 字长文带你搞懂 Linux 命令行
- 阿里云储道深度解析存储系统设计——NVMe SSD性能影响因素一探究竟
- [HNOI2004]高精度开根
- Uber收购动作引发巨震 美国外卖“三国杀”有望诞生“美团”?
- ElasticSearch---------------------step3,安装Kibana
- 微信公众平台开发 账号快速申请
- 安卓音乐播放器app开发(一)---功能分析及启动页的制作
- 计算机专业学生如何找到一份优质实习?如何进大厂呢?
- 小白笔记---坐标系、坐标参照系、坐标变换、投影变换
- 京东宣布涨薪,两年内将员工平均年薪从14薪涨到16薪!
- Javascript中的事件捕获、事件冒泡与事件委托
- Unity 事件系统
热门文章
- Alexa工具查询网站流量
- linux命令行模式kvm,Linux命令行管理KVM虚拟机【一】 | C/C++程序员之家
- 励志暖心英语短文-改变我的那个瞬间
- 魔兽怀旧服服务器周2维护一次吗,魔兽怀旧服:服务器再次崩溃,角色全部消失,需要维护14个小时...
- 2023全国安全生产合格证危险化学品生产单位主要负责人模拟一[安考星]
- python按键脚本会被检测封号_js调用python脚本文件挂会封号吗
- 002945华林证券75天亏86%中签的人却亏了近200%
- 【如何在linux系统里安装无线网卡驱动】
- 对import与require用法
- 同学,这有一份「实践证明」请领取!