当前,在各大NLP竞赛中,对抗训练已然成为上分神器,尤其是fgm和pgd使用较多,下面来说说吧。对抗训练是一种引入噪声的训练方式,可以对参数进行正则化,提升模型鲁棒性和泛化能力。

fgm

FGM的全称是Fast Gradient Method, 出现于Adversarial Training Methods for Semi-supervised Text Classification这篇论文,FGM是根据具体的梯度进行scale,得到更好的对抗样本:

整个对抗训练的过程如下,伪代码如下:

  • 1.计算x的前向loss、反向传播得到梯度;
  • 2.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r;
  • 3.计算x+r的前向loss,反向传播得到对抗的梯度,累加到(1)的梯度上;
  • 4.将embedding恢复为(1)时的值;
  • 5.根据(3)的梯度对参数进行更新。

fgm代码实现如下:

class FGM:def __init__(self, model: nn.Module, eps=1.):self.model = (model.module if hasattr(model, "module") else model)self.eps = epsself.backup = {}# only attack word embeddingdef attack(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm and not torch.isnan(norm):r_at = self.eps * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='word_embeddings'):for name, para in self.model.named_parameters():if para.requires_grad and emb_name in name:assert name in self.backuppara.data = self.backup[name]self.backup = {}

fgm应用代码如下:

##对应第一步
loss = model(**batch_data)[0]
loss.backward()
##对应第二步
fgm.attack()
#对应第三步
loss_adv = model(**batch_data)[0]
loss_adv.backward()
#对应第四步
fgm.restore()
#对应第五步
optimizer.step()

fgsm

FGSM的全称是Fast Gradient Sign Method. 如果要说FGSM和FGM的区别,核心区别在计算扰动的方式不一样,FGSM扰动的计算方式如下:

FGSM的其他算法流程跟FGM一样,这里不再赘述。

pgd

FGM直接通过epsilon参数一下子算出了对抗扰动,这样得到的可能不是最优的。因此PGD进行了改进,多迭代几次,慢慢找到最优的扰动。
引用:

FGM简单粗暴的“一步到位”,可能走不到约束内的最优点。PGD则是“小步走,多走几步”,如果走出了扰动半径为epsilon的空间,就映射回“球面”上,以保证扰动不要过大


并且

pgd整个对抗训练的过程如下,伪代码如下:

  • 1.计算x的前向loss、反向传播得到梯度并备份;
  • 2.对于每步t:
  •  a.根据embedding矩阵的梯度计算出r,并加到当前embedding上,相当于x+r(超出范围则投影回epsilon内);
    
  •  if t 不是最后一步,则进行b步骤:将模型梯度归0,根据a的x+r计算前后向并得到梯度,继续a步骤;if t 是最后一步,则进行c步骤:恢复(1)的梯度,根据a的x+r计算前后向得到梯度并将梯度累加到(1)的梯度上,跳出循环;
    
  • 3.将embedding恢复为(1)时的值;
  • 4.根据2c的梯度对参数进行更新。

可以看到,在循环中r是逐渐累加的,要注意的是最后更新参数只使用最后一个x+r算出来的梯度。
pgd代码实现如下:

class PGD:def __init__(self, model, eps=1., alpha=0.3):self.model = (model.module if hasattr(model, "module") else model)self.eps = epsself.alpha = alphaself.emb_backup = {}self.grad_backup = {}def attack(self, emb_name='word_embeddings', is_first_attack=False):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = self.alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data)def restore(self, emb_name='word_embeddings'):for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data):r = param_data - self.emb_backup[param_name]if torch.norm(r) > self.eps:r = self.eps * r / torch.norm(r)return self.emb_backup[param_name] + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad and param.grad is not None:self.grad_backup[name] = param.grad.clone()def restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad and param.grad is not None:param.grad = self.grad_backup[name]

pgd应用代码如下:

loss = model(**batch_data)[0]
loss.backward()
pgd.backup_grad()
for _t in range(pgd_k):pgd.attack(is_first_attack=(_t == 0))if _t != pgd_k - 1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(**batch_data)[0]loss_adv.backward()
pgd.restore()
optimizer.step()

注:在torch中,每次迭代时,如果不把模型的梯度清零,会默认将模型每次迭代的梯度累加的。

FreeAT

FreeAT (Free Adversarial Training):来源于NIPS2019的一篇论文,从FGSM到PGD,主要是优化对抗扰动的计算,但计算量也一步步增加。对于每个样本,FGSM和FGM都只用计算两次,一次是计算x的前后向,一次是计算x+r的前后向。而PGD则计算了K+1次,消耗了更多的计算资源。因此FreeAT被提了出来,在PGD的基础上进行训练速度的优化。

FreeAT的思想是在对每个样本x连续重复m次训练,计算r时复用上一步的梯度,为了保证速度,整体epoch会除以m。r的更新公式为:


整个对抗训练的过程如下,伪代码如下:

初始化r=0
对于epoch=1...N/m:对于每个x:对于每步m:1.利用上一步的r,计算x+r的前后向,得到梯度2.根据梯度更新参数3.根据梯度更新r

