Flair是一个基于PyTorch构建的NLP开发包,它在解决命名实体识别(NER)、语句标注(POS)、文本分类等NLP问题时达到了当前的顶尖水准。本文将介绍如何使用Flair构建定制的文本分类器。

简介

文本分类是一种用来将语句或文档归入一个或多个分类的有监督机器学习方法,被广泛应用于垃圾邮件过滤、情感分析、新文章归类等众多业务领域。

当前绝大多数领先的文本分类方法都依赖于文本嵌入技术,它将文本转换为高维空间的数值表示,可以将文档、句子、单次或字符表示为这个高维空间的一个向量。

Flair基于Zalando Research的论文“用于串行标准的上下文相关字符串嵌入”,论文算法表现可以毙掉之前的最好方案,该算法在Flair中得到完整实现,可以用来构建文本分类器。

1. 准备

Flair安装需要Python 3.6,执行pip安装即可:

~$ pip install flair

上面的命令将安装运行Flair所需要的依赖包,当然也包括了PyTorch。

2. 使用训练好的预置分类模型

最新的Flair 0.4版本包含有两个预先训练好的模型。一个基于IMDB数据集训练的情感分析模型和一个攻击性语言探测模型(当前仅支持德语)。

只需一个命令就可以下载、存储并使用模型,这使得预置模型的使用过程异常简单。例如,下面的代码将使用情感分析模型:

from flair.models import TextClassifier
from flair.data import Sentenceclassifier = TextClassifier.load('en-sentiment')sentence = Sentence('Flair is pretty neat!')
classifier.predict(sentence)# print sentence with predicted labels
print('Sentence above is: ', sentence.labels)

当第一次运行上述代码时,Flari将下载情感分析模型,默认情况下会保存到本地用户主目录的.flair子目录,下载可能需要几分钟。

上面的代码首先载入必要的库,然后载入情感分析模型到内存中(必要时先下载),接下来就可以预测“Flair is pretty neat!”的情感分值了(0~1之间)。最后的命令输入结果为:

The sentence above is: [Positive (1.0)].

就是这么简单!现在你可以将上述代码整合为一个REST API,提供类似于google云端情感分析API的功能了!

3. 训练自定义文本分类器

要训练一个自定义的文本分类器,首先需要一个标注文本集。Flair的分类数据集格式基于Facebook的FastText格式,要求在每一行的开始使用**label**前缀定义一个或多个标签。格式如下:

__label__<class_1> <text>
__label__<class_2> <text>

在这篇文章中我们将使用Kaggle的SMS垃圾信息检测数据集来用Flair构建一个垃圾/非垃圾分类器。这个数据集很适合我们的学习任务,因为它很小,只有5572行数据,可以在单个CPU上只花几分钟就完成模型的训练。

3.1 预处理 - 构建数据集

首先下载Kaggle上的数据集,得到spam.csv;然后再数据集目录下,运行我们的处理脚本,得到训练集、开发集和测试集:

import pandas as pd
data = pd.read_csv("./spam.csv", encoding='latin-1').sample(frac=1).drop_duplicates()data = data[['v1', 'v2']].rename(columns={"v1":"label", "v2":"text"})data['label'] = '__label__' + data['label'].astype(str)data.iloc[0:int(len(data)*0.8)].to_csv('train.csv', sep='\t', index = False, header = False)
data.iloc[int(len(data)*0.8):int(len(data)*0.9)].to_csv('test.csv', sep='\t', index = False, header = False)
data.iloc[int(len(data)*0.9):].to_csv('dev.csv', sep='\t', index = False, header = False);

上面的脚本会进行剔重和随机乱序处理,并按照80/10/10的比例进行数据集的分割。脚本成功执行后,就会得到FastText格式的三个数据文件:train.csv、dev.csv和test.csv。

3.2 训练自定义文本分类模型

用下面的脚本训练模型:

