1 大纲概述

  文本分类这个系列将会有十篇左右,包括基于word2vec预训练的文本分类,与及基于最新的预训练模型(ELMo,BERT等)的文本分类。总共有以下系列:

  word2vec预训练词向量

  textCNN 模型

  charCNN 模型

  Bi-LSTM 模型

  Bi-LSTM + Attention 模型

  RCNN 模型

  Adversarial LSTM 模型

  Transformer 模型

  ELMo 预训练模型

  BERT 预训练模型

  所有代码均在textClassifier仓库中。

2 数据集

  数据集为IMDB 电影影评,总共有三个数据文件,在/data/rawData目录下,包括unlabeledTrainData.tsv,labeledTrainData.tsv,testData.tsv。在进行文本分类时需要有标签的数据(labeledTrainData),数据预处理如文本分类实战(一)—— word2vec预训练词向量中一样,预处理后的文件为/data/preprocess/labeledTrain.csv。

3 BERT预训练模型

  BERT 模型来源于论文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。BERT模型是谷歌提出的基于双向Transformer构建的语言模型。BERT模型和ELMo有大不同,在之前的预训练模型(包括word2vec,ELMo等)都会生成词向量,这种类别的预训练模型属于domain transfer。而近一两年提出的ULMFiT,GPT,BERT等都属于模型迁移。

  BERT 模型是将预训练模型和下游任务模型结合在一起的,也就是说在做下游任务时仍然是用BERT模型,而且天然支持文本分类任务,在做文本分类任务时不需要对模型做修改。谷歌提供了下面七种预训练好的模型文件。

  

  BERT模型在英文数据集上提供了两种大小的模型,Base和Large。Uncased是意味着输入的词都会转变成小写,cased是意味着输入的词会保存其大写(在命名实体识别等项目上需要)。Multilingual是支持多语言的,最后一个是中文预训练模型。

  在这里我们选择BERT-Base,Uncased。下载下来之后是一个zip文件,解压后有ckpt文件,一个模型参数的json文件,一个词汇表txt文件。

  在应用BERT模型之前,我们需要去github上下载开源代码,我们可以直接clone下来,在这里有一个run_classifier.py文件,在做文本分类项目时,我们需要修改这个文件,主要是添加我们的数据预处理类。clone下来的项目结构如下:

    

  在run_classifier.py文件中有一个基类DataProcessor类,其代码如下:

class DataProcessor(object):"""Base class for data converters for sequence classification data sets."""def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_test_examples(self, data_dir):"""Gets a collection of `InputExample`s for prediction."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for this data set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):"""Reads a tab separated value file."""with tf.gfile.Open(input_file, "r") as f:reader = csv.reader(f, delimiter="\t", quotechar=quotechar)lines = []for line in reader:lines.append(line)return lines

  在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来我们要定义自己的数据处理的类,我们将我们的类命名为IMDBProcessor

class IMDBProcessor(DataProcessor):"""IMDB data processor"""def _read_csv(self, data_dir, file_name):with tf.gfile.Open(data_dir + file_name, "r") as f:reader = csv.reader(f, delimiter=",", quotechar=None)lines = []for line in reader:lines.append(line)return linesdef get_train_examples(self, data_dir):lines = self._read_csv(data_dir, "trainData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "train-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_dev_examples(self, data_dir):lines = self._read_csv(data_dir, "devData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "dev-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_test_examples(self, data_dir):lines = self._read_csv(data_dir, "testData.csv")examples = []for (i, line) in enumerate(lines):if i == 0:continueguid = "test-%d" % (i)text_a = tokenization.convert_to_unicode(line[0])label = tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examplesdef get_labels(self):return ["0", "1"]

  在这里我们没有直接用基类中的静态方法_read_tsv,因为我们的csv文件是用逗号分隔的,因此就自己定义了一个_read_csv的方法,其余的方法就是读取训练集,验证集,测试集和标签。在这里标签就是一个列表,将我们的类别标签放入就行。训练集,验证集和测试集都是返回一个InputExample对象的列表。InputExample是run_classifier.py中定义的一个类,代码如下:

class InputExample(object):"""A single training/test example for simple sequence classification."""def __init__(self, guid, text_a, text_b=None, label=None):"""Constructs a InputExample.Args:guid: Unique id for the example.text_a: string. The untokenized text of the first sequence. For singlesequence tasks, only this sequence must be specified.text_b: (Optional) string. The untokenized text of the second sequence.Only must be specified for sequence pair tasks.label: (Optional) string. The label of the example. This should bespecified for train and dev examples, but not for test examples."""self.guid = guidself.text_a = text_aself.text_b = text_bself.label = label

  在这里定义了text_a和text_b,说明是支持句子对的输入的,不过我们这里做文本分类只有一个句子的输入,因此text_b可以不传参。

  另外从上面我们自定义的数据处理类中可以看出,训练集和验证集是保存在不同文件中的,因此我们需要将我们之前预处理好的数据提前分割成训练集和验证集,并存放在同一个文件夹下面,文件的名称要和类中方法里的名称相同。

  到这里之后我们已经准备好了我们的数据集,并定义好了数据处理类,此时我们需要将我们的数据处理类加入到run_classifier.py文件中的main函数下面的processors字典中,结果如下:

  

  之后就可以直接执行run_classifier.py文件,执行脚本如下:

export BERT_BASE_DIR=../modelParams/uncased_L-12_H-768_A-12export DATASET=../data/python run_classifier.py \--data_dir=$MY_DATASET \--task_name=imdb \--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--output_dir=../output/ \--do_train=true \--do_eval=true \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \--max_seq_length=200 \--train_batch_size=16 \--learning_rate=5e-5\--num_train_epochs=2.0 

  在这里的task_name就是我们定义的数据处理类的键,BERT模型较大,加载时需要较大的内存,如果出现内存溢出的问题,可以适当的降低batch_size的值。

  目前迭代完之后的输出比较少,而且只有等迭代结束后才会有结果输出,不利于观察损失的变化,后续将修改输出。目前的输出结果:

  

  测试集上的准确率达到了90.7% ,这个结果比Bi-LSTM + Attention(87.7%)的结果要好。

4 增加验证集输出的指标值

  目前验证集上的输出指标值只有loss和accuracy,如上图所示,然而在分类时,我们可能还需要看auc,recall,precision的值。增加几行代码就可以搞定:

  

  在我的代码中743行这里有个metric_fn函数,之前这个函数下只有loss和accuracy的计算,我们在这里加上auc,recall,precision的计算,然后加入到return的这个字典中就可以了。现在的输出结果:

  

 

5 关于BERT的问题

  在run_classifier.py文件中,训练模型,验证模型都是用的tensorflow中的estimator接口,因此我们无法实现在训练迭代100步就用验证集验证一次,在run_classifier.py文件中提供的方法是先运行完所有的epochs之后,再加载模型进行验证。训练模型时的代码:

  

  在我的代码中948行这里,在这里我们加入了几行代码,可以实现训练时输出loss,就是上面的:

tensors_to_log = {"train loss": "loss/Mean:0"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)

  这是我们添加进去的,加入了一个hooks的参数,让训练的时候没迭代100步就输出一次loss。然而这样的意义并不是很大。

  下面的日志可以看到验证时是加载训练完的模型来进行验证的,见下图第一行:Restoring xxx

  

  这种无法在训练时输出验证集上的结果,会导致我们很难直观的看到损失函数的变化。就无法很方便的确定模型是否收敛,这也是tensorflow中这些高级API的问题,高级封装虽然让书写代码更容易,但也让代码更死板。

  

  在https://github.com/jiangxinyang227/bert-for-task中提供了bert,albert在各种任务中的应用,代码已经标准化,可以快速的训练,预测,线上部署。

分类: 文本分类, 自然语言处理

好文要顶 关注我 收藏该文  

微笑sun
关注 - 18
粉丝 - 444

+加关注

16

0

« 上一篇: 文本分类实战(九)—— ELMO 预训练模型 
» 下一篇: BERT模型在多类别文本分类时的precision, recall, f1值的计算

文本分类实战(十)—— BERT 预训练模型相关推荐

  1. 【文本分类】基于BERT预训练模型的灾害推文分类方法、基于BERT和RNN的新闻文本分类对比

    ·阅读摘要: 两篇论文,第一篇发表于<图学学报>,<图学学报>是核心期刊:第二篇发表于<北京印刷学院学报>,<北京印刷学院学报>没有任何标签. ·参考文 ...

  2. 天池零基础入门NLP竞赛实战:Task4-基于深度学习的文本分类3-基于Bert预训练和微调进行文本分类

    Task4-基于深度学习的文本分类3-基于Bert预训练和微调进行文本分类 因为天池这个比赛的数据集是脱敏的,无法利用其它已经预训练好的模型,所以需要针对这个数据集自己从头预训练一个模型. 我们利用H ...

  3. 6个你应该用用看的用于文本分类的最新开源预训练模型 忆臻

    作者:PURVA HUILGOL 编译:ronghuaiyang (AI公园) 原文链接: 6个你应该用用看的用于文本分类的最新开源预训练模型​mp.weixin.qq.com 导读 文本分类是NLP ...

  4. 6个用于文本分类的最新开源预训练模型(NLP必备)

    作者:PURVA HUILGOL  编译:ronghuaiyang 导读 文本分类是NLP的基础任务之一,今天给大家介绍6个最新的预训练模型,做NLP的同学一定要用用看. 介绍 我们正站在语言和机器的 ...

  5. 如何学习使用Bert预训练模型

    目录 (一)bert预训练模型下载 (二)使用bert做中文文本分类 (一)bert预训练模型下载 在bert官网下载自己需要的预训练模型. 下图是进入官网的图片. 点击想要选择的模型,选择Files ...

  6. Pytorch——BERT 预训练模型及文本分类(情感分类)

    BERT 预训练模型及文本分类 介绍 如果你关注自然语言处理技术的发展,那你一定听说过 BERT,它的诞生对自然语言处理领域具有着里程碑式的意义.本次试验将介绍 BERT 的模型结构,以及将其应用于文 ...

  7. 手把手教 | 使用Bert预训练模型文本分类(内附源码)

    作者:GjZero 标签:Bert, 中文分类, 句子向量 本文约1500字,建议阅读8分钟. 本文从实践入手,带领大家进行Bert的中文文本分类和作为句子向量进行使用的教程. Bert介绍 Bert ...

  8. 使用Bert预训练模型进行中文文本分类(基于pytorch)

    前言 最近在做一个关于图书系统的项目,需要先对图书进行分类,想到Bert模型是有中文文本分类功能的,于是打算使用Bert模型进行预训练和实现下游文本分类任务 数据预处理 2.1 输入介绍 在选择数据集 ...

  9. 【BERT-多标签文本分类实战】之五——BERT模型库的挑选与Transformers

    ·请参考本系列目录:[BERT-多标签文本分类实战]之一--实战项目总览 ·下载本实战项目资源:>=点击此处=< [1] BERT模型库   从BERT模型一经Google出世,到tens ...

最新文章

  1. TCP拥塞控制算法内核实现剖析(二)
  2. 后台开发经典书籍--unix环境高级编程
  3. 计算机世界的虚拟机,容器和医学界的人工硬脑膜
  4. 93号涨0.86元售6.2元/升 20日油价正式上调
  5. CF802C-Heidi and Library(hard)【费用流】
  6. leetcode练习:292. Nim Game
  7. 今日恐慌与贪婪指数为38 恐慌程度明显上升
  8. metacube 链接 mysql_2019 年 5月 随笔档案 - rgqancy - 博客园
  9. MATLAB-图像分割
  10. 微讲师录课软件下载、录屏软件下载
  11. 网络安全和CTF相关内容
  12. 再谈微软复兴,纳德拉与库克、马斯克、皮查伊在管理上有什么不同
  13. php源码比赛,TSRC挑战赛: PHP防御绕过挑战实录
  14. 微信小程序入门开发教程(详解)
  15. 浏览器flash过期无法使用完美解决
  16. 多分区装linux系统,Linux安装之多系统分区
  17. 【PM必知】项目管理的“六大核心”内容详解
  18. sentencepiece原理与实践
  19. 用ul制作html表单,要利用 display属性把段落P、标题h1、表单form、列表ul和li都可以定义成行内块元素,其属性值为...
  20. Git 回滚Rollback

热门文章

  1. python练习题:跟电脑玩剪刀石头布,一直循环玩,可手动退出,推出后可以计算玩家胜率
  2. 面向安全需求的VANET信道拥塞联合控制框架
  3. 目标检测FCOS的初步理解
  4. Springboot2.1.1版本升级到2.3.10版本报错合集及解决办法
  5. Yum,搭建软件仓库
  6. 【Linux】一步一步学Linux——readelf命令(253)
  7. SSD 闪存盘技术详解
  8. SRC任意账号密码重置的6种方法
  9. android多渠道打包签名配置,Gradle For Android(二) 多渠道打包与签名配置
  10. pta判断回文字符串