1 训练

import tqdm
from datasets import load_dataset
import lawrougeimport datasets
import random
import pandas as pdfrom datasets import dataset_dict
import datasetsfrom transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainerimport warnings
from pathlib import Path
from typing import List, Tuple, Unionfrom torch import nnimport jieba
import numpy as np
import lawrougefrom transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers.utils import loggingdataset = load_dataset('json', data_files='nlpcc_data.json', field='data')def flatten(example):return {"document": example["content"],"summary": example["title"],"id":"0"}
dataset = dataset["train"].map(flatten, remove_columns=["title", "content"]) # , remove_columns=["title", "content"]TokenModel = "bert-base-chinese"from transformers import AutoTokenizer, BertConfig
tokenizer = AutoTokenizer.from_pretrained(TokenModel)config = BertConfig.from_pretrained(TokenModel)model_checkpoint = "fnlp/bart-large-chinese"print(model_checkpoint)max_input_length = 512 # input, source text 注意长度,复旦BART中文预训练模型使用的bert tokenizer
max_target_length = 128 # summary, target textdef preprocess_function(examples):inputs = [doc for doc in examples["document"]]model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)# Setup the tokenizer for targetswith tokenizer.as_target_tokenizer():labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)model_inputs["labels"] = labels["input_ids"]return model_inputsraw_datasets = dataset
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)train_data_txt, validation_data_txt = dataset.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
train_data_txt, test_data_tex = train_data_txt.train_test_split(test_size=0.1,shuffle=True,seed=42).values()
# 装载数据
dd = datasets.DatasetDict({"train":train_data_txt,"validation": validation_data_txt,"test":test_data_tex }) raw_datasets = dd
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)logger = logging.get_logger(__name__)batch_size = 4
args = Seq2SeqTrainingArguments(output_dir="results",num_train_epochs=50,  # demodo_train=True,do_eval=True,per_device_train_batch_size=batch_size,  # demoper_device_eval_batch_size=batch_size,learning_rate=1e-04,warmup_steps=500,weight_decay=0.001,label_smoothing_factor=0.1,predict_with_generate=True,logging_dir="logs",logging_steps=500,evaluation_strategy="epoch",save_total_limit=3,# generation_max_length最大生成长度,系统默认20 generation_num_beams=1表示贪心解码,大于1为树搜索generation_max_length=64,generation_num_beams=1,
)data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)# 这里用的是中文lawrouge 至于字符级还是词级计算看自己调整 这里是字符级
def compute_metrics(eval_pred):predictions, labels = eval_preddecoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)# Replace -100 in the labels as we can't decode them.labels = np.where(labels != -100, labels, tokenizer.pad_token_id)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)decoded_preds = ["".join(pred.replace(" ", "")) for pred in decoded_preds]decoded_labels = ["".join(label.replace(" ", "")) for label in decoded_labels]# Rouge with jieba cut# decoded_preds = [" ".join(jieba.cut(pred.replace(" ", ""))) for pred in decoded_preds]# decoded_labels = [" ".join(jieba.cut(label.replace(" ", ""))) for label in decoded_labels]labels_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in labels]# length = len(prediction_lens)# print(decoded_preds)# print(decoded_labels)rouge = lawrouge.Rouge()result = rouge.get_scores(decoded_preds, decoded_labels,avg=True)# print(result)print(result)result = {'rouge-1': result['rouge-1']['f'], 'rouge-2': result['rouge-2']['f'], 'rouge-l': result['rouge-l']['f']}result = {key: value * 100 for key, value in result.items()}return result;trainer = Seq2SeqTrainer(model,args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["validation"],data_collator=data_collator,tokenizer=tokenizer,compute_metrics=compute_metrics
)train_result = trainer.train()
print(train_result)trainer.save_model()
metrics = train_result.metrics
trainer.log_metrics("train",metrics)
trainer.save_metrics("train",metrics)
trainer.save_state()import torch
model.load_state_dict(torch.load('./results/pytorch_model.bin'))def generate_summary(test_samples, model):inputs = tokenizer(test_samples,padding="max_length",truncation=True,max_length=max_input_length,return_tensors="pt",)input_ids = inputs.input_ids.to(model.device)attention_mask = inputs.attention_mask.to(model.device)outputs = model.generate(input_ids, attention_mask=attention_mask,max_length=128)print(outputs)output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)return outputs, output_str_,x = generate_summary("20日凌晨,寒风刺骨,两名年纪相仿的婴儿相继被狠心的父母遗弃在翔安的两个角落,一个在莲花总院厕所里,一个在东园社区一榕树下。两名婴儿被发现时间相距不过10分钟,莲河边防派出所民警连夜走访,未寻得婴儿家属。目前,一名婴儿已被送往福利院,另一名暂时安置在村民家中。据悉,经医生初步检查,两名婴儿均身体健康,无残疾、无疾病。记者陈佩珊通讯员蔡美娟林才龙",model)
print(x)
print(len(x[0]))'''
tensor([[ 102,  101, 1336, 7305, 5425, 2128,  697, 1399, 2399, 5279, 4685,  820,4638, 2048, 1036, 4685, 5326, 6158, 6890, 2461, 1762,  697,  702, 6235,5862,  117,  671,  782, 1762, 5813, 5709, 2600, 7368, 1329, 2792, 7027,117,  671,  702, 1762,  691, 1736, 4852, 1277,  671, 3525, 3409,  678,511,  102]], device='cuda:0')
['厦 门 翔 安 两 名 年 纪 相 仿 的 婴 儿 相 继 被 遗 弃 在 两 个 角 落, 一 人 在 莲 花 总 院 厕 所 里, 一 个 在 东 园 社 区 一 榕 树 下 。']
91
'''eval_results = trainer.evaluate()
print(eval_results)

