使用Bert预训练模型进行文本分类

bert做文本分类,简单来说就是将每句话的第一个位置加入了特殊分类嵌入[CLS]。而该[CLS]包含了整个句子的信息,它的最终隐藏状态(即,Transformer的输出)被用作分类任务的聚合序列表示。

下载bert预训练模型

Google提供了多种预训练好的bert模型,有针对不同语言的和不同模型大小的。我们的任务是针对临床试验筛选标准进行分类,是中文模型,所以我们下载的是Bert-Base, Chinese这个模型Bert模型下载链接。
该模型解压后的目录包含bert_config.json(模型的超参数),bert_model.ckpt.data-00000-of-00001,bert_model.ckpt.index,bert_model_ckpt.meta(保存预训练模型与权重的 ckpt 文件)和vocab.txt(词表)五个文件。

在自己的数据集微调过程

任务介绍

根据预先给定的44个类别和一系列中文临床试验筛选标准的描述句子,判断该中文临床医学描述句子属于的类别。

评价指标

本任务的评价指标包括宏观准确率(Macro Precision)、宏观召回率(Macro Recall)、Average F1值。最终排名以Average F1值为基准。假设我们有n个类别,C1, … …, Ci, … …, Cn。
宏观准确率Pi=正确预测为类别Ci的样本个数/预测为Ci类样本的个数
召回率Ri=正确预测为类别Ci的样本个数/真实的Ci类样本的个数
平均F1=(1/n)求和[(2Pi*Ri)/(Pi+Ri)]

前期数据分析

训练集数据22962条
验证机数据7682条
测试集数据7697条
统计每个句子的长度,看大部分的句子长度为多少,则将最长的句子设为多少
句子最长为341个字,最短为2个字

前期数据准备
  1. 我的数据格式如下:
    要先对数据进行一定的处理,初略观察数据集,数据集中包含无用的标点符号、数字,去除这些部分,同时对文本进行分词去除掉停用词。最后得到 label+句子的格式,中间用’\t’分隔

将数据集中的后两列提取出来,同时把句子中的停用词去除,得到的结果写入到train_data.csv中

#将训练集提取标签和句子
将训练集提取标签和句子,并过滤停用词
lines=open(train_path,encoding='utf-8').read().split('\n')
with open(new_train,'w',encoding='utf-8')as w:content=''for i in range(len(lines)-1):output=''item=lines[i].split('\t')content += str(item[1])content += '\t'seg=jieba.cut(str(item[2]),cut_all=False)for j in seg:if j not in stopwordset:output+=jcontent += str(output)content += '\n'# print(content)w.write(str(content))
w.close()

得到处理后的数据如下:

2. 我们需要在run_classifiler.py中定义自己任务的DataProcessor子类,根据我们数据集的格式重写获取训练集、开发集、测试集的样本数据的方法以及获取标签的方法

我们可以仿照Cola处理器来写我们自己文本分类的处理器

我自己文本分类的处理器,一共分为44个类别

