基于检索的多层次文本分类

1、项目说明

以前的分类任务中,标签信息作为无实际意义,独立存在的one-hot编码形式存在,这种做法会潜在的丢失标签的语义信息,本方案把文本分类任务中的标签信息转换成含有语义信息的语义向量,将文本分类任务转换成向量检索和匹配的任务。这样做的好处是对于一些类别标签不是很固定的场景,或者需要经常有一些新增类别的需求的情况非常合适。另外,对于一些新的相关的分类任务,这种方法也不需要模型重新学习或者设计一种新的模型结构来适应新的任务。总的来说,这种基于检索的文本分类方法能够有很好的拓展性,能够利用标签里面包含的语义信息,不需要重新进行学习。这种方法可以应用到相似标签推荐,文本标签标注,金融风险事件分类,政务信访分类等领域。

本方案是基于语义索引模型的分类,语义索引模型的目标是:给定输入文本,模型可以从海量候选召回库中快速、准确地召回一批语义相关文本。基于语义索引的分类方法有两种,第一种方法是直接把标签变成召回库,即把输入文本和标签的文本进行匹配,第二种是利用召回的文本带有类别标签,把召回文本的类别标签作为给定输入文本的类别。本方案使用双塔模型,训练阶段引入In-batch Negatives 策略,使用hnswlib建立索引库,并把标签作为召回库,进行召回测试。最后利用召回的结果使用 Accuracy 指标来评估语义索引模型的分类的效果。下面用一张图来展示与传统的微调方案的区别,在预测阶段,微调的方式则是用分类器分类得到的结果,而基于检索的方式是通过比较文本和标签的相似度得到的分类结果。

本项目源代码全部开源在 PaddleNLP。

详细内容可参考文本分类系统方案:https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/text_classification

  • 如果对您有帮助,欢迎star收藏一下,不易走丢哦~链接指路:https://github.com/PaddlePaddle/PaddleNLP

1.1 应用特色

  • 低门槛

    • 手把手搭建层次化文本分类
  • 效果好

    • 业界领先的文本分类模型: ERNIE 3.0
    • 业界领先的检索预训练模型: RocketQA Dual Encoder
    • 针对数据量少,类别体系复杂,类别数目多,并且经常需要变动的场景领先解决方案
  • 性能快

    • 基于 Paddle Inference 快速抽取向量
    • 基于 Milvus 快速查询和高性能建库
    • 基于 Paddle Serving 高性能部署

1.2 文本分类流程设计

对于相对均衡的样本,选择分类的方式,极端不均衡和类别变化频繁的样本,使用向量索引的方式分类。基于检索的搜索方案的流程图如图所示,对于向量索引的模型,我们选用基于ERNIE3.0训练的RocketQA模型,然后把文本和标签构成句子对训练语义向量模型,训练结束后得到语义向量抽取模型,然后把标签放入模型抽取标签向量(Label Embedding),最后放入索引向量库中。紧接着对于需要分类的文本,利用语义向量抽取模型抽取句子语义向量(Sentence Embedding),最后利用ANN引擎从标签索引库中找最相似的标签向量,然后得到对应的标签,并把该标签当作最终的分类文本。

2、安装说明

AI Studio平台默认安装了Paddle和PaddleNLP,并定期更新版本。 如需手动更新,可参考如下说明:

# 首次更新完以后,重启后方能生效
!pip install --upgrade paddlenlp
# 安装相关的依赖包
!pip install -r requirements.txt
# 加载系统的API
import abc
import sys
from functools import partial
import argparse
import os
import random
import time
import numpy as np
# 加载飞桨的API
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import inference
# 加载PaddleNLP的API
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset, MapDataset
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.utils.downloader import get_path_from_url
import paddle_serving_client.io as serving_io
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer

3、数据准备

下载数据集:

if(not os.path.exists('baike_qa_category.zip')):get_path_from_url('https://paddlenlp.bj.bcebos.com/applications/baike_qa_category.zip',root_dir='.')!unzip -o baike_qa_category.zip
Archive:  baike_qa_category.zipinflating: data/label.txt          inflating: data/.DS_Store          inflating: data/dev.txt            inflating: data/train.txt

3.1 加载数据

