本文将会介绍如何使用keras-bert实现文本多标签分类任务,其中对BERT进行微调

项目结构

  本项目的项目结构如下:

其中依赖的Python第三方模块如下:

pandas==0.23.4
Keras==2.3.1
keras_bert==0.83.0
numpy==1.16.4

数据集介绍

  本文采用的数据集与文章NLP(二十八)多标签文本分类中的一致,以事件抽取比赛的数据集为参考,形成文本与事件类型的多标签数据集,一共为65种事件类型。样例数据(csv格式)如下:

label,content
司法行为-起诉|组织关系-裁员,最近,一位前便利蜂员工就因公司违规裁员,将便利蜂所在的公司虫极科技(北京)有限公司告上法庭。
组织关系-裁员,思科上海大规模裁员人均可获赔100万官方澄清事实
组织关系-裁员,日本巨头面临危机,已裁员1000多人,苹果也救不了它!
组织关系-裁员|组织关系-解散,在硅谷镀金失败的造车新势力们:蔚来裁员、奇点被偷窃、拜腾解散

在label中,每个事件类型用|隔开。
  在该数据集中,训练集一共11958个样本,测试集一共1498个样本。

模型训练

  模型训练的脚本model_train.py的完整代码如下:

# -*- coding: utf-8 -*-
import json
import codecs
import pandas as pd
import numpy as np
from keras_bert import load_trained_model_from_checkpoint, Tokenizer
from keras.layers import *
from keras.models import Model
from keras.optimizers import Adam# 建议长度<=510
maxlen = 256
BATCH_SIZE = 8
config_path = './chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = './chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './chinese_L-12_H-768_A-12/vocab.txt'token_dict = {}
with codecs.open(dict_path, 'r', 'utf-8') as reader:for line in reader:token = line.strip()token_dict[token] = len(token_dict)class OurTokenizer(Tokenizer):def _tokenize(self, text):R = []for c in text:if c in self._token_dict:R.append(c)else:R.append('[UNK]')   # 剩余的字符是[UNK]return Rtokenizer = OurTokenizer(token_dict)def seq_padding(X, padding=0):L = [len(x) for x in X]ML = max(L)return np.array([np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X])class DataGenerator:def __init__(self, data, batch_size=BATCH_SIZE):self.data = dataself.batch_size = batch_sizeself.steps = len(self.data) // self.batch_sizeif len(self.data) % self.batch_size != 0:self.steps += 1def __len__(self):return self.stepsdef __iter__(self):while True:idxs = list(range(len(self.data)))np.random.shuffle(idxs)X1, X2, Y = [], [], []for i in idxs:d = self.data[i]text = d[0][:maxlen]x1, x2 = tokenizer.encode(first=text)y = d[1]X1.append(x1)X2.append(x2)Y.append(y)if len(X1) == self.batch_size or i == idxs[-1]:X1 = seq_padding(X1)X2 = seq_padding(X2)Y = seq_padding(Y)yield [X1, X2], Y[X1, X2, Y] = [], [], []# 构建模型
def create_cls_model(num_labels):bert_model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=None)for layer in bert_model.layers:layer.trainable = Truex1_in = Input(shape=(None,))x2_in = Input(shape=(None,))x = bert_model([x1_in, x2_in])cls_layer = Lambda(lambda x: x[:, 0])(x)    # 取出[CLS]对应的向量用来做分类p = Dense(num_labels, activation='sigmoid')(cls_layer)     # 多分类model = Model([x1_in, x2_in], p)model.compile(loss='binary_crossentropy',optimizer=Adam(1e-5), # 用足够小的学习率metrics=['accuracy'])model.summary()return modelif __name__ == '__main__':# 数据处理, 读取训练集和测试集print("begin data processing...")train_df = pd.read_csv("data/train.csv").fillna(value="")test_df = pd.read_csv("data/test.csv").fillna(value="")select_labels = train_df["label"].unique()labels = []for label in select_labels:if "|" not in label:if label not in labels:labels.append(label)else:for _ in label.split("|"):if _ not in labels:labels.append(_)with open("label.json", "w", encoding="utf-8") as f:f.write(json.dumps(dict(zip(range(len(labels)), labels)), ensure_ascii=False, indent=2))train_data = []test_data = []for i in range(train_df.shape[0]):label, content = train_df.iloc[i, :]label_id = [0] * len(labels)for j, _ in enumerate(labels):for separate_label in label.split("|"):if _ == separate_label:label_id[j] = 1train_data.append((content, label_id))for i in range(test_df.shape[0]):label, content = test_df.iloc[i, :]label_id = [0] * len(labels)for j, _ in enumerate(labels):for separate_label in label.split("|"):if _ == separate_label:label_id[j] = 1test_data.append((content, label_id))# print(train_data[:10])print("finish data processing!")# 模型训练model = create_cls_model(len(labels))train_D = DataGenerator(train_data)test_D = DataGenerator(test_data)print("begin model training...")model.fit_generator(train_D.__iter__(),steps_per_epoch=len(train_D),epochs=10,validation_data=test_D.__iter__(),validation_steps=len(test_D))print("finish model training!")# 模型保存model.save('multi-label-ee.h5')print("Model saved!")result = model.evaluate_generator(test_D.__iter__(), steps=len(test_D))print("模型评估结果:", result)

  模型结构如下:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None)         0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0
