基于Robert的文本分类任务,在此基础上考虑融合对比学习、Prompt和对抗训练来提升模型的文本分类能力,我本地有SST-2数据集的train.txt、dev.txt两个文件,每个文件包含文本内容和标签两列,是个二分类任务,本项目基于pytorch实现。

先介绍一下要融合的三个技术。

1. 对比学习旨在通过对比相似和不相似的样本来提高分类模型的性能。对于每个样本,我们可以在训练时随机选取一个与其相似的样本,并加入到训练中,以鼓励模型更好地学习相似样本的特征,同时在训练时也要随机选取一个不相似的样本,并将其加入到训练中。这可以帮助模型更好地区分不同类别之间的特征。

2. Prompt是一种基于预设文本片段的模型输入方式。通过给定关键词和语法结构,Prompt可以引导模型学习某些具体任务。在文本分类任务中,我们可以给模型预设一些文本提示,以帮助模型更好地学习关键特征。

3. 对抗训练是一种在训练模型时加入干扰数据(扰动)的技术,以增强模型的鲁棒性。在文本分类任务中,我们可以通过向文本中添加词语或修改词语顺序,来生成干扰数据,从而帮助模型更好地区分和理解输入文本。

目录

一、安装依赖库

二、载数据集并进行数据预处理

三、定义模型并训练模型

四、对比学习实现

五、Prompt实现

六、对抗训练实现

七、整个过程封装成一个函数


一、安装依赖库

下面是具体实现的代码,我们将使用PyTorch框架:

首先安装必要的库:

!pip install transformers
!pip install torch
!pip install scikit-learn

然后我们导入需要的库以及设置随机种子以保证实验可重复性等必要组件:

import random
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from transformers import RobertaTokenizer, RobertaForSequenceClassification, AdamW
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import get_linear_schedule_with_warmupdevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True

二、载数据集并进行数据预处理

class TextDataset(Dataset):def __init__(self, tokenizer, path, max_length):self.tokenizer = tokenizerself.max_length = max_lengthself.labels = []self.texts = []with open(path) as f:for line in f:line = line.strip().split('\t')text, label = line[0], int(line[1])self.labels.append(label)self.texts.append(text)def __len__(self):return len(self.labels)def __getitem__(self, idx):text, label = self.texts[idx], self.labels[idx]encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,return_tensors='pt')return dict(text=text,input_ids=encoding['input_ids'].squeeze(),attention_mask=encoding['attention_mask'].squeeze(),labels=torch.tensor(label))tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
train_dataset = TextDataset(tokenizer, 'train.txt', 256)
dev_dataset = TextDataset(tokenizer, 'dev.txt', 256)train_sampler = RandomSampler(train_dataset)
dev_sampler = SequentialSampler(dev_dataset)train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=16)
dev_loader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=16)

三、定义模型并训练模型

model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)# We will use a linear decay scheduler
total_steps = len(train_loader) * 5
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)for epoch in range(5):model.train()for batch in train_loader:batch = {k: v.to(device) for k, v in batch.items()}optimizer.zero_grad()outputs = model(**batch)loss = outputs[0]loss.backward()optimizer.step()scheduler.step()model.eval()with torch.no_grad():targets, preds = [], []for batch in dev_loader:batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)targets.extend(batch['labels'].tolist())preds.extend(torch.argmax(outputs.logits, axis=-1).tolist())acc = accuracy_score(targets, preds)f1 = f1_score(targets, preds)print(f'\nEpoch {epoch + 1}:')print(f'Dev Accuracy: {acc:.4f}')print(f'Dev F1 Score: {f1:.4f}')

至此,我们已经成功地训练了一款基于RoBERTa模型的文本分类器。下面是加入融合技术的实现。

四、对比学习实现

def random_similar_text(texts, labels):res_texts, res_labels = [], []for idx, text in enumerate(texts):res_texts.append(text)res_labels.append(labels[idx])# 随机选择一个与当前样本相似的样本,将它加入到数据集中rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])# 随机选择一个不相似的样本,将它加入到数据集中rand_idx = np.random.choice(len(texts), 1)[0]while rand_idx == idx:rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])return res_texts, res_labelstrain_texts, train_labels = random_similar_text(train_dataset.texts, train_dataset.labels)
train_dataset = TextDataset(tokenizer, 'train.txt', 256)

