每天给你送来NLP技术干货!


来自:ChallengeHub

论文标题:SSMix: Saliency-Based Span Mixup for Text Classification

论文链接:https://arxiv.org/pdf/2106.08062.pdf

论文代码:https://github.com/clovaai/ssmix

论文作者:{soyoungyoon etc.}

1

论文摘要

数据增强已证明对各种计算机视觉任务是有效的。尽管文本取得了巨大的成功,但由于文本由可变长度的离散标记组成,因此将混合应用于NLP任务一直存在障碍。在这项工作中,我们提出了SSMix,一种新的混合方法,其中操作是对输入文本执行的,而不是像以前的方法那样对隐藏向量执行的。SSMix通过基于跨度的混合,综合一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,我们实证验证了我们的方法在广泛的文本分类基准上优于隐藏级混合方法,包括文本隐含、情感分类和问题类型分类。

数据增强的效果已经在各种计算机视觉任务中被证实是有效的。尽管数据增强非常有效,由于文本是由变长的离散字符组成的,所以将mixup应用与NLP任务一直存在障碍。在本篇论文,作者提出了SSMix算法,一种针对输入文本增强的mixup算法,而非之前针对隐藏向量的方法。SSMix通过跨度混合( span-based mixing)在保留原始两个文本的条件下合成一个句子,同时保留两个原始文本的位置,并依赖于显著性信息保留更多与预测相关的标记。通过大量的实验,论文验证了该算法在广泛的文本分类基准上优于隐藏级混合方法,包括文本推断、情感分类和问题类型分类任务。

2

算法简介

由于数据收集与标志的昂贵成本,数据增强在自然语言处理(NLP)中越来越重要。其中一些已往研究包括基于简单的规则和模型来生成类似的文本。比如通过标准方法或先进的训练方法与原始样本联合进行训练,也有基于混淆(mixup)插值文本和标签进行增强。

Mixup及其变体训练算法成为计算机视觉中常用的正则化方法,用来提高神经网络的泛化能力。混合方法分为输入级混合和隐藏级混合( hidden-level mixup),两者取决于混合操作的位置。输入级混合是一种比隐藏级混合更普遍的方法,因为它的简单性和能够捕获局部性,从而具有更好的准确性。

由于文本数据的离散性和可变的序列长度,在NLP中应用mixup比在计算机视觉中更具有挑战性和难度。因此,之前大多数关于文本混合的尝试将mixup应用于嵌入向量,如嵌入或中间表示。然而根据计算机视觉的增强直观感受,输入级混合一般比隐藏级混合有优势。这一动机鼓励作者对探究文本数据的输入级混淆方法。

在这项工作中,作者提出了SSMix(图1),一种新的输入级结合跨度(Span)的显著性混合数据增强法算法。首先,作者通过用另一个文本中的跨度替换连续的标记来进行混淆,这一灵感来自CutMixarXiv,在混合文本中保留两个源文本的位置。其次,选择一个要替换的跨度,并基于显著性信息进行替换,以使混合文本包含与输出预测更相关的标记,这在语义上可能很重要。文本的输入级方法不同于隐级混合方法,当当前的隐级混合方法线性插值原始隐向量,我们的方法在输入级上混合文本字符,产生非线性输出。同时,利用显著性值从每个句子中选择跨度,并离散地定义跨度的长度和混合比,这是与隐藏级别混合增强区别的地方。

SSMix已经通过大量的文本分类基准实验被证明是有效的。特别强调的是,论文证明了输入级混合方法一般要优于隐层混合方法。论文还展示了在进行文本混合增强的同时,在跨度水平上使用显著性信息和限制标记选择的重要性。

3

SSMix算法

SSMix基本原理为:给定两个文本和,通过将文本的片段替换为来自另一文本的显著信息片段生成得到新的文本。同时,对于新文本,基于两个文本标签和重新为新文本设置一个新的标签。最后可以使用这个生成的增强虚拟样本(,)来进行训练模型。

Saliency:显著性信息

Saliency衡量了文本数据的每个字符对最终结果预测的影响。在以往研究中基于梯度的方法被广泛用于显著性计算,文本同样计算了分类损失相对于输入嵌入的梯度,并使用其大小作为显著性:。文中应用l2范数来获得一个梯度向量的大小,代表着每个字符的类似于PuzzleMix的显著性。

Mixing Text:文本合成

之前提到过,Mixing Text主要是是指两个文本序列和如何合成新的文本。大致思路是根据梯度显著性计算方法得到两个文本中每个字符的显著性分数,然后在文本中选取一个显著性最低的片段,长度为,在文本中选取一个显著性最低的片段,长度为。长度设置为==,其中为mixup比例参数。最后生成新文本w为,其中和为原始文本中替换片段的左右的两部分。

