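Text classification tasks can be divided by label type into three kinds: multi-class, multi-label, and hierarchical classification. We use the news text classification example from the accompanying figure to illustrate how the three tasks differ.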
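The difference between the three label types can be sketched with toy data. The news categories below are hypothetical and only illustrate the target shapes; the small taxonomy reuses label names from the label.txt sample later in this document, and the helper names are illustrative rather than part of PaddleNLP.

```python
# Hypothetical flat label set for a news classifier (illustration only).
NEWS_LABELS = ["sports", "finance", "technology"]

# Small two-level taxonomy built from label names that appear later in label.txt.
TAXONOMY = {
    "CS": ["Computer vision", "Machine learning"],
    "ECE": ["Lorentz force law"],
}


def one_hot(label, label_list):
    """Multi-class: exactly one label per sample."""
    return [1 if name == label else 0 for name in label_list]


def multi_hot(labels, label_list):
    """Multi-label: any number of labels per sample, with no structure among them."""
    chosen = set(labels)
    return [1 if name in chosen else 0 for name in label_list]


def is_valid_path(path, taxonomy):
    """Hierarchical: labels form a tree, so a child must belong to its parent."""
    parent, child = path
    return parent in taxonomy and child in taxonomy[parent]


multi_class_target = one_hot("finance", NEWS_LABELS)
multi_label_target = multi_hot(["sports", "technology"], NEWS_LABELS)
valid = is_valid_path(("CS", "Machine learning"), TAXONOMY)
invalid = is_valid_path(("ECE", "Machine learning"), TAXONOMY)
```

The hierarchical case is the only one where the relationship between labels carries information, which is why it needs its own data format and label file.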


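PaddleNLP provides easy-to-use interfaces through AutoModelForSequenceClassification and AutoTokenizer. Given a model name or a path to model parameter files, the from_pretrained() method loads pretrained models with different network architectures and stacks a linear layer on top of the output layer; the corresponding pretrained weights also download quickly and reliably. The Transformer pretrained model collection covers more than 40 mainstream pretrained models, such as ERNIE, BERT, and RoBERTa, with more than 500 sets of model weights. The following uses the ERNIE 3.0 Chinese base model to show how to load a pretrained model and tokenizer: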

from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
num_classes = 10
model_name = "ernie-3.0-base-zh"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_classes=num_classes)
tokenizer = AutoTokenizer.from_pretrained(model_name)

0.1 Introduction to the hierarchical classification task



0.2 Overview of the end-to-end text classification workflow


  • Data preparation


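  • Model training
    Once the data is ready, we can fine-tune the pretrained model on our dataset. The configurable parameters can be adjusted to the task, and training can run on either GPU or CPU; by default the script saves the model that performs best on the development set. Chinese tasks default to the "ernie-3.0-base-zh" model and English tasks to "ernie-2.0-base-en". ERNIE 3.0 also supports several lightweight Chinese models (see the ERNIE model summary), which can be chosen according to task and device requirements.

First, choose the task directory that matches the scenario: multi-class, multi-label, or hierarchical classification.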


  • Model prediction




The text classification application also offers a server-side deployment option based on Paddle Serving.










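├── train.py # training and evaluation script
├── predict.py # prediction script
├── export_model.py # script that exports dynamic-graph parameters to static-graph parameters
├── utils.py # utility functions
├── metric.py # metric script
├── prune.py # pruning script
├── prune_trainer.py # pruning trainer script
├── prune_config.py # pruning training parameter configuration
├── requirements.txt # environment dependencies
└── README.md # usage instructions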

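We use the public hierarchical classification dataset WOS (Web of Science) as an example, fine-tuning the model on the training set and validating it on the development set. WOS is a two-level hierarchical text classification dataset with 7 parent classes and 134 subclasses. Each sample has one parent label and one child label, and the parent and child labels form a tree-structured hierarchy.

Running the program performs training, evaluation, and testing automatically. During training, the model that performs best on the development set is saved to the specified save_dir, with the following file structure: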

├── model_config.json
├── model_state.pdparams
├── tokenizer_config.json
└── vocab.txt
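The "keep the best development-set model" behaviour, combined with the --early_stop flag used later, can be sketched as follows. This is a minimal illustration, not the logic of train.py itself: the patience parameter and the function name are hypothetical, and the place where a real script would write the checkpoint is only marked by a comment.

```python
def run_with_early_stop(dev_scores, patience=2):
    """Track the best dev score and stop after `patience` epochs without improvement.

    dev_scores: one development-set score per epoch (e.g. macro F1).
    Returns (best_epoch, best_score, last_epoch_run).
    """
    best_score = float("-inf")
    best_epoch = None
    epochs_without_gain = 0
    last_epoch = 0
    for epoch, score in enumerate(dev_scores, start=1):
        last_epoch = epoch
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_gain = 0
            # A real training script would save the model and tokenizer here.
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                break
    return best_epoch, best_score, last_epoch


result = run_with_early_stop([0.10, 0.30, 0.20, 0.25, 0.28], patience=2)
```

With these made-up scores, the best epoch is the second one and the loop stops after the fourth, so the fifth epoch never runs.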


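To resume training, set init_from_ckpt, for example init_from_ckpt=checkpoint/model_state.pdparams.
To train a Chinese text classification task, just change the pretrained model parameter model_name. "ernie-3.0-base-zh" is recommended for Chinese tasks; see the Transformer pretrained model list for more options.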



!wget https://paddlenlp.bj.bcebos.com/datasets/wos_data.tar.gz
!tar -zxvf wos_data.tar.gz
!mv wos_data data


├── train.txt # training set file
├── dev.txt # development set file
├── test.txt # optional, test set file
├── label.txt # classification label file
└── data.txt # optional, file with data to predict

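In train.txt (training set file), dev.txt (development set file), and test.txt (optional test set file), n is the maximum number of levels in the label hierarchy and <level i label> is the level-i label of a sample. The input text and the labels of different levels are separated by '\t', and multiple labels within one level are separated by a comma ','. Note that when a sample has no label at level i, <level i label> is the empty string ''.

train.txt/dev.txt/test.txt file format:

<input sequence 1>'\t'<level 1 label 1>','<level 1 label 2>'\t'<level 2 label 1>','<level 2 label 2>'\t'...'\t'<level n label 1>','<level n label 2>
<input sequence 2>'\t'<level 1 label>'\t'<level 2 label>'\t'...'\t'<level n label>

train.txt/dev.txt/test.txt file sample: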

unintended pregnancy continues to be a substantial public health problem. emergency contraception (ec) provides a last chance at pregnancy prevention. several safe and effective options for emergency contraception are currently available. the yuzpe method, a combined hormonal regimen, was essentially replaced by other oral medications including levonorgestrel and the antiprogestin ulipristal. the antiprogestin mifepristone has been studied for use as emergency contraception. the most effective postcoital method of contraception is the copper intrauterine device (iud). obesity and the simultaneous initiation of progestin-containing contraception may decrease the effectiveness of some emergency contraception.    Medical    Emergency Contraception
the objective of this paper is to present an example in which matrix functions are used to solve a modern control exercise. specifically, the solution for the equation of state, which is a matrix differential equation is calculated. to resolve this, two different methods are presented, first using the properties of the matrix functions and by other side, using the classical method of laplace transform.    ECE    Control engineering
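A line in this format could be parsed as sketched below. The function name and the num_levels argument are illustrative and not taken from utils.py; the sketch only follows the format rules stated above (tab between fields, comma between labels of one level, empty string for a missing level).

```python
def parse_example(line, num_levels=2):
    """Split one tab-separated line into (text, labels_per_level).

    labels_per_level is a list with one list of label strings per level;
    a level without labels becomes an empty list.
    """
    fields = line.rstrip("\n").split("\t")
    text, level_fields = fields[0], fields[1:]
    # Pad missing trailing levels so every sample has num_levels entries.
    level_fields = level_fields + [""] * (num_levels - len(level_fields))
    labels_per_level = [
        [label for label in field.split(",") if label] for field in level_fields
    ]
    return text, labels_per_level


text, levels = parse_example("some abstract text\tMedical\tEmergency Contraception")
_, multi = parse_example("another text\tCS,ECE\tMachine learning")
_, missing = parse_example("third text\tCS\t")
```

Here levels holds one label per level, multi shows two comma-separated labels at level 1, and missing has an empty list for level 2.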


label.txt file format:

<level 1: label>
<level 1: label>'##'<level 2: label>
<level 1: label>'##'<level 2: label>'##'<level 3: label>

label.txt file sample:

CS##Computer vision
CS##Machine learning
ECE##Lorentz force law
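One plausible way to turn a sample's label path into a training target is a multi-hot vector over the lines of label.txt, with a 1 for every prefix of the path. The sketch below assumes a single parent-to-child path per sample (as in WOS) and is not necessarily how utils.py encodes labels; the function names are hypothetical.

```python
def path_to_label_names(path):
    """['CS', 'Machine learning'] -> ['CS', 'CS##Machine learning']."""
    return ["##".join(path[:depth]) for depth in range(1, len(path) + 1)]


def encode_path(path, label_to_id):
    """Return a multi-hot vector with one position per line of label.txt."""
    target = [0] * len(label_to_id)
    for name in path_to_label_names(path):
        target[label_to_id[name]] = 1
    return target


# Lines as they might appear in label.txt, including the level-1 entries.
label_lines = [
    "CS",
    "CS##Computer vision",
    "CS##Machine learning",
    "ECE",
    "ECE##Lorentz force law",
]
label_to_id = {name: index for index, name in enumerate(label_lines)}

target = encode_path(["CS", "Machine learning"], label_to_id)
```

Encoding both the parent and the child as positive labels turns the hierarchical task into a multi-label problem over all nodes of the tree, which is consistent with the BCEWithLogitsLoss criterion used in the evaluation code below.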


data.txt file format:

<input sequence 1>
<input sequence 2>

data.txt file sample:

previous research exploring cognitive biases in bulimia nervosa suggests that attentional biases occur for both food-related and body-related cues. individuals with bulimia were compared to non-bulimic controls on an emotional-stroop task which contained both food-related and body-related cues. results indicated that bulimics (but not controls) demonstrated a cognitive bias for both food-related and body related cues. however, a discrepancy between the two cue-types was observed with body-related cognitive biases showing the most robust effects and food-related cognitive biases being the most strongly associated with the severity of the disorder. the results may have implications for clinical practice as bulimics with an increased cognitive bias for food-related cues indicated increased bulimic disorder severity. (c) 2016 elsevier ltd. all rights reserved.
posterior reversible encephalopathy syndrome (pres) is a reversible clinical and neuroradiological syndrome which may appear at any age and characterized by headache, altered consciousness, seizures, and cortical blindness. the exact incidence is still unknown. the most commonly identified causes include hypertensive encephalopathy, eclampsia, and some cytotoxic drugs. vasogenic edema related subcortical white matter lesions, hyperintense on t2a and flair sequences, in a relatively symmetrical pattern especially in the occipital and parietal lobes can be detected on cranial mr imaging. these findings tend to resolve partially or completely with early diagnosis and appropriate treatment. here in, we present a rare case of unilateral pres developed following the treatment with pazopanib, a testicular tumor vascular endothelial growth factor (vegf) inhibitory agent.


!python train.py --early_stop --epochs 5  --warmup --save_dir "./checkpoint" --batch_size 32 --dataset_dir "data/wos_data"


[2022-07-27 17:54:18,773] [    INFO] - global step 1870, epoch: 2, batch: 930, loss: 0.04018, micro f1 score: 0.56644, macro f1 score: 0.04182, speed: 1.79 step/s
[2022-07-27 17:54:24,434] [    INFO] - global step 1875, epoch: 2, batch: 935, loss: 0.03838, micro f1 score: 0.56670, macro f1 score: 0.04185, speed: 1.79 step/s
[2022-07-27 17:54:29,539] [    INFO] - global step 1880, epoch: 2, batch: 940, loss: 0.03892, micro f1 score: 0.56682, macro f1 score: 0.04187, speed: 1.98 step/s
[2022-07-27 17:55:27,020] [    INFO] - eval loss: 0.03925, micro f1 score: 0.59396, macro f1 score: 0.04428
[2022-07-27 17:55:27,021] [    INFO] - Current best macro f1 score: 0.04428
[2022-07-27 17:55:28,033] [    INFO] - tokenizer config file saved in ./checkpoint/tokenizer_config.json
[2022-07-27 17:55:28,034] [    INFO] - Special tokens file saved in ./checkpoint/special_tokens_map.json
[2022-07-27 17:55:30,385] [    INFO] - global step 1885, epoch: 3, batch: 5, loss: 0.03854, micro f1 score: 0.64000, macro f1 score: 0.04778, speed: 0.16 step/s
[2022-07-27 17:55:31,980] [    INFO] - global step 1890, epoch: 3, batch: 10, loss: 0.03603, micro f1 score: 0.63455, macro f1 score: 0.04747, speed: 6.57 step/s
[2022-07-27 17:55:33,539] [    INFO] - global step 1895, epoch: 3, batch: 15, loss: 0.03707, micro f1 score: 0.62945, macro f1 score: 0.04679, speed: 6.73 step/s
[2022-07-27 17:55:35,138] [    INFO] - global step 1900, epoch: 3, batch: 20, loss: 0.03549, micro f1 score: 0.62788, macro f1 score: 0.04674, speed: 6.56 step/s
[2022-07-27 17:55:36,823] [    INFO] - global step 1905, epoch: 3, batch: 25, loss: 0.03838, micro f1 score: 0.62448, macro f1 score: 0.04646, speed: 6.20 step/s
[2022-07-27 17:55:38,457] [    INFO] - global step 1910, epoch: 3, batch: 30, loss: 0.03717, micro f1 score: 0.62339, macro f1 score: 0.04635, speed: 6.42 step/s
[2022-07-27 17:55:40,075] [    INFO] - global step 1915, epoch: 3, batch: 35, loss: 0.04115, micro f1 score: 0.62302, macro f1 score: 0.04632, speed: 6.48 step/s
[2022-07-27 17:55:41,742] [    INFO] - global step 1920, epoch: 3, batch: 40, loss: 0.03842, micro f1 score: 0.61973, macro f1 score: 0.04607, speed: 6.29 step/s
[2022-07-27 17:55:43,423] [    INFO] - global step 1925, epoch: 3, batch: 45, loss: 0.03772, micro f1 score: 0.61950, macro f1 score: 0.04606, speed: 6.22 step/s
[2022-07-27 17:55:45,118] [    INFO] - global step 1930, epoch: 3, batch: 50, loss: 0.04074, micro f1 score: 0.61848, macro f1 score: 0.04602, speed: 6.17 step/s







max_seq_length: maximum sequence length used by the ERNIE model, at most 512; lower this value if GPU memory runs out. Defaults to 512.


device: device used for training; one of cpu, gpu, xpu, npu. When training on gpu, use the gpus parameter to specify the GPU card IDs.






epochs: number of training epochs; defaults to 1000.


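warmup_steps: number of steps for the learning-rate warmup strategy. If set to 2000, the learning rate increases gradually from 0 to learning_rate over the first 2000 steps and then decays slowly. Defaults to 2000.

logging_steps: interval, in steps, between log messages; defaults to 5.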
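The warmup behaviour can be sketched as a simple function of the step number. The linear shape of the decay, the base_lr value, and total_steps are assumptions made for illustration; the text above only states that the rate rises from 0 to learning_rate during warmup and decays afterwards.

```python
def lr_at_step(step, base_lr=3e-5, warmup_steps=2000, total_steps=10000):
    """Linear warmup from 0 to base_lr, then linear decay back to 0 (assumed shape)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return base_lr * remaining / max(total_steps - warmup_steps, 1)


schedule = [lr_at_step(step) for step in (0, 1000, 2000, 6000, 10000)]
```

The learning rate is 0 at step 0, half of base_lr halfway through warmup, equal to base_lr at step 2000, and back to 0 at the last step.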



2.2.1 Evaluation metric definitions


    criterion = paddle.nn.BCEWithLogitsLoss()
    metric = MetricReport()
    # Compute the F1 scores; to change the metric, refer to the multi-class article
    micro_f1_score, macro_f1_score = evaluate(model, criterion, metric, dev_data_loader)
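The gap between the micro and macro F1 scores in the training log above (roughly 0.59 versus 0.04) follows from how the two averages are defined. The NumPy sketch below is written for illustration and is not the code of MetricReport: micro F1 pools true positives, false positives, and false negatives over all labels, while macro F1 averages the per-label F1, so many rare labels with an F1 of 0 pull it down.

```python
import numpy as np


def micro_macro_f1(y_true, y_pred):
    """Micro and macro F1 for multi-hot label matrices of shape (samples, labels)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = (y_true & y_pred).sum(axis=0).astype(float)
    fp = (~y_true & y_pred).sum(axis=0).astype(float)
    fn = (y_true & ~y_pred).sum(axis=0).astype(float)

    def f1(tp, fp, fn):
        denom = 2 * tp + fp + fn
        # A label with no positives and no predictions gets F1 = 0.
        return np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)

    micro = float(f1(tp.sum(), fp.sum(), fn.sum()))
    macro = float(f1(tp, fp, fn).mean())
    return micro, macro


# The model gets the frequent label right but never predicts the two rare ones.
y_true = [[1, 0, 0], [1, 0, 0], [1, 1, 0], [1, 0, 1]]
y_pred = [[1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0]]
micro, macro = micro_macro_f1(y_true, y_pred)
```

In this toy case micro F1 is 0.8 while macro F1 is only 1/3, the same pattern as in the early epochs of the log, where the 7 parent classes are learned before the 134 subclasses.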



from sklearn.metrics import f1_score, classification_report


from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score

import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from paddle.metric import Metric


class MultiLabelReport(Metric):
    """AUC and F1 Score for multi-label text classification task."""

    def __init__(self, name='MultiLabelReport', average='micro'):
        super(MultiLabelReport, self).__init__()
        self.average = average
        self._name = name
        self.reset()

    def f1_score(self, y_prob):
        '''Returns the f1 score by searching the best threshold'''
        best_score = 0
        # Initialised so the return value is defined even if no threshold scores above 0
        precision, recall = 0.0, 0.0
        for threshold in [i * 0.01 for i in range(100)]:
            self.y_pred = y_prob > threshold
            score = f1_score(y_pred=self.y_pred, y_true=self.y_true, average=self.average)
            if score > best_score:
                best_score = score
                precision = precision_score(y_pred=self.y_pred, y_true=self.y_true, average=self.average)
                recall = recall_score(y_pred=self.y_pred, y_true=self.y_true, average=self.average)
        return best_score, precision, recall

    def reset(self):
        """Resets all of the metric state."""
        self.y_prob = None
        self.y_true = None

    def update(self, probs, labels):
        # Accumulate predicted probabilities and true labels batch by batch
        if self.y_prob is not None:
            self.y_prob = np.append(self.y_prob, probs.numpy(), axis=0)
        else:
            self.y_prob = probs.numpy()
        if self.y_true is not None:
            self.y_true = np.append(self.y_true, labels.numpy(), axis=0)
        else:
            self.y_true = labels.numpy()

    def accumulate(self):
        auc = roc_auc_score(y_score=self.y_prob, y_true=self.y_true, average=self.average)
        f1, precision, recall = self.f1_score(y_prob=self.y_prob)
        return auc, f1, precision, recall

    def name(self):
        """Returns metric name"""
        return self._name


#!python -m paddle.distributed.launch --gpus "0" train.py --early_stop --dataset_dir data
# For multi-GPU training, specify several GPU card IDs, for example --gpus "0,1"

2.3 Model prediction



python predict.py --params_path ./checkpoint/


!python predict.py --params_path ./checkpoint/ --dataset_dir data/wos_data


input data: a high degree of uncertainty associated with the emission inventory for china tends to degrade the performance of chemical transport models in predicting pm2.5 concentrations especially on a daily basis. in this study a novel machine learning algorithm, geographically -weighted gradient boosting machine (gw-gbm), was developed by improving gbm through building spatial smoothing kernels to weigh the loss function. this modification addressed the spatial nonstationarity of the relationships between pm2.5 concentrations and predictor variables such as aerosol optical depth (aod) and meteorological conditions. gw-gbm also overcame the estimation bias of pm2.5 concentrations due to missing aod retrievals, and thus potentially improved subsequent exposure analyses. gw-gbm showed good performance in predicting daily pm2.5 concentrations (r-2 = 0.76, rmse = 23.0 g/m(3)) even with partially missing aod data, which was better than the original gbm model (r-2 = 0.71, rmse = 25.3 g/m(3)). on the basis of the continuous spatiotemporal prediction of pm2.5 concentrations, it was predicted that 95% of the population lived in areas where the estimated annual mean pm2.5 concentration was higher than 35 g/m(3), and 45% of the population was exposed to pm2.5 >75 g/m(3) for over 100 days in 2014. gw-gbm accurately predicted continuous daily pm2.5 concentrations in china for assessing acute human health effects. (c) 2017 elsevier ltd. all rights reserved.
predicted result:
level 1: CS
level 2:
input data: previous research exploring cognitive biases in bulimia nervosa suggests that attentional biases occur for both food-related and body-related cues. individuals with bulimia were compared to non-bulimic controls on an emotional-stroop task which contained both food-related and body-related cues. results indicated that bulimics (but not controls) demonstrated a cognitive bias for both food-related and body related cues. however, a discrepancy between the two cue-types was observed with body-related cognitive biases showing the most robust effects and food-related cognitive biases being the most strongly associated with the severity of the disorder. the results may have implications for clinical practice as bulimics with an increased cognitive bias for food-related cues indicated increased bulimic disorder severity. (c) 2016 elsevier ltd. all rights reserved.
predicted result:
level 1: Psychology
level 2:
input data: posterior reversible encephalopathy syndrome (pres) is a reversible clinical and neuroradiological syndrome which may appear at any age and characterized by headache, altered consciousness, seizures, and cortical blindness. the exact incidence is still unknown. the most commonly identified causes include hypertensive encephalopathy, eclampsia, and some cytotoxic drugs. vasogenic edema related subcortical white matter lesions, hyperintense on t2a and flair sequences, in a relatively symmetrical pattern especially in the occipital and parietal lobes can be detected on cranial mr imaging. these findings tend to resolve partially or completely with early diagnosis and appropriate treatment. here in, we present a rare case of unilateral pres developed following the treatment with pazopanib, a testicular tumor vascular endothelial growth factor (vegf) inhibitory agent.
predicted result:
level 1: Medical
level 2: 
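An empty "level 2:" line in the output above means that no level-2 label was selected for that input. A minimal sketch of how such output could be produced is shown below; it assumes that prediction applies a sigmoid to each logit and keeps labels whose probability exceeds 0.5, which is an assumption about predict.py rather than something stated here. The logits are made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def decode(logits, label_names, threshold=0.5, num_levels=2):
    """Group the labels whose probability exceeds the threshold by hierarchy level.

    Assumes no label name is deeper than num_levels.
    """
    levels = [[] for _ in range(num_levels)]
    for logit, name in zip(logits, label_names):
        if sigmoid(logit) > threshold:
            parts = name.split("##")
            # The number of '##'-separated parts gives the level of the label.
            levels[len(parts) - 1].append(parts[-1])
    return levels


label_names = ["CS", "CS##Computer vision", "CS##Machine learning"]
predicted = decode([2.0, -1.5, -0.3], label_names)

for index, names in enumerate(predicted, start=1):
    print(f"level {index}: {','.join(names)}")
```

With these logits only the parent label passes the threshold, so level 2 stays empty, matching the shape of the results printed above. That is plausible for a model trained for only a few epochs, when macro F1 over the subclasses is still low.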








