用模型“想象”出来的target来训练,可以提高分类的效果! - 知乎LearnFromPapers系列——用模型“想象”出来的target来训练,可以提高分类的效果作者:郭必扬 时间:2020年最后一天前言:今天是2020年最后一天,这篇文章也是我的SimpleAI公众号2020年的最后一篇推文,感谢大家一…https://zhuanlan.zhihu.com/p/340930767

https://github.com/Changanyue/NLP_paper_implementationhttps://github.com/Changanyue/NLP_paper_implementation        这是一篇在文本分类上的文章,但是整体的对于分类问题的流程,图像和文本都是相似的。文章可以看成是对label smoothing的进一步优化,label smoothing对于label的soft是利用一个均匀分布来加权,本身并没有意义,它只是让target更加soft了,让softmax probs更小了从而保证它在loss曲线的中间那段,本文的思路是对label进行一个编码,让label除了更加soft之外,还能反映标签之间存在的联系(词向量空间改造one-hot), 是对one-hot这种暴力标签的一种优化,其实在做创意项目时,创意图的分类往往大多数情况下标签都不是非常准确的,有噪声的,比如说简约文艺和前卫新锐之间差别很多时间都很难以准确区分,one-hot标签除了噪声之外还特别容易over-confident,因此这方面优化是很有必要的。

1.Abstract

The inadequate one-hot representations tend to train the model to be over-confident, which may result in arbitrary prediction and model overfitting, especially for confused datasets (datasets with very similar labels) or noisy datasets (datasets with labeling errors).不足的one-hot表示去训练模型over-confident,这可能导致任意预测和模型过拟合,尤其是混淆(具有非常相似的标签)和噪声数据集(具有错误标签)。label smoothing虽然可以在一定程度上缓解这个问题,但是它无法捕获标签之间的真实关系。这里有两个点,1.one-hot容易over-confident,2.label smoothing无法捕获标签联系。

2.Intrdouction

作者在介绍中也说了one-hot的三个问题:1.真实标签跟其他标签之间的关系被忽略了,很多有用的知识无法学到;比如:“鸟”和“飞机”本来也比较像,因此如果模型预测觉得二者更接近,那么应该给予更小的loss。2.倾向于让模型更加“武断”,成为一个“非黑即白”的模型,导致泛化性能差。3.面对易混淆的分类任务、有噪音(误打标)的数据集时,更容易受影响。In brief, the limitation of current learning paradigm will lead to confusion in prediction that the model is hard to distinguish some labels, which we refer as label confusion problem (LCP). 当前学习范式的局限性会导致模型难以区分某些标签的预测混乱,我们将其称为标签混淆问题(LCP)。

3.label smoothing

由 LS 生成的标签分布不能反映每个训练样本的真实标签分布,因为它是通过简单地添加一些噪声获得的。真实的标签分布应该揭示实例和每个标签之间的语义关系,并且相似的标签在分布中应该具有相似的程度。In nature, label smoothing encourages the model to learn less, rather than learn more accurately of the knowledge in training samples, which may have the risk of underfitting.本质上,标签平滑鼓励模型学习更少,而不是更准确地学习训练样本中的知识,这可能存在欠拟合的风险。

4.Our approach

直观地说,存在一个标签分布,它反映了每个标签如何描述大多数分类任务的当前实例的程度。 然而,在实践中,我们只能获得样本的唯一标签(单标签分类)或多个标签(多标签分类),而不是样本与标签的关联度。如果没有提供统计信息,就没有一种自然且经过验证的方法可以将one-hot标签转移到适当的分布。 虽然理论上的真实标签分布不容易实现,但我们仍然可以尝试通过挖掘实例和标签背后的语义信息来模拟它。做创意图质量评估时,做过这样的方案,一幅图质量怎么样,可以分为1-10,找200个人来评分,则有一个0-1之间的分布,那么对于模型来说,就是监督这个分布,最终用均值方差来刻画。对于one-hot形式来说,softmax对输出的prob进行刻画时,就是0或者1,不是样本和label的关联度,作者用lcm其实就想模拟这个有类别个数维度的分布,比如鸟,花,飞机这个标签原本是(1,0,0),最终在模拟的标签分布上应该是(0.8,0,0.2)。

考虑到标签混淆问题通常是由语义相似性引起的,我们假设能够反映标签之间相似关系的标签分布有助于训练更强的模型并解决标签混淆问题。一个简单的想法是找到每个标签的描述并计算每两个标签之间的相似度。 然后对于每个one-hot标签表示,我们可以使用归一化的相似度值来创建标签分布。 然而,以这种方式得到的标签分布对于具有相同标签的实例来说都是相同的,无论它们的内容如何。 实际上,即使两个实例具有相同的标签,它们的内容也可能大不相同,因此它们的标签分布也应该不同。解释了作者为什么要在label编码之后和输入数据进行相似度计算,即便是具有相同标签的实例,他们的内容也在不断变化,有的差别甚至很大,他需要标签的同步变化来让loss始终有一个很好的值,让他始终在sigmoid中间的那块区域上。

