BERT模型也出来很久了,之前看了论文学习过它的大致模型(可以参考前些日子写的笔记NLP大杀器BERT模型解读),但是一直有杂七杂八的事拖着没有具体去实现过真实效果如何。今天就趁机来动手写一写实战,顺便复现一下之前的内容。这篇文章的内容还是以比较简单文本分类任务入手,数据集选取的是新浪新闻cnews,包括了[‘体育’, ‘财经’, ‘房产’, ‘家居’, ‘教育’, ‘科技’, ‘时尚’, ‘时政’, ‘游戏’, ‘娱乐’]总共十个主题的新闻数据。那么我们就开始吧!

Transformer模型

BERT模型就是以Transformer基础上训练出来的嘛,所以在开始之前我们首先复习一下目前NLP领域可以说是最高效的‘变形金刚’Transformer。由于网上Transformer介绍解读文章满天飞了都,这里就不浪费太多时间了。

本质上来说,Transformer就是一个只由attention机制形成的encoder-decoder结构。关于attention的具体介绍可以参考之前这篇理解Attention机制原理及模型。理解Transformer模型可以将其进行解剖,分成几个组成部分:

  1. Embedding (word + position)
  2. Attention mechanism (scaled dot-product + multi-head)
  3. Feed-Forward network
  4. ADD(类似于Resnet里的残差操作)
  5. Norm(加快收敛)
  6. Softmax
  7. Fine-tuning

前期准备

1.下载BERT

我们要使用BERT模型的话,首先要去github上下载相关源码:

git clone  https://github.com/google-research/bert.git

下载成功以后我们现在的文件大概就是这样的

2.下载bert预训练模型

Google提供了多种预训练好的bert模型,有针对不同语言的和不同模型大小的。Uncased参数指的是将数据全都转成小写的(大多数任务使用Uncased模型效果会比较好,当然对于一些大小写影响严重的任务比如NER等就可以选择Cased)

对于中文模型,我们使用Bert-Base, Chinese。下载后的文件包括五个文件:

bert_model.ckpt:有三个,包含预训练的参数
vocab.txt:词表
bert_config.json:保存模型超参数的文件

3. 数据集准备

前面有提到过数据使用的是新浪新闻分类数据集,每一行组成是 【标签+ TAB + 文本内容】

Start Working

BERT非常友好的一点就是对于NLP任务,我们只需要对最后一层进行微调便可以用于我们的项目需求。我们只需要将我们的数据输入处理成标准的结构进行输入就可以了。

DataProcessor基类

首先在run_classifier.py文件中有一个基类DataProcessor类:

class DataProcessor(object):"""Base class for data converters for sequence classification data sets."""def get_train_examples(self, data_dir):"""Gets a collection of `InputExample`s for the train set."""raise NotImplementedError()def get_dev_examples(self, data_dir):"""Gets a collection of `InputExample`s for the dev set."""raise NotImplementedError()def get_test_examples(self, data_dir):"""Gets a collection of `InputExample`s for prediction."""raise NotImplementedError()def get_labels(self):"""Gets the list of labels for this data set."""raise NotImplementedError()@classmethoddef _read_tsv(cls, input_file, quotechar=None):"""Reads a tab separated value file."""with tf.gfile.Open(input_file, "r") as f:reader = csv.reader(f, delimiter="\t", quotechar=quotechar)lines = []for line in reader:lines.append(line)return lines

在这个基类中定义了一个读取文件的静态方法_read_tsv,四个分别获取训练集,验证集,测试集和标签的方法。接下来我们要定义自己的数据处理的类,我们将我们的类命名为MyTaskProcessor

编写MyTaskProcessor

MyTaskProcessor继承DataProcessor,用于定义我们自己的任务

class MyTaskProcessor(DataProcessor):"""Processor for my task-news classification """def __init__(self):self.labels = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']def get_train_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, 'cnews.train.txt')), 'train')def get_dev_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, 'cnews.val.txt')), 'val')def get_test_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, 'cnews.test.txt')), 'test')def get_labels(self):return self.labelsdef _create_examples(self, lines, set_type):"""create examples for the training and val sets"""examples = []for (i, line) in enumerate(lines):guid = '%s-%s' %(set_type, i)text_a = tokenization.convert_to_unicode(line[1])label = tokenization.convert_to_unicode(line[0])examples.append(InputExample(guid=guid, text_a=text_a, label=label))return examples

