对比学习的思想是拉近同类样本的距离,增大不同类样本的距离,目标是要从样本中学习到一个好的语义表示空间。SimCSE是一种简单的无监督对比学习框架,它通过对同一句子两次Dropout得到一对正样例,将该句子与同一个batch内的其它句子作为一对负样例。模型结构如下所示:

损失函数为:
ℓi=−log⁡esim⁡(hizi,hizi′)/τ∑j=1Nesim⁡(hizi,hjzj′)/τ\ell_{i}=-\log \frac{e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}}, \mathbf{h}_{i}^{z_{i}^{\prime}}\right) / \tau}}{\sum_{j=1}^{N} e^{\operatorname{sim}\left(\mathbf{h}_{i}^{z_{i}}, \mathbf{h}_{j}^{z_{j}^{\prime}}\right) / \tau}} i=logj=1Nesim(hizi,hjzj)/τesim(hizi,hizi)/τ

代码实现

在作者的代码中,并不是将一个句子输入到模型中两次,而是复制一份放到同一个batch里。模型的核心是 cl_forward 函数:

def cl_forward(cls,encoder,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None,head_mask=None,inputs_embeds=None,labels=None,output_attentions=None,output_hidden_states=None,return_dict=None,mlm_input_ids=None,mlm_labels=None,
):return_dict = return_dict if return_dict is not None else cls.config.use_return_dictori_input_ids = input_ids    # 形状为[bs, num_sent, sent_len], bs=32batch_size = input_ids.size(0)# Number of sentences in one instance# 2: pair instance,[自己,自己]; 3: pair instance with a hard negative,[自己,自己,难例]num_sent = input_ids.size(1)mlm_outputs = None# Flatten input for encodinginput_ids = input_ids.view((-1, input_ids.size(-1))) # [bs * num_sent, sent_len]attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # [bs * num_sent, sent_len]if token_type_ids is not None:token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # [bs * num_sent, sent_len]# Get raw embeddings, [bs, num_sent, sent_len, hidden_size]outputs = encoder(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,return_dict=True,)# MLM auxiliary objectiveif mlm_input_ids is not None:mlm_input_ids = mlm_input_ids.view((-1, mlm_input_ids.size(-1)))mlm_outputs = encoder(mlm_input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=True if cls.model_args.pooler_type in ['avg_top2', 'avg_first_last'] else False,return_dict=True,)# Poolingpooler_output = cls.pooler(attention_mask, outputs)pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1))) # (bs, num_sent, hidden_size)# If using "cls", we add an extra MLP layer# (same as BERT's original implementation) over the representation.if cls.pooler_type == "cls":pooler_output = cls.mlp(pooler_output)# Separate representation, [bs, hidden_size], 同一样本经过“两次Dropout”得到的两个句向量z1, z2 = pooler_output[:,0], pooler_output[:,1]# Hard negativeif num_sent == 3:z3 = pooler_output[:, 2]# Gather all embeddings if using distributed trainingif dist.is_initialized() and cls.training:# Gather hard negativeif num_sent >= 3:z3_list = [torch.zeros_like(z3) for _ in range(dist.get_world_size())]dist.all_gather(tensor_list=z3_list, tensor=z3.contiguous())z3_list[dist.get_rank()] = z3z3 = torch.cat(z3_list, 0)# Dummy vectors for allgatherz1_list = [torch.zeros_like(z1) for _ in range(dist.get_world_size())]z2_list = [torch.zeros_like(z2) for _ in range(dist.get_world_size())]# Allgatherdist.all_gather(tensor_list=z1_list, tensor=z1.contiguous())dist.all_gather(tensor_list=z2_list, tensor=z2.contiguous())# Since allgather results do not have gradients, we replace the# current process's corresponding embeddings with original tensorsz1_list[dist.get_rank()] = z1z2_list[dist.get_rank()] = z2# Get full batch embeddings: (bs x N, hidden)z1 = torch.cat(z1_list, 0)z2 = torch.cat(z2_list, 0)# [bs, bs],计算该样本与其它样本的相似度cos_sim = cls.sim(z1.unsqueeze(1), z2.unsqueeze(0))# Hard negativeif num_sent >= 3:z1_z3_cos = cls.sim(z1.unsqueeze(1), z3.unsqueeze(0))cos_sim = torch.cat([cos_sim, z1_z3_cos], 1)# [bs, ], 内容为[0,1,...,bs-1],表示每个样本最相似的样本下标labels = torch.arange(cos_sim.size(0)).long().to(cls.device)# 此处显示出对比学习loss和常规交叉熵loss的区别,# 对比学习的label数是[bs,bs],而交叉熵的label数是[bs, label_nums]loss_fct = nn.CrossEntropyLoss()# Calculate loss with hard negativesif num_sent == 3:# Note that weights are actually logits of weightsz3_weight = cls.model_args.hard_negative_weightweights = torch.tensor([[0.0] * (cos_sim.size(-1) - z1_z3_cos.size(-1)) + [0.0] * i + [z3_weight] + [0.0] * (z1_z3_cos.size(-1) - i - 1) for i in range(z1_z3_cos.size(-1))]).to(cls.device)cos_sim = cos_sim + weightsloss = loss_fct(cos_sim, labels)# Calculate loss for MLMif mlm_outputs is not None and mlm_labels is not None:mlm_labels = mlm_labels.view(-1, mlm_labels.size(-1))prediction_scores = cls.lm_head(mlm_outputs.last_hidden_state)masked_lm_loss = loss_fct(prediction_scores.view(-1, cls.config.vocab_size), mlm_labels.view(-1))loss = loss + cls.model_args.mlm_weight * masked_lm_lossif not return_dict:output = (cos_sim,) + outputs[2:]return ((loss,) + output) if loss is not None else outputreturn SequenceClassifierOutput(loss=loss,logits=cos_sim,hidden_states=outputs.hidden_states,attentions=outputs.attentions,)