from data import read_text_pair, convert_example, create_dataloader, gen_id2corpus, gen_text_file, convert_corpus_example
from paddlenlp.transformers import AutoModel, AutoTokenizer
train_set_file = 'data/train.txt'
train_ds = load_dataset(read_text_pair,data_path=train_set_file,lazy=False)
max_seq_length = 384
pretrained_model = AutoModel.from_pretrained('rocketqa-zh-dureader-query-encoder')
tokenizer = AutoTokenizer.from_pretrained('rocketqa-zh-dureader-query-encoder')trans_func = partial(convert_example,tokenizer=tokenizer,max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # query_inputPad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # query_segmentPad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # title_inputPad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # tilte_segment): [data for data in fn(samples)]
# batch_size设置大一点,效果会更好一点
batch_size = 24
train_data_loader = create_dataloader(train_ds,mode='train',batch_size=batch_size,batchify_fn=batchify_fn,trans_fn=trans_func)
recall_result_dir = "recall_result_dir"
recall_result_file = "recall_result.txt"
evaluate_result = "evaluate_result.txt"
similar_text_pair_file = "data/dev.txt"
corpus_file = 'data/label.txt'
batchify_fn_dev = lambda samples, fn=Tuple(Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype='int64'),  # text_inputPad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype='int64'),  # text_segment): [data for data in fn(samples)]eval_func = partial(convert_example,tokenizer=tokenizer,max_seq_length=max_seq_length)
id2corpus = gen_id2corpus(corpus_file)
# conver_example function's input must be dict
corpus_list = [{idx: text} for idx, text in id2corpus.items()]
corpus_ds = MapDataset(corpus_list)
corpus_data_loader = create_dataloader(corpus_ds,mode='predict',batch_size=batch_size,batchify_fn=batchify_fn_dev,trans_fn=eval_func)
# convert_corpus_example
query_func = partial(convert_example,tokenizer=tokenizer,max_seq_length=max_seq_length)
text_list, _ = gen_text_file(similar_text_pair_file)
query_ds = MapDataset(text_list)
query_data_loader = create_dataloader(query_ds,mode='predict',batch_size=batch_size,batchify_fn=batchify_fn_dev,trans_fn=query_func)
if not os.path.exists(recall_result_dir):os.mkdir(recall_result_dir)
recall_result_file = os.path.join(recall_result_dir,recall_result_file)
# 打印标签数据,标签文本数据被映射成了ID的形式
for batch in corpus_data_loader:print(batch)break
[Tensor(shape=[24, 19], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,[[1   , 82  , 227 , 17963, 287 , 70  , 30  , 343 , 2275, 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 128 , 941 , 17963, 305 , 742 , 30  , 163 , 716 , 94  , 159 , 2   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 1553, 64  , 407 ,193 , 2   , 0   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 137 , 405 , 18  , 489 , 30  ,139 , 405 , 2   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 389 , 26  , 80  , 227 , 30  ,32  , 159 , 138 , 318 , 401 , 525 , 2   ],[1   , 691 , 736 , 30  , 103 , 147 , 30  , 1208, 813 , 103 , 147 , 2   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 21  , 205 , 30  , 188 , 494 , 17963, 2608, 3255, 30  , 1134, 954 ,17963, 661 , 737 , 2   , 0   , 0   , 0   ],[1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 274 , 2180, 407 ,193 , 2   , 0   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 38  , 35  , 18  , 147 , 30  ,111 , 38  , 18  , 2   , 0   , 0   , 0   ],[1   , 278 , 26  , 17963, 38  , 625 , 30  , 161 , 645 , 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 38  , 35  , 18  , 147 , 30  ,21  , 122 , 18  , 2   , 0   , 0   , 0   ],[1   , 1030, 320 , 30  , 29  , 320 , 30  , 371 , 979 , 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 82  , 227 , 17963, 287 , 70  , 30  , 2001, 497 , 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 691 , 736 , 30  , 407 , 193 , 188 , 390 , 30  , 821 , 325 , 188 ,390 , 2   , 0   , 0   , 0   , 0   , 0   ],[1   , 343 , 621 , 30  , 128 , 367 , 343 , 621 , 2   , 0   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 245 , 225 , 212 , 399 , 2   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 691 , 736 , 30  , 326 , 292 , 111 , 38  , 147 , 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 343 , 621 , 30  , 1015, 19  , 343 , 621 , 30  , 241 , 734 , 203 ,280 , 2   , 0   , 0   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 8   , 68  , 18  , 147 , 30  ,405 , 545 , 18  , 2   , 0   , 0   , 0   ],[1   , 80  , 227 , 17963, 147 , 18  , 30  , 389 , 26  , 80  , 227 , 30  ,53  , 112 , 141 , 401 , 525 , 2   , 0   ],[1   , 691 , 736 , 30  , 137 , 147 , 30  , 2249, 1601, 137 , 147 , 2   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 691 , 736 , 30  , 760 , 1411, 147 , 30  , 760 , 1411, 396 , 2   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 343 , 621 , 30  , 1032, 98  , 343 , 621 , 2   , 0   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ],[1   , 1599, 354 , 30  , 604 , 799 , 17963, 287 , 660 , 2   , 0   , 0   ,0   , 0   , 0   , 0   , 0   , 0   , 0   ]]), Tensor(shape=[24, 19], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])]
# 分类文本数据,文本数据被映射成了ID的形式
for batch in query_data_loader:print(batch)break
[Tensor(shape=[24, 384], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,[[1   , 647 , 358 , ..., 0   , 0   , 0   ],[1   , 75  , 883 , ..., 0   , 0   , 0   ],[1   , 2308, 889 , ..., 0   , 0   , 0   ],...,[1   , 1228, 1928, ..., 0   , 0   , 0   ],[1   , 252 , 560 , ..., 0   , 0   , 0   ],[1   , 76  , 1234, ..., 0   , 0   , 0   ]]), Tensor(shape=[24, 384], dtype=int64, place=Place(gpu_pinned), stop_gradient=True,[[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],...,[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0],[0, 0, 0, ..., 0, 0, 0]])]

