文章目录

  • 引言
  • 一、对抗训练一般原理
    • 1.对抗样本
  • 二、对抗训练的经典算法
  • 三、对抗训练代码实现
    • 1.FGM
    • 2.PGD

引言

  对抗训练对于NLP来说,是一种非常好的上分利器,所以,非常有必要加深对对抗训练的认识。

一、对抗训练一般原理

  小学语文课上,我们都学习过《矛与盾》这篇课文,

从辩证唯物史观角度来看,矛与盾并没有严格意义上的谁更厉害,谁一直占优。矛盾着的双方又同一又斗争,双方力量此长彼消,不断前进,从而推动事物发展。这也就是对抗训练。

1.对抗样本

  两句话,只是部分英文单词发生了改变,但是在我们看来,含义还是几乎不变的,但是就是这种变化,基于BERT的文本分类模型对句子的情感分类居然是完全相反的。

  根据研究,我们发现InferSent MultiNLI模型、ESIM MultiNLI模型、BERT MultiNLI模型在SNLI数据集上面准确率为84%以上,但是在对抗文本上,不管是传统的深度网络还是BERT模型,准确率都只有个位数。尽管BERT模型经过微调可以在不同的下游任务得到好的结果,但仍然是非常容易受到攻击的(模型不鲁棒)。

  那么,如何使得模型更加鲁棒呢?比如:

  • 我们可以对数据进行处理,将对攻击的防御放到数据处理过程中
  • 我们也可以对模型输出的向量表征做一些转换,将防御放在模型的输出端
  • 我们也可以将防守放在模型本身上来

二、对抗训练的经典算法

  lan Goodfellow提出对抗训练方法,它的思想是在训练时,在原始的输入样本中加上扰动(对抗样本),我们用对抗样本来进行模型的训练,使得模型更加鲁棒

扰动的计算方式定义为:求得模型损失,对于xxx求梯度,对所求的梯度经过符号函数处理,在乘以一个系数ϵ\epsilonϵ

这种对抗训练的方法叫做Fast Gradient Sign Method (FGSM)。
  对抗训练其实可以当做正则化,减少模型的过拟合,提升模型的泛化性能。FGSM对抗训练方法应用与CV,cv的输入都是图像,图像输入模型时为RGB三个通道上的像素值,它的输入本身就是连续的,但是NLP的输入是离散的单词序列。那么,我们应该如何在NLP模型上定义对抗训练?
  虽然NLP的输入都是离散的单词序列,但是会经过embedding转变成低维空间上的向量表征,我们可以将embedding后的向量表征当成上述对抗训练模型中的xxx。lan Goodfellow在2017年提出了在连续的embedding上做扰动。

扰动的计算方式为:

  • xxx表示文本序列的embedding vectors

这种对抗训练的方法叫做Fast Gradient Method (FGM)。FGM通过一步,就移动到对抗样本上,如果梯度太大,可能会导致扰动过大,对模型造成误导。
  FGM对抗训练方法一步得到对抗样本,容易导致扰动过大。Projected Gradient Descent
(PGD)方法限定了扰动的范围,对抗样本并不是一步就得到了,而是通过沿着不同点的梯度走了多步之后再去得到。

其中,

  • 小步走:如果沿着梯度走的比较远,则通过投影的方式,投影到球面s上;
  • 多步走:生成一个对抗样本是走多步得到的;

这种对抗训练的方法叫做Projected Gradient Descent(PGD)。

三、对抗训练代码实现

1.FGM

class FGM():"""定义对抗训练方法FGM,对模型embedding参数进行扰动"""def __init__(self, model, epsilon=0.25,):# BERT模型self.model = model# 求干扰时的系数值self.epsilon = epsilonself.backup = {}def attack(self, emb_name='word_embeddings'):"""得到对抗样本:param emb_name:模型中embedding的参数名:return:"""# 循环遍历模型所有参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息,并且包含了模型中embedding的参数名if param.requires_grad and emb_name in name:# 把真实参数保存起来self.backup[name] = param.data.clone()# 对参数的梯度求范数norm = torch.norm(param.grad)# 如果范数不等于0并且norm中没有缺失值if norm != 0 and not torch.isnan(norm):# 计算扰动,param.grad / norm=单位向量,起到了sgn(param.grad)一样的作用r_at = self.epsilon * param.grad / norm# 在原参数的基础上添加扰动param.data.add_(r_at)def restore(self, emb_name='word_embeddings'):"""将模型原本的参数复原:param emb_name:模型中embedding的参数名"""# 循环遍历模型所有参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息,并且包含了模型中embedding的参数名if param.requires_grad and emb_name in name:# 断言assert name in self.backup# 取出模型真实参数param.data = self.backup[name]# 清空self.backupself.backup = {}
# 实例初始化
fgm = FGM(model)
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)# 反向传播,得到正常的gradloss.backward() # 对抗训练,在embedding上添加对抗扰动fgm.attack() # embedding参数被修改,此时,输入序列得到的embedding表征不一样loss_adv = model(batch_input, batch_label)# 反向传播,并在正常的grad基础上,累加对抗训练的梯度loss_adv.backward() # 恢复embedding参数fgm.restore() # 梯度下降,更新参数optimizer.step()# 将梯度清零model.zero_grad()