上述代码考虑诸多场景,比如分布式训练、难例三元组、mlm mask,写的较为复杂。

以下是简化版,更加符合论文的表述:

loss_func = nn.CrossEntropyLoss()
def simcse_loss(batch_emb):"""用于无监督SimCSE训练的loss"""batch_size = batch_emb.size(0)    # [bs, hidden_size]# 构造标签, [bs, 2], bs=64y_true = torch.cat([torch.arange(1, batch_size, step=2, dtype=torch.long).unsqueeze(1),torch.arange(0, batch_size, step=2, dtype=torch.long).unsqueeze(1)],dim=1).reshape([batch_size,])# 计算score和lossnorm_emb = F.normalize(batch_emb, dim=1, p=2)# [bs, bs],计算该样本与其它样本的相似度sim_score = torch.matmul(norm_emb, norm_emb.transpose(0,1))# 对角线的位置,也就是自身的余弦相似度,肯定为1,不产生loss,需要mask掉sim_score = sim_score - torch.eye(batch_size) * 1e12sim_score = sim_score * 20    # 温度系数loss = loss_func(sim_score, y_true)return loss

FAQ

  • 如果同一个batch里有其它语义相似的正样本,但在这里被当作了负样例处理,不是也拉远了同类样本的距离吗?

参考

  • princeton-nlp/SimCSE
  • “被玩坏了”的Dropout
  • 细节满满!理解对比学习和SimCSE,就看这6个知识点
  • SIMCSE算法源码分析