4、 模型选择

模型 Accuracy 策略简要说明
ernie-3.0-medium-zh 50.580 ernie-3.0-medium-zh多分类,5个epoch,对于新增类别需要重新训练
In-batch Negatives + RocketQA 49.755 Inbatch-negative有监督训练,标签当作召回集,对新增类别不需要重新训练
In-batch Negatives + RocketQA + 投票 51.756 Inbatch-negative有监督训练,训练集当作召回集,对新增类别,需要至少一条的数据放入召回库中

对于类别体系不固定,且未来有新增类别可能性的情况,我们推荐使用RocketQA的语义检索的方案。

【注意】对于新增的类别,不需要重新训练检索模型,只需要把新增类别的标签抽取成向量,放入语义索引库中,就能够继续使用了。

margin = 0.2
scale = 20
# 设置为0,默认不对语义向量进行降维度
output_emb_size = 0
from model import SemanticIndexBatchNeg
model = SemanticIndexBatchNeg(pretrained_model,margin=margin,scale=scale,output_emb_size=output_emb_size)model = paddle.DataParallel(model)

5、模型训练

5.1 参数配置

# 正常需要设置50~100 epoch左右,为了演示方便,这里设置成了10
epochs = 5
learning_rate = 5E-5
warmup_proportion = 0
save_dir ='checkpoints_recall'
weight_decay = 0.0
log_steps= 100
recall_num = 20
num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps,warmup_proportion)
decay_params = [p.name for n, p in model.named_parameters()if not any(nd in n for nd in ["bias", "norm"])]
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,parameters=model.parameters(),weight_decay=weight_decay,apply_decay_param_fun=lambda x: x in decay_params,grad_clip=nn.ClipGradByNorm(clip_norm=1.0))

5.2 模型评估