因此,我们应该利用实例和标签之间的关系来构建标签分布,这样标签分布就会随着标签相同的不同实例而动态变化。 对于文本分类问题,我们可以通过文档文本表示与每个标签之间的相似度来模拟标签分布。 这样,不仅捕获了实例和标签之间的关系,标签之间的依赖关系也反映在这些关系上。我们设计了一个标签混淆模型,通过计算实例和标签之间的语义关系来学习模拟标签分布。 然后将 模拟标签分布视为真实的标签分布,并将其与预测分布进行比较,以通过 KL散度计算损失。

上式中, v是每个特征表示,经过softmax拿到预测的标签分布(PLD),是个0-1之间的分布。

上式中,V是标签训练的矩阵(nn.embedding),c个类,拿到标签矩阵之后,这里有一个nn.linear,来建模输入的特征表示和标签的训练矩阵之间的关系,通过点积计算它们的相似度值,然后用softmax做一次归一化,yt是one-hot的,yc是标签矩阵的,做了个融合的到ys。

最终ys和预测得到的yp做kl loss。

如下图:

PLD是预测得到的经过softmax的0-1分布,结合下面的代码,就是y_pred,右侧是LCM,先对输入的label做nn.embedding,得到label_emb,在经过一个nn.linear+激活函数tanh,就是作者的label encoder,之后得到label_emb就是标签的矩阵表示,这里有个similarity layer,其实就是用torch.mm做一次相似度计算,在经过一个nn.linear层,此时已经得到了label_sim_dict,想要得到作者图中的label condusion vector,则是在lcm loss中对label_sim_dict做了一次softmax,再将label_sim_dict和one-hot标签做了次融合,在经过softmax,得到模拟的标签分布,nn.kldivloss前对预测结果做交叉熵,这里看原论文作者的代码可能更清晰一些,直接用nn.kldivloss优化也没问题。

simulated_y_true = K.softmax(label_sim_dist+alpha*y_true)
loss1 = -K.categorical_crossentropy(simulated_y_true,simulated_y_true)
loss2 = K.categorical_crossentropy(simulated_y_true,pred_probs)
class BERT_LCM(nn.Module):def __init__(self,pretrained_model_name_or_path,hidden_size,num_classes,alpha,wvdim=768,max_len=128,label_embedding_matrix=None):super(BERT_LCM, self).__init__()self.num_classes = num_classesself.bert = BertModel.from_pretrained(pretrained_model_name_or_path)self.bert_fc1 = nn.Linear(self.bert.config.hidden_size, hidden_size)self.bert_fc2 = nn.Linear(hidden_size, num_classes)# label_encoder:if label_embedding_matrix is None: # 不使用pretrained embeddingself.label_emb = nn.Embedding(num_classes,wvdim) # (n,wvdim)else:self.label_emb = nn.Embedding(num_classes,wvdim,_weight=label_embedding_matrix)self.label_fc = nn.Linear(wvdim, hidden_size)self.sim_fc = nn.Linear(num_classes, num_classes)def forward(self, input_ids=None, token_type_ids=None, labels=None):bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids)text_emb = bert_output['last_hidden_state'][:,0,:]text_emb = torch.tanh(self.bert_fc1(text_emb))# print(text_emb.shape,'text')  # [16,64]y_pred = self.bert_fc2(text_emb)label_emb = self.label_emb(labels)label_emb = F.tanh(self.label_fc(label_emb))# print(label_emb.shape,'label')  # [16,20,64]doc_product = torch.bmm(label_emb, text_emb.unsqueeze(-1))  # (b,n,d) dot (b,d,1) --> (b,n,1)# print(doc_product.shape)   # [16,20,1]label_sim_dict = self.sim_fc(doc_product.squeeze(-1))#print(label_sim_dict.shape)return y_pred, label_sim_dictdef lcm_loss(y_true,y_pred,label_sim_dist,alpha):label_sim_dist = F.softmax(label_sim_dist)pred_log_probs = F.log_softmax(y_pred)simulated_y_true = F.softmax(label_sim_dist + alpha * F.one_hot(y_true, num_classes))loss = nn.KLDivLoss()(pred_log_probs, simulated_y_true)return loss

5.experiment

左图是混淆矩阵,对角线表示正确分类,对角线外反应了两个标签的相似度,右侧是t-sne在2d上的20个类的表示,相似类有类聚。