Sample span length:相等片段长度

本文将原始()的长度和替换()跨度设置为相同的,主要原因是使用不同长度的span(片段)将导致冗余和语义不明确的mixup 转换。另外,计算不同长度的span之间的mixup 比列也过于复杂。在以往研究中也采用了这种相同大小的替换策略。在替换span长度相同的情况下,论文的SSMix算法能够使显著性的效果最大化。由于SSMix不限制字符的位置,可以同时选择最显著的span和被替换的最不显著片段。如图片1中,in this在文本中是不显著的,transcedent love在文本中是最显著的,那么可以用transcedent love替换in this

Mixing Label:标签合成

作者将mixup 比列设置为:

由于λ是通过计算片段内的字符数量来重新计算的,因此它可能与λ0不相等。然后的标签为:

算法1展示了如何利用原始样本对来计算增广样本的混合损失。公式中计算了增强输出logit相对于每个样本的原始目标标签的交叉熵损失,并通过加权和进行组合,因此SSMix算法与数据集标签个数是不相关的,在任何数据集上,输出标签比例是通过两个原始标签的线性组合来计算。

Paired sentence tasks:句子对任务

对于需要一对文本作为输入的任务,如文本隐含推断和相似性分类,SSMix以成对的方式进行混合,并通过聚合每个mixup结果中的标记计数来计算mixup比例。给定样本,,合成的新样本为,mixup比例记为,其中和为每个mixup操作中的替换片段。

如下图所示:

  • 为 "Fun for only children."

  • 为  "Fun foradults and children."

  • 为 "Problems in data synthesis."

  • 为  "Issues in data synthesis."

  • λ

4

实验设置

实验数据集

论文实验数据集有文本分类和句子对分类任务:

对比实验

论文将SSMix与三个基线进行了比较:(1) standard training without mixup,(2)EmbedMixMix(3)TMix。与基线和消融研究的实验结果进行了比较。所有的准确率值都是使用不同种子的5次运行的平均精度(%)。MNLI表示MNLI-不匹配的开发集的准确性。论文报告了GLUE的验证精度,TREC的测试精度,以及ANLI的有效(上)/测试(较低)精度,可以看出SSMix在大部分数据集效果要优于其他混合增强算法。

5

论文总结

  • 与隐层混合方法相比,SSMix在具有足够数据量的数据集上充分证明了其有效性。由于SSMix是一个离散的组合,而不是两个数据样本的线性组合,它在一个合成空间上创建数据样本的范围大于隐藏级别的混合。论文假设,大量的数据有助于更好地在合成空间中进行表示。

  • SSMix对于多个类标签数据集(TREC、ANLI、MNLI、QNLI)尤其有效。因此,在没有混合的训练条件下,SSMix在TREC-fine(47个标签)上的精度增益远高于TRECcrare(6个标签),+分别为3.56和+为0.52。具有多个总类标签的数据集增加了在混合源的随机抽样中被选择交叉标签的可能性,所以可以认为在这些多标签分类数据集中的混合性能会显著提高

  • 在成对句子任务上具有显著优势,如文本隐含或相似性分类。现有的方法(隐藏层混合)在隐藏层上应用混合,而不考虑特殊的标记,即[SEP]、[CLS]。这些方法可能会丢失关于句子开头的信息或句子对的适当分离。相比之下,SSMix在应用混合时可以考虑单个字符的特性。-SSMix 及其变体的消融研究结果表明,随着对片段约束和显著性信息的增加,性能有所提高。在混合操作中添加片段约束受益于更好的可定位能力,并且大多数显著的片段与相应的标签有更多的关系,而丢弃最小显著的片段,这些片段相对于原始标签在语义上不重要。其中,引入显著性信息对精度的贡献相对高于片段约束。

6

代码实现

