基于TNEWS' 今日头条中文新闻(短文本)分类

  • 数据部分
    • 内容
    • 数据处理
  • 模型构建
    • 训练配置
  • 模型训练和预测
    • 定义评估函数
    • 训练
    • 预测

数据部分

内容

TNEWS’今日头条中文新闻数据集来自今日头条的新闻版块,共提取了15个类别的新闻,包括旅游,教育,金融,军事等。数据量:训练集(53,360),验证集(10,000),测试集(10,000)

数据处理

将文本数据转换成id

# 转换成id的函数
def convert_example(example, tokenizer):encoded_inputs = tokenizer(text=example["sentence"], max_seq_len=128, pad_to_max_seq_len=True)return tuple([np.array(x, dtype="int64") for x in [encoded_inputs["input_ids"], encoded_inputs["token_type_ids"], [example["label"]]]])# 把训练集合转换成id
train_ds = train_ds.map(partial(convert_example, tokenizer=tokenizer))
# 把验证集合转换成id
dev_ds = dev_ds.map(partial(convert_example, tokenizer=tokenizer))
# 构建训练集合的dataloader
train_batch_size=32
dev_batch_size=32
train_batch_sampler = paddle.io.DistributedBatchSampler(dataset=train_ds, batch_size=train_batch_size, shuffle=True)
train_data_loader = paddle.io.DataLoader(dataset=train_ds, batch_sampler=train_batch_sampler, return_list=True)# 针对验证集数据加载,我们使用单卡进行评估,所以采用 paddle.io.BatchSampler 即可
# 定义验证集的dataloader
dev_batch_sampler = paddle.io.BatchSampler(dev_ds, batch_size=dev_batch_size, shuffle=False)dev_data_loader = paddle.io.DataLoader(dataset=dev_ds,batch_sampler=dev_batch_sampler,return_list=True)

模型构建

class ShortTextClassification(nn.Layer):def __init__(self, pretrained_model,num_class,dropout=None):super().__init__()self.ptm = pretrained_modelself.dropout = nn.Dropout(dropout if dropout is not None else 0.1)# num_labels = 2 (similar or dissimilar)self.classifier = nn.Linear(self.ptm.config["hidden_size"], num_class)def forward(self,input_ids,token_type_ids=None,position_ids=None,attention_mask=None):_, cls_embedding = self.ptm(input_ids, token_type_ids, position_ids,attention_mask)cls_embedding = self.dropout(cls_embedding)logits = self.classifier(cls_embedding)return logits
model = ShortTextClassification(pretrained_model,num_class=len(train_ds.label_list))

训练配置

epochs = 3
num_training_steps = len(train_data_loader) * epochs# 定义 learning_rate_scheduler,负责在训练过程中对 lr 进行调度
lr_scheduler = LinearDecayWithWarmup(2E-5, num_training_steps, 0.0)# 训练结束后,存储模型参数
save_dir ="checkpoint"
# 创建保存的文件夹
os.makedirs(save_dir,exist_ok=True)# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [p.name for n, p in model.named_parameters()if not any(nd in n for nd in ["bias", "norm"])
]# 定义 Optimizer
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,parameters=model.parameters(),weight_decay=0.0,apply_decay_param_fun=lambda x: x in decay_params)
# 交叉熵损失
criterion = paddle.nn.loss.CrossEntropyLoss()
# 评估的时候采用准确率指标
metric = paddle.metric.Accuracy()

模型训练和预测

定义评估函数

# 因为训练过程中同时要在验证集进行模型评估,因此我们先定义评估函数
@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader, phase="dev"):model.eval()metric.reset()losses = []for batch in data_loader:input_ids, token_type_ids, labels = batchprobs = model(input_ids=input_ids, token_type_ids=token_type_ids)# 计算损失loss = criterion(probs, labels)losses.append(loss.numpy())# 计算准确率correct = metric.compute(probs, labels)#准确率更新metric.update(correct)accu = metric.accumulate()print("eval {} loss: {:.5}, accu: {:.5}".format(phase,np.mean(losses), accu))model.train()metric.reset()return np.mean(losses),accu