五、Prompt实现

def add_prompt(prompt, texts):return [f'{prompt}{text}' for text in texts]train_dataset.texts = add_prompt('This text is', train_dataset.texts)
dev_dataset.texts = add_prompt('This text is', dev_dataset.texts)

六、对抗训练实现

def add_perturbations(text, n):# 随机选择n个词,并在其周围添加一些噪声生成n个干扰文本words = text.split()idx_list = np.random.choice(len(words), n, replace=False)for idx in idx_list:words[idx] = f'[{words[idx]}]'return ' '.join(words)def generate_perturbations(texts):return [add_perturbations(text, 3) for text in texts]train_dataset.texts += generate_perturbations(train_dataset.texts)
dev_dataset.texts += generate_perturbations(dev_dataset.texts)

七、整个过程封装成一个函数

def train_roberta_with_fusion(train_path, dev_path, num_classes, fusion_type):def random_similar_text(texts, labels):res_texts, res_labels = [], []for idx, text in enumerate(texts):res_texts.append(text)res_labels.append(labels[idx])rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])rand_idx = np.random.choice(len(texts), 1)[0]while rand_idx == idx:rand_idx = np.random.choice(len(texts), 1)[0]res_texts.append(texts[rand_idx])res_labels.append(labels[rand_idx])return res_texts, res_labelsdef add_perturbations(text, n):words = text.split()idx_list = np.random.choice(len(words), n, replace=False)for idx in idx_list:words[idx] = f'[{words[idx]}]'return ' '.join(words)def add_prompt(prompt, texts):return [f'{prompt}{text}' for text in texts]def generate_perturbations(texts):return [add_perturbations(text, 3) for text in texts]class TextDataset(Dataset):def __init__(self, tokenizer, path, max_length):self.tokenizer = tokenizerself.max_length = max_lengthself.labels = []self.texts = []with open(path) as f:for line in f:line = line.strip().split('\t')text, label = line[0], int(line[1])self.labels.append(label)self.texts.append(text)if fusion_type == 'contrastive':self.texts, self.labels = random_similar_text(self.texts, self.labels)if fusion_type == 'adversarial':self.texts += generate_perturbations(self.texts)if fusion_type == 'prompt':self.texts = add_prompt('This text is', self.texts)def __len__(self):return len(self.labels)def __getitem__(self, idx):text, label = self.texts[idx], self.labels[idx]encoding = self.tokenizer(text, truncation=True, padding='max_length', max_length=self.max_length,return_tensors='pt')return dict(text=text,input_ids=encoding['input_ids'].squeeze(),attention_mask=encoding['attention_mask'].squeeze(),labels=torch.tensor(label))device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')tokenizer = RobertaTokenizer.from_pretrained('roberta-base')train_dataset = TextDataset(tokenizer, train_path, 256)dev_dataset = TextDataset(tokenizer, dev_path, 256)train_sampler = RandomSampler(train_dataset)dev_sampler = SequentialSampler(dev_dataset)train_loader = DataLoader(train_dataset, sampler=train