from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Pathcorpus = NLPTaskDataFetcher.load_classification_corpus(Path('./'), test_file='train.csv', dev_file='dev.csv', train_file='test.csv')word_embeddings = [WordEmbeddings('glove'), FlairEmbeddings('news-forward-fast'), FlairEmbeddings('news-backward-fast')]document_embeddings = DocumentLSTMEmbeddings(word_embeddings, hidden_size=512, reproject_words=True, reproject_words_dimension=256)classifier = TextClassifier(document_embeddings, label_dictionary=corpus.make_label_dictionary(), multi_label=False)trainer = ModelTrainer(classifier, corpus)trainer.train('./', max_epochs=20)

第一次运行上面这个脚本时,Flair会自动下载所需要的嵌入模型,这可能需要几分钟,然后接下来的整个训练过程还需要大约5分钟。

脚本首先载入需要的库和数据集,得到一个corpus对象。

接下来,我们创建一个嵌入列表,包含两个Flair上下文字符串嵌入和一个GloVe单词嵌入,这个列表接下来将作为我们文档嵌入对象的输入。堆叠和文本嵌入是Flair中最有趣的感念之一,它们提供了将不同的嵌入整合在一起的手段,你可以同时使用传统的单词嵌入(例如GloVe、word2vector、ELMo)和Flair的上下文字符串嵌入。在上面的示例中我们使用一个基于LSTM的方法来生成文档嵌入,关于该方法的详细描述可以参考这里。

最后,上面的代码训练模型并生成两个模型文件:final-model.pt和best-model.pt。

3.3 用训练好的模型进行预测

现在我们可以使用导出的模型进行预测了。脚本如下:

from flair.models import TextClassifier
from flair.data import Sentenceclassifier = TextClassifier.load_from_file('./best-model.pt')sentence = Sentence('Hi. Yes mum, I will...')classifier.predict(sentence)print(sentence.labels)

上面的代码输出如下:

[ham (1.0)]

这意味着模型100%的确信我们输入的示例消息不是垃圾信息。

Flair是如何超越其他框架的?

与Facebook的FastText或者Google的AutoML平台不同,用Flair进行文本分类还是相对底层的任务。我们可以完全控制文本如何嵌入,也可以设置训练的参数例如学习速率、批大小、损失函数、优化器选择策略等,这些超参数是要实现最优性能所必须进行调整的。Flair提供了著名的超参数调整库Hyperopt的一个封装。

在这篇文章中,出于简化考虑我们使用了默认的超参数,得到的Flair模型的f1-score在20个epoch之后达到了0.973。

为了对比,我们使用FastText和AutoML训练了一个文本分类器。我们首先使用默认参数运行FastText,得到的f1-score为0.883,这意味着我们的Flair模型远远优于FastText模型,不过FastText的训练很快,只需要几秒钟。

然后我们也与AutoML自然语言平台上得到的结果进行了对比。平台首先需要20分钟来解析数据集,然后我们启动训练过程,这大约花了3个小时才完成,但是f1-score达到了99.211,要稍好于我们自己训练的Flair模型。

汇智网翻译整理,转载请标明出处。链接:用Flair进行文本分类

