Fast Gradient Sign Attack(FGSM)算法小结


对抗攻击引发了机器学习一系列的思考,训练出来的model是否拥有强大的泛化能力?模型的准确率是否真实?
在对抗攻击中添加一些肉眼无法识别的噪声可能会对识别效果产生巨大的影响。

什么是对抗攻击

对抗攻击的核心思想就是人为地制造干扰项去迷惑模型,使模型产生错误的结果。在计算机视觉中,对抗攻击就是在原图上添加一些人为无法识别的噪声生成干扰图片,使得模型作出错误的判断。

对抗攻击分类

无目标的对抗攻击:只是让目标模型的判断出错

有目标的对抗攻击:引导目标模型做出我们想要错误判断

以对目标模型的了解程度为标准,对抗攻击又可以分成白盒攻击和黑盒攻击

白盒攻击:在已经获取机器学习模型内部的所有信息和参数上进行攻击

黑盒攻击:在神经网络结构为黑箱时,仅通过模型的输入和输出,逆推生成对抗样本。

FGSM算法原理

直观来看就是在输入的基础上沿损失函数的梯度方向加入了一定的噪声,使目标模型产生了误判。
如下图所示:原图加上超参数 ϵ 乘与 损失梯度生成新的干扰图片。

如下图所示为生成干扰图片的完整公式:x*表示对抗样本,x表示原样本,J() 表示损失函数,ϵ 表示超参数。

对于某个特定的模型而言,FGSM将损失函数近似线性化(对于神经网络而言,很多神经网络为了节省计算上的代价,都被设计成了非常线性的形式,这使得他们更容易优化,但是这样”廉价”的网络也导致了对于对抗扰动的脆弱性)。

也就是说,即是是神经网络这样的模型,也能通过线性干扰来对它进行攻击。

FGSM实例

导入所需包

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

初始化超参数等全局变量

epsilons = [0, .05, .1, .15, .2, .25, .3]
pretrained_model = "../datasets/lenet_mnist_model.pth"  # 使用的预训练模型路径
use_cuda=True

定义网络结构

class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../datasets', train=False, download=True, transform=transforms.Compose([transforms.ToTensor(),])),batch_size=1, shuffle=True
)
device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")model = Net().to(device)model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
model.eval()

定义对抗图片生成函数

def fgsm_attack(image, epsilon, data_grad):"""获取扰动图片:param image: 原始图片:param epsilon: 扰动量:param data_grad: 损失梯度:return:"""# Collect the element-wise sign of the data gradientsign_data_grad = data_grad.sign()# Create the perturbed image by adjusting each pixel of the input imageperturbed_image = image + epsilon*sign_data_grad# Adding clipping to maintain [0,1] rangeperturbed_image = torch.clamp(perturbed_image, 0, 1)# Return the perturbed imagereturn perturbed_image

测试函数定义

def test(model, device, test_loader, epsilon):# Accuracy countercorrect = 0adv_examples = []for data, target in test_loader:data, target = data.to(device), target.to(device)# Set requires_grad attribute of tensor. Important for Attackdata.requires_grad = Trueoutput = model(data)init_pred = output.max(1, keepdim=True)[1]if init_pred.item() != target.item():continue# lossloss = F.nll_loss(output, target) # 用于多分类的负对数似然损失函数(Negative Log Likelihood) loss(x,label)=−xlabelmodel.zero_grad()loss.backward()# Collect datagraddata_grad = data.grad.dataperturbed_data = fgsm_attack(data, epsilon, data_grad)output = model(perturbed_data)final_pred = output.max(1, keepdim=True)[1]if final_pred.item() == target.item():correct += 1if (epsilon == 0) and (len(adv_examples) < 5):adv_ex = perturbed_data.squeeze().detach().cpu().numpy()adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))else:if len(adv_examples) < 5:adv_ex = perturbed_data.squeeze().detach().cpu().numpy()adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))# Calculate final accuracy for this epsilonfinal_acc = correct / float(len(test_loader))print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))# Return the accuracy and an adversarial examplereturn final_acc, adv_examples

开始测试

accuracies = []
examples = []
for eps in epsilons:acc, ex = test(model, device, test_loader, eps)accuracies.append(acc)examples.append(ex)plt.figure(figsize=(5,5))
plt.plot(epsilons, accuracies, "*-")
plt.yticks(np.arange(0, 1.1, step=0.1))
plt.xticks(np.arange(0, .35, step=0.05))
plt.title("Accuracy vs Epsilon")
plt.xlabel("Epsilon")
plt.ylabel("Accuracy")
plt.show()# Plot several examples of adversarial samples at each epsilon
cnt = 0
plt.figure(figsize=(8,10))
for i in range(len(epsilons)):for j in range(len(examples[i])):cnt += 1plt.subplot(len(epsilons),len(examples[0]),cnt)plt.xticks([], [])plt.yticks([], [])if j == 0:plt.ylabel("Eps: {}".format(epsilons[i]), fontsize=14)orig,adv,ex = examples[i][j]plt.title("{} -> {}".format(orig, adv))plt.imshow(ex, cmap="gray")
plt.tight_layout()
plt.show()

下图所示为测试结果图:

最后得出超参数 ϵ 越大(但是不超过 1 )生成的干扰图片就越有效。