训练

def do_train(model, criterion, metric, dev_data_loader,train_data_loader):global_step = 0tic_train = time.time()best_accuracy=0.0for epoch in range(1, epochs + 1):for step, batch in enumerate(train_data_loader, start=1):input_ids, token_type_ids, labels = batchprobs = model(input_ids=input_ids, token_type_ids=token_type_ids)loss = criterion(probs, labels)correct = metric.compute(probs, labels)metric.update(correct)acc = metric.accumulate()global_step += 1 # 每间隔 100 step 输出训练指标if global_step % 100 == 0:print("global step %d, epoch: %d, batch: %d, loss: %.5f, accu: %.5f, speed: %.2f step/s"% (global_step, epoch, step, loss, acc,10 / (time.time() - tic_train)))tic_train = time.time()loss.backward()optimizer.step()lr_scheduler.step()optimizer.clear_grad()# 每间隔 100 step 在验证集和测试集上进行评估if global_step % 500 == 0:eval_loss,eval_accu=evaluate(model, criterion, metric, dev_data_loader, "dev")if(best_accuracy<eval_accu):best_accuracy=eval_accu# 保存模型save_param_path = os.path.join(save_dir, 'model_best.pdparams')paddle.save(model.state_dict(), save_param_path)# 保存tokenizertokenizer.save_pretrained(save_dir)do_train(model, criterion, metric, dev_data_loader,train_data_loader)

预测

state_dict=paddle.load('checkpoint/model_best.pdparams')
model.load_dict(state_dict)
# 测试集可以选择 test,test1.0两个
test_ds = load_dataset('clue', task_name, splits=['test1.0'])
def do_predict(model,example):# 把文本转换成input_ids,token_type_ids# example=test_ds[0]encoded_text = tokenizer(text=example["sentence"], max_seq_len=512, pad_to_max_seq_len=True)# 把input_ids变成paddle tensorinput_ids = paddle.to_tensor([encoded_text['input_ids']])# 把token_type_ids变成paddle tensorsegment_ids = paddle.to_tensor([encoded_text['token_type_ids']])# 模型预测pooled_output = model(input_ids, segment_ids)# 取概率值最大的索引out2 = paddle.argmax(pooled_output, axis=1)fgh# print('预测的label标签为 {}'.format(out2.numpy()[0]))# print('真实的label标签为 {}'.format(test_ds[0]['label']))return out2.numpy()[0]predict_label=[]
for i in tqdm(range(len(test_ds))):example=test_ds[i]label_pred=do_predict(model,example)predict_label.append(label_pred)
output_submit_file = "tnews10_predict.json"
label_map = {i: label for i, label in enumerate(train_ds.label_list)}
# 保存标签结果
with open(output_submit_file, "w") as writer:for i, pred in enumerate(predict_label):json_d = {}json_d['id'] = ijson_d['label'] = str(label_map[pred])writer.write(json.dumps(json_d) + '\n')