class classification(DataProcessor):def get_train_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, "train_data.csv")), "train")def get_dev_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, "val_data.csv")), "dev")def get_test_examples(self, data_dir):return self._create_examples(self._read_tsv(os.path.join(data_dir, "test_data.csv")), "test")def get_labels(self):#一共有44个标签return ['Disease', 'Symptom','Sign','Pregnancy-related Activity','Neoplasm Status','Non-Neoplasm Disease Stage','Allergy Intolerance','Organ or Tissue Status','Life Expectancy', 'Oral related','Pharmaceutical Substance or Drug','Therapy or Surgery','Device','Nursing','Diagnostic', 'Laboratory Examinations','Risk Assessment','Receptor Status','Age','Special Patient Characteristic','Literacy','Gender','Education', 'Address','Ethnicity','Consent','Enrollment in other studies','Researcher Decision','Capacity','Ethical Audit','Compliance with Protocol','Addictive Behavior', 'Bedtime','Exercise','Diet','Alcohol Consumer','Sexual related', 'Smoking Status','Blood Donation','Encounter','Disabilities','Healthy','Data Accessible','Multiple']def _create_examples(self, lines, set_type):"""Creates examples for the training and dev sets."""examples = []for (i, line) in enumerate(lines):guid = "%s-%s" % (set_type, i)if set_type=="test":text_a = tokenization.convert_to_unicode(line[0])label = "Disease"#当是测试集的时候,标签默认为diseaseelse:text_a = tokenization.convert_to_unicode(line[1])label = tokenization.convert_to_unicode(line[0])examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))return examples
  1. 在run_classifiler.py中修改main函数,将自己写的文本分类处理器加入到processors中
  processors = {"cola": ColaProcessor,"mnli": MnliProcessor,"mrpc": MrpcProcessor,"xnli": XnliProcessor,"task":classification,}
  1. 执行run_classifiler.py

    需要必填参数data_dir,task_name,vocab_file,bert_config_file,output_dir。参数do_train,do_eval和do_predict分别控制了是否进行训练,评估和预测,可以按需将其设置为True或者False
  2. 参数的设置
python run_classifier.py --data_dir=data --task_name=classification--vocab_file=chinese_L-12_H-768_A-12R/vocab.txt --bert_config_file=chinese_L-12_H-768_A-12/bert_config.json --output_dir=output--do_train=true --do_eval=false--init_checkpoint=chinese_L-12_H-768_A-12/bert_model.ckpt --max_seq_length=200 (句子的最大长度,可以根据数据集大部分数据的长度来设计)--train_batch_size=16 (如果内存太小可以适当缩小,每轮喂入的多少条数据)--learning_rate=5e-5--num_train_epochs=2.0
  1. 可以增加验证集的指标在分类时,我们可能还需要看auc,recall,precision的值。
   def metric_fn(per_example_loss, label_ids, logits, is_real_example):predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)#求loggits[batch size,num_labels]的行最大值的下标accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions, weights=is_real_example)auc=tf.metrics.auc(labels=label_ids,predictions=predictions,weights=is_real_example)precision=tf.metrics.precision(labels=label_ids,predictions=predictions,weights=is_real_example)recall=tf.metrics.recall(labels=label_ids,predictions=predictions,weights=is_real_example)f1_score=tf.metrics.mean((2*precision*recall)/(precision+recall))loss = tf.metrics.mean(values=per_example_loss, weights=is_real_example) #求加权平均lossreturn {"eval_accuracy": accuracy,"eval_auc":auc,"eval_precision":precision,"eval_recall":recall,"eval_f1":f1_score,"eval_loss": loss,}

成功运行

python run_classifier.py  -—bert_config_file=chinese_L-12_H-768_A-12/bert_config.json --vocab_file=chinese_L-12_H-768_A-12/vocab.txt --init_checkpoint=chinese_L-12_H-768_A-12/bert_model.ckpt


将输入转为特征

从init_checkpoint读取参数

因为自己电脑cpu跑的原因,所以选择了部分数据,训练的轮次也设得很小,效果不是很好。

参考
教程:使用Bert预训练模型文本分类
BERT文本分类使用指南
文本分类实战(十)—— BERT 预训练模型

