这是一种使用GAN来生成对抗样本的模型

代码:

首先来看一个训练过程

代码中首先训练的是D

首先用generator生成干扰项 perturbation,然后与原图相加形成对抗样本 adv_images

当然训练一个D的loss分为了两部分,loss_D_real旨在拉近吃正样本之后的输出与1的距离

loss_D_fake旨在拉近吃负样本之后与0的距离,这里的负样本就是对抗样本,输入的时候不要忘了detach掉

        # optimize D# x are the input imagesfor i in range(1):perturbation = self.netG(x)  # torch.Size([128, 1, 28, 28])# add a clipping trickadv_images = torch.clamp(perturbation, -0.3, 0.3) + xadv_images = torch.clamp(adv_images, self.box_min, self.box_max)self.optimizer_D.zero_grad()pred_real = self.netDisc(x)loss_D_real = F.mse_loss(pred_real, torch.ones_like(pred_real, device=self.device))loss_D_real.backward()pred_fake = self.netDisc(adv_images.detach())loss_D_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake, device=self.device))loss_D_fake.backward()loss_D_GAN = loss_D_fake + loss_D_realself.optimizer_D.step()

训练G的过程就有些复杂了

首先要G的Gan损失的训练目标是让自己生成的对抗样本,在D看起来和正样本1相近

下方的retain_graph = True的意思是保留当前方向传播的计算图,可以做梯度累加

可以参见这两篇博客https://www.cnblogs.com/picassooo/p/13748618.html

https://www.cnblogs.com/picassooo/p/13818952.html

            self.optimizer_G.zero_grad()# cal G's loss in GANpred_fake = self.netDisc(adv_images)loss_G_fake = F.mse_loss(pred_fake, torch.ones_like(pred_fake, device=self.device))loss_G_fake.backward(retain_graph=True)

接下来就是限制扰动大小的损失

这里设计的是一个batch之中所有图片的矩阵二范数都不能太大

            # calculate perturbation normC = 0.1loss_perturb = torch.mean(torch.norm(perturbation.view(perturbation.shape[0], -1), 2, dim=1))

接下来就是样本对抗损失

onehot_labels这里的实现是比较优雅的,总体功能是根据手写数字的类别转换为onehot编码的格式,torch.eye的功能就是得到onehot编码,然后使用lables变量中对应的类别把他提取出来

real的功能,按照我粗浅的理解,是得到网络针对一个batch中所有对抗样本预测正确的概率。other的功能,是得到了网络针对一个batch中的所有对抗样本预测为错误的类别中,可能性最大的概率。

那个torch.max(real-other,0)的功能,按照我粗浅的理解,首先看real-ohter的部分,因为损失函数都是梯度下降的,最小化这个损失函数,相当于训练模型让real更小,other更大,犯错的概率越大。之所以要与0相max,也许是小于0的时候,other已经大于real了,然后没必要训练这个部分了?

最后一个损失函数我可能理解的不正确,还是要看一下那个C&W模型是怎么设计的

            # cal adv losslogits_model = self.model(adv_images)probs_model = F.softmax(logits_model, dim=1)onehot_labels = torch.eye(self.model_num_labels, device=self.device)[labels]  # torch.Size([128, 10])# C&W loss functionreal = torch.sum(onehot_labels * probs_model, dim=1)  # [128]other, _ = torch.max((1 - onehot_labels) * probs_model - onehot_labels * 10000, dim=1)zeros = torch.zeros_like(other)loss_adv = torch.max(real -other, zeros)loss_adv = torch.sum(loss_adv)

接下来就是把这两个loss乘以一个超参权重,然后backward就好了

            adv_lambda = 10pert_lambda = 1loss_G = adv_lambda * loss_adv + pert_lambda * loss_perturbloss_G.backward()self.optimizer_G.step()return loss_D_GAN.item(), loss_G_fake.item(), loss_perturb.item(), loss_adv.item()