2.PGD

class PGD():"""定义对抗训练方法PGD"""def __init__(self, model, epsilon=1.0, alpha=0.3):# BERT模型self.model = model# 两个计算参数self.epsilon = epsilonself.alpha = alpha# 用于存储embedding参数self.emb_backup = {}# 用于存储梯度,与多步走相关self.grad_backup = {}def attack(self, emb_name='word_embeddings', is_first_attack=False):"""对抗:param emb_name: 模型中embedding的参数名:param is_first_attack: 是否是第一次攻击"""# 循环遍历模型的每一个参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息,并且包含了模型中embedding的参数名if param.requires_grad and emb_name in name:# 如果是第一次攻击if is_first_attack:# 存储embedding参数self.emb_backup[name] = param.data.clone()# 求梯度的范数norm = torch.norm(param.grad)# 如果范数不等于0if norm != 0:# 计算扰动,param.grad / norm=单位向量相当于sgn符号函数r_at = self.alpha * param.grad / norm# 在原参数的基础上添加扰动param.data.add_(r_at)# 控制扰动后的模型参数值# 投影到以原参数为原点,epsilon大小为半径的球上面param.data = self.project(name, param.data, self.epsilon)def restore(self, emb_name='word_embeddings'):"""将模型原本参数复原:param emb_name: 模型中embedding的参数名"""# 循环遍历每一个参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息,并且包含了模型中embedding的参数名if param.requires_grad and emb_name in name:assert name in self.emb_backup# 取出模型真实参数param.data = self.emb_backup[name]# 清空emb_backupself.emb_backup = {}def project(self, param_name, param_data, epsilon):"""控制扰动后的模型参数值:param param_name::param param_data::param epsilon:"""# 计算加了扰动后的参数值与原始参数的差值r = param_data - self.emb_backup[param_name]# 如果差值的范数大于epsilonif torch.norm(r) > epsilon:# 对差值进行截断r = epsilon * r / torch.norm(r)# 返回新的加了扰动后的参数值return self.emb_backup[param_name] + rdef backup_grad(self):"""对梯度进行备份"""# 循环遍历每一个参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息if param.requires_grad:# 如果参数没有梯度if param.grad is None:print("{} param has no grad !!!".format(name))continue# 将参数梯度进行备份self.grad_backup[name] = param.grad.clone()def restore_grad(self):"""将梯度进行复原"""# 循环遍历每一个参数for name, param in self.model.named_parameters():# 如果当前参数在计算中保留了对应的梯度信息if param.requires_grad:# 如果没有备份if name not in self.grad_backup:continue# 如果备份了,就将原始模型参数梯度取出param.grad = self.grad_backup[name]
# 实例初始化
pgd = PGD(model)
steps_for_at = 3
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)# 反向传播,得到正常的gradloss.backward() # 保存正常的梯度pgd.backup_grad()# PGD要走多步,迭代走多步for t in range(steps_for_at):# 在embedding上添加对抗扰动, first attack时备份param.datapgd.attack(is_first_attack=(t == 0))# 中间过程,梯度清零if t != steps_for_at - 1:optimizer.zero_grad()# 最后一步,恢复正常的gradelse:pgd.restore_grad()# embedding参数被修改,此时,输入序列得到的embedding表征不一样loss_at = model(batch_input, batch_label)# 对抗样本上的损失loss_at = outputs_at[0]# 反向传播,并在正常的grad基础上,累加对抗训练的梯度loss_at.backward()# 恢复embedding参数pgd.restore()# 梯度下降,更新参数optimizer.step()# 将梯度清零model.zero_grad()

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


