训练技巧 | 功守道:NLP中的对抗训练 + PyTorch实现
作者丨Nicolas
单位丨追一科技AI Lab研究员
研究方向丨信息抽取、机器阅读理解
对抗样本
对抗训练的基本概念
Min-Max公式
从CV到NLP
Because the set of high-dimensional one-hot vectors does not admit infinitesimal perturbation, we define the perturbation on continuous word embeddings instead of discrete word inputs.
NLP中的两种对抗训练 + PyTorch实现
Fast Gradient Method(FGM)
上面我们提到,Goodfellow 在 15 年的 ICLR [7] 中提出了 Fast Gradient Sign Method(FGSM),随后,在 17 年的 ICLR [9] 中,Goodfellow 对 FGSM 中计算扰动的部分做了一点简单的修改。假设输入的文本序列的 embedding vectors 为 x ,embedding 的扰动为:
实际上就是取消了符号函数,用二范式做了一个 scale,需要注意的是:这里的 norm 计算的是,每个样本的输入序列中出现过的词组成的矩阵的梯度 norm。原作者提供了一个 TensorFlow 的实现 [10],在他的实现中,公式里的 x 是 embedding 后的中间结果(batch_size, timesteps, hidden_dim),对其梯度 g 的后面两维计算 norm,得到的是一个 (batch_size, 1, 1) 的向量。
为了实现插件式的调用,笔者将一个 batch 抽象成一个样本,一个 batch 统一用一个 norm,由于本来 norm 也只是一个 scale 的作用,影响不大。笔者的实现如下:
import torch
class FGM():def __init__(self, model):self.model = modelself.backup = {}def attack(self, epsilon=1., emb_name='emb.'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:self.backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = epsilon * param.grad / normparam.data.add_(r_at)def restore(self, emb_name='emb.'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.backupparam.data = self.backup[name]self.backup = {}
需要使用对抗训练的时候,只需要添加五行代码:
# 初始化
fgm = FGM(model)
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)loss.backward() # 反向传播,得到正常的grad# 对抗训练fgm.attack() # 在embedding上添加对抗扰动loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度fgm.restore() # 恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()
PyTorch 为了节约内存,在 backward 的时候并不保存中间变量的梯度。因此,如果需要完全照搬原作的实现,需要用 register_hook 接口 [11] 将 embedding 后的中间变量的梯度保存成全局变量,norm 后面两维,计算出扰动后,在对抗训练 forward 时传入扰动,累加到 embedding 后的中间变量上,得到新的 loss,再进行梯度下降。不过这样实现就与我们追求插件式简单好用的初衷相悖,这里就不赘述了,感兴趣的读者可以自行实现。
Projected Gradient Descent(PGD)
内部 max 的过程,本质上是一个非凹的约束优化问题,FGM 解决的思路其实就是梯度上升,那么 FGM 简单粗暴的“一步到位”,是不是有可能并不能走到约束内的最优点呢?当然是有可能的。于是,一个很 intuitive 的改进诞生了:Madry 在 18 年的 ICLR 中 [8],提出了用 Projected Gradient Descent(PGD)的方法,简单的说,就是“小步走,多走几步”,如果走出了扰动半径为 ϵ 的空间,就映射回“球面”上,以保证扰动不要过大:
其中为扰动的约束空间,α 为小步的步长。
import torch
class PGD():def __init__(self, model):self.model = modelself.emb_backup = {}self.grad_backup = {}def attack(self, epsilon=1., alpha=0.3, emb_name='emb.', is_first_attack=False):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name:if is_first_attack:self.emb_backup[name] = param.data.clone()norm = torch.norm(param.grad)if norm != 0 and not torch.isnan(norm):r_at = alpha * param.grad / normparam.data.add_(r_at)param.data = self.project(name, param.data, epsilon)def restore(self, emb_name='emb.'):# emb_name这个参数要换成你模型中embedding的参数名for name, param in self.model.named_parameters():if param.requires_grad and emb_name in name: assert name in self.emb_backupparam.data = self.emb_backup[name]self.emb_backup = {}def project(self, param_name, param_data, epsilon):r = param_data - self.emb_backup[param_name]if torch.norm(r) > epsilon:r = epsilon * r / torch.norm(r)return param_data + rdef backup_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:self.grad_backup[name] = param.graddef restore_grad(self):for name, param in self.model.named_parameters():if param.requires_grad:param.grad = self.grad_backup[name]
使用的时候,要麻烦一点:
pgd = PGD(model)
K = 3
for batch_input, batch_label in data:# 正常训练loss = model(batch_input, batch_label)loss.backward() # 反向传播,得到正常的gradpgd.backup_grad()# 对抗训练for t in range(K):pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.dataif t != K-1:model.zero_grad()else:pgd.restore_grad()loss_adv = model(batch_input, batch_label)loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度pgd.restore() # 恢复embedding参数# 梯度下降,更新参数optimizer.step()model.zero_grad()
在 [8] 中,作者将这一类通过一阶梯度得到的对抗样本称之为“一阶对抗”,在实验中,作者发现,经过 PGD 训练过的模型,对于所有的一阶对抗都能得到一个低且集中的损失值,如下图所示:
我们可以看到,面对约束空间 S 内随机采样的十万个扰动,PGD 模型能够得到一个非常低且集中的 loss 分布,因此,在论文中,作者称 PGD 为“一阶最强对抗”。也就是说,只要能搞定 PGD 对抗,别的一阶对抗就不在话下了。
实验对照
为了说明对抗训练的作用,笔者选了四个 GLUE 中的任务进行了对照试验。实验代码是用的 Huggingface 的 transfomers/examples/run_glue.py [12],超参都是默认的,对抗训练用的也是相同的超参。
我们可以看到,对抗训练还是有效的,在 MRPC 和 RTE 任务上甚至可以提高三四个百分点。不过,根据我们使用的经验来看,是否有效有时也取决于数据集。毕竟:缘,妙不可言~
总结
这篇博客梳理了 NLP 对抗训练发展的来龙去脉,介绍了对抗训练的数学定义,并对于两种经典的对抗训练方法,提供了插件式的实现,做了简单的实验对照。由于笔者接触对抗训练的时间也并不长,如果文中有理解偏差的地方,希望读者不吝指出。
一个彩蛋:Virtual Adversarial Training
除了监督训练,对抗训练还可以用在半监督任务中,尤其对于 NLP 任务来说,很多时候输入的无监督文本多的很,但是很难大规模地进行标注,那么就可以参考 [13] 中提到的 Virtual Adversarial Training 进行半监督训练。
首先,我们抽取一个随机标准正态扰动(),加到 embedding 上,并用 KL 散度计算梯度:
然后,用得到的梯度,计算对抗扰动,并进行对抗训练:
实现方法跟 FGM 差不多,这里就不给出了。
Reference
点击以下标题查看更多往期内容:
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
训练技巧 | 功守道:NLP中的对抗训练 + PyTorch实现相关推荐
- 浅谈NLP中的对抗训练方式
©作者 | 林远平 单位 | QTrade AI研发中心 研究方向 | 自然语言处理 前言 什么是对抗训练呢?说起"对抗",我们就想起了计算机视觉领域的对抗生成网络(GAN).在计 ...
- pytorch 对抗样本_【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现
本文分享一个"万物皆可盘"的NLP对抗训练实现,只需要四行代码即可调用.盘他. 最近,微软的FreeLB-Roberta [1] 靠着对抗训练 (Adversarial Train ...
- 【炼丹之道】NLP中的对抗训练
作者 | 王嘉宁@华师数据学院 整理 | NewBeeNLP https://blog.csdn.net/qq_36426650/article/details/122807916 大家好,这里是Ne ...
- 【NLP】一文搞懂NLP中的对抗训练
本文主要串烧了FGSM, FGM, PGD, FreeAT, YOPO, FreeLB, SMART这几种对抗训练方法,希望能使各位大佬炼出的丹药更加圆润有光泽,一颗永流传 简介 对抗训练是一种引入噪 ...
- 【NLP】NLP中的对抗训练
作者 | 王嘉宁@华师数据学院 整理 | NewBeeNLP https://blog.csdn.net/qq_36426650/article/details/122807916 对抗训练本质是为了 ...
- NLP中的对抗训练(附PyTorch实现)
对抗样本的基本概念 要认识对抗训练,首先要了解"对抗样本",它首先出现在论文Intriguing properties of neural networks之中.简单来说,它是指对 ...
- 文本中的对抗学习 + pytorch实现
最近,微软的FreeLB-Roberta [1] 靠着对抗训练 (Adversarial Training) 在GLUE榜上超越了Facebook原生的Roberta,追一科技也用到了这个方法仅凭单模 ...
- 『功守道』软件供应链安全大赛·C源代码赛季启示录
背景 软件供应链安全,这可以说是一个新近的人造的概念热词.泛泛来讲,如今的软件系统中任何一方都不是孤立的:套用到企业的场景,就有了供应链的概念. 以典型互联网企业为例.线上生产环境所依赖的操所系统,配 ...
- 「功守道」软件供应链安全大赛·C源代码赛季启示录
背景 软件供应链安全,这可以说是一个新近人造的概念热词.泛泛来讲,如今的软件系统中任何一方都不是孤立的:套用到企业的场景,就有了供应链的概念. 以典型互联网企业为例.线上生产环境所依赖的操所系统,配套 ...
最新文章
- access-control-allow-origin php,PHP通过Access-Control-Allow-Origin 跨域
- 代码生成器,自己实现的一个基于模板的在线代码生成网站
- 大学计算机二级考试 vb,大学计算机二级考试常用vb代码.docx
- 男人必看,男性排毒同样重要 - 生活至上,美容至尚!
- mac使用被动ftp模式(pasv)_ftp主动模式和被动模式
- python 进度条_六种酷炫Python运行进度条
- EF Core 插件 —— ToSql
- 数据分析之pandas笔记
- linux用date指令,Linux中date指令的使用
- 我的欧拉工程之路_3
- 广度优先搜索——岛屿数量(Leetcode 200)
- LUNA16数据集肺结节显示亲测
- 使用py 和flask 实现的服务器系统目录浏览,日志文件实时显示到网页的功能
- 苹果x微信为什么不出定位服务器,苹果x微信发动态为什么显示不了位置
- EMMC和Nand傻傻分不清
- 请问如何查询一个APP的Android和iOS下载量?
- QCustomPlot使用心得三:线样式,点样式
- c语言0x前缀的作用,有趣的问题,C语言程序中,为什么十六进制数字以前缀0x开头呢?...
- 时间戳计算机网络,时间戳
- element-ui input组件源码分析整理笔记(六)
热门文章
- python中long类型的取值范围_java基本数据类型取值范围
- php redis decr_对于高并发的问题你知道怎么处理吗?php接口如何处理并发问题
- 今天写一个关于浮动的页面,页面高度不能设置。用元素将他撑开。
- (六)6-3Mysql操作据二
- OC的项目网址(自己编写的项目)
- 关于Log 的一些东西
- 【转】C++中的SFINAE
- HOWTO:如何在代码中获取安装包目标机上的Windows Installer(MSI)版本
- 面向搜索的中文分词设计
- java宝典_JAVA宝典之_JAVA基础