用Flair(PyTorch构建的NLP开发包)进行文本分类相关推荐

  1. 零基础入门天池NLP赛事之——新闻文本分类(5)

    基于深度学习的文本分类 一.学习目标: 学习Word2Vec的使用和基础原理 学习使用TextCNN.TextRNN进行文本表示 学习使用HAN网络结构完成文本分类 二.文本表示方法 Part3: 词 ...

  2. 【NLP】深度学习文本分类|模型代码技巧

    文本分类是NLP的必备入门任务,在搜索.推荐.对话等场景中随处可见,并有情感分析.新闻分类.标签分类等成熟的研究分支和数据集. 本文主要介绍深度学习文本分类的常用模型原理.优缺点以及技巧,是「NLP入 ...

  3. 【NLP基础理论】03 文本分类

    注: Unimelb Comp90042 NLP笔记 相关tutorial代码链接 Text Classification(文本分类) 目录 Text Classification(文本分类) 1 分 ...

  4. NLP入门之新闻文本分类竞赛——文本分类模型

    一.Word2Vec word2vec模型背后的基本思想是对出现在上下文环境里的词进行预测.对于每一条输入文本,我们选取一个上下文窗口和一个中心词,并基于这个中心词去预测窗口里其他词出现的概率.因此, ...

  5. pytorch创建模型并训练(初探文本分类问题)

        本博客对pytorch在深度学习上的使用进行了介绍,本博客并不会对怎么训练一个好的模型进行介绍(其实我也不会),我觉得训练一个好的模型首先得选对一个模型(关键的问题在于模型如何设计),然后再经 ...

  6. NLP实战-中文新闻文本分类

    目录 1.思路 2.基于paddle的ERINE模型进行迁移学习训练 3.分步实现 3.1 获取数据 (1)数据解压 (2)将文本转成变量,这里为了好计算,我只选了新闻标题做文本分类 3.2 中文分词 ...

  7. NLP学习笔记-FastText文本分类(四)

    分类的目的和分类的方法 1. 文本分类的目的 回顾之前的流程,我们可以发现文本分类的目的就是为了进行意图识别 在当前我们的项目的下,我们只有两种意图需要被识别出来,所以对应的是2分类的问题 可以想象, ...

  8. NLP实战之HAN文本分类

    HAN(层叠注意力)神经网络文本分类 原理讲解 HAN出处:论文Hierarchical Attention Networks for Document Classification 可以参见讲解文献 ...

  9. 【NLP】知乎文本分类比赛第一名笔记

    转载自:https://zhuanlan.zhihu.com/p/25928551 知乎"看山杯" 夺冠记 陈云 研究僧 537 人赞了该文章 知乎看山杯夺冠记 比赛源码(PyTo ...

最新文章

  1. ssh mysql环境搭建_搭建一个MySQL高可用架构集群环境
  2. 完美解决tomcat/springboot启动速度相当慢 快死的状态了
  3. java camel swagger,java – CAMEL_CASE_TO_LOWER_CASE_WITH_UNDERSCORES没有反映在swagger.json中
  4. 看看大货车到底有多少盲区,肯定用得到!救命的!
  5. vue中使用高德地图 amap--基础使用方法
  6. c语言编程新思路知道答案,C语言编程新思路_知道答案公众号免费
  7. elipse调试linux内核,debug eclipse cdt + qemu虚拟机调试linux内核
  8. drawboard pdf拆分文件_掌握在线PDF拆分技巧,从此打开文件不再处于“加载中”...
  9. 用XSLT和XML改进Struts
  10. 【原创】rabbitmq-echopid用户手册(翻译)
  11. 【sklearn第十九讲】高斯混合模型
  12. linux 软键盘输入密码,Linux系统中使用屏幕键盘的方法
  13. s7200cpu224xp手册_河南西门子CPU224XP模块使用手册
  14. 软件设计师教程第5版.PDF.高清
  15. 新手如何学习FPGA技术
  16. np问题 量子计算机,P vs NP与经典与量子计算机可解决的问题相同吗?
  17. 客户端负载均衡Ribbon
  18. oracle nlv 全称,oracle: OCA-047-题解与实验(9)--SQL语句中COUNT和NLV的用法
  19. 常用矩阵向量求导公式
  20. Mac终端命令和连接服务器

热门文章

  1. JavaScript — DOM
  2. python语言Camelot库: 人类的 PDF 表提取
  3. Oauth2.0 安全性(以微信授权登陆为例)
  4. php图片居中在div,在div中水平居中图像
  5. Opencv之图像滤波:6.双边滤波(cv2.bilateralFilter)
  6. oracle 存储二进制流,oracle存二进制流
  7. 2022年湖北省科技进步奖详细解答,该奖项申报条件以及奖励补贴具体情况解析
  8. 什么是“分布式应用系统”
  9. 计算机自动设置开机,电脑可以设置系统自动开机吗
  10. 【精华贴】数字音频接口详解-I2S接口PCM/TDM接口PDM接口