from data import build_index
# 索引的大小
hnsw_max_elements=1000000
# 控制时间和精度的平衡参数
hnsw_ef=100
hnsw_m=100
output_emb_size = 0
recall_num = 10def recall(rs, N=10):recall_flags = [np.sum(r[0:N]) for r in rs]return np.mean(recall_flags)@paddle.no_grad()
def evaluate(model, corpus_data_loader, query_data_loader, recall_result_file,text_list, id2corpus):# Load pretrained semantic modelinner_model = model._layersfinal_index = build_index(corpus_data_loader, inner_model,output_emb_size=output_emb_size,hnsw_max_elements=hnsw_max_elements,hnsw_ef=hnsw_ef,hnsw_m=hnsw_m)query_embedding = inner_model.get_semantic_embedding(query_data_loader)with open(recall_result_file, 'w', encoding='utf-8') as f:for batch_index, batch_query_embedding in enumerate(query_embedding):recalled_idx, cosine_sims = final_index.knn_query(batch_query_embedding.numpy(), recall_num)batch_size = len(cosine_sims)for row_index in range(batch_size):text_index = batch_size * batch_index + row_indexfor idx, doc_idx in enumerate(recalled_idx[row_index]):f.write("{}\t{}\t{}\n".format(text_list[text_index]["text"], id2corpus[doc_idx],1.0 - cosine_sims[row_index][idx]))text2similar = {}with open(similar_text_pair_file, 'r', encoding='utf-8') as f:for line in f:text_arr = line.rstrip().rsplit("\t")text, similar_text = text_arr[0], text_arr[1].replace('##', ',')text2similar[text] = similar_textrs = []with open(recall_result_file, 'r', encoding='utf-8') as f:relevance_labels = []for index, line in enumerate(f):if index % recall_num == 0 and index != 0:rs.append(relevance_labels)relevance_labels = []text_arr = line.rstrip().rsplit("\t")text, similar_text, cosine_sim = text_arrif text2similar[text] == similar_text:relevance_labels.append(1)else:relevance_labels.append(0)recall_N = []recall_num_list = [1, 5, 10, 20]for topN in recall_num_list:R = round(100 * recall(rs, N=topN), 3)recall_N.append(str(R))evaluate_result_file = os.path.join(recall_result_dir,evaluate_result)result = open(evaluate_result_file, 'a')res = []timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime())res.append(timestamp)for key, val in zip(recall_num_list, recall_N):print('recall@{}={}'.format(key, val))res.append(str(val))result.write('\t'.join(res) + '\n')return float(recall_N[0])

5.3 启动训练

save_root_dir='checkpoints'
global_step = 0
best_recall = 0.0
tic_train = time.time()
for epoch in range(1, epochs + 1):for step, batch in enumerate(train_data_loader, start=1):query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batchloss = model(query_input_ids=query_input_ids,title_input_ids=title_input_ids,query_token_type_ids=query_token_type_ids,title_token_type_ids=title_token_type_ids)global_step += 1if global_step % log_steps == 0:print("global step %d, epoch: %d, batch: %d, loss: %.5f, speed: %.2f step/s"% (global_step, epoch, step, loss, 10 /(time.time() - tic_train)))tic_train = time.time()loss.backward()optimizer.step()lr_scheduler.step()optimizer.clear_grad()print("evaluating")recall_5 = evaluate(model, corpus_data_loader, query_data_loader,recall_result_file, text_list, id2corpus)if recall_5 > best_recall:best_recall = recall_5save_dir = os.path.join(save_root_dir, "model_best")if not os.path.exists(save_dir):os.makedirs(save_dir)save_param_path = os.path.join(save_dir, 'model_state.pdparams')paddle.save(model.state_dict(), save_param_path)tokenizer.save_pretrained(save_dir)with open(os.path.join(save_dir, "train_result.txt"),'a',encoding='utf-8') as fp:fp.write('epoch=%d, global_step: %d, recall: %s\n' %(epoch, global_step, recall_5))

语义索引模型的收敛速度比较慢,建议设置较大的epoch,性能会提升比较大,但训练时间会比较长一点。

6、模型预测