SimCSE论文及源码解读相关推荐

  1. Faster R-CNN论文及源码解读

    R-CNN是目标检测领域中十分经典的方法,相比于传统的手工特征,R-CNN将卷积神经网络引入,用于提取深度特征,后接一个分类器判决搜索区域是否包含目标及其置信度,取得了较为准确的检测结果.Fast R ...

  2. 语义分割之PointRend论文与源码解读

    参考:https://zhuanlan.zhihu.com/p/98508347?utm_source=qq 存在问题: 一般的语义分割网络,在得到一定分辨率的mask之后,都会直接插值回原像素尺寸, ...

  3. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  4. DeformableDetr论文简介+mmdet源码解读

    文章目录 前言 一.论文解读 1.1. 研究问题 1.2. 可形变注意力模块 1.3. 拓展到多层特征图 二. mmdet源码讲解 2.1. 图像特征提取 2.2. 生成mask和位置编码 2.3. ...

  5. 【Deformable DETR 论文+源码解读】Deformable Transformers for End-to-End Object Detection

    目录 前言 一.背景和改进思路 二.细节原理和源码讲解 2.1.多尺度特征 2.1.1.backbone生成多尺度特征 2.1.2.多尺度位置编码 2.2.多尺度可变形注意力 2.2.1.普通多头注意 ...

  6. PTMs:QLoRA技巧的简介、使用方法、论文解读、源码解读之详细攻略

    PTMs:QLoRA技巧的简介.使用方法.论文解读.源码解读之详细攻略 目录 QLoRA技巧的简介 1.量化.分页优化器 QLoRA技巧的使用方法 1.安装 2.入

  7. Bert系列(二)——源码解读之模型主体

    本篇文章主要是解读模型主体代码modeling.py.在阅读这篇文章之前希望读者们对bert的相关理论有一定的了解,尤其是transformer的结构原理,网上的资料很多,本文内容对原理部分就不做过多 ...

  8. Ubuntu 16.04下Caffe-SSD的应用(四)——ssd_pascal.py源码解读

    前言 caffe-ssd所有的训练时的参数,全部由ssd_pascal.py来定义,之后再去调用相关的脚本和函数,所以想要训练自己的数据,首先要明白ssd_pascal.py各个定义参数的大体意思. ...

  9. Pseudo-document-based Topic Model(基于伪文档的主题模型)的理解以及源码解读

    本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 内容可能有不到之处,欢迎交流. 未经本人允许禁止转载. 论文来源 Zuo Y, Wu J, Zhang H, e ...

最新文章

  1. Redis学习和环境搭建
  2. oracle更改归档日志路径,oracle修改归档日志的路径
  3. Python中菱形继承的MRO顺序及property属性
  4. exar 带容隔离_带有美白功效的6款隔离霜 美白隔离两不误
  5. eclipse怎样改编码格式_Eclipse中各种编码格式及设置
  6. 鸿蒙手机(真机)播放音乐-第二集
  7. 谈谈API接口安全性设计思路
  8. 那些坑爹的老代码,究竟改还是不改?!
  9. c语言计算器程序代码 链栈,【C语言】简易科学计算器源代码(链栈应用)(原创).doc...
  10. BZOJ1001 狼抓兔子(网络流转最短路:对偶图)
  11. 杂项-黑苹果安装教程
  12. Could not find module ‘xxx‘ for target ‘xxx‘; found: i386, x86_64-apple-ios-simula错误解决
  13. centos7 安装wekan 看板
  14. 摆脱无效报警?十年运维监控报警优化经验总结
  15. tableau的骑行路线地理数据可视化
  16. Tmux_Hotkey
  17. Leco题目:回文数
  18. 《武则天正传》——林语堂版本,读后感
  19. 【电子电路】变送器信号 4~20mA 恒流源电路设计
  20. python format是什么

热门文章

  1. HTML基础学习记录
  2. 1758: [Wc2010]重建计划(TLE)
  3. React16版本更新的新特性
  4. 我的游戏学习日志52——游戏与艺术(3)
  5. Qt完成简易闹钟与画板
  6. android上调试H5小工具
  7. 科技也会成为一种文化
  8. 联想拯救者15isk-i5版加装固态硬盘和内存条
  9. mysql中字段约束unique_什么是MySQL UNIQUE约束,我们如何将其应用于表的字段?
  10. 阿呆穿越当程序员之设计模式系列-总纲