advGAN代码笔记相关推荐

  1. CSDN技术主题月----“深度学习”代码笔记专栏

    from: CSDN技术主题月----"深度学习"代码笔记专栏 2016-09-13 nigelyq 技术专题 Hi,各位用户 CSDN技术主题月代码笔记专栏会每月在CODE博客为 ...

  2. 看完师兄的代码笔记,我失眠了

    祝大家中秋节快乐! 最近很多公司的秋季招聘都已经启动了. 想必大家(尤其是经历过求职面试的)都知道,数据结构和算法在求职笔试/面试中的重要性. 纵观如今的互联网公司面试,清一色地都在重点考查这块,不开 ...

  3. LSTM TF核心实现代码笔记

    LSTM TF核心实现代码笔记 1. LSTM TF里的核心代码实现 2. 代码详细讲解 1. LSTM TF里的核心代码实现 LSTM网络的核心实现是在这个包里tensorflow/python.k ...

  4. 2018年最新Spring Boot视频教程附代码笔记资料(50G)

    1. Spring Boot  项目实战 ----- 技术栈博客企业前后端 链接:https://pan.baidu.com/s/1hueViq4 密码:4ma8 2.Spring Boot  项目实 ...

  5. Python Text Processing with NLTK 2.0 Cookbook代码笔记

    如下是<Python Text Processing with NLTK 2.0 Cookbook>一书部分章节的代码笔记. Tokenizing text into sentences ...

  6. Transformer课程 第8课NER案例代码笔记-IOB标记

    Transformer课程 第8课NER案例代码笔记-IOB标记 NER Tags and IOB Format 训练集和测试集都是包含餐厅相关文本(主要是评论和查询)的单个文件,其中每个单词都有一个 ...

  7. 2018尚硅谷SpringBoot视频教程附代码+笔记+课件(内含Docker)

    尚硅谷SpringBoot视频教程(内含Docker)附代码+笔记+课件 下载地址:百度网盘

  8. 吴恩达机器学习MATLAB代码笔记(1)梯度下降

    吴恩达机器学习MATLAB代码笔记(1)梯度下降 单变量线性回归 1.标记数据点(Plotting the Date) fprintf('Plotting Data') data = load('D: ...

  9. Transformer课程 第8课NER案例代码笔记-部署简介

    Transformer课程 第8课NER案例代码笔记 BERT微调器 NER是信息提取的子任务,旨在将非结构化文本中提到的命名实体定位并分类为预定义类别,如人名.组织.位置.医疗代码.时间表达式.数量 ...

  10. 【SFND_Lidar_Obstacle_Detection】代码笔记

    源代码链接: https://github.com/williamhyin/SFND_Lidar_Obstacle_Detection 激光雷达数据: x,y,z,indensity(可用于评价物体的 ...

最新文章

  1. Mysql 获取当月和上个月第一天和最后一天的解决方案
  2. 流式计算框架Storm后台启动命令(避免新开窗口)
  3. Python中的并行处理(Pool.map()、Pool.starmap()、Pool.apply()、)
  4. oracle中directory的使用
  5. H3C服务器系统配置ip,H3C交换机DHCP 服务器动态分配地址典型配置指导
  6. 【Elasticsearch】Elasticsearch性能调优:千万不要做愚蠢的事
  7. 利用cookie爬取QQ邮箱的python脚本
  8. 富士施乐248b粉盒清零_能不能告诉我施乐5070硒鼓芯片清零方法是什么
  9. 最小环flody hdu6080
  10. Python-re中search()函数的用法-----查找ip(超详细)
  11. asp.net消除锯齿的办法
  12. python海龟教程_Python 零基础 快速入门 趣味教程 (咪博士 海龟绘图 turtle) 7. 条件循环...
  13. cdrx8如何批量导出jpg_cdr x8批量导出插件
  14. 蒙太奇服务器维修,蒙太奇服务器多台互连导片方法
  15. python平均值代码_python中的运行平均值
  16. 为什么没有下划线_资料1907:xumin字体打不出下划线?凌哥英语送您改进版!
  17. 从I4GL迁移到EGL
  18. 信息化课堂怎么控屏教学的
  19. A Game of Thrones(19)
  20. mt6735 [AudioDriver]mt6592使用如何使用2nd I2S输出24bits format

热门文章

  1. CSS布局——导航栏悬浮滚动更改背景色
  2. 时域特征提取_EEG信号特征提取算法
  3. 一款基于java开发的开源监控平台
  4. 企业微信 PC端多开
  5. 模块一 day03 Python基础
  6. linux监控系统catic,网络设备监控-Catic添加H3C的监控图解
  7. detached entity passed to persist:xxx;
  8. 计算程序运行时间,并将毫秒换算成人看得懂的文字,展示形式为时分秒
  9. pycharm报错: with exit code -1073740791 (0xC0000409)
  10. x的x分之一次方极限x趋于0_e的x分之一的左右极限