from data import build_index
# 构建索引
final_index = build_index(corpus_data_loader, model._layers,output_emb_size=output_emb_size,hnsw_max_elements=hnsw_max_elements,hnsw_ef=hnsw_ef,hnsw_m=hnsw_m)
[2022-09-20 17:21:27,837] [    INFO] - start build index..........
[2022-09-20 17:21:28,296] [    INFO] - Total index number:316
query_embedding = model._layers.get_semantic_embedding(query_data_loader)
list_data=[]
for batch_index, batch_query_embedding in enumerate(query_embedding):recalled_idx, cosine_sims = final_index.knn_query(batch_query_embedding.numpy(), recall_num)batch_size = len(cosine_sims)for row_index in range(batch_size):text_index = batch_size * batch_index + row_indexfor idx, doc_idx in enumerate(recalled_idx[row_index]):list_data.append([text_list[text_index]["text"], id2corpus[doc_idx],1.0 - cosine_sims[row_index][idx]])
# 打印若干预测的值
print(list_data[:20])
[['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '教育/科学,理工学科,心理学', 0.6262896060943604], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,恋爱', 0.5953264832496643], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,情感情绪', 0.5355594754219055], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '健康,精神心理科', 0.5196096301078796], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼', 0.3823578953742981], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '娱乐', 0.328216552734375], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '健康,精神心理科,心理科', 0.26789504289627075], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '烦恼,交友技巧', 0.2628575563430786], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '电脑/网络,互联网', 0.226911723613739], ['请问深入骨髓地喜欢一个人怎么办我不能确定对方是不是喜欢我,我却想我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点', '娱乐,电视', 0.21047455072402954], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,诛仙', 0.7549923658370972], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,金山游戏,封神榜', 0.543825089931488], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,神鬼传奇', 0.530568540096283], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,久游游戏,仙剑ol', 0.516210675239563], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网络游戏,天龙八部', 0.4695323705673218], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网络游戏,破天一剑', 0.46719789505004883], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,武林外传', 0.42875951528549194], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,单机游戏,仙剑奇侠传', 0.4053085446357727], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,完美游戏,口袋西游', 0.30933690071105957], ['我登陆诛仙2时总说我账号密码错误,但是我打的是正确的,就算不对我?', '游戏,网页游戏', 0.3086514472961426]]

7、模型部署

模型部署分为3步,第一步是把模型转换成静态图,这样能够提升速度;第二步是搭建一个分类索引引擎,构建标签向量库,可以使用Milvus,FAISS,Elastic Search等;第三步是PaddleServing部署,把模型部署成Serving的形式,可拓展性更好。

7.1 动转静导出