import copy
import random
import torch
import torch.nn.functional as Ffrom .saliency import get_saliencyclass SSMix:def __init__(self, args):self.args = argsdef __call__(self, input1, input2, target1, target2, length1, length2, max_len):batch_size = len(length1)if self.args.ss_no_saliency:if self.args.ss_no_span:inputs_aug, ratio = self.ssmix_nosal_nospan(input1, input2, length1, length2, max_len)else:inputs_aug, ratio = self.ssmix_nosal(input1, input2, length1, length2, max_len)else:assert not self.args.ss_no_spaninput2_saliency, input2_emb, _ = get_saliency(self.args, input2, target2)inputs_aug, ratio = self.ssmix(batch_size, input1, input2,length1, length2, input2_saliency, target1, max_len)return inputs_aug, ratiodef ssmix(self, batch_size, input1, input2, length1, length2, saliency2, target1, max_len):inputs_aug = copy.deepcopy(input1)for i in range(batch_size):  # cut off length bigger than max_len ( nli task )if length1[i].item() > max_len:length1[i] = max_lenfor key in inputs_aug.keys():inputs_aug[key][i][max_len:] = 0inputs_aug['input_ids'][i][max_len - 1] = 102saliency1, _, _ = get_saliency(self.args, inputs_aug, target1)ratio = torch.ones((batch_size,), device=self.args.device)for i in range(batch_size):l1, l2 = length1[i].item(), length2[i].item()limit_len = min(l1, max_len) - 2  # mixup except [CLS] and [SEP]mix_size = max(int(limit_len * (self.args.ss_winsize / 100.)), 1)if l2 < mix_size:ratio[i] = 1continuesaliency1_nopad = saliency1[i, :l1].unsqueeze(0).unsqueeze(0)saliency2_nopad = saliency2[i, :l2].unsqueeze(0).unsqueeze(0)saliency1_pool = F.avg_pool1d(saliency1_nopad, mix_size, stride=1).squeeze(0).squeeze(0)saliency2_pool = F.avg_pool1d(saliency2_nopad, mix_size, stride=1).squeeze(0).squeeze(0)# should not select first and lastsaliency1_pool[0], saliency1_pool[-1] = 100, 100saliency2_pool[0], saliency2_pool[-1] = -100, -100input1_idx = torch.argmin(saliency1_pool)input2_idx = torch.argmax(saliency2_pool)inputs_aug['input_ids'][i, input1_idx:input1_idx + mix_size] = \input2['input_ids'][i, input2_idx:input2_idx + mix_size]ratio[i] = 1 - (mix_size / (l1 - 2))return inputs_aug, ratiodef ssmix_nosal(self, input1, input2, length1, length2, max_len):inputs_aug = copy.deepcopy(input1)ratio = torch.ones((len(length1),), device=self.args.device)for idx in range(len(length1)):if length1[idx].item() > max_len:for key in inputs_aug.keys():inputs_aug[key][idx][max_len:] = 0inputs_aug['input_ids'][idx][max_len - 1] = 102  # artificially add EOS token.l1, l2 = min(length1[idx].item(), max_len), length2[idx].item()if self.args.ss_winsize == -1:window_size = random.randrange(0, l1)  # random sampling of window_sizeelse:# remove EOS & SOS when calculating ratio & window size.window_size = int((l1 - 2) *self.args.ss_winsize / 100.) or 1if l2 <= window_size:ratio[idx] = 1continuestart_idx = random.randrange(0, l1 - window_size)  # random sampling of starting pointif (l2 - window_size) < start_idx:  # not enough text for reference.ratio[idx] = 1continueelse:ref_start_idx = start_idxmix_percent = float(window_size) / (l1 - 2)for key in input1.keys():inputs_aug[key][idx, start_idx:start_idx + window_size] = \input2[key][idx, ref_start_idx:ref_start_idx + window_size]ratio[idx] = 1 - mix_percentreturn inputs_aug, ratiodef ssmix_nosal_nospan(self, input1, input2, length1, length2, max_len):batch_size, n_token = input1['input_ids'].shapeinputs_aug = copy.deepcopy(input1)len1 = length1.clone().detach()ratio = torch.ones((batch_size,), device=self.args.device)for i in range(batch_size): # force augmented output length to be no more than max_lenif len1[i].item() > max_len:len1[i] = max_lenfor key in inputs_aug.keys():inputs_aug[key][i][max_len:] = 0inputs_aug['input_ids'][i][max_len - 1] = 102mix_len = int((len1[i] - 2) * (self.args.ss_winsize / 100.)) or 1if (length2[i] - 2) < mix_len:mix_len = length2[i] - 2flip_idx = random.sample(range(1, min(len1[i] - 1, length2[i] - 1)), mix_len)inputs_aug['input_ids'][i][flip_idx] = input2['input_ids'][i][flip_idx]ratio[i] = 1 - (mix_len / (len1[i].item() - 2))return inputs_aug, ratio

投稿或交流学习,备注:昵称-学校(公司)-方向,进入DL&NLP交流群。

方向有很多:机器学习、深度学习,python,情感分析、意见挖掘、句法分析、机器翻译、人机对话、知识图谱、语音识别等。

记得备注呦

整理不易,还望给个在看!