2 测试

import loggingimport jieba
import lawrouge
import numpy as np
import datasets
import torch
from datasets import load_dataset, Dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge
from torch.utils.data import dataloader
from tqdm import tqdm
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, BertTokenizer, \BartForConditionalGenerationlogger = logging.getLogger("bart-base-chinese")
logging.basicConfig(level=logging.INFO)dataset = load_dataset('json', data_files=[r'E:\zwj\nlp\datasets\Summarization\NLPCC2017\evaluation_with_ground_truth_origianl.txt'], field='data')tokenizer = BertTokenizer.from_pretrained(r'checkpoint-337500')
model = BartForConditionalGeneration.from_pretrained(r'checkpoint-337500')def flatten(example):return {"document": example["article"],"summary": example["summarization"],"id": "0"}dataset = dataset["train"].map(flatten, remove_columns=["summarization", "article"])max_input_length = 512
max_target_length = 64
model_inputs = tokenizer(dataset[0]["document"], max_length=max_input_length,padding="max_length", truncation=True)def preprocess_function(examples):inputs = [doc for doc in examples["document"]]model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)with tokenizer.as_target_tokenizer():labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)model_inputs["labels"] = labels["input_ids"]return  model_inputs# 装载数据
dd = datasets.DatasetDict({"test": dataset})
test_batch_size = 1
tokenized_datasets = dd.map(preprocess_function, batched=True)
args = Seq2SeqTrainingArguments(fp16 = True,output_dir=r'./',do_eval=True,per_device_eval_batch_size=test_batch_size,label_smoothing_factor=0.1,predict_with_generate=True,
)
trainer = Seq2SeqTrainer(model,args,train_dataset=tokenized_datasets["test"],eval_dataset=tokenized_datasets["test"],data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),tokenizer=tokenizer,
)eval_dataloader = trainer.get_eval_dataloader()model.to("cuda:0")rouge = Rouge()
rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0num_return_sequences = 4for i,batch in enumerate(tqdm(eval_dataloader)):model.eval()with torch.no_grad():output = model.generate(input_ids=batch["input_ids"].to("cuda:0"),attention_mask=batch["attention_mask"].to("cuda:0"),early_stopping=True,num_beams=num_return_sequences,length_penalty=1.,# no_repeat_ngram_size=0,# diversity_penalty=0.,num_return_sequences=1,num_beam_groups=1,max_length=64,)#lsp = []for s in range(test_batch_size):lsp.append(int(s * num_return_sequences)) # num_return_sequencesoutputs = output[lsp,:]labels = batch["labels"]labels = np.where(labels != -100, labels, tokenizer.pad_token_id)decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)decoded_preds = [" ".join((pred.replace(" ", ""))) for pred in decoded_preds]decoded_labels = [" ".join((label.replace(" ", ""))) for label in decoded_labels]for decoded_label, decoded_pred in zip(decoded_labels, decoded_preds):scores = rouge.get_scores(hyps=decoded_pred, refs=decoded_label)rouge_1 += scores[0]['rouge-1']['f']rouge_2 += scores[0]['rouge-2']['f']rouge_l += scores[0]['rouge-l']['f']bleu += sentence_bleu(references=[decoded_label.split(' ')],hypothesis=decoded_pred.split(' '),smoothing_function=SmoothingFunction().method1)
bleu /= len(dataset)
rouge_1 /= len(dataset)
rouge_2 /= len(dataset)
rouge_l /= len(dataset)

3 数据格式

这里强调一下,这里使用的复旦fnlp/bart-large-chinese的bart中文预训练模型,数据集是nlpcc,五万条数据,只训练epoch=1,rouge-1 49还算可以吧,感谢复旦提供预训练模型。

lcsts的数据集也可以直接训练,预处理方式见上一片博客。

# 用PART_I做训练的,PART_II做验证的,epoch = 1# 给几个测试的结果,有的确实很不错,但有的就很糟糕了,颠倒是非了
'''
(1)
中信证券员工迎娶世界小姐张梓琳的消息,瞬间引爆投行圈。何许人能娶到世界小姐?据爆料,张梓琳的新郎名叫聂磊(Neil),目前是中信证券债务资本市场部的SVP(高级副总裁)。有投行圈人士感叹“我追高圆圆也有戏啊”。
生成:中 信 证 券 员 工 迎 娶 世 界 小 姐 投 行 圈 人 士
参考:中信员工迎娶世界小姐张梓琳投行男直呼励志
(2)此前曾宣布退出线下POS市场的支付宝,近日正在宁夏、江西等地布局线下支付业务。此举被一些业内人士解读为向“银联发起总攻”。一位业内分析人士更是指出,支付宝甚至可能挑战中国银联在线下支付市场的主导地位,“成为第二家‘银联’”
生成:支 付 宝 向 银 联 发 起 总 攻
参考:支付宝重返线下市场二维码支付监管标准公布可期
(3)上海质局对上海生产和销售的水嘴产品质量进行了专项监督抽查,共抽查水嘴产品68批次,不合格产品达21批次,其中包括成霖洁具等。3成不合格产品中6批次产品经检测析出过量的铅。过量的铅会损害人的神经系统、造血系统,甚至生殖系统。
生成:成 霖 洁 具 等 6 批 次 水 嘴 产 品 检 出 过 量 铅
参考:成霖洁具等被爆铅超标或损害造血生殖系统
(4)韩方应对路径可以概括为:企业道歉担责;政府公正不护短;民间祈福关怀。他们深知形象的重要,竭力呵护企业品牌和国家形象。正如有评论,韩国“政府+企业+民众”三位一体式呵护韩国国家形象的“苦心经营”,的确有值得我们借鉴之处。
生成:韩 国 国 家 形 象 的 苦 心 经 营
参考:从韩亚航空事故看其应对路径
(5)63岁退休教师谢淑华,拉着人力板车,历时1年,走了2万4千里路,带着年过九旬的妈妈环游中国,完成了妈妈“一辈子在锅台边转,也想出去走走”的心愿。她说:“妈妈愿意出去走走,我就愿意拉着,孝心不能等,能走多远就走多远。
生成:63 岁 女 教 师 拉 板 车 带 九 旬 母 亲 游 中 国
参考:女子用板车拉九旬老母环游中国1年走2万4千里
'''

我还测试了复旦cpt模型,效果感觉不太行阿。

细节不同的地方如下:

from modeling_cpt import CPTForConditionalGeneration
model_checkpoint = "fnlp/cpt-large"

注意不要搞错了,具体见下一篇

ps:注意计算metric是的rouge,有两个库都可以计算,rouge和lawrouge

(1)不分词:直接调用lawrouge中文计算函数就行不用分词

(2)rouge库需要将中文用空格分开(不分词+空格隔开),如果要计算分词后的rouge值也要用rouge库(分词+空格隔开)

(3)注意:lawrouge比rouge库计算指标低一些,做论文一定要注意!

import lawrouge
from rouge import Rouge
hypothesis = ['我 爱 最 美 的 中 国','老 夫 子']
reference = ['我 爱 中 国','你 好 老 夫子']rouge = Rouge()
scores = rouge.get_scores(hypothesis, reference)
print(scores)rouge = lawrouge.Rouge()
scores = rouge.get_scores(['我爱最美的中国','老夫子'], ['我爱中国','你好老夫子'])
print(scores)'''
[{'rouge-1': {'r': 1.0, 'p': 0.5714285714285714, 'f': 0.7272727226446282}, 'rouge-2': {'r': 0.6666666666666666, 'p': 0.3333333333333333, 'f': 0.44444444000000005}, 'rouge-l': {'r': 1.0, 'p': 0.5714285714285714, 'f': 0.7272727226446282}}, {'rouge-1': {'r': 0.6, 'p': 1.0, 'f': 0.7499999953125}, 'rouge-2': {'r': 0.5, 'p': 1.0, 'f': 0.6666666622222223}, 'rouge-l': {'r': 0.6, 'p': 1.0, 'f': 0.7499999953125}}]
[{'rouge-1': {'f': 0.7272727226446282, 'p': 0.5714285714285714, 'r': 1.0}, 'rouge-2': {'f': 0.44444444000000005, 'p': 0.3333333333333333, 'r': 0.6666666666666666}, 'rouge-l': {'f': 0.7272727226446282, 'p': 0.5714285714285714, 'r': 1.0}}, {'rouge-1': {'f': 0.7499999953125, 'p': 1.0, 'r': 0.6}, 'rouge-2': {'f': 0.6666666622222223, 'p': 1.0, 'r': 0.5}, 'rouge-l': {'f': 0.7499999953125, 'p': 1.0, 'r': 0.6}}]
'''

pps:transformers==4.4.1 pytorch = 1.10.0


给大家提供一个经过简单清洗的CNewSum中文摘要数据集,数据集共包含275596条摘要数据,源文件是经过分词的,我这里没有分词,先提供一个训练集。

百度网盘链接:CNewSum_trainhttps://pan.baidu.com/s/1FXq_deJ0Yi9rfSAElQLoIQ

提取码:b6f9


nlpcc2017清洗数据

链接:nlpcc2017_cleanhttps://pan.baidu.com/s/1qXC81XmcWY9GprQKe5i58whttps://pan.baidu.com/s/1qXC81XmcWY9GprQKe5i58w 
提取码:knci


提供一个经过层次位置分解编码的BART中文预训练权重支持最大输入长度为1024,修改配置文件config.json的max_position_embeddings=1024即可

百度网盘链接:bart-large-chinese-1024https://pan.baidu.com/s/1xisndfl27sOm1YyZzp-vYw

提取码:b6f9


参考文献:pytorch 使用BART模型进行中文自动摘要_keep-hungry的博客-CSDN博客

复旦bart中文预训练模型:小: https://huggingface.co/fnlp/bart-base-chinese/tree/main

大:

复旦cpt中文预训练模型:fnlp/cpt-large · Hugging Face

BART中文摘要生成,(nplcc与LCSTS数据集)相关推荐

  1. 如何用pytorch做文本摘要生成任务(加载数据集、T5 模型参数、微调、保存和测试模型,以及ROUGE分数计算)

    摘要:如何使用 Pytorch(或Pytorchlightning) 和 huggingface Transformers 做文本摘要生成任务,包括数据集的加载.模型的加载.模型的微调.模型的验证.模 ...

  2. 中文摘要生成 综述

    GPT2-中文摘要生成青空栀浅:https://zhuanlan.zhihu.com/p/113869509 IJCAI 2018 | 腾讯知文等提出新型生成式摘要模型:结合主题信息和强化训练生成更优 ...

  3. 中文自动文本摘要生成指标计算,Rouge/Bleu/BertScore/QA代码实现

    本部分讲述下如何计算生成摘要与参考摘要的指标,指标方面分为两类,一类基于n-grams计算,如Rouge-1,Rouge-2,Rouge-L,BLEU,主要衡量摘要的句法的连贯性,不能衡量生成摘要的真 ...

  4. 基于BERT-PGN模型的中文新闻文本自动摘要生成——文本摘要生成(论文研读)

    基于BERT-PGN模型的中文新闻文本自动摘要生成(2020.07.08) 基于BERT-PGN模型的中文新闻文本自动摘要生成(2020.07.08) 摘要: 0 引言 相关研究 2 BERT-PGN ...

  5. 知识图谱如何助力文本摘要生成

    来源:丁香园大数据 本文约3800字,建议阅读8分钟 本文基于摘要生成,重点考虑如何帮助模型生成特定领域的知识点,并简要介绍一些用于应对无关重复这类退化现象的方案. 引言 文本生成类任务应用场景广泛, ...

  6. ACL 2021 | SimCLS: 概念简单但足够有效的对比学习摘要生成框架

    ©PaperWeekly 原创 · 作者 | Maple小七 学校 | 北京邮电大学硕士生 研究方向 | 自然语言处理 作者提出了一个概念简单但足够有效的摘要生成框架:SimCLS,在当前的 SOTA ...

  7. COLING 2020 | 面向医疗对话的摘要生成

    ©PaperWeekly 原创 · 作者|李东明 学校|香港中文大学(深圳)本科生 研究方向|文本生成 摘要 医疗对话是一类特殊的对话形态,属于任务驱动型的对话场景,通常包含极为关键的病人求诊信息以及 ...

  8. 优于人类参考摘要,适用CNN新闻,OpenAI用人类反馈提升了摘要生成质量

    选自arXiv 作者:Nisan Stiennon 等 机器之心编译 编辑:杜伟.小舟.陈萍 近日,来自 OpenAI 的研究者利用人类反馈优化了文本摘要生成模型,该模型生成的摘要质量显著提升,并且可 ...

  9. TensorFlow文本摘要生成 - 基于注意力的序列到序列模型

    1 相关背景 维基百科对自动摘要生成的定义是, "使用计算机程序对一段文本进行处理, 生成一段长度被压缩的摘要, 并且这个摘要能保留原始文本的大部分重要信息". 摘要生成算法主要分 ...

最新文章

  1. java gps时间转换工具_java 时间戳和时间互转工具 和 时间偏移量计算
  2. 通过cat /proc/cpuinfo看处理器特点
  3. [转]使用批处理设置、启动和停止服务
  4. Kernel PCI总线框架
  5. Ubuntu16.04中WPS不能输入中文
  6. sql_1-2_get基于盲注
  7. Excel VBA实现批量创建链接
  8. 圆形比例分布图怎么做_解读宝山区2035总体规划:建设用地的比例在上海非中心城区中最高...
  9. 前端学习(3158):react-hello-react之一个简单的helloworld
  10. 修改MySQL自动递增值
  11. 支付宝升级商家积分等服务 商家积分权益增至60余种
  12. Fluent Ribbon项目出现“命名空间“clr-namespace:Fluent;assembly=Fluent”中不存在“RibbonWindow”名称”的解决方法...
  13. paip.手机ROOT过程总结
  14. r 中文乱码_配置R语言环境,这一篇就够了!
  15. el-table中使用el-popover点击取消按钮时popover框的显示与隐藏问题
  16. IMPERVA - WAF syslog配置及注意事项
  17. unix time stamp(时间戳)和常规时间相互转换的C++代码
  18. WEB 期末项目 小黑帽书屋
  19. 对链特异性建库的理解
  20. 想进世界顶尖投行 那我该上什么样的学校?

热门文章

  1. linux下route路由设置命令详解
  2. [网赚项目] 分享一个刚需赚钱项目,可多重变现,月入好几个w
  3. 如何解决 MacBook 电池耗电问题
  4. openCV之waitKey函数简介
  5. KernelSU: 内核 ROOT 方案, KernelSU KernelSU KernelSU 新的隐藏root防止检测 封号方案
  6. 我TM究竟应该选哪个版本的MySQL?!
  7. 讨论 | 博士延期毕业?如何避免?
  8. TM4C123G学习记录(2)--GPIO
  9. MongoDB数据库设计中6条重要经验法则 Part 2
  10. 开启xmp1还是2_英雄联盟手游高帧率模式怎么开启-高帧率模式开启方法