基于TNEWS‘ 今日头条中文新闻(短文本)分类相关推荐

  1. TNEWS今日头条中文新闻(短文本)分类

    数据分析: 数据分布情况:占比例多少 文本 :一句话多长.截断处理. 平均50个字<=110词.(分完词的长度) 多分类: (代码可复现) 方法一:6个二分类器.一条句子分别跑6个模型,分最高, ...

  2. 今日头条中文新闻文本(多层)分类数据集(NLP/文本分类)

    这是另一个数据集的加强版,为多级分类,分类更全(含1000+多级分类),量更大. 数据来源: 今日头条客户端 文本多层分类的概念见下图 数据格式: 1000866069|,|tip,news|,|[互 ...

  3. 基于BERT-PGN模型的中文新闻文本自动摘要生成——文本摘要生成(论文研读)

    基于BERT-PGN模型的中文新闻文本自动摘要生成(2020.07.08) 基于BERT-PGN模型的中文新闻文本自动摘要生成(2020.07.08) 摘要: 0 引言 相关研究 2 BERT-PGN ...

  4. 今日头条的新闻推荐算法原理

    转自: http://www.sohu.com/a/217514835_488163 信息越来越海量,用户获取信息越来越茫然,而推荐算法则能有助于更好的匹配海量内容和用户需求,使之更加的"有 ...

  5. 【干货】今日头条的新闻推荐算法原理

    信息越来越海量,用户获取信息越来越茫然,而推荐算法则能有助于更好的匹配海量内容和用户需求,使之更加的"有的放矢" .为让产业各方更好的了解算法分发的相关技术和原理,我们特整理了当下 ...

  6. python编程100例头条-python 简单爬取今日头条热点新闻(一)

    今日头条如今在自媒体领域算是比较强大的存在,今天就带大家利用python爬去今日头条的热点新闻,理论上是可以做到无限爬取的: 在浏览器中打开今日头条的链接,选中左侧的热点,在浏览器开发者模式netwo ...

  7. python 爬取今日头条热点新闻

    嗯,今天就让我们来一起爬爬今日头条的热点新闻吧! 今日头条地址:https://www.toutiao.com/ch/news_hot/ 在浏览器中打开今日头条的链接,选中左侧的热点,在浏览器开发者模 ...

  8. python爬虫今日头条_python 简单爬取今日头条热点新闻(

    今日头条如今在自媒体领域算是比较强大的存在,今天就带大家利用python爬去今日头条的热点新闻,理论上是可以做到无限爬取的: 在浏览器中打开今日头条的链接,选中左侧的热点,在浏览器开发者模式netwo ...

  9. python爬虫爬取今日头条_python 简单爬取今日头条热点新闻(一)

    今日头条如今在自媒体领域算是比较强大的存在,今天就带大家利用python爬去今日头条的热点新闻,理论上是可以做到无限爬取的: 在浏览器中打开今日头条的链接,选中左侧的热点,在浏览器开发者模式netwo ...

最新文章

  1. 回望2018,展望2019
  2. G - IP地址转换
  3. 卷积神经网络中的参数计算
  4. USB学习5---android usb驱动源代码目录说明
  5. 浅谈文献总结(2018.9.28)——坚恒勇毅论文课笔记
  6. 图片跟着鼠标_刷完几百张网易云Banner,我发现了2个PPT图片处理的大招!
  7. 【软考-软件设计师】CPU的功能
  8. Netty实战 IM即时通讯系统(八)服务端和客户端通信协议编解码
  9. python mro文件_Python MRO
  10. 苹果将削减iPhone SE及AirPods产量 iPhone 13也要求减产
  11. 一个maven错误:org/apache/maven/shared/filtering/MavenFilteringException
  12. Perl复制、移动、重命名文件/目录
  13. php CSRF攻击与防御
  14. c语言程序设计立体化教程,C语言程序设计立体化教程
  15. 第1关:身份证归属地查询
  16. 如何将qlv格式视频转换成mp4格式
  17. java多线程系列(一)
  18. 网络游戏服务器开发(一)
  19. 讯飞语义相似度baseline
  20. TP-link WR720N路由器刷入OpenWrt

热门文章

  1. 技术的本质与内在规律:进化
  2. Android原生分屏,原生ROM都有分屏 为啥MIUI做了那么久?
  3. 【C++】CString转LPCSTR
  4. 通达信接口被暂停是什么原因?
  5. 2021国家开放大学计算机网络安全技术形成性考核一
  6. 《程序员》10月:浪漫的计算机科学家Alan Kay
  7. office2016/2019版本打开时出现卡顿的解决办法
  8. 住房公积金个人缴存2000/月,缴存比例12%,是不是就代表个人每年的全部收入为20W?
  9. 枪火重生灵界狂潮攻略(七)猴子流派
  10. cursor设置为自定义图片