ACL2021 | 没想到Mixup还可以用于文本:SSMix相关推荐

  1. 没想到MySQL还会问这些...

    文本已收录至我的GitHub精选文章,欢迎Star:https://github.com/ZhongFuCheng3y/3y 在前一阵子,大哥问过我:"你知道MySQL的原子性是怎么保证的吗 ...

  2. 大西瓜卖了130多个!没想到后面还结出更多惊喜

    一.西瓜里有虫 本周赶上了大西瓜的热度,由Nowpaper杨老师在B站直播开发的<合成大西瓜>目前在我们的资源商店上卖出了130+了! 随着购买的人增多,游戏中的BUG也被大家给发现了,一 ...

  3. kafka删除队列_没想到 Kafka 还会这样问,学会这些带你轻松搞定大厂面试!

    一.前言 自上次师兄遭受了面试官 「Kafka」 的暴击追问后,回来发奋图强,企图"「吊打面试官」",奈何还是面试官套路深啊,最近的面试,又被问到「知识盲点」了!让我们一起来看看, ...

  4. python slice函数画高维图_没想到Python还能画六维图

    我们的大脑通常最多能感知三维空间,超过三维就很难想象了.尽管是三维,理解起来也很费劲,所以大多数情况下都使用二维平面. 来自维基百科 不过,我们仍然可以绘制出多维空间,今天就来用 Python 的 p ...

  5. python做动态课件素材_万万没想到,还能这么玩!用 Python 生成动态 PPT

    在工作的过程中,我们会发现那些能够把知识.成果讲透的人很多都会做动态图表. 这篇文章就介绍了 Python 中一种简单的动态图表制作方法,这样生成的动图就可以丰富我们的PPT啦~ 数据暴增的年代,数据 ...

  6. 不断在创业路上寻找突破,他开店多家,没想到公司还上市了

    1990年,正是我国从计划经济向市场经济的关键时期,一个标志性的事件就是上海.深圳两个证券交易所的营业.也是这一年,22岁的李伟从河南大学新闻系毕业.此后的6年中,李伟前前后后换了五六份工作,但是全都 ...

  7. 真没想到,Python还能实现5毛特效

    来源 | ZackSock(ID:ZackSock) 图源 | 视觉中国 Python牛已经不是一天两天的事了,但是我开始也没想到,Python能这么牛.前段时间接触了一个批量抠图的模型库,而后在一些 ...

  8. python合并视频和音频_真没想到,Python 还能实现 5 毛特效

    作者 | ZackSock 来源 | ZackSock(ID:ZackSock) Python牛已经不是一天两天的事了,但是我开始也没想到,Python能这么牛.前段时间接触了一个批量抠图的模型库,而 ...

  9. 造了一个 Redis 分布锁的轮子,没想到还学到这么多东西!!!

    手撸分布式锁 这篇文章本来是准备写下 Mysql 查询左匹配的问题,但是还没研究出来.那就先写下最近在鼓捣一个东西,使用 Redis 实现可重入分布锁. 看到这里,有的朋友可能会提出来使用 redis ...

  10. python可以做特效吗_真没想到,Python还能实现5毛特效

    来源 | ZackSock(ID:ZackSock) 图源 | 视觉中国 Python牛已经不是一天两天的事了,但是我开始也没想到,Python能这么牛.前段时间接触了一个批量抠图的模型库,而后在一些 ...

最新文章

  1. BZOJ1002 [FJOI2007]轮状病毒(最小生成树计数)
  2. 简单分析Flask 数据库迁移详情
  3. pythontuple数据类型_Python数据类型之元组的详细介绍
  4. [结对2]必应缤纷桌面软件测试报告
  5. Android 四大组件 —— 活动(活动的隐式跳转)
  6. 【雕爷学编程】Arduino动手做(16)---数字触摸传感器
  7. Intellij idea 自动生成serialVersionUID
  8. linux基础命令-查看系统状态-free -m以及top命令详解
  9. Python小游戏-接苹果
  10. 佳铁精雕机连接电脑设置_佳铁精雕机在程式里怎么更改G57之后的坐标
  11. 常见四大类型视频接口
  12. linux配置超时_自动退出登录TMOUT
  13. 服务器网页篡改,网站服务器网页防篡改系统
  14. Web全栈工程师技能树梳理
  15. mac输密码麻烦?一位数密码来了!
  16. CS106B Assignment #4:Boggle
  17. python 加速度_「加速度公式」加速度公式1 - seo实验室
  18. java测试输入星座匹配_java十二星座 (快来测试你是什么星座吧)
  19. python预测糖尿病_使用决策树与随机深林预测糖尿病(python)
  20. 光电振荡器的MATLAB仿真,基于Matlab的RLC阻尼振荡电路建模与仿真研究

热门文章

  1. 使用struts2未登录,不能操作
  2. Windows Bash on Ubuntu
  3. Hidden (NOIP模拟赛)(字符串模拟QAQ)
  4. HDFS体系结构(NameNode、DataNode详解)
  5. 十一、JUC包中的锁
  6. 20175323 团队项目 服务器端函数功能与业务逻辑详解
  7. Python 将中文、字母转成数字
  8. python之IO多路复用
  9. 我遇到的JPA中事务回滚的问题
  10. classNotFound异常的一个原因