__________________________________________________________________________________________________
model_2 (Model)                 (None, None, 768)    101677056   input_1[0][0]                    input_2[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 768)          0           model_2[1][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 65)           49985       lambda_1[0][0]
==================================================================================================
Total params: 101,727,041
Trainable params: 101,727,041
Non-trainable params: 0
__________________________________________________________________________________________________

  从中我们可以发现,该模型结构与文章NLP(三十五)使用keras-bert实现文本多分类任务中给出的文本多分类模型结构大体一致,修改之处在于BERT后接的网络结构,所接的依然是dense层,但激活函数采用sigmoid函数,同时损失函数为binary_crossentropy。就其本质而言,该模型结构是对输出的65个结果采用0-1分类,故而激活函数采用sigmoid,这当然是文本多分类模型转化为多标签标签的最便捷方式,但不足之处在于,该模型并未考虑标签之间的依赖关系。

模型评估

  模型评估脚本model_evaluate.py的完整代码如下:

# -*- coding: utf-8 -*-
# @Time : 2020/12/23 15:28
# @Author : Jclian91
# @File : model_evaluate.py
# @Place : Yangpu, Shanghai
# 模型评估脚本,利用hamming_loss作为多标签分类的评估指标,该值越小模型效果越好
import json
import numpy as np
import pandas as pd
from keras.models import load_model
from keras_bert import get_custom_objects
from sklearn.metrics import hamming_loss, classification_reportfrom model_train import token_dict, OurTokenizermaxlen = 256# 加载训练好的模型
model = load_model("multi-label-ee.h5", custom_objects=get_custom_objects())
tokenizer = OurTokenizer(token_dict)
with open("label.json", "r", encoding="utf-8") as f:label_dict = json.loads(f.read())# 对单句话进行预测
def predict_single_text(text):# 利用BERT进行tokenizetext = text[:maxlen]x1, x2 = tokenizer.encode(first=text)X1 = x1 + [0] * (maxlen - len(x1)) if len(x1) < maxlen else x1X2 = x2 + [0] * (maxlen - len(x2)) if len(x2) < maxlen else x2# 模型预测并输出预测结果prediction = model.predict([[X1], [X2]])one_hot = np.where(prediction > 0.5, 1, 0)[0]return one_hot, "|".join([label_dict[str(i)] for i in range(len(one_hot)) if one_hot[i]])# 模型评估
def evaluate():test_df = pd.read_csv("data/test.csv").fillna(value="")true_y_list, pred_y_list = [], []true_label_list, pred_label_list = [], []common_cnt = 0for i in range(test_df.shape[0]):print("predict %d samples" % (i+1))true_label, content = test_df.iloc[i, :]true_y = [0] * len(label_dict.keys())for key, value in label_dict.items():if value in true_label:true_y[int(key)] = 1pred_y, pred_label = predict_single_text(content)if true_label == pred_label:common_cnt += 1true_y_list.append(true_y)pred_y_list.append(pred_y)true_label_list.append(true_label)pred_label_list.append(pred_label)# F1值print(classification_report(true_y_list, pred_y_list, digits=4))return true_label_list, pred_label_list, hamming_loss(true_y_list, pred_y_list), common_cnt/len(true_y_list)true_labels, pred_labels, h_loss, accuracy = evaluate()
df = pd.DataFrame({"y_true": true_labels, "y_pred": pred_labels})
df.to_csv("pred_result.csv")print("accuracy: ", accuracy)
print("hamming loss: ", h_loss)

Hamming Loss为多标签分类所特有的评估方式,其值越小代表多标签分类模型的效果越好。运行上述模型评估代码,输出结果如下:

   micro avg     0.9341    0.9578    0.9458      1657macro avg     0.9336    0.9462    0.9370      1657
weighted avg     0.9367    0.9578    0.9456      1657samples avg     0.9520    0.9672    0.9531      1657accuracy:  0.8985313751668892
hamming loss:  0.001869158878504673

  在这里,笔者希望与之前的文章NLP(二十八)多标签文本分类中的模型对比一下。当时采用的模型为用ALBERT提取特征向量,再用Bi-GRU+Attention+FCN进行分类,模型结构如下:

  对该模型同样采用上述评估办法,输出的结果如下:

   micro avg     0.9424    0.8292    0.8822      1657macro avg     0.8983    0.7218    0.7791      1657
weighted avg     0.9308    0.8292    0.8669      1657samples avg     0.8675    0.8496    0.8517      1657
accuracy:  0.7983978638184246
hamming loss:  0.0037691280681934887

可以发现,采用BERT微调的模型,在accuracy方面高出了约10%,各种F1值高出约5%-10%,Hamming Loss也小了很多。因此,BERT微调的模型比之前的模型效果好很多。

