NLP实践——以T5模型为例训练seq2seq模型

  • 0. 介绍
  • 1. 数据下载与加载
  • 2. 创建模型
  • 3. 训练评估函数
  • 4. 模型训练
  • 5. 模型预测

0. 介绍

回顾这两年NLP领域的研究,生成式模型可谓是一大热门方向,huggingface的transformers模块中,包含了常见的各类生成式模型框架,以及它们对应的生成任务,使得生成模型的搭建已经非常方便。

本文将以T5为例,介绍如何利用transformers模块搭建生成式模型,并训练seq2seq生成模型。

本文是对huggingface社区的一个项目的搬运整理,但是我现在找不到原项目的链接了,所以很抱歉没能贴出。

正式开始之前,设置全局变量:

TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 2
TRAIN_EPOCHS = 5
VAL_EPOCHS = 1
LEARNING_RATE = 1e-4
MAX_LEN = 512
SUMMARY_LEN = 150

1. 数据下载与加载

数据来自kaggle网站,同样找不到地址了,所以我把数据传网盘了。
提取码:t7pv

使用pandas组织数据:

import pandas as pddf = pd.read_csv('./news_summary.csv', encoding='latin-1')
df = df[['text','ctext']]
df.head()

其中text是摘要,ctext是对应的原文。

对训练集和验证集进行划分:

train_size = 0.8
train_dataset = df.sample(frac=train_size, random_state=0)
val_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("VAL Dataset: {}".format(val_dataset.shape))

然后我们构造一个DataSet类,进而生成DataLoader。

from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSamplerclass CustomDataset(Dataset):def __init__(self, dataframe, tokenizer, source_len, summ_len):self.tokenizer = tokenizerself.data = dataframeself.source_len = source_lenself.summ_len = summ_lenself.text = self.data.textself.ctext = self.data.ctextdef __len__(self):return len(self.data)def __getitem__(self, index):ctext = str(self.ctext[index])ctext = ' '.join(ctext.split())text = str(self.text[index])text = ' '.join(text.split())source = self.tokenizer.batch_encode_plus([ctext], max_length= self.source_len, pad_to_max_length=True, return_tensors='pt')target = self.tokenizer.batch_encode_plus([text], max_length= self.summ_len, pad_to_max_length=True, return_tensors='pt')source_ids = source['input_ids'].squeeze()source_mask = source['attention_mask'].squeeze()target_ids = target['input_ids'].squeeze()target_mask = target['attention_mask'].squeeze()return {'source_ids': source_ids.to(dtype=torch.long), 'source_mask': source_mask.to(dtype=torch.long), 'target_ids': target_ids.to(dtype=torch.long),'target_ids_y': target_ids.to(dtype=torch.long)}# 创建DataSet
training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN, SUMMARY_LEN)
val_set = CustomDataset(val_dataset, tokenizer, MAX_LEN, SUMMARY_LEN)# 创建DataLoader
train_params = {'batch_size': TRAIN_BATCH_SIZE,'shuffle': True,'num_workers': 0}val_params = {'batch_size': VALID_BATCH_SIZE,'shuffle': False,'num_workers': 0}training_loader = DataLoader(training_set, **train_params)
val_loader = DataLoader(val_set, **val_params)

2. 创建模型

我们利用transformers模块搭建模型。

from transformers import T5Tokenizer, T5ForConditionalGeneration, PreTrainedTokenizer, PreTrainedModel# 下载T5模型,存放于某目录下
t5_path = 'xxxxxxxxx/T5-base'
tokenizer = T5Tokenizer.from_pretrained(t5_path)
model = T5ForConditionalGeneration.from_pretrained(t5_path)
model.to('cuda:0')
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

3. 训练评估函数

接下来我们需要写3个方法,分别对应模型的训练、评估和预测。

训练没有什么问题,直接采用原有的方法:

def train(epoch, tokenizer, model, device, loader, optimizer):model.train()for _, data in tqdm(enumerate(loader, 0), desc='step'):y = data['target_ids'].to(device, dtype = torch.long)y_ids = y[:, :-1].contiguous()lm_labels = y[:, 1:].clone().detach()lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100ids = data['source_ids'].to(device, dtype = torch.long)mask = data['source_mask'].to(device, dtype = torch.long)outputs = model(input_ids=ids, attention_mask=mask, decoder_input_ids=y_ids, labels=lm_labels)loss = outputs[0]# if _%500==0:#     print(f'Epoch: {epoch}, Loss:  {loss.item()}')optimizer.zero_grad()loss.backward()optimizer.step()