注意这里有一个self._read_tsv()方法,规定读取的数据是使用TAB分割的,如果你的数据集不是这种形式组织的,需要重写一个读取数据的方法,更改“_create_examples()”的实现。

编写main以及训练

至此我们就完成了对我们的数据加工成BERT所需要的格式,就可以进行模型训练了。

def main(_):tf.logging.set_verbosity(tf.logging.INFO)processors = {"cola": ColaProcessor,"mnli": MnliProcessor,"mrpc": MrpcProcessor,"xnli": XnliProcessor,"mytask": MyTaskProcessor,}
python run_classifier.py \--task_name=mytask \--do_train=true \--do_eval=true \--data_dir=$DATA_DIR/ \--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \--max_seq_length=128 \--train_batch_size=32 \--learning_rate=2e-5 \--num_train_epochs=3.0 \--output_dir=mytask_output

其中DATA_DIR是你的要训练的文本的数据所在的文件夹,BERT_BASE_DIR是你的bert预训练模型存放的地址。task_name要求和你的DataProcessor类中的名称一致。下面的几个参数,do_train代表是否进行fine tune,do_eval代表是否进行evaluation,还有未出现的参数do_predict代表是否进行预测。如果不需要进行fine tune,或者显卡配置太低的话,可以将do_trian去掉。max_seq_length代表了句子的最长长度,当显存不足时,可以适当降低max_seq_length。

BERT prediction

上面一节主要就是介绍了怎么去根据我们实际的任务(多文本分类)去fine-tune bert模型,那么训练好适用于我们特定的任务的模型后,接下来就是使用这个模型去做相应地预测任务。预测阶段唯一需要做的就是修改 – do_predict=true。你需要将测试样本命名为test.csv,输出会保存在输出文件夹的test_result.csv,其中每一行代表一个测试样本对应的预测输出,每一列代表对应于不同类别的概率。

export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
export GLUE_DIR=/path/to/glue
export TRAINED_CLASSIFIER=/path/to/fine/tuned/classifierpython run_classifier.py \--task_name=MRPC \--do_predict=true \--data_dir=$GLUE_DIR/MRPC \--vocab_file=$BERT_BASE_DIR/vocab.txt \--bert_config_file=$BERT_BASE_DIR/bert_config.json \--init_checkpoint=$TRAINED_CLASSIFIER \--max_seq_length=128 \--output_dir=/tmp/mrpc_output/

有趣的优化

指定训练时输出loss

bert自带代码中是这样的,在run_classifier.py文件中,训练模型,验证模型都是用的tensorflow中的estimator接口,因此我们无法实现在训练迭代100步就用验证集验证一次,在run_classifier.py文件中提供的方法是先运行完所有的epochs之后,再加载模型进行验证。训练模型时的代码:

train_input_fn = file_based_input_fn_builder(input_file=train_file,seq_length=FLAGS.max_seq_length,is_training=True,drop_remainder=True)estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

想要实现在训练过程中输出loss日志,我们可以使用hooks参数:

train_input_fn = file_based_input_fn_builder(input_file=train_file,seq_length=FLAGS.max_seq_length,is_training=True,drop_remainder=True)tensors_to_log = {'train loss': 'loss/Mean:0'}logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)estimator.train(input_fn=train_input_fn, hooks=[logging_hook], max_steps=num_train_steps)
增加验证集输出的指标值

原生BERT代码中验证集的输出指标值只有loss和accuracy,

def metric_fn(per_example_loss, label_ids, logits, is_real_example):predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions, weights=is_real_example)loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)return {"eval_accuracy": accuracy,"eval_loss": loss,}

但是在分类时,我们可能还需要分析auc,recall,precision等的值。

def metric_fn(per_example_loss, label_ids, logits, is_real_example):predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions, weights=is_real_example)loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example)auc = tf.metrics.auc(labels=label_ids, predictions=predictions, weights=is_real_example)precision = tf.metrics.precision(labels=label_ids, predictions=predictions, weights=is_real_example)recall = tf.metrics.recall(labels=label_ids, predictions=predictions, weights=is_real_example)return {"eval_accuracy": accuracy,"eval_loss": loss,'eval_auc': auc,'eval_precision': precision,'eval_recall': recall,}


以上~
2019.03.21