本文整体上看还是对标签分布的一次优化,是label smoothing的升级版本,label smoothing是加一个均匀分布的噪声来改善one-hot,本文是通过nn.embedding+nn.linear来生成一个标签矩阵,但是这个标签矩阵要和输入产生联系,产生联系之后每一次的标签分布就不一样了,随着进来的输入不断在变化,因此,用输入的特征和标签矩阵计算相似度,再包一次nn.linear用softmax再归一化,最终怕信号丢了,又融合了一次one-hot做了这么一个模拟标签分布,用kl loss来控制两个分布的相似度,从而输出监督信号,其实设计的还是挺复杂的。不过作者是希望model能够拿到label之间联系的,至于有没有真正实现了这个意图还是得看看模拟标签分布的实际情况。

label confusion learning to enhance text classification models相关推荐

  1. 【论文笔记】Adversarial Multi-task Learning for Text Classification

    一.概要   该文章发于ACL 2017,针对于已有的大多数神经网络多任务学习模型进行文本分类,存在的共享特征(shared features)可能再次包含特定任务特征(task-specific f ...

  2. 综述:基于深度学习的文本分类 --《Deep Learning Based Text Classification: A Comprehensive Review》总结(一)

    文章目录 综述:基于深度学习的文本分类 <Deep Learning Based Text Classification: A Comprehensive Review>论文总结(一) 总 ...

  3. Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation 阅读笔记

    Dual Contrastive Learning: Text Classification via Label-Aware Data Augmentation PyTorch实现:https://g ...

  4. [ACL22] An Imitation Learning Curriculum for Text Editing with Non-Autoregressive Models阅读笔记

    An Imitation Learning Curriculum for Text Editing with Non-Autoregressive Models [pdf] 论文状态:被ACL22接收 ...

  5. TensorFlow版本的BERT微调:Fine Tune BERT for Text Classification with TensorFlow

    文章目录 写在前面 Project Structure Task 2: Setup your TensorFlow and Colab Runtime. Install TensorFlow and ...

  6. Deep Unordered Composition Rivals Syntactic Methods for Text Classification(简摘DAN模型)

    Deep Unordered Composition Rivals Syntactic Methods for Text Classification(简摘) 摘要 成果 模型结构 结论 摘要 Man ...

  7. 论文解读:Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference

    论文解读:Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference   随着 ...

  8. 朴素贝叶斯案例之text classification

    准备数据:20news groups 你可以在github上下到该数据集:20newsbydate.tar.gz 然后找到dataset loader 打开twenty_newsgroups.py 将 ...

  9. 【提示学习】Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inference

    论文信息 名称 内容 论文标题 Exploiting Cloze Questions for Few Shot Text Classification and Natural Language Inf ...

  10. Text Classification Algorithms: A Survey——1. Introduction引言

    "Most text classification and document categorization systems can be deconstructed into the fol ...

最新文章

  1. 12.QT线程的两种启动方式
  2. 你还记得当初为什么进入IT行业吗?
  3. 如何使用Docker暴露多个端口?
  4. Transformer architecture的解释
  5. modelandview 可以返回html么_Python: 爬虫网页解析工具lxml.html(一)
  6. linux软件依赖库,【Linux】ubuntu系统安装及软件依赖库
  7. Android 入门篇(一)环境搭建
  8. 带你玩转Visual Studio(八)——带你跳出坑爹的Runtime Library坑
  9. 编程之美 裴波那楔数列
  10. ubuntu 14.04 server 安装virtualbox (虚拟机)deb
  11. (6)java的内存泄露问题
  12. JAVA经纬度距离计算并排序-Spatial4j+ForkJoin
  13. 修复软件图标不正常显示问题
  14. 美创解读|《数据安全法》实施,企业数据安全合规技术能力建设
  15. python动画精灵_Python小课堂第18课:如何使用Pygame做动画精灵和碰撞检测
  16. linux用户的主要配置文件,Linux用户和组的主要配置文件及其相关命令
  17. 信息学奥赛一本通(c++):2066:【例2.3】买图书
  18. 1、ABPZero系列教程之拼多多卖家工具 前言
  19. chrome 未连接到互联网 代理服务器出现问题,或者地址有误
  20. Python中and和or的运算规则,短路计算

热门文章

  1. Linked Data_百度百科
  2. 必读论文 | 机器交互必读论文8篇
  3. Codeforces1196D2
  4. 【年中总结】衣带渐宽终不悔
  5. 随机测试数据生成 与 Pandas迭代方法性能对比
  6. 哈夫曼码的编译码系统
  7. linux系统是不是国产的,LINUX是什么系统,是国产软件吗
  8. php nginx 伪静态规则,常见PHP程序的Nginx 伪静态规则
  9. U盘格式化导致存储空间变小的解决方法汇总
  10. excel文件的工作表保护密码忘记了