复现《Deep Leakage from Gradients》的攻击实验

Deep Leakage from Gradients

在GitHub上找到一个在pytorch实现《Deep Leakage from Gradients》论文中对CIFAR100数据集攻击的实验,加上了自己的理解

class LeNet(nn.Module):
def __init__(self):super(LeNet, self).__init__()act = nn.Sigmoidself.body = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, padding=5//2, stride=2),act(),nn.Conv2d(32, 32, kernel_size=5, padding=5//2, stride=2),act(),nn.Conv2d(32, 64, kernel_size=5, padding=5//2, stride=1),act(),)self.fc = nn.Sequential(nn.Linear(4096, 100))def forward(self, x):out = self.body(x)out = out.view(out.size(0), -1)#print(out.size())out = self.fc(out)return out

`

原代码有两种模型,一种Lenet,一种为Resnet,我用的第一种其中它源代码的卷积通道都为12,但是自己在实现的时候发现最后恢复不了原始的图片,全部都是噪音,不知到它是怎么实现的,摊手.jpg。然后自己将通道数换成32,32,64,然后就奇迹发生了,只迭代了0次loss就0.001???,可能是自己实现的有问题吧。
上图!

有没有很奇怪的感觉?? ,但是可以恢复出原始数据,(_),管不了这么多了
异常清晰啊,朋友们,不得不说这篇论文的idel是真的棒啊!!!
最后上主要代码:
修改了一些变量名并且加入了一些注释方便理解

# -*- coding: utf-8 -*-
import argparse
import numpy as np
from pprint import pprintfrom PIL import Image
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transformsfrom utils import label_to_onehot, cross_entropy_for_onehot   #将标签onehot化   并使用onehot形式的交叉熵损失函数parser = argparse.ArgumentParser(description='Deep Leakage from Gradients.')
parser.add_argument('--index', type=int, default="45",help='the index for leaking images on CIFAR.')
parser.add_argument('--image', type=str,default="",help='the path to customized image.')
args = parser.parse_args()device = "cpu"
if torch.cuda.is_available():device = "cuda"
print("Running on %s" % device)data_cifar = datasets.CIFAR100("/.torch", download=True)
To_tensor = transforms.ToTensor()
To_image = transforms.ToPILImage()img_index = args.indexgt_data = To_tensor(data_cifar[img_index][0]).to(device)  #image_index[i][0]表示的是第I张图片的data,image_index[i][1]表示的是第i张图片的lableif len(args.image) > 1:    #得到预设参数的图片并将其转换为tensor对象gt_data = Image.open(args.image)gt_data = To_tensor(gt_data).to(device)gt_data = gt_data.view(1, *gt_data.size())gt_label = torch.Tensor([data_cifar[img_index][1]]).long().to(device)
gt_label = gt_label.view(1, )
gt_onehot_label = label_to_onehot(gt_label)plt.imshow(To_image(gt_data[0].cpu()))from models.vision import LeNet,  ResNet18
net = LeNet().to(device)torch.manual_seed(1234)#net.apply(weights_init)
criterion = cross_entropy_for_onehot  #调用损失函数# compute original gradient
pred = net(gt_data)
y = criterion(pred, gt_onehot_label)
dy_dx = torch.autograd.grad(y, net.parameters())   #获取对参数W的梯度original_dy_dx = list((_.detach().clone() for _ in dy_dx))    #对原始梯度复制# generate dummy data and label
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)plt.imshow(To_image(dummy_data[0].cpu()))optimizer = torch.optim.LBFGS([dummy_data, dummy_label])history = []
for iters in range(300):def closure():optimizer.zero_grad()  #梯度清零dummy_pred = net(dummy_data) dummy_onehot_label = F.softmax(dummy_label, dim=-1)dummy_loss = criterion(dummy_pred, dummy_onehot_label) dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)   #faked数据得到的梯度grad_diff = 0for gx, gy in zip(dummy_dy_dx, original_dy_dx): grad_diff += ((gx - gy) ** 2).sum()   #计算fake梯度与真实梯度的均方损失grad_diff.backward()    #对损失进行反向传播    优化器的目标是fake_data, fake_labelreturn grad_diffoptimizer.step(closure)if iters % 10 == 0: current_loss = closure()print(iters, "%.4f" % current_loss.item())history.append(To_image(dummy_data[0].cpu()))plt.figure(figsize=(12, 8))
for i in range(30):plt.subplot(3, 10, i + 1)plt.imshow(history[i])plt.title("iter=%d" % (i * 10))plt.axis('off')plt.show()

这篇论文的核心就是构建一个损失函数——自己创建的一个fake的(dummy_data,dummy_label)所得到的梯度与原始训练数据(True_data,True_label)所得到的梯度的均方误差,再用损失函数对(dummy_data,dummy_label)最优化,不断迭代以恢复出原始数据

原文再cifar100数据集迭代了差不多100次才能恢复出原始数据,而我参考的这个代码很快就迭代完成了,也不知道原文的源代码是怎样写的,后续再仔细研究一下这个代码。

参考:
代码地址: https://github.com/mit-han-lab/dlg
文献地址:https://papers.nips.cc/paper/9617-deep-leakage-from-gradients.pdf

复现《Deep Leakage from Gradients》的攻击实验相关推荐

  1. Deep Leakage From Gradients文献阅读及代码重现

    目录 1.DLG文献解析 1.1 背景介绍 1.2 算法描述 1.3 实验结果 2.iDLG文献解析 2.1 算法描述 2.2 实验结果 3.代码(DLG和iDLG) 1.DLG文献解析 文献地址: ...

  2. Deep leakage from Gradients论文解析

    Deep leakage from Gradients论文解析 今天来给大家介绍下2019年NIPS上发表的一篇通过梯度进行原始数据恢复的论文. 论文传送门 **问题背景:**现在分布式机器学习和联邦 ...

  3. 论文阅读:Deep Leakage From Gradients

    论文名字 Deep Leakage From Gradients 来源 顶会 NeurIPS 年份 2019.12 作者 Ligeng Zhu  Zhijian Liu  Song Han 核心点 主 ...

  4. Deep Leakage from Gradients

    Summary 对于分布式学习,特别是相关之前共享梯度的学习,提出了一种攻击方式(DLG).通过窃取client之间传递的梯度反推出(也是使用机器学习迭代的方式)原始的输入.并在图像分类.Masked ...

  5. 脏牛(Dirty COW)漏洞攻击实验(SEED-Lab:Dirty-COW Attack Lab)

    <脏牛(Dirty COW)漏洞攻击实验> 目录 <脏牛(Dirty COW)漏洞攻击实验> **一:实验目的** **二:实验步骤与结果** **漏洞原理:** **COW机 ...

  6. 实验三 ShellShock 攻击实验

    ShellShock 攻击实验 沙雨济 一. 实验描述 2014年9月24日,Bash中发现了一个严重漏洞shellshock,该漏洞可用于许多系统,并且既可以远程也可以在本地触发.在本实验中,学生需 ...

  7. CSAPP lab3 bufbomb-缓冲区溢出攻击实验(下)bang boom kaboom

    CSAPP lab3 bufbomb-缓冲区溢出攻击实验(上)smoke fizz CSAPP lab3 bufbomb-缓冲区溢出攻击实验(下)bang boom kaboom 栈结构镇楼 这里先给 ...

  8. 《Linux内核原理与设计》第十一周作业 ShellShock攻击实验

    <Linux内核原理与设计>第十一周作业 ShellShock攻击实验 分组: 和20179215袁琳完成实验及博客攥写 实验内容:   Bash中发现了一个严重漏洞shellshock, ...

  9. 计算机系统基础学习笔记(7)-缓冲区溢出攻击实验

    缓冲区溢出攻击实验 实验介绍 实验任务 实验数据 目标程序 bufbomb 说明 bufbomb 程序接受下列命令行参数 目标程序bufbomb中函数之间的调用关系 缓冲区溢出理解 目标程序调用的ge ...

最新文章

  1. IJCAI 2019:中国团队录取论文超三成,北大、南大榜上有名
  2. 在Python中将整数附加到列表的开头
  3. Nacos 2.0的Spring Boot Starter来了!
  4. .NET 2.0 CER学习笔记
  5. Codeforces Round #501 (Div. 3)【未完结】
  6. struct and union
  7. SpringMVC表单验证与Velocity整合
  8. 年度神作!这本Python 3.6的书刷爆朋友圈,网友:太香!
  9. 四边形不等式优化dp
  10. php批量数据提交mysql_php在mysql里批量插入数据(代码实例)
  11. 连续状态空间模型离散化
  12. opencv项目6----AI绘画(隔空绘画)
  13. 计算机42D,汉印G42D 电子面单打印机
  14. 建造者模式(Builder)---创建型
  15. springboot网页小图标
  16. 获取固定到任务栏的快捷方式的图标
  17. OC5038内置 MOS 开关降压型 LED 恒流驱动器
  18. asr标注工具_传统ASR全流程【转载】
  19. 二分图的Hall定理
  20. 多年厮杀,腾讯、阿里、百度、小米的投资版图长什么样!

热门文章

  1. 海天味业又火了,市值突破五千亿,“卖酱油”甚比“卖茅台”
  2. 整理的jquery使用技巧
  3. 微服务架构下的 服务熔断, 降级, 限流
  4. Java如何获取JSON数据中的值 备忘
  5. 环迅支付2015新代理政策
  6. 云计算就像马拉松 京东CTO为啥这么说
  7. VB6.0简繁体转换步骤
  8. ❤️垃圾大学,想自学 Java 可以吗?难吗?毕业后能找到一份 6k左右的工作吗?
  9. 操作系统-用信号量解决小和尚打水老和尚喝水问题
  10. 如何学好数据结构与算法(视频+文字版)