如何使用BERT实现中文的文本分类(附代码)
如何使用BERT模型实现中文的文本分类
- 前言
- Pytorch
- readme
- 参数表
- 算法流程
- 1. 概述
- 2. 读取数据
- 3. 特征转换
- 4. 模型训练
- 5. 模型测试
- 6. 测试结果
- 7. 总结
前言
- Google官方BERT代码(Tensorflow)
- 本文章参考的BERT分类代码(Pytorch)
- 本文章改进的BERT中文文本分类代码(Pytorch)
- BERT模型介绍
Pytorch
readme
- 请先安装pytorch的BERT代码,代码源见前言(2)
pip install pytorch-pretrained-bert
参数表
data_dir | bert_model | task_name |
---|---|---|
输入数据目录 | 加载的bert模型,对于中文文本请输入’bert-base-chinese’ | 输入数据预处理模块,最好根据应用场景自定义 |
model_save_pth | max_seq_length* | train_batch_size |
模型参数保存地址 | 最大文本长度 | batch大小 |
learning_rate | num_train_epochs | |
Adam初始学习步长 | 最大epoch数 |
* max_seq_length = 所设定的文本长度 + 2 ,BERT会给每个输入文本开头和结尾分别加上[CLS]和[SEP]标识符,因此会占用2个字符空间,其作用会在后续进行详细说明。
算法流程
1. 概述
2. 读取数据
- 对应于参数表中的task_name,是用于数据读取的模块
- 可以根据自身需要自定义新的数据读取模块
- 以输入数据为json文件时为例,数据读取模块包含两个部分:
- 基类DataProcessor:
class DataProcessor(object): def get_train_examples(self, data_dir):raise NotImplementedError()def get_dev_examples(self, data_dir):raise NotImplementedError()def get_test_examples(self, data_dir):raise NotImplementedError()def get_labels(self):raise NotImplementedError()@classmethoddef _read_json(cls, input_file, quotechar=None):"""Reads a tab separated value file."""dicts = []with codecs.open(input_file, 'r', 'utf-8') as infs:for inf in infs:inf = inf.strip()dicts.append(json.loads(inf))return dicts
- 用于数据读取的模块MyPro:
class MyPro(DataProcessor):def get_train_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), 'train')def get_dev_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "val.json")), 'dev')def get_test_examples(self, data_dir):return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), 'test')def get_labels(self):return [0, 1]def _create_examples(self, dicts, set_type):examples = []for (i, infor) in enumerate(dicts):guid = "%s-%s" % (set_type, i)text_a = infor['question']label = infor['label']examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examples
- 基类DataProcessor:
- 需要注意的几点是:
- data_dir目录下应包含名为train、val、test的三个文件,根据文件格式不同需要对读取方式进行修改
- get_labels()返回的是所有可能的类别label_list,比如
['数学', '英语', '语文']
、[1, 2, 3]
… - 模块最终返回一个名为examples的列表,每个列表元素中包含序号、中文文本、类别三个元素
3. 特征转换
- convert_examples_to_features是用于将examples转换为特征,也即features的函数。
- features包含4个数据:
- input_ids:分词后每个词语在vocabulary中的id,补全符号对应的id为0,[CLS]和[SEP]的id分别为101和102。应注意的是,在中文BERT模型中,中文分词是基于字而非词的分词。
- input_mask:真实字符/补全字符标识符,真实文本的每个字对应1,补全符号对应0,[CLS]和[SEP]也为1。
- segment_ids:句子A和句子B分隔符,句子A对应的全为0,句子B对应的全为1。但是在多数文本分类情况下并不会用到句子B,所以基本不用管。
- label_id :将label_list中的元素利用字典转换为index标识,即
label_map = {} for (i, label) in enumerate(label_list):label_map[label] = i
- features中一个元素的例子是:
- 转换完成后的特征值就可以作为输入,用于模型的训练和测试
4. 模型训练
- 完成读取数据、特征转换之后,将特征送入模型进行训练
- 训练算法为BERT专用的Adam算法
- 训练集、测试集、验证集比例为6:2:2
- 每一个epoch后会在验证集上进行验证,并给出相应的f1值,如果f1值大于此前最高分则保存模型参数,否则flags加1。如果flags大于6,也即连续6个epoch模型的性能都没有继续优化,停止训练过程。
f1 = val(model, processor, args, label_list, tokenizer, device) if f1 > best_score:best_score = f1print('*f1 score = {}'.format(f1))flags = 0checkpoint = {'state_dict': model.state_dict()}torch.save(checkpoint, args.model_save_pth) else:print('f1 score = {}'.format(f1))flags += 1if flags >=6:break
- 如果epoch数超过先前设定的num_train_epochs,同样会停止迭代。
5. 模型测试
- 先加载模型
- 送数据,取得分,完事
- 暂时还没加打印测试结果到文件的功能,后续会加上
6. 测试结果
val_F1 | test_F1 | |
---|---|---|
Fast text | 0.7218 | 0.7094 |
Text rnn + bigru | 0.7383 | 0.7194 |
Text cnn | 0.7292 | 0.7088 |
bigru + attention | 0.7335 | 0.7146 |
RCNN | 0.7355 | 0.7213 |
BERT | 0.7938 | 0.787 |
- 基于真实数据做的文本分类,用过不少模型,BERT的性能可以说是独一档
- BERT确实牛逼,不过一部分原因也是模型量级就不一样
7. 总结
- 使用代码的时候按照参数表修改下参数,把数据按照命名规范放data_dir目录下一般就没啥问题了
- 最多还要修改下读取数据的代码(如果数据不是.json格式的),就可以跑通了
- 最后可以根据个人需要,对模型训练逻辑、epoch数、学习步长等地方做进一步修改
- 代码地址已经放在前言(3)里了
如何使用BERT实现中文的文本分类(附代码)相关推荐
- 【NLP保姆级教程】手把手带你RNN文本分类(附代码)
写在前面 这是NLP保姆级教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关于循环神经网络在多任务文本分类上的应用:Recur ...
- 【NLP傻瓜式教程】手把手带你RCNN文本分类(附代码)
继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...
- 【NLP傻瓜式教程】手把手带你HAN文本分类(附代码)
继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...
- 【NLP傻瓜式教程】手把手带你fastText文本分类(附代码)
写在前面 已经发布: [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) 继续NLP傻瓜式教程系列,今天的教程是基于FAIR的Bag of ...
- 毕业设计-基于 BERT 的中文长文本分类系统
目录 前言 课题背景和意义 实现技术思路 一.文本分类的相关技术 二.文本表示模型 三.文本分类模型 实现效果图样例 最后 前言
- 【NLP傻瓜式教程】手把手带你CNN文本分类(附代码)
文章来源于NewBeeNLP,作者kaiyuan 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classification[ ...
- 【NLP保姆级教程】手把手带你CNN文本分类(附代码)
分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...
- 【NLP】保姆级教程:手把手带你CNN文本分类(附代码)
分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...
- 【NLP傻瓜式教程】手把手带你RNN文本分类(附代码)
文章来源于NewBeeNLP,作者kaiyuan 写在前面 这是NLP傻瓜式教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关 ...
最新文章
- SpringBoot (六) :SpringBoot定时器实现(简单入门)
- python itertools模块位置_Python高效编程之itertools模块详解
- 阐述html语言的理解,大学语文课后思考题答案
- Spring boot配置项目访问路径server.context-path不起作用(改为server.servlet.context-path)
- centos 7 单独安装mysql和mysqli和pdo_mysql扩展
- 软件工程师证书有用吗_考证:BIM工程师证书有用吗?
- Java并发编程实战_《Java并发编程实战》PDF版本下载
- 测试低频噪音软件,设计制作并验证0.1Hz~10Hz超低频微弱噪音检测放大器STEP BY STEP...
- python计算日期是一年中的第几天,Python根据年月日,计算是一年的第几天
- 使命召唤手游服务器显示错误,使命召唤手游无法连接服务器是什么原因
- 华为认证HCIA-云服务工程师正式发布
- 高德打车宣布上线共享雨伞:或许是醉翁之意不在酒
- Python 基础 之 词云(词的频率统计大小成图)的简单实现(包括图片词云,词云颜色,词的过滤)
- 2048小游戏HTML网页版源码共享
- B站回应HR称“核心用户都是Loser”、求职者是“白嫖党”:已被劝退
- SX1278性能评估
- Shiro第十二章-与Spring集成、配置文件初解
- 爬楼梯当中的递归简化计算
- 老司机带你快速实现Python下载与安装
- 电气工程及其自动化-课程体系介绍
热门文章
- JavaFX 控件 ImageView
- fjut第四次周赛A(欧拉路径)D题
- 医疗器械许可证怎么办理
- c++笔记(1):C++中命令行参数argc,argc[ ]究竟是什么
- 手机音频口反向列表 类Square手机刷卡器注意事项
- jQuery面试笔试题汇总整理
- vivo Z1青春版全面评测:Z系列的继承者,展示全新千元态度
- Vue3+Vite+TS后台项目 ~ 4. axios请求封装
- 【火牛STM32F103VC】RT-Thread 蜂鸣器BEEP功能验证
- VScode快速下载方法和快捷键