output_path='output'
# 切换成eval模式,关闭dropout
model.eval()
# Convert to static graph with specific input description
model = paddle.jit.to_static(model._layers,input_spec=[paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_idspaddle.static.InputSpec(shape=[None, None], dtype="int64")  # segment_ids])
# Save in static graph model.
save_path = os.path.join(output_path, "inference")
paddle.jit.save(model, save_path)

7.2 分类引擎

模型准备结束以后,开始搭建 Milvus 的向量检索引擎,用于文本语义向量的快速检索,本项目使用Milvus开源工具进行向量检索,Milvus 的搭建教程请参考官方教程 Milvus官方安装教程本案例使用的是 Milvus 的1.1.1 CPU版本,建议使用官方的 Docker 安装方式,简单快捷。

如果想使用最新版本的Milvus,可以参考Neural Search的实现,最新的Milvus更易用一点。

7.3 Paddle Serving 部署

dirname="output"
# 模型的路径
model_filename="inference.pdmodel"
# 参数的路径
params_filename="inference.pdiparams"
# server的保存地址
server_path="serving_server"
# client的保存地址
client_path="serving_client"
# 指定输出的别名
feed_alias_names=None
# 制定输入的别名
fetch_alias_names="output_embedding"
# 设置为True会显示日志
show_proto=False
serving_io.inference_model_to_serving(dirname=dirname,serving_server=server_path,serving_client=client_path,model_filename=model_filename,params_filename=params_filename,show_proto=show_proto,feed_alias_names=feed_alias_names,fetch_alias_names=fetch_alias_names)
(dict_keys(['query_input_ids', 'title_input_ids']),dict_keys(['mean_0.tmp_0']))

搭建结束以后,就可以启动server部署服务,使用client端访问server端就行了。具体细节参考代码:
PaddleNLP文本分类应用

# 删除产生的模型文件
!rm -rf output/
# rm -rf checkpoints/
!rm -rf checkpoints_recall
!rm -rf serving_server/
!rm -rf serving_client/

8、模型优化

对于基于索引的方法除了用更多的数据集,调参数以外,还可以通过丰富标签的语义信息得到进一步的优化,由于标签大多都是关键词,这对于句子级别语义建模的方法不是很好,可以考虑把这些关键字,增加一些描述信息,这样就提供了更多的信息,对文本和标签的向量相似度计算很有帮助。

  • 如果对您有帮助,欢迎star收藏一下,不易走丢哦~链接指路:https://github.com/PaddlePaddle/PaddleNLP

此文章为搬运
原项目链接

PaddleNLP创新思路:基于检索实现层次化文本分类相关推荐

  1. Datawhale NLP入门:Task5 基于深度学习的文本分类2

    Task5 基于深度学习的文本分类2 在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习 ...

  2. Task5 基于深度学习的文本分类2

    Task5 基于深度学习的文本分类2 在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习 ...

  3. Datawhale零基础入门NLP day5/Task5基于深度学习的文本分类2

    基于深度学习的文本分类 本章将继续学习基于深度学习的文本分类. 学习目标 学习Word2Vec的使用和基础原理 学习使用TextCNN.TextRNN进行文本表示 学习使用HAN网络结构完成文本分类 ...

  4. Datawhale零基础入门NLP赛事 - Task5 基于深度学习的文本分类2

    在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习基于深度学习的文本分类. 学习目标 学 ...

  5. 新闻文本分类--任务5 基于深度学习的文本分类2

    Task5 基于深度学习的文本分类2 在上一章节,我们通过FastText快速实现了基于深度学习的文本分类模型,但是这个模型并不是最优的.在本章我们将继续深入. 基于深度学习的文本分类 本章将继续学习 ...

  6. 综述:基于深度学习的文本分类 --《Deep Learning Based Text Classification: A Comprehensive Review》总结(一)

    文章目录 综述:基于深度学习的文本分类 <Deep Learning Based Text Classification: A Comprehensive Review>论文总结(一) 总 ...

  7. NLP以赛代练 Task5:基于深度学习的文本分类 2

    基于深度学习的文本分类 2 学习目标 文本表示方法 Part3 词向量 1. Skip-grams原理和网络结构 2. Skip-grams训练 2.1 Word pairs and "ph ...

  8. 文本基线怎样去掉_ICML 2020 | 基于类别描述的文本分类模型

    论文标题: Description Based Text Classification with Reinforcement Learning 论文作者: Duo Chai, Wei Wu, Qing ...

  9. 【项目实战课】NLP入门第1课,人人免费可学,基于TextCNN的新闻文本分类实战...

    欢迎大家来到我们的项目实战课,本期内容是<基于TextCNN的新闻文本分类实战>. 所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题,进行代码级的实战讲解,可以 ...

最新文章

  1. 如何利用大数据进行精准营销
  2. 3、leetcode35 搜索插入位置**
  3. SpringMVC中使用Interceptor拦截器
  4. 用批处理修复 win10 无法升级的问题
  5. 计算机显卡是指什么时候,电脑哪个是显卡
  6. ssh 登陆错误后禁止ip再次登陆_macOS破坏SSH默认规则,程序员无法登录Web服务器...
  7. -bash : ** : command not found的问题解决(图文详解)
  8. shell 封装mysql查询
  9. 如果希望点击父控件子控件也响应的话, 可以给子控件加如下属性:  android:duplicateParentState=true...
  10. Multi-attributed heterogeneous graph convolutional network for bot detection(SCI CCF B)
  11. java实现登录注册界面
  12. xp的ie显示无服务器,WinXP系统IE无法打开站点怎么办?
  13. Vue 设置图片不转为base64
  14. 用python模拟一个文本浏览器来抓取网页
  15. 通过CSS实现太极图案例
  16. 删除插件mysearchresult(chrome和firebox)
  17. 18025 小明的密码
  18. adobe acrobat dc 2021如何统一pdf页面大小?
  19. 微服务商城mall-swarm本地搭建
  20. LPIPS 图像相似性度量标准(感知损失)

热门文章

  1. GUAVA本地缓存01_概述、优缺点、创建方式、回收机制、监听器、统计、异步锁定
  2. 【对汇编语言又爱又恨?那是没找对方法或者合适的书】
  3. java脚本语言是什么_什么是脚本语言
  4. Lniux操作系统下的火狐浏览器英文修改成中文
  5. GoldenDict词典安装和使用
  6. java和c语言哪个用途更大更广泛?
  7. bp神经网络对数据的要求,bp神经网络适用条件
  8. 今天许多的家庭有计算机翻译成英语,新视野英语教程课后翻译答案(高职高专版)...
  9. Java Scanner类的详细介绍(Java键盘输入)
  10. 丢人丢大了!深圳一公司违反开源协议还耍无赖,科技博主上门教做人!