Bert实战--文本分类(一)相关推荐

  1. 【NLP】Kaggle从零到实践:Bert中文文本分类

    Bert是非常强化的NLP模型,在文本分类的精度非常高.本文将介绍Bert中文文本分类的基础步骤,文末有代码获取方法. 步骤1:读取数据 本文选取了头条新闻分类数据集来完成分类任务,此数据集是根据头条 ...

  2. 二分类问题:基于BERT的文本分类实践!附完整代码

    Datawhale 作者:高宝丽,Datawhale优秀学习者 寄语:Bert天生适合做分类任务.文本分类有fasttext.textcnn等多种方法,但在Bert面前,就是小巫见大巫了. 推荐评论展 ...

  3. tensorflow 加载bert_用NodeJS/TensorFlowJS调用BERT实现文本分类

    题图 "JavaScript Logo"byb0neskullis licensed underCC BY-NC-SA 2.0 几个提前知识 TensorFlowJS可以简单认为有 ...

  4. 实战文本分类对抗攻击

    文章写得比较长,先列出大纲,以便读者直取重点. "文本分类对抗攻击"是清华大学和阿里安全2020年2月举办的一场AI比赛,从开榜到比赛结束20天左右,内容是主办方在线提供1000条 ...

  5. Flair实战文本分类

    2019独角兽企业重金招聘Python工程师标准>>> Flair是一个基于PyTorch构建的NLP开发包,它在解决命名实体识别(NER).语句标注(POS).文本分类等NLP问题 ...

  6. Pytorch Bert+BiLstm文本分类

    文章目录 前言 一.运行环境 二.数据 三.模型结构 四.训练 五.测试及预测 前言 昨天按照该文章(自然语言处理(NLP)Bert与Lstm结合)跑bert+bilstm分类的时候,没成功跑起来,于 ...

  7. 使用Bert进行文本分类(立场检测)

    主要流程:载入数据-->进行分词(单个词)-->根据词典转换为编码-->加载BERT预训练模型-->微调出结果.

  8. 《基于Tensorflow的知识图谱实战》 --- 实战文本分类与命名实体识别,快速构建知识图谱(王晓华 著)

    ⚽开发平台:jupyter lab

  9. bert 文本分类实战

    前言: 由于课题需要,学习自然语言处理(NLP),于是在网上找了找文章和代码进行学习,在此记录,课题代码就不展示了,使用网上的代码和大家分享.思想和代码大部分参考苏神,在此感谢. 任务目标: 希望be ...

最新文章

  1. 三调 图斑地类面积_国土三调攻坚冲刺,大疆无人机为调查举证提供加速度
  2. 将来以静态网页形式展示漏洞影响产品信息
  3. spring 源码 找不到 taskprovider_Spring 源码阅读环境的搭建
  4. 简单比较python语言和c语言的异同-Python快速入门之与C语言异同
  5. IOS基础之Foundation框架常用类NSFileManager,DSDate,CGPoint,CGSize,copy,单例
  6. ABB机器人 系统参数配置
  7. c++读取excel_Java 嵌入 SPL 轻松实现 Excel 文件合并
  8. 省一级计算机选择题题库及答案,计算机一级考试选择题题库之excel题及答案(最新版).doc...
  9. Day4 dict和set
  10. ”天空之城”的主题曲《君をのせて》
  11. vue.js引入外部CSS样式和外部JS文件的方法
  12. Android Glide加载本地gif动态图
  13. 数据驱动型企业的海外服务器管理实践
  14. php那好,php那好【货币问答】- php那好所有答案 - 联合货币
  15. mt4量化交易接口:分享日常量化选股方法
  16. python读取word文档并做简单的批量文档筛选
  17. 基于51单片机智能手机锂电池充电器设计
  18. 电脑派位系统(新生入学摇号) v2016
  19. 网上报修系统java源码_网上报修系统管理软件
  20. BUUCTF [GYCTF2020] Blacklist

热门文章

  1. otf和ctf的意义_北京邮电大学出版社
  2. idea条件断点和异常断点
  3. MATLAB中写TXT文件换行的实现
  4. 计时+事件倒计时网页源码分享
  5. 5G的主要业务场景:eMBB、URLLC、mMTC
  6. 【算法和数据结构】模拟和暴力
  7. 格式: echo -e \033[字背景颜色;字体颜色m字符串\033[0m
  8. 人工智能就业方向及前景,前景如何?好就业吗?
  9. tekton TriggerTemplate资源
  10. 基于PaddleSpeech搭建个人语音听写服务