以上所述的对抗训练方法在不同的训练数据上表现大同小异,需要根据具体场景和具体数据去选择对应的算法,没有最好的,只有当下场景最适合的算法。
详细代码实现参见github代码实现: 对抗训练.

对抗训练fgm、fgsm和pgd原理和源码分析相关推荐

  1. java.lang.ThreadLocal实现原理和源码分析

    java.lang.ThreadLocal实现原理和源码分析 1.ThreadLocal的原理:为每一个线程维护变量的副本.某个线程修改的只是自己的副本. 2.ThreadLocal是如何做到把变量变 ...

  2. Nacos高级特性Raft算法以及原理和源码分析

    Nacos高级特性Raft算法以及原理和源码分析 对比springcloud-config配置中心 springcloud-config工作原理 Nacos的工作原理图 springcloud-con ...

  3. 【项目一、xxx病虫害检测项目】1、SSD原理和源码分析

    目录 前言 一.SSD backbone 1.1.总体结构 1.2.修改vgg 1.3.额外添加层 1.4.需要注意的点 二.SSD head 2.1.检测头predictor 2.2.生成defau ...

  4. RocketMq-dashboard:topic 5min trend 原理和源码分析(一)

    本文阅读基础:使用或了解过rocketMq:想了解"topic 5min trend"背后的原理:想了解监控模式如何实现. RocketMq的dashboard,有运维页面,驾驶舱 ...

  5. 高级JAVA - 动态代理的实现原理和源码分析

    在之前的一篇文章中 , 我们简单了解了一下代理模式(JAVA设计模式 - 代理模式) , 本篇我们来学习一下动态代理的实现原理 , 以及源码是怎样的 . JDK动态代理的主要实现步骤如下 : 1 . ...

  6. Tomcat原理和源码分析

    Tomcat是什么? 首先看下官网的解释说明(看不懂的可以翻译一下),从第一句Tomcat是Java Servlet,JavaServer页,Java表达式语言和Java的WebSocket技术的一个 ...

  7. ConcurrentLinkedQueue的实现原理和源码分析

    原文链接:http://www.jianshu.com/p/26d9745614dd 前言 我们要实现一个线程安全的队列有两种实现方式一种是使用阻塞算法,另一种是使用非阻塞算法.使用阻塞算法的队列可以 ...

  8. ConcurrentHashMap的实现原理和源码分析

    原文链接:http://www.jianshu.com/p/7f42ba895a64 前言 在Java1.5中,并发编程大师Doug Lea给我们带来了concurrent包,而该包中提供的Concu ...

  9. 深入理解GO语言:map结构原理和源码分析

    Map结构是go语言项目经常使用的数据结构,map使用简单对于数据量不大的场合使用非常合适.Map结构是如何实现的?我们先从测试程序入手,我们希望分析map的创建.插入.查询.删除等流程,因此我们的测 ...

  10. Alertmanager 配置文件分析、原理和源码分析

    相关prometheus组件的基本知识总结,以下分析仅代表个人观点,如有错误还请指出,不胜感谢! 基本概述 我们先从应用的角度来看详细的介绍一下alertmanager以下简称am,以下是官方文档介绍 ...

最新文章

  1. 1039 Course List for Student
  2. 安全36计 你需要了解的那些安全术语
  3. 【CV实战】年轻人的第一个深度学习图像分割项目应该是什么样的(Pytorch框架)?...
  4. abap--关于异常的处理
  5. 商品评价判别,文本分类——学习笔记
  6. Java字节序,java整型数与网络字节序 byte[] 数组转换关系
  7. 源码 状态机_[源码阅读] 阿里SOFA服务注册中心MetaServer(1)
  8. 如何防止头文件被重复包含或引用?
  9. linux服务器文件名称乱码,linux中文文件名乱码怎么解决?
  10. vs2008试用版的评估期已经结束解决办法
  11. 并发编程---死锁||递归锁---信号量---Event事件---定时器
  12. 用PLSQL将Excel数据导入到Oracle中
  13. 机器学习(六)——降维处理原理
  14. 自适应和响应式区别以及写法
  15. 微信小程序短视频去水印解析
  16. 使用Photoshop制作相框
  17. oracle minus 利用率,oracle minus用法
  18. python爬取豆瓣读书_爬取豆瓣读书.py
  19. 云开发:未来的软件开发方式
  20. ResponseEntity返回图片,下载图片

热门文章

  1. iOS手势识别的工作原理
  2. 懂一些数据分析工具,为啥还要考CPDA数据分析师证书?
  3. html前端验证代码,前端js+html实现简单验证码
  4. js 上传文件到 minio
  5. python ray定时任务_python定时任务APScheduler
  6. EndNoteX9插入参考文献
  7. endnotex8与9的区别_下载安装EndnoteX8或EndnoteX9,建立数据库并以自己的名字命名。...
  8. Vins-fusion gps融合 KITTY数据集测试
  9. Oracle索引的建立及优缺点
  10. 1.2 DICOM成像协议剖析