总结

  本项目已经开源,Github地址为:https://github.com/percent4/keras_bert_multi_label_cls 。
  2020年12月27日于上海浦东

参考文章

  1. NLP(二十八)多标签文本分类:https://blog.csdn.net/jclian91/article/details/105386190
  2. NLP(三十五)使用keras-bert实现文本多分类任务:https://blog.csdn.net/jclian91/article/details/111742576

NLP(三十六)使用keras-bert实现文本多标签分类任务相关推荐

  1. [Python从零到壹] 三十六.图像处理基础篇之图像算术与逻辑运算详解

    欢迎大家来到"Python从零到壹",在这里我将分享约200篇Python系列文章,带大家一起去学习和玩耍,看看Python这个有趣的世界.所有文章都将结合案例.代码和作者的经验讲 ...

  2. Python编程基础:第三十六节 模块Modules

    第三十六节 模块Modules 前言 实践 前言 我们目前所有的代码都写在一个文档里面.如果你的项目比较大,那么把所有功能写在一个文件里就非常不便于后期维护.为了提高我们代码的可读性,降低后期维护的成 ...

  3. OpenCV学习笔记(三十六)——Kalman滤波做运动目标跟踪 OpenCV学习笔记(三十七)——实用函数、系统函数、宏core OpenCV学习笔记(三十八)——显示当前FPS OpenC

    OpenCV学习笔记(三十六)--Kalman滤波做运动目标跟踪 kalman滤波大家都很熟悉,其基本思想就是先不考虑输入信号和观测噪声的影响,得到状态变量和输出信号的估计值,再用输出信号的估计误差加 ...

  4. NeHe OpenGL教程 第三十六课:从渲染到纹理

    转自[翻译]NeHe OpenGL 教程 前言 声明,此 NeHe OpenGL教程系列文章由51博客yarin翻译(2010-08-19),本博客为转载并稍加整理与修改.对NeHe的OpenGL管线 ...

  5. 三十六、Java集合中的HashMap

    @Author:Runsen @Date:2020/6/3 作者介绍:Runsen目前大三下学期,专业化学工程与工艺,大学沉迷日语,Python, Java和一系列数据分析软件.导致翘课严重,专业排名 ...

  6. 三十六、rsync通过服务同步、Linux系统日志、screen工具

    三十六.rsync通过服务同步.Linux系统日志.screen工具 一.rsync通过服务同步 该方式可以理解为:在远程主机上建立一个rsync的服务器,在服务器上配置好各种应用,然后本机将作为客户 ...

  7. 嵌入式实时操作系统ucos-ii_「正点原子NANO STM32开发板资料连载」第三十六章 UCOSII 实验 1任务调度...

    1)实验平台:alientek NANO STM32F411 V1开发板2)摘自<正点原子STM32F4 开发指南(HAL 库版>关注官方微信号公众号,获取更多资料:正点原子 第三十六章 ...

  8. 第三百三十六章 斗宗强者间的大战!

    第三百三十六章 斗宗强者间的大战! <script language="javascript" src="/js/style2.js"></s ...

  9. 三十六進制之間隨便轉換

    去年在網上給一家公司投簡歷的時候,對方要求寫一個任意進制轉換的函數,當時沒有回過神來,也不知道JAVA中有這樣的函數,呵呵.于是就自己操刀,寫了這個三十六進制之音隨便轉的函數.不過,權當練習吧,如果你 ...

最新文章

  1. eclipse rcp 多线程
  2. net 快速打印日志
  3. mysql 20小时内,mysql中关于date(Y-m-d H:i:s) 入库慢8小时的解决
  4. windows下eclipse远程连接hadoop集群开发mapreduce
  5. 初一模拟赛总结(2019.3.9)
  6. 【计算机网络】链路与连通
  7. Why Open vSwitch?
  8. mybatis 学习一 建立maven项目
  9. HT for Web中3D流动效果的实现与应用
  10. P2835 刻录光盘
  11. bzoj 5281: [Usaco2018 Open]Talent Show【dp】
  12. mysql sjis 校对乱码_MySQL乱码问题深层分析
  13. 上周热点回顾(5.8-5.14)
  14. java剪刀石头布编程_Java如何编写石头剪子布游戏程序
  15. Vue地图导航调用百度地图
  16. 了解5G技术与未来5G面临的问题
  17. dva 的一些特殊的写法
  18. 因数据迁移导致跨库连接失效的解决办法
  19. vue 一个动态链接url转成二维码
  20. let和const的区别

热门文章

  1. matlab序列补零dft,补零位置的不同对频谱的影响
  2. Ubuntu20.04密码忘记了怎么办?
  3. + smarty 模板
  4. 普通电脑安装华为电脑管家操作流程
  5. 最小编辑距离算法 Edit Distance(经典DP)
  6. 道路分割 matlab,MATLAB图像的道路分割技术研究
  7. c语言课程设计-计算器
  8. python怎么打开ipynb文件_怎么在Jupyter里打开ipynb文件
  9. 分享160多种ChatGPT 高频中文prompt 提示词指令合集——秒变AI训练师
  10. 常用好玩的Git命令