BERT模型—6.对抗训练原理与代码实现相关推荐

  1. 【NLP】BERT蒸馏完全指南|原理/技巧/代码

    小朋友,关于模型蒸馏,你是否有很多问号: 蒸馏是什么?怎么蒸BERT? BERT蒸馏有什么技巧?如何调参? 蒸馏代码怎么写?有现成的吗? 今天rumor就结合Distilled BiLSTM/BERT ...

  2. IREC-GAN:在线推荐中基于模型的对抗训练强化学习

    IREC-GAN 1 Introduction 推荐系统已经成功地将用户与其在各种应用领域中最感兴趣的内容联系起来.然而,由于用户的兴趣和行为模式不同,只向每个用户呈现一小部分项目,记录的反馈更少.对 ...

  3. BERT模型—7.BERT模型在句子分类任务上的微调(对抗训练)

    文章目录 引言 二.项目环境配置 二.数据集介绍 三.代码介绍 四.测试 1.代码执行流程 数据代码见:https://gitee.com/lj857335332/bert_finetune_cls_ ...

  4. [深度学习] 自然语言处理 --- BERT模型原理

    一 BERT简介 NLP:自然语言处理(NLP)是信息时代最重要的技术之一.理解复杂的语言也是人工智能的重要组成部分.Google AI 团队提出的预训练语言模型 BERT(Bidirectional ...

  5. 对抗训练硬核分析:对抗样本与模型参数的关系

    ©PaperWeekly 原创 · 作者|孙裕道 学校|北京邮电大学博士生 研究方向|GAN图像生成.情绪对抗样本生成 引言 对抗训练是防御对抗样本一种有效的方法,但是对于它有效性的边界,一直都是很模 ...

  6. 从源码到实战:BERT模型训练营

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 开课吧教育 方向:NLP 之 BERT实战 都说BERT模型开启了NLP的新时代,更有" ...

  7. 【NLP】一份相当全面的BERT模型精讲

    本文概览: 1. Autoregressive语言模型与Autoencoder语言模型 1.1 语言模型概念介绍 Autoregressive语言模型:指的是依据前面(或后面)出现的单词来预测当前时刻 ...

  8. 对抗训练浅谈:意义、方法和思考(附Keras实现)

    ©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 当前,说到深度学习中的对抗,一般会有两个含义:一个是生成对抗网络(Generative Adversari ...

  9. 开启NLP新时代的BERT模型,真的好上手吗?

    都说BERT模型开启了NLP的新时代,更有"BERT在手,天下我有"的传说,它解决了很多NLP的难题: 1.BERT让低成本地训练超大规模语料成为可能: 2.BERT能够联合神经网 ...

  10. BERT模型实战之多文本分类(附源码)

    BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何.今天就趁机来动手写一写实战,顺 ...

最新文章

  1. 【青少年编程(第33周)】Scratch(三级)公益活动开营了!
  2. 构建一个增量推荐系统
  3. p2596 书架(Treap)
  4. Mysql流程控制结构
  5. Java Statement PK PrepareStatement
  6. php编译支持mysql,编译php支持curl和pdo_mysql
  7. 聊一聊 java8 中的 Optional
  8. kaldi 源码分析(七) - HCLG 分析
  9. Ajax应用简单实例
  10. Windows Phone开发手记-WinRT下启动器替代方案
  11. linux 新老软件切换,Linux下非常重要的软件切换命令
  12. how to set up github blog
  13. ROS(9):机器人开源项目poppy-project
  14. 什么是静态分析?代码分析工具
  15. 微信小程序弹窗显隐动态控制页面滚动
  16. PS 仿制图章 轻松换支付宝赞赏码 GIF动态图手把手教你
  17. python学习之路之:import(详细介绍import的各种调用原理和使用方法)
  18. 将串口转换成TCP连接
  19. 相片尺寸规格像素一览
  20. 复旦大学邱锡鹏:若优化顺利,MOSS三月底开源;库克或被踢出苹果董事会;华为云联合CSDN发布智能化编程助手Snap|极客头条...

热门文章

  1. Neuron:Neural activities in V1 create a bottom-up saliency map
  2. 纯css实现照片墙3D效果
  3. 邀请您加入移动开发专家联盟
  4. IDEA 没有创建servlet选项问题的解决
  5. POJ 1094 Sorting It All Out 【拓扑排序】
  6. git 删除已经 add 的文件
  7. Uva 10306 e-Coins
  8. shiro的详细讲解
  9. Vue 中watch和computed 的用法及区别
  10. weixin-api生成二维码