Robert+Prompt+对比学习+对抗训练文本分类相关推荐

  1. 超越SimCSE两个多点,Prompt+对比学习的文本表示新SOTA

    可能是因为对比学习,今年以来文本表示方向突然就卷起来了,SOTA刷的嗖嗖的,我还停留在我们ConSERT的阶段,结果别人不精调就已经超了. 昨天实习同学发了我几篇Open Review上ACL的投稿, ...

  2. 用深度学习解决大规模文本分类问题

     用深度学习解决大规模文本分类问题 人工智能头条 2017-03-27 22:14:22 淘宝 阅读(228) 评论(0) 声明:本文由入驻搜狐公众平台的作者撰写,除搜狐官方账号外,观点仅代表作者 ...

  3. 【NLP】相当全面:各种深度学习模型在文本分类任务上的应用

    论文标题:Deep Learning Based Text Classification:A Comprehensive Review 论文链接:https://arxiv.org/pdf/2004. ...

  4. 【NLP从零入门】预训练时代下,深度学习模型的文本分类算法(超多干货,小白友好,内附实践代码和文本分类常见中文数据集)

    如今NLP可以说是预训练模型的时代,希望借此抛砖引玉,能多多交流探讨当前预训练模型在文本分类上的应用. 1. 任务介绍与实际应用 文本分类任务是自然语言处理(NLP)中最常见.最基础的任务之一,顾名思 ...

  5. VideoCLIP-FacebookCMU开源视频文本理解的对比学习预训练,性能SOTA!适用于零样本学习!...

    关注公众号,发现CV技术之美 0 写在前面 在本文中,作者提出了VideoCLIP,这是一种不需要下游任务的任何标签,用于预训练零样本视频和文本理解模型的对比学习方法.VideoCLIP通过对比时间重 ...

  6. [深度学习] 自然语言处理 --- 文本分类模型总结

    文本分类 包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMO,BERT等)的文本分类 fastText 模型 textCNN 模型 charCNN 模型 Bi-LSTM 模型 ...

  7. Prompt+对比学习,更好地学习句子表征

    每天给你送来NLP技术干货! 作者 | 王嘉宁@华师数据学院 整理 | NewBeeNLP https://wjn1996.blog.csdn.net/article/details/12552885 ...

  8. 【NLP】Prompt+对比学习,更好地学习句子表征

    作者 | 王嘉宁@华师数据学院 整理 | NewBeeNLP https://w‍jn1996.blog.csdn.net/article/details/125528859 ‍ 虽然BERT等语言模 ...

  9. NLP深度学习:PyTorch文本分类

    文本分类是NLP领域的较为容易的入门问题,本文记录文本分类任务的基本流程,大部分操作使用了torch和torchtext两个库. 1. 文本数据预处理 首先数据存储在三个csv文件中,分别是train ...

最新文章

  1. 19_Android中图片处理原理篇,关于人脸识别网站,图片加载到内存,图片缩放,图片翻转倒置,网上撕衣服游戏案例编写...
  2. 学python是看书还是看视频-自学Python是看书还是看视频?
  3. 响应式精美列商城发卡源码
  4. C++ string类不能像C字符串能靠在i位赋值为‘\0’来截断
  5. 欧姆龙PLC HostLink协议整理
  6. NASA数据批量下载——wget
  7. Latex学习之插入编号-实心圆点列表,横杆,数字
  8. 配置JDK、Tomcat环境、DNK环境
  9. 关于网页加载慢的一个解决方法——取消勾选【局域网设置】中的【自动检测设置】
  10. 《从零开始的 RPG 游戏制作教程》第五期:制作物品和技能
  11. 韩商言房子卖价有多高,做现女友就有多难?
  12. teamviewer 使用数量到达上限_解决Teamviewer免费版设备数量限制
  13. c语言编写51单片机中断程序,执行过程是怎样的?
  14. linux下 部署调用SAP接口
  15. c语言汉诺塔问题用指针变量,谁会用C语言解决汉诺塔问题?请进,最好把每一步的解释写上有三个 爱问知识人...
  16. Saruman's Army
  17. eclipse汉化方式(下载,安装,中英切换)
  18. 【Unity3D开发小游戏】《塔防游戏》Unity开发教程
  19. 【python】在图片加上数字
  20. Proxmox VE 超融合集群不停服务更换硬盘操作实录

热门文章

  1. 如何正确设计一个界面
  2. opentracing-02 dapper论文词汇摘要
  3. 南京地铁票价查询HTML版
  4. ART、SIRT、SART算法
  5. 玲珑杯 1160 - 康娜与玲珑杯
  6. linux block layer第二篇bio 的操作
  7. FilterDispatcher is deprecated! 1
  8. 从零搭建若依(Ruoyi-Vue)管理系统(13)--登录和鉴权的实现
  9. NoSQL数据库知多少--列存储Cassandra数据库
  10. 解决Maven报错:Plugin execution not covered by lifecycle configuration