完整代码:
https://github.com/AndyandViky/ML-study/blob/master/pytorch/adversarial-example.py
Reference
https://undefinedf.github.io/2018/11/02/对抗攻击之FGSM/
https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

Fast Gradient Sign Attack(FGSM)算法小结相关推荐

  1. 以FGSM算法为例的对抗训练的实现(基于Pytorch)

    如果可以,请点个赞,这是我写博客的动力,谢谢各位观众 1. 前言 深度学习虽然发展迅速,但是由于其线性的特性,受到了对抗样本的影响,很容易造成系统功能的失效. 以图像分类为例子,对抗样本很容易使得在测 ...

  2. 《Fast Gradient Projection Method for Text Adversary Generation and Adversarial Training》论文学习笔记

    最近在学习对抗学习在文本分类方面的论文,对抗训练在提高深度神经网络对图像分类的鲁棒性方面表现出了有效性和高效性.然而,对于文本分类,文本输入空间的离散特性使得基于梯度的对抗方法难以从图像域进行自适应. ...

  3. APG(Accelerate Proximal Gradient)加速近端梯度算法 和 NAG(Nesterov accelerated gradient)优化器原理 (二)

    文章目录 前言 NAG优化器 APG 与 NAG的结合 Pytorch 代码实现 总结 附录 公式(11)推导 引用 前言 近期在阅读Data-Driven Sparse Structure Sele ...

  4. Proximal Gradient Method近端梯度算法

    本文参考文献附在最后.是对参考文献的理解. 1:此算法解决凸优化问题模型如下: minF(x)=g(x)+h(x) min F(x)=g(x)+h(x)其中 g(x) g(x)凸的,可微的. h(x) ...

  5. 数据挖掘中分类算法小结

    数据挖掘中分类算法小结   数据仓库,数据库或者其它信息库中隐藏着许多可以为商业.科研等活动的决策提供所需要的知识.分类与预测是两种数据分析形式,它们可以用来抽取能够描述重要数据集合或预测未来数据趋势 ...

  6. [转载]SIFT(尺度不变特征变换)算法小结

    原文地址:SIFT(尺度不变特征变换)算法小结[转]作者:慕容天峰 最近一直在看SIFT算法.Sift是David Lowe于1999年提出的局部特征描述子,并于2004年进行了更深入的发展和完善.S ...

  7. linemod算法小结

    Linemod算法小结   LineMod方法是由Hinterstoisser[1][2][3]在2011年提出,主要解决的问题是复杂背景下3D物体的实时检测与定位,用到了RGBD的信息,可以应对无纹 ...

  8. 分治算法小结(附例题详解)

    分治算法小结(附例题详解) 我的理解: 分治算法我的理解就是看人下菜碟,我们要解决的问题就好像一群人构成的集体,要我们解决这个问题,那我们就要满足这群人里面每个人不同的需求,也就是写出解决的代码,把每 ...

  9. APG(Accelerate Proximal Gradient)加速近端梯度算法 和 NAG(Nesterov accelerated gradient)优化器原理 (一)

    文章目录 前言 APG(Accelerate Proximal Gradient)加速近端梯度算法[^1] PGD (Proximal Gradient Descent)近端梯度下降法推导[^2] E ...

最新文章

  1. 米家电磁炉显示e10_小米“米家电磁炉C1”评测:7挡火力,2100W大功率设计
  2. 局域网流量控制_羡慕多屏协同?这3款 App 让你的电脑也能轻松控制 Android 手机...
  3. 在博客园写了一年博客,收获的不仅仅是写作技能——我能一直保持积极的学习和工作态度...
  4. MyCat学习:使用MySQL搭建主从复制(双主双从模式)
  5. Winform中设置多条Y轴时新增的Y轴刻度不显示问题解决
  6. linux 运行python 看不到异常信息_linux python运行报编码错误
  7. Mysql笔记——DML
  8. HttpHandler浅析
  9. 世界手机号码格式_世界上手机号码最长的国家是中国,最短的是哪个国家?
  10. 【网络安全工程师面试合集】—黑客常用的端口及攻击方法汇总
  11. 有什么办法让Beyond Compare以网页形式显示文件
  12. MAC maven 安装和配置
  13. 统计通话次数和时间的软件_通话时间统计器下载-通话时间统计 安卓版v2.6-PC6安卓网...
  14. 20-统一网关Gateway-全局过滤器
  15. Windows上查看MTU值和修改MTU的方法
  16. 【Unity2D入门教程】简单制作一个弹珠游戏之制作场景①(开场,结束,板子,球)
  17. 二手书电商闲鱼、转转们的花样淘金和眼前僵局
  18. 欠债还钱,天经地义(一)
  19. HAL库版STM32双轮自平衡车(五) ———— 调参
  20. unity 音频可视化

热门文章

  1. 数组转换成字符串 join、toString、toLocaleString
  2. 刚刚入门,还请各路大神多多关照
  3. gcc编译选项-fPIC
  4. Unity批量更改脚本名字
  5. C# WinForm中四种显示信息的方式
  6. vscode模板自动补全
  7. Vue移动端项目(一)
  8. 使用Glide加载、缓存图片、Gif、解决背景出现浅绿色、GlideModules冲突
  9. 大厂出品的Web端AI语音转文字神器
  10. 项目部署到centos7服务器验证码乱码