评估方法在原来的笔记中没有引入评价指标,我认为是有问题的,这里我选择了用BLEU-4指标进行评估。当然,你也可以采用其他合理的指标进行评价。
(我不记得原来的代码是怎么写的了,因为已经被我删了,我只贴出我的评估方法)

from nltk.translate.bleu_score import sentence_bleudef evaluate(tokenizer, model, device, loader):"""用BLEU4评估"""model.eval()bleus = []with torch.no_grad():for _, data in tqdm(enumerate(loader, 0), desc='Evaluate'):target_ids = data['target_ids'].to(device, dtype = torch.long)ids = data['source_ids'].to(device, dtype = torch.long)mask = data['source_mask'].to(device, dtype = torch.long)generated_ids = model.generate(input_ids = ids,attention_mask = mask, max_length=150, num_beams=2,repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in target_ids]bleu_4 = sentence_bleu([tar.split() for tar in target], preds[0].split(), [0, 0, 0, 1])bleus.append(bleu_4)return sum(bleus) / len(bleus)

原来的笔记没有提供预测方法,所以这里我们自己写一个,也非常简单:

def predict(tokenizer: PreTrainedTokenizer, model: PreTrainedModel, text: str, device):with torch.no_grad():inputs = tokenizer(text, max_length=MAX_LEN, padding=True, return_tensors='pt')ids = inputs['input_ids']mask = inputs['attention_mask']ids = ids.to(device)mask = mask.to(device)generated_ids = model.generate(input_ids = ids,attention_mask = mask, max_length=150, num_beams=2,repetition_penalty=2.5, length_penalty=1.0, early_stopping=True)preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]return preds

4. 模型训练

接下来进入训练环节。

best_bleu = 0
for epoch in tqdm(range(TRAIN_EPOCHS), desc='epoch'):train(epoch, tokenizer, model, device, training_loader, optimizer)cur_bleu = evaluate(tokenizer, model, device, val_loader)if cur_bleu > best_bleu:torch.save(model.state_dict, 't5_best_model.pt')best_bleu = cur_bleuprint('Best bleu: {}, Current bleu: {}'.format(best_bleu, cur_bleu))

BLEU-4最高的模型会被保存在当前目录下。

5. 模型预测

输入一段文字,使用之前写的预测函数进行预测:

text = """Provided by NBC News The remains of a sailor missing in action since the Dec. 7, 1941, attack on Pearl Harbor have been identified, a federal agency said. Petty Ofc. 2nd Class Claude Ralph Garcia died at age 25 while serving as a ship fitter aboard the USS West Virginia when Japanese forces attacked the U.S. naval base near Honolulu. The Defense POW/MIA Accounting Agency, which accounts for missing defense personnel, made the positive identification recently. Garcia was born to father Rafael Garcia in Ventura County, California, on April 27, 1916, according to Honor States, an organization that tracks the life and achievements of fallen military members. He graduated from Ventura High School in 1933 and attended community college before enlisting in the Navy, according to the VC Star, which said local news reports from 1943 described Garcia as Ventura's first World War II presumed casualty, and his memorial service was estimated to have drawn over 300 mourners."""preds = predict(tokenizer, model, text, device)
print(preds)# ['the remains of a sailor missing in action since the Dec. 7, 1941, attack on Pearl Harbor have been identified, a federal agency has said. Garcia was born to father Rafael Garcia in Ventura County, California, on April 27, 1916. He died at age 25 while serving as a ship fitter aboard the USS West Virginia when Japanese forces attacked the U.S. naval base near Honolulu.']

从预测的结果来看,使用这个数据集训练出来的模型尽管是一个生成模型,但是生成的结果基本都是原文中连续的片段,类似于抽取式的模型。

以上就是本文的全部内容了。

在今后的节目 博客中,我还准备了很多我自己编写、整理的原创作品,期待的话请多多为我投币吧。

NLP实践——以T5模型为例训练seq2seq模型相关推荐

  1. seq2seq模型_彻底理解 Seq2Seq 模型

    Seq2Seq 是一种循环神经网络的变种,包括编码器 (Encoder) 和解码器 (Decoder) 两部分.Seq2Seq 是自然语言处理中的一种重要模型,可以用于机器翻译.对话系统.自动文摘. ...

  2. 保存模型后无法训练_模型构建到部署实践

    导读 在工业界一般会采用了tensorflow-serving进行模型的部署,而在模型构建时会因人而异会使用不同的深度学习框架,这就需要在使用指定深度学习框架训练出模型后,统一将模型转为pb格式,便于 ...

  3. MMDeteceion之系列一(环境安装、模型测试、训练以及模型后处理工具)

    1.MMDeteceion初识 MMDetection是一款优秀的基于PyTorch的深度学习目标检测工具箱,由香港中文大学(CUHK)多媒体实验室(mmlab)开发.基本上支持所有当前SOTA二阶段 ...

  4. 软件测试设计之MFQ模型用例设计覆盖模型

    建模完成后,需要使用测试用例来覆盖这些模型,在以前的写用例过程中用例和数据是同时完成的,在MFQ模型中,将测试用例设计分成两个步骤:一是设计基础测试用例来覆盖模型:二是针对每个测试用例更多的测试数据产 ...

  5. seq2seq模型_Pytorch学习记录-Seq2Seq模型对比

    Pytorch学习记录-torchtext和Pytorch的实例4 0. PyTorch Seq2Seq项目介绍 在完成基本的torchtext之后,找到了这个教程,<基于Pytorch和tor ...

  6. 使用TensorFlow训练WDL模型性能问题定位与调优

    简介 TensorFlow是Google研发的第二代人工智能学习系统,能够处理多种深度学习算法模型,以功能强大和高可扩展性而著称.TensorFlow完全开源,所以很多公司都在使用,但是美团点评在使用 ...

  7. ####haohaohao######基于知识库的问答KBQA:seq2seq模型实践

    问题描述 基于知识图谱的自动问答(Question Answering over Knowledge Base, 即 KBQA)问题的大概形式是,预先给定一个知识库(比如Freebase),知识库中包 ...

  8. 论文浅尝 | 使用预训练深度模型和迁移学习方法的端到端模糊实体匹配

    论文笔记整理:高凤宁,南京大学硕士,研究方向为知识图谱.实体消解. 链接:https://doi.org/10.1145/3308558.3313578 动机 目前实体匹配过程中实体之间的差异比较微妙 ...

  9. RNN模型与NLP应用:机器翻译与Seq2Seq模型-7/9

    目录 一.前言 二.Seq2Seq模型的搭建 三.Seq2Seq模型的预测 四.Seq2Seq的改进 五.总结 六.参考连接 一.前言 Seq2Seq模型把英语翻译成德语 我们可以注意到机器翻译是一个 ...

最新文章

  1. 在linux上面合并多个windows文件乱码的问题
  2. Yii2.0 数据库更新update
  3. 19.12 添加自定义监控项目 19.13/19.14 配置邮件告警 19.15 测试告警 19.16 不发邮件的问题处理...
  4. 酱油和gbt酱油哪个好_都说日本的酿造酱油品质好,我国的酱油究竟差在哪儿?...
  5. webpack配置--传统多页面项目
  6. java读取excel数据保存到数据库中_java读取excel的内容(可保存到数据库中)
  7. fiddler抓取火狐浏览器上https协议请求
  8. halcon2D Metrology测量算子,卡尺测量算子,持续更新
  9. 【LeetCode】【HOT】142. 环形链表 II(快慢指针)
  10. 前端工程师的迷茫:不知道我这种前端是不是被淘汰了?
  11. 算法 Tricks(二) —— 大数的处理
  12. opencv 模板匹配,在图像中寻找物体
  13. 清理Windows.edb文件释放C盘空间(原创)
  14. 写了个淡入淡出的jq幻灯片插件
  15. FreeBSD下使用Blogbio写cnblogs博客
  16. matlab随机线性微分方程,基于MATLAB的随机线性微分方程的求解
  17. 重温LuGre摩擦力模型
  18. 百战程序员python视频下载_[视频教程] 百战程序员python400集(第一季115集)
  19. python制作u盘病毒_Python-记一次U盘中病毒及文件找回
  20. ChatGPT聊天机器人如何发图片????

热门文章

  1. 江苏省一级计算机ms,计算机一级六大MS题型介绍
  2. MXNet对DenseNet(稠密连接网络)的实现
  3. uniapp写微信授权登录
  4. t-star腾讯安全高校挑战赛2022 writeup
  5. 外业精灵,在水土流失监测野外调查工作中的应用
  6. 基于Matlab模拟AWGN 信道上 OFDM附完整代码
  7. Java 中的设计模式详细介绍
  8. 在pycharm中如何使用pyinstaller
  9. 在C语言中使用bool
  10. “知识共享”早期版本是什么样子?