```python
import numpy as np
import torch   # pytorch机器学习开源框架
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
import torch.optim as optim
from torchvision import transforms
from tqdm import *
import matplotlib.pyplot as plt
import copy
from torch.autograd.gradcheck import zero_gradientsclass Net(nn.Module):# 定义Net的初始化函数,这个函数定义了该神经网络的基本结构def __init__(self):# 复制并使用Net的父类的初始化方法,即先运行nn.Module的初始化函数super(Net, self).__init__()# 定义fc1(fullconnect)全连接函数1为线性函数:y = Wx + b,并将28*28个节点连接到300个节点上。self.fc1 = nn.Linear(28*28, 300)# 定义fc2(fullconnect)全连接函数2为线性函数:y = Wx + b,并将300个节点连接到100个节点上。self.fc2 = nn.Linear(300, 100)# 定义fc3(fullconnect)全连接函数3为线性函数:y = Wx + b,并将100个节点连接到10个节点上。self.fc3 = nn.Linear(100, 10)#定义该神经网络的向前传播函数,该函数必须定义,一旦定义成功,向后传播函数也会自动生成(autograd)def forward(self, x):# 输入x经过全连接1,再经过ReLU激活函数,然后更新xx = F.relu(self.fc1(x))x = F.relu(self.fc2(x))# 输入x经过全连接3,然后更新xx = self.fc3(x)return x# 定义数据转换格式
mnist_transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x : x.resize_(28*28))])# 导入数据,定义数据接口
# 1.root,表示mnist数据的加载的相对目录
# 2.train,表示是否加载数据库的训练集,false的时候加载测试集
# 3.download,表示是否自动下载mnist数据集
# 4.transform,表示是否需要对数据进行预处理,none为不进行预处理
traindata = torchvision.datasets.MNIST(root="./drive/My Drive/fgsm/mnist", train=True, download=True, transform=mnist_transform)
testdata  = torchvision.datasets.MNIST(root="./drive/My Drive/fgsm/mnist", train=False, download=True, transform=mnist_transform)# 将训练集的*张图片划分成*份,每份256(batch_size)张图,用于mini-batch输入
# shffule=True在表示不同批次的数据遍历时,打乱顺序
# num_workers=n表示使用n个子进程来加载数据
trainloader = torch.utils.data.DataLoader(traindata, batch_size=256, shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testdata, batch_size=256, shuffle=True, num_workers=0)# 展示图片
index = 100
image = testdata[index][0]
label = testdata[index][1]
image.resize_(28,28) # 调整图片大小
img = transforms.ToPILImage()(image)
plt.imshow(img)
plt.show()index = 100
batch = iter(testloader).next() # 将testloader转换为迭代器
# 例如:如果batch_size为4,则取出来的images是4×c×h×w的向量,labels是1×4的向量
image = batch[0][index]
label = batch[1][index]
image.resize_(28,28)
img = transforms.ToPILImage()(image)
plt.imshow(img)
plt.show()net = Net()
loss_function = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-04) # 随机梯度下降优化num_epoch = 50
for epoch in tqdm(range(num_epoch)): # python进度条,num_epoch=50,所以每2%显示一次losses = 0.0for data in trainloader:inputs, labels = data # 获取输入# inputs, labels = Variable(inputs), Variable(labels)optimizer.zero_grad() # 参数梯度置零# 前向+ 反向 + 优化outputs = net(inputs)loss = loss_function(outputs, labels) # 计算lossloss.backward() # 传回反向梯度optimizer.step() # 梯度传回,利用优化器将参数更新losses += loss.data.item() # 输出统计 print("*****************当前平均损失为{}*****************".format(losses/2000.0))correct = 0 # 定义预测正确的图片数,初始化为0
total = 0 # 总共参与测试的图片数,也初始化为0
for data in testloader:images, labels = data outputs = net(Variable(images))# 输入网络进行测试 # 因为神经网络只能输入Variable_, predicted = torch.max(outputs.data, 1)#返回了最大的索引,即预测出来的类别。# 这个_,predicted是python的一种常用的写法,表示后面的函数其实会返回两个值# 但是我们对第一个值不感兴趣,就写个_在那里,把它赋值给_就好,我们只关心第二个值predicted# torch.max(outputs.data,1) ,返回一个tuple(元组)。第二个元素是labeltotal += labels.size(0) # 更新测试图片的数量correct += (predicted == labels).sum() # 更新正确分类的图片的数量
print("预测准确率为:{}/{}".format(correct, total))# 保存整个网络 #...It won't be checked...是保存模型时的输出
PATH1="./drive/My Drive/fgsm/mnist_net_all.pkl"
torch.save(net,PATH1)
# 保存网络中的参数,速度快,占空间少,PATH1是保存路径和文件名
PATH2="./drive/My Drive/fgsm/mnist_net_param.pkl"
torch.save(net.state_dict(),PATH2)
#针对上面一般的保存方法,加载的方法分别是:
#model_dict=torch.load(PATH)
#model_dict=model.load_state_dict(torch.load(PATH))net = torch.load(PATH1) # 加载模型index = 100 # 选择测试样本
image = testdata[index][0]
label = testdata[index][1]outputs = net(Variable(image)) # 因为神经网络只能输入Variable
predicted = torch.max(outputs.data,0)[1]
print('预测值为:{}'.format(predicted))image.resize_(28,28) # 显示一下测试的图片,和上文代码相同
img = transforms.ToPILImage()(image)
plt.imshow(img)
plt.show()PATH1="./drive/My Drive/fgsm/mnist_net_all.pkl"
net = torch.load(PATH1) # 加载模型
index = 100 # 选择测试样本
image = Variable(testdata[index][0].resize_(1,784), requires_grad=True)
label = torch.tensor([testdata[index][1]])f_image = net.forward(image).data.numpy().flatten() # flatten()函数默认是按行的方向降维
I = (np.array(f_image)).flatten().argsort()[::-1] # argsort()[::-1]表示降维排序后返回索引值
label = I[0] # 样本标签input_shape = image.data.numpy().shape # 获取原始样本的维度,返回(第一维长度,第二维长度,...)
pert_image = copy.deepcopy(image) # 深度复制原始样本,复制出来后就独立了
w = np.zeros(input_shape) # 返回来一个给定形状和类型的用0填充的数组
r_tot = np.zeros(input_shape)loop_i = 0 # 循环
max_iter = 50 # 最多迭代次数
overshoot = 0.0x = Variable(pert_image, requires_grad=True) # 因为神经网络只能输入Variable
fs = net.forward(x) # 调用forward函数
fs_list = [fs[0][I[k]] for k in range(len(I))] # 每个类别的取值情况,及其对应的梯度值
k_i = labelwhile k_i == label and loop_i < max_iter: # 分类标签变化时结束循环pert = np.inf # np.inf表示一个足够大的数fs[0][I[0]].backward(retain_graph=True) # 反向传播,计算当前梯度;连续执行两次backward,参数表明保留backward后的中间参数。orig_grad = x.grad.data.numpy().copy() # 原始梯度for k in range(len(I)): # 获得x到各分类边界的距离zero_gradients(x)fs[0][I[k]].backward(retain_graph=True)cur_grad = x.grad.data.numpy().copy() # 现在梯度w_k = cur_grad - orig_gradf_k = (fs[0][I[k]] - fs[0][I[0]]).data.numpy()pert_k = abs(f_k) / np.linalg.norm(w_k.flatten())if pert_k < pert:  # 获得最小的分类边界距离向量pert = pert_k # 更新perk,pert为最小距离w = w_kr_i = (pert + 1e-4) * w / np.linalg.norm(w)r_tot = np.float32(r_tot + r_i) # 累积扰动pert_image = image + (1+overshoot)*torch.from_numpy(r_tot) # 添加扰动x = Variable(pert_image, requires_grad=True)fs = net.forward(x)k_i = np.argmax(fs.data.numpy().flatten()) # 扰动后的分类标签loop_i += 1
r_tot = (1+overshoot)*r_tot # 最终累积的扰动outputs = net(pert_image.data.resize_(1,784))
predicted = torch.max(outputs.data,1)[1] #outputs含有梯度值,其处理方式与之前有所不同
print('预测值为:{}'.format(predicted[0]))pert_image = pert_image.reshape(28,28)
img = transforms.ToPILImage()(pert_image)
plt.imshow(img)
plt.show()

deepfool简单实现相关推荐

  1. 对抗攻击经典论文剖析(下)【DeepFool、One pixel attack、Universal adversarial perturbations、ATN】

    引言 上一篇讲的几篇经典对抗攻击论文主要讲的是如何在梯度上扰动或者优化,即尽可能保证下的扰动,不被人类发现,却大大降低了模型的性能.这一篇我们将会有一些更有意思的对抗攻击样本生成,包括像素级别的扰动以 ...

  2. 在docker上安装部署tomcat项目 超简单,拿来主义

    在docker中部署tomcat,非常简单,而且省去了手动安装jdk等步骤,只需要将war包复制在容器tomcat实例中的webapps下面即可.以下将详细讲解流程: 在windows中打好包以后用w ...

  3. Linux下tomcat的安装与卸载以及配置(超简单)

    无敌简单的几步 1.安装 //首先你需要下载好tomcat包 sudo tar -xvzf apache-tomcat-7.0.85.tar.gz(这里是包名) -C 你要放的位置 2.卸载 rm - ...

  4. Docker安装Apache与运行简单的web服务——httpd helloworld

    Docker运行简单的web服务--httpd helloworld目录[阅读时间:约5分钟] 一.Docker简介 二.Docker的安装与配置[CentOS环境] 三.Docker运行简单的web ...

  5. Docker的安装、镜像源更换与简单应用

    Docker的安装.镜像源更换与简单应用[阅读时间:约20分钟] 一.概述 二.系统环境&项目介绍 1.系统环境 2.项目的任务要求 三.Docker的安装 四.Docker的简单应用 1. ...

  6. 基于Golang的简单web服务程序开发——CloudGo

    基于Golang的简单web服务程序开发--CloudGo[阅读时间:约10分钟] 一.概述 二.系统环境&项目介绍 1.系统环境 2.项目的任务要求 (1)基本要求 (2)扩展要求 三.具体 ...

  7. 简单图文配置golang+vscode【win10/centos7+golang helloworld+解决install failed等情况】

    博客目录(阅读时间:10分钟) 一.win10 0.系统环境 1. win10配置golang环境 ①下载相关软件 ②创建gowork工作空间 ③配置环境变量(GOPATH+PATH) ④验证环境配置 ...

  8. 简单介绍互联网领域选择与营销方法

    在我看来,互联网领域的选择是"安家",而营销方法的不同则表现了"定家"的方式多种多样,只有选对了,"家"才得以"安定". ...

  9. JAVA用最简单的方法来构建一个高可用的服务端,提升系统可用性

    一.什么是提升系统的高可用性 JAVA服务端,顾名思义就是23体验网为用户提供服务的.停工时间,就是不能向用户提供服务的时间.高可用,就是系统具有高度可用性,尽量减少停工时间.如何用最简单的方法来搭建 ...

  10. java发送简单邮件_Java程序实现发送简单文本邮件

    /** * Java程序实现发送简单文本邮件 * * @author Administrator * */ public class SendTextMail { // 定义发件人地址 public  ...

最新文章

  1. Kimera实时重建的语义SLAM系统
  2. 性能测试的“2-5-10原则”
  3. 关于解决[INSTALL_FAILED_UPDATE_INCOMPATIBLE]
  4. 2015蓝桥杯省赛---java---B---3(三羊献瑞)
  5. 梦想还是要有的,万一实现了呢
  6. Perspective Mockups mac(PS透视模型动作插件)支持ps2021
  7. CentOS6.9安装Kafka
  8. python能制作ppt动画效果吗_你听说过Python可以做动画吗
  9. js产生两个数字之间的随机数
  10. paip.c++ static 变量的定义以及使用...
  11. python网页抓取与按键精灵原理一样吗_按键精灵等以GUI接口为基础的程序在爬虫界的地位是怎样的?...
  12. VBScript: 正则表达式(RegExp对象)
  13. 可以直接复制的emoji符号(表情)
  14. c语言简单计算器减编程,C语言实现简单的计算器(加、减、乘、除)
  15. java 排秩,lamd(java lambda表达式)
  16. 方维直播Android打包流程
  17. 【蓝桥杯冲刺 day12】题目全解析
  18. hw叠加层开还是不开_停用hw叠加层有什么用
  19. 陌生人交友软件有哪些?陌生人社交APP排名|良心推荐
  20. SIFT--特征描述符

热门文章

  1. 2022-2027年中国非人寿保险市场竞争态势及行业投资前景预测报告
  2. virtualization technology设置
  3. c语言tc游戏代码大全,wintcC语言小游戏画图代码.doc
  4. pmm9010在线测试软件,EMC/EMI 数字式测试接收机
  5. 淘宝商品上传API接口
  6. 2021虫虫百度域名URL批量采集工具【自动去重】
  7. linux学习---内存管理以及结存结构描述
  8. 用Keil工具搭建S3C2440编译环境
  9. 用三元运算符判断奇数和偶数
  10. QCOM chi-camera bring up