一、简介

在开始使用之前,我们先简单介绍一下到底什么是BERT,大家也可以去BERT的github上进行详细的了解。在CV问题中,目前已经有了很多成熟的预训练模型供大家使用,我们只需要修改结尾的FC层或根据实际场景添加softmax层,也就是我们常说的迁移学习。那在NLP领域是否有类似的方法呢,答案是肯定的,BERT就是这样的预训练模型。对于NLP的正常流程来说,我们需要做一些预处理,例如分词、W2V等,BERT包含所有的预训练过程,只需要提供文本数据即可,接下来我们会基于NLP常用的文本分类问题来介绍如何使用BERT。

BERT 模型来源于论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。BERT模型是谷歌提出的基于双向Transformer构建的语言模型。BERT模型和ELMo有大不同,在之前的预训练模型(包括word2vec,ELMo等)都会生成词向量,这种类别的预训练模型属于domain transfer。而近一两年提出的ULMFiT,GPT,BERT等都属于模型迁移。

BERT 模型是将预训练模型和下游任务模型结合在一起的,也就是说在做下游任务时仍然是用BERT模型,而且天然支持文本分类任务,在做文本分类任务时不需要对模型做修改。谷歌提供了下面七种预训练好的模型文件。

BERT模型在英文数据集上提供了两种大小的模型,Base和Large。Uncased是意味着输入的词都会转变成小写,cased是意味着输入的词会保存其大写(在命名实体识别等项目上需要)。Multilingual是支持多语言的,最后一个是中文预训练模型。

二、下载预训练模型和代码

首先,在github上clone谷歌的BERT项目,或者直接下载。项目地址

然后,下载中文预训练模型,地址

预训练模型主要包含三个内容:

  • TensorFlow 用来保存预训练模型的三个 checkpoint 文件(bert_model.ckpt.xxx)
  • 字典文件,用于做ID的映射 (vocab.txt)
  • 配置文件,该文件的参数是fine-tuning时模型用到的,可自行调整 (bert_config.json)

三、环境准备

tensorflow >= 1.11.0, BERT base 模型占用显存约为 9.5G。

四、数据准备

我们需要将文本数据分为三部分:

  • Train: train.tsv
  • Evaluate: dev.tsv
  • Test: test.tsv

每个文件的格式,非常简单,一列为需要做分类的文本数据,另一列则是对应的 Label。

五、编写代码,修改processor

模型准备好后就可以编写代码了,我们先把BERT的github代码clone下来,之后我们的代码编写会基于run_classifier.py文件

由于我们要做的是文本多分类任务,可以在 run_classifier.py 基础上面做调整。主要是添加我们的数据预处理类。

可以看到,在执行run_classifier.py时需要先输入这5个必填参数,这里我们对参数做一个简单的说明:

参数 说明
data_dir 训练数据的地址
task_name processor的名字
vocab_file 字典地址,用默认提供的就可以了,当然也可以自定义
bert_config_file 配置文件
output_dir 模型的输出地址

再补充下以下三个可选参数说明:

参数 说明
do_train 是否做fine-tuning,默认为false,如果为true必须重写获取训练集的方法
do_eval 是否运行验证集,默认为false,如果为true必须重写获取验证集的方法
dopredict 是否做预测,默认为false,如果为true必须重写获取测试集的方法

在run_classifier.py文件中有一个基类DataProcessor类,其代码如下:

class DataProcessor(object):"""Base class for data converters for sequence classification data sets."""def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_test_examples(self, data_dir):"""Gets a collection of `InputExample`s for prediction."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for this data set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):"""Reads a tab separated value file."""with tf.gfile.Open(input_file, "r") as f:reader = csv.reader(f, delimiter="\t", quotechar=quotechar)lines = []for line in reader:lines.append(line)return lines

在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。在run_classsifier.py文件中我们可以看到,google对于一些公开数据集已经写了一些processor,如XnliProcessor,MnliProcessor,MrpcProcessor和ColaProcessor。这给我们提供了一个很好的示例,指导我们如何针对自己的数据集来写processor。接下来我们要定义自己的数据处理的类,我们将我们的类命名为IMDBProcessor

class IMDBProcessor(DataProcessor):"""IMDB data processor"""def _read_csv(self, data_dir, file_name):with tf.gfile.Open(data_dir + file_name, "r") as f:reader = csv.reader(f, delimiter=",", quotechar=None)lines = []for line in reader:lines.append(line)return linesdef get_train_examples(self, data_dir):lines = self._read_csv(data_dir, "trainData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "train-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_dev_examples(self, data_dir):lines = self._read_csv(data_dir, "devData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "dev-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_test_examples(self, data_dir):lines = self._read_csv(data_dir, "testData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "test-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_labels(self):return ["0", "1"]

可以用基类中的静态方法_read_tsv读取数据集是用\t分割的,你的数据集可以按照\t分割,或者自己定义一个_read_csv的方法读取数据集,其余的方法就是读取训练集,验证集,测试集和标签。在这里标签就是一个列表,将我们的类别标签放入就行。训练集,验证集和测试集都是返回一个InputExample对象的列表。

读取的数据需要封装成一个InputExample的对象并添加到list中,InputExample是run_classifier.py中定义的一个类,源码为:

class InputExample(object):"""A single training/test example for simple sequence classification."""def __init__(self, guid, text_a, text_b=None, label=None):"""Constructs a InputExample.Args:guid: Unique id for the example.text_a: string. The untokenized text of the first sequence. For singlesequence tasks, only this sequence must be specified.text_b: (Optional) string. The untokenized text of the second sequence.Only must be specified for sequence pair tasks.label: (Optional) string. The label of the example. This should bespecified for train and dev examples, but not for test examples."""self.guid = guidself.text_a = text_aself.text_b = text_bself.label = label

注意这里有一个guid的参数,这个参数是必填的,是用来区分每一条数据的。是否进行训练集、验证集、测试集的计算,在执行代码时会有参数控制,我们下文会讲,所以这里的抽象方法也并不是需要全部都重写,但是为了体验一个完整的流程, 建议大家还是简单写一下。

在这里定义了text_a和text_b,说明是支持句子对的输入的,不过我们这里做文本分类只有一个句子的输入,因此text_b可以不传参。

另外从上面我们自定义的数据处理类中可以看出,训练集和验证集是保存在不同文件中的,因此我们需要将我们之前预处理好的数据提前分割成训练集和验证集,并存放在同一个文件夹下面,文件的名称要和类中方法里的名称相同。

get_labels方法返回的是一个数组,因为相似度问题可以理解为分类问题,所以返回的标签只有0和1,注意,这里我返回的是参数是字符串,所以在重写获取数据的方法时InputExample中的label也要传字符串的数据,可以看到上图中我对label做了一个str()的处理。

到这里之后我们已经准备好了我们的数据集,并定义好了数据处理类,此时我们需要将我们的数据处理类加入到run_classifier.py文件中的main函数下面的processors字典中,给Processor加一个名字,在运行时告诉代码我们要执行哪一个Processor。结果如下:

def main(_):tf.logging.set_verbosity(tf.logging.INFO)processors = {"cola": ColaProcessor,"mnli": MnliProcessor,"mrpc": MrpcProcessor,"xnli": XnliProcessor,"imdb": IMDBProcessor}tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, FLAGS.init_checkpoint)

六、Fine-tune训练模型

假设已有标注好的数据集,接下来我们就开始用自己的数据集训练模型了

在运行时需要制定一些参数,这里给出一个较为完整的运行参数命令,如下所示:

#设置全局变量BERT_BASE_DIR:下载的bert预训练模型地址
export BERT_BASE_DIR=/xxx/chinese_L-12_H-768_A-12
#设置全局变量MY_DATASET:数据集地址
export MY_DATASET=/xxx/your_data_name python run_classifier.py \--data_dir=$MY_DATASET \  #训练数据路径--task_name=imdb \  # #自己添加processor在processors字典里的key名--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--output_dir=$MY_DATASET/out \  # 模型输出路径--do_train=true \--do_eval=true \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \--max_seq_length=128 \  #语句长度--train_batch_size=32 \--learning_rate=5e-5\--num_train_epochs=2.0

或者不设置全局变量,直接在传入参数的时候指定,例如可以按照下面的方式执行:

python run_classifier.py \--data_dir=$MY_DATASET \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \--output_dir=$MY_DATASET/out \--task_name=imdb \--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--do_train=true \--do_eval=true \--max_seq_length=128 \--train_batch_size=32 \--learning_rate=5e-5\--num_train_epochs=2.0

在这里的task_name就是我们定义的数据处理类的键。

BERT模型较大,加载时需要较大的内存,如果出现内存溢出(OOM)的问题,如下图:

可以适当的降低batch_size的值。

batch size具体的大小可参考下图进行设置

目前迭代完之后的输出比较少,而且只有等迭代结束后才会有结果输出,不利于观察损失的变化,后续将修改输出。目前的输出结果:

七、预测

如果需要使用模型预测,可以执行以下命令:

python run_classifier.py \--task_name=imdb \--do_predict=true \--data_dir=$MY_DATASET \--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--init_checkpoint=/data/bert_model_sim \--max_seq_length=128 \--output_dir=$MY_DATASET/out/

测试完成后会在output_dir路径下生成一个test_results.tsv文件,该文件包含了测试用例和相似度probabilities。

注:执行脚本中的$BERT_BASE_DIR和$MY_DATASET结合自己的目录更改。

BERT文本分类实战相关推荐

  1. bert 文本分类实战

    前言: 由于课题需要,学习自然语言处理(NLP),于是在网上找了找文章和代码进行学习,在此记录,课题代码就不展示了,使用网上的代码和大家分享.思想和代码大部分参考苏神,在此感谢. 任务目标: 希望be ...

  2. 『NLP学习笔记』BERT文本分类实战

    BERT技术详细介绍! 文章目录 一. 数据集介绍 二. 数据读取 三. 训练集和验证集划分 四. 数据分词tokenizer 五. 定义数据读取(继承Dataset类) 六. 定义模型以及优化方法 ...

  3. 【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers

    ·请参考本系列目录:[BERT-多标签文本分类实战]之一--实战项目总览 ·下载本实战项目资源:>=点击此处=< [1] BERT模型库   从BERT模型一经Google出世,到tens ...

  4. 【BERT-多标签文本分类实战】之二——BERT的地位与名词术语解释

    ·请参考本系列目录:[BERT-多标签文本分类实战]之一--实战项目总览 ·下载本实战项目资源:>=点击此处=< [注]本篇将从宏观上介绍bert的产生和在众多模型中的地位,以及与bert ...

  5. 文本分类实战(十)—— BERT 预训练模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  6. 【英文文本分类实战】之三——数据清洗

    ·请参考本系列目录:[英文文本分类实战]之一--实战项目总览 ·下载本实战项目资源:神经网络实现英文文本分类.zip(pytorch) [1] 为什么要清洗文本   这里涉及到文本分类任务中:词典.词 ...

  7. 文本分类实战(三)—— charCNN模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  8. 文本分类实战(七)—— Adversarial LSTM模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

  9. 文本分类实战—— Bi-LSTM模型

    1 大纲概述 文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类.总共有以下系列: word2vec预训练词向量 te ...

最新文章

  1. JSP--JavaBean
  2. 技术图文:如何利用C# 实现 Prim 最小生成树算法?
  3. 自学python可以找到好的工作吗-学好python能找到好工作吗?
  4. mysql数据库周考_周考三
  5. 一文详解神经网络模型
  6. PHP保留小数三种方法
  7. Typeface 字体样式
  8. idea中新建分支并且切换到新建的分支上
  9. 泛型算法(lambda表达式、function类模板、bind函数适配器、迭代器类别、链表数据结构独有的算法)
  10. opencv编译问题
  11. 山东到底有没有互联网?
  12. Windows引导及安装
  13. Main线程与main()方法的关系
  14. JavaCore/HeapDump文件分析工具
  15. 最新修复版电影网站源码_2021版米酷影视v7.2.1源码 修复幻灯片 分类网址错误
  16. python获取浏览器cookie_python3实现读取chrome浏览器cookie
  17. 网上下载图片去水印的方法
  18. fliebeat+kafka ELK日志分析平台实战
  19. php解决缓慢http请求,php CURL 服务器响应慢的问题
  20. 传统行业如何搭建大数据团队?

热门文章

  1. 小达同学软件测试第五讲-测试技术与应用(完结)
  2. E710芯片系列模块的特性
  3. 继电保护整定值计算软件_继电保护整定电流
  4. 西农大 C plus
  5. Moglue:无需编程的交互式电子书制作软件(视频演示)
  6. 简明教程-linux 之 7z文件解压缩
  7. 如何播放无限长度的音乐
  8. 离职,你想清楚了吗?
  9. 无线网手动添加服务器,无线网络手动设置的问题解决v
  10. queued_在Linux上,诸如“ UnrecovData 10B8B BadCRC”和“失败的命令:READ FPDMA QUEUED”之类的消息有什么问题?...