BERT模型实战之多文本分类(附源码)相关推荐

  1. 何使用BERT模型实现中文的文本分类

    原文网址:https://blog.csdn.net/Real_Brilliant/article/details/84880528 如何使用BERT模型实现中文的文本分类 前言 Pytorch re ...

  2. Android富文本编辑器附源码

    Android富文本编辑器附源码 1,源码分析 本软件是Android端创建富文本数据,向服务器发送;安卓端创建数据可以是文字,图片,语音,文件,还可以录音,可以拍照.录音完毕保存的时候应该提交给服务 ...

  3. Python实战例子(32个附源码)

    Python是一种高级编程语言,具有简洁.清晰的语法,易于理解和使用,因此受到广泛的欢迎.尤其在数据科学.人工智能.机器学习.自然语言处理等领域,Python已成为最受欢迎的编程语言之一.Python ...

  4. 100个Python实战练手项目(附源码+素材),学习必备

    前言: 不管学习哪门语言都希望能做出实际的东西来,这个实际的东西当然就是项目啦,不用多说大家都知道学编程语言一定要做项目才行. 这里整理了最新32个Python实战项目列表,都有完整且详细的视频教程和 ...

  5. 超简单的pyTorch训练-onnx模型-C++ OpenCV DNN推理(附源码地址)

    学更好的别人, 做更好的自己. --<微卡智享> 本文长度为1974字,预计阅读5分钟 前言 很早就想学习深度学习了,因为平时都是自学,业余时间也有限,看过几个pyTorch的入门,都是一 ...

  6. 基于卷积神经网络的句子分类模型【经典卷积分类附源码链接】

    https://www.toutiao.com/a6680124799831769603/ 基于卷积神经网络的句子分类模型 题目: Convolutional Neural Networks for ...

  7. python k-means聚类算法 物流分配预测实战(超详细,附源码)

    数据集和地图可以点赞关注收藏后评论区留下QQ邮箱或者私信博主要 聚类是一类机器学习基础算法的总称. 聚类的核心计算过程是将数据对象集合按相似程度划分成多个类,划分得到的每个类称为聚类的簇 聚类不等于分 ...

  8. 《Tensorflow 实战》(完整版,附源码)

    向AI转型的程序员都关注了这个号

  9. Android App开发实战项目之购物车(附源码 超详细必看)

    需要源码请点赞关注收藏后评论区留言~~~ 一.需求描述 电商App的购物车可谓是司空见惯了,可以知道购物车除了底部有一个结算行,其余部分主要是已加入购物车的商品列表,然后每个商品左边是商品小图,右边是 ...

最新文章

  1. 【网络】通讯名词解释:带宽、速率、波特率、奈奎斯特定律、香农定理
  2. 7 papers | 对抗样本前,BERT也不行;AutoML的商业实践综述
  3. 手写识别python_Python徒手实现识别手写数字—图像识别算法(K最近邻)
  4. 解决Extjs中textarea不支持keyup事件的问题
  5. 11.4 上限分析-机器学习笔记-斯坦福吴恩达教授
  6. 【.net 深呼吸】自定义应用程序配置节
  7. 基2FFT算法matlab程序编写,频率抽取(DIF)基2FFT算法的MATLAB实现
  8. 求连通域面积matlab
  9. Python自定义分页组件
  10. linux:C++的socket编程
  11. TomCat7安装与配置
  12. element table 组件内容换行
  13. mysql数据库维护(mysql学习笔记)
  14. 将ajax的值传给控制器,ASP.Net C#MCV - 将值从Ajax Jquery传递给Controller(示例代码)
  15. 深圳高中女生街头版someone like you
  16. 看了《我的白大褂》才明白,原来平安是福
  17. yolov5方框的颜色及粗细更改
  18. 求生之路2正版rpg服务器,求生之路2怎么屏蔽rpg服务器 求生之路2屏蔽rpg服务器方法-游侠网...
  19. html钟表代码运行原理,·钟表指针运行方向的基本原理
  20. 新概念二册 Lesson 21 Mad or not?是不是疯了? ( 被动语态)

热门文章

  1. SAP License:作业费用分割均分常见原因
  2. SAP License:你是工程师还是顾问
  3. SAP License:SAP应用随想
  4. web端业务数据管理平台+Axure运营数据管理平台+月度数据统计分析+年度排行榜数据统计页面分析+运营大数据统计管理后台+用户信息管理+Axure通用web端高保真交互业务数据管理平台
  5. Android手机刷机失败的自救方法
  6. Java SE 6 中实现 Cookie 功能
  7. 如何修改博客园里个人首页背景(form:cot 大犇)
  8. 基于设备树的TQ2440 DMA学习(3)—— DMA控制器驱动
  9. Understanding ASP.NET Validation Techniques
  10. 徐中约与《中国近代史》 (zz)