作者 | Thilina Rajapakse
译者 | Raku
编辑 | 夕颜
出品 | AI科技大本营(ID: rgznai100)
【导读】本文将介绍一个简单易操作的Transformers库——Simple Transformers库。它是AI创业公司Hugging Face在Transformers库的基础上构建的。Hugging Face Transformers是供研究与其他需要全面控制操作方式的人员使用的库,简单易操作。
简介
Simple Transformers专为需要简单快速完成某项工作而设计。不必拘泥于源代码,也不用费时费力地去弄清楚各种设置,文本分类应该非常普遍且简单——Simple Transformers就是这么想的,并且专为此实现。
一行代码建立模型,另一行代码训练模型,第三行代码用来预测,老实说,还能比这更简单吗?
所有源代码都可以在Github Repo上找到,如果你有任何问题或疑问,请在这上面自行寻求答案。
GitHub repo:
https://github.com/ThilinaRajapakse/simpletransformers
安装
1、从这里(https://www.anaconda.com/distribution/)安装Anaconda或Miniconda Package Manager。
2、创建一个新的虚拟环境并安装所需的包。
conda create -n transformers python pandas tqdm
conda activate transformers
如果是cuda:
conda install pytorch cpuonly -c pytorch
conda install -c anaconda scipy
conda install -c anaconda scikit-learn
pip install transformers
pip install tensorboardx
3、安装simpletransformers。
用法
让我们看看如何对AGNews数据集执行多类分类。
对于用Simple Transformers简单二分类,参考这里。
下载并提取数据
1、从Fast.ai(https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz)下载数据集。
2、提取train.csv和test.csv并将它们放在目录data/ 中。
为训练准备数据
Simple Transformers要求数据必须包含在至少两列的Pandas DataFrames中。你只需为列的文本和标签命名,SimpleTransformers就会处理数据。或者你也可以遵循以下约定:
•  第一列包含文本,类型为str。
•  第二列包含标签,类型为int。
对于多类分类,标签应该是从0开始的整数。如果数据具有其他标签,则可以使用python dict保留从原始标签到整数标签的映射。
模型
from simpletransformers.model import TransformerModel # Create a TransformerModel model = TransformerModel('roberta', 'roberta-base', num_labels=4)
这将创建一个TransformerModel,用于训练,评估和预测。第一个参数是model_type,第二个参数是model_name,第三个参数是数据中的标签数:
•  model_type可以是['bert','xlnet','xlm','roberta','distilbert']之一。
•  有关可用于model_name的预训练模型的完整列表,请参阅“当前预训练模型”(https://github.com/ThilinaRajapakse/simpletransformers#current-pretrained-models)。
要加载以前保存的模型而不是默认模型的模型,可以将model_name更改为包含已保存模型的目录的路径。
TransformerModel具有dict参数,其中包含许多属性,这些属性提供对超参数的控制。有关每个属性的详细说明,请参阅repo。默认值如下所示:
self.args = { 'output_dir': 'outputs/',   'cache_dir': 'cache_dir',   'fp16': True, 'fp16_opt_level': 'O1', 'max_seq_length': 128,    'train_batch_size': 8,    'gradient_accumulation_steps': 1, 'eval_batch_size': 8, 'num_train_epochs': 1,    'weight_decay': 0,    'learning_rate': 4e-5,    'adam_epsilon': 1e-8, 'warmup_ratio': 0.06, 'warmup_steps': 0,    'max_grad_norm': 1.0, 'logging_steps': 50,  'save_steps': 2000,   'overwrite_output_dir': False,    'reprocess_input_data': False,    'process_count': cpu_count() - 2 if cpu_count() > 2 else 1,    }
在创建TransformerModel或调用其train_model方法时,只要简单地传递包含要更新的键值对的字典,就可以修改这些属性中的任何一个。下面给出一个例子:
# Create a TransformerModel with modified attributes
model = TransformerModel('roberta', 'roberta-base', num_labels=4,
args={'learning_rate':1e-5, 'num_train_epochs': 2,
'reprocess_input_data': True, 'overwrite_output_dir': True})
训练
# Train the model
model.train_model(train_df)
这就是训练模型所需要做的全部。你还可以通过将包含相关属性的字典传递给train_model方法来更改超参数。请注意,即使完成训练,这些修改也将保留。
train_model方法将在第n个步骤(其中n为self.args ['save_steps'])的第n个步骤创建模型的检查点(保存)。训练完成后,最终模型将保存到self.args ['output_dir']。
评估
要评估模型,只需调用eval_model。此方法具有三个返回值:
•  result:dict形式的评估结果。默认情况下,仅对多类分类计算马修斯相关系数(MCC)。
•  model_outputs:评估数据集中每个项目的模型输出list。用softmax函数来计算预测值,输出   每个类别的概率而不是单个预测。
•  wrong_predictions:每个错误预测的InputFeature list。可以从InputFeature.text_a属性获取文本。(可以在存储库 https://github.com/ThilinaRajapakse/simpletransformers 的utils.py文件中找到InputFeature类)
你还可以包括在评估中要使用的其他指标。只需将指标函数作为关键字参数传递给eval_model方法。指标功能应包含两个参数,第一个是真实标签,第二个是预测,这遵循sklearn标准。
对于任何需要附加参数的度量标准函数(在sklearn中为f1_score),你可以在添加了附加参数的情况下将其包装在自己的函数中,然后将函数传递给eval_model。
from sklearn.metrics import f1_score, accuracy_score def f1_multiclass(labels, preds):   return f1_score(labels, preds, average='micro')  result, model_outputs, wrong_predictions = model.eval_model(eval_df, f1=f1_multiclass, acc=accuracy_score
作为参考,我使用这些超参数获得的结果如下:
考虑到我实际上并没有进行任何超参数调整,效果还不错。感谢RoBERTa!
预测/测试
在实际应用中,我们常常不知道什么是真正的标签。要对任意示例执行预测,可以使用predict方法。此方法与eval_model方法非常相似,不同之处在于,该方法采用简单的文本列表并返回预测列表和模型输出列表。
结论
在许多实际应用中,多分类是常见的NLP任务,Simple Transformers是将Transformers的功能应用于现实世界任务的一种简单方法,你无需获得博士学位才能使用它。
关于项目
我计划在不久的将来将“问答”添加到Simple Transformers 库中。敬请关注!
Simple Transformers 库:https://github.com/ThilinaRajapakse/simpletransformers
原文链接:
https://medium.com/swlh/simple-transformers-multi-class-text-classification-with-bert-roberta-xlnet-xlm-and-8b585000ce3a

(*本文为 AI科技大本营翻译文章,转载请微信联系 1092722531

精彩推荐

2019 中国大数据技术大会(BDTC)再度来袭!豪华主席阵容及百位技术专家齐聚,15 场精选专题技术和行业论坛,超强干货+技术剖析+行业实践立体解读,深入解析热门技术在行业中的实践落地。5 折票倒计时 5 天!

推荐阅读

你点的每个“在看”,我都认真当成了AI

Simple Transformer:用BERT、RoBERTa、XLNet、XLM和DistilBERT进行多类文本分类相关推荐

  1. 【Bert、T5、GPT】fine tune transformers 文本分类/情感分析

    [Bert.T5.GPT]fine tune transformers 文本分类/情感分析 0.前言 text classification emotions 数据集 data visualizati ...

  2. bert使用做文本分类_使用BERT进行深度学习的多类文本分类

    bert使用做文本分类 Most of the researchers submit their research papers to academic conference because its ...

  3. Transformer, BERT, ALBERT, XLNet全面解析(ALBERT第一作者亲自讲解)

    现在是国家的非常时期,由于疫情各地陆续延迟复工,以及各大院校延期开学.作为一家AI教育领域的创业公司,我们希望在这个非常时期做点有价值的事情,并携手共渡难关.在疫情期间,我们决定联合国内外顶尖AI专家 ...

  4. ALBERT第一作者亲自讲解:Transformer、BERT、ALBERT、XLNet全面解析

    现在是国家的非常时期,由于疫情各地陆续延迟复工,以及各大院校延期开学.作为一家AI教育领域的创业公司,我们希望在这个非常时期做点有价值的事情,并携手共渡难关.在疫情期间,我们决定联合国内外顶尖AI专家 ...

  5. NLP专题直播 | Transformer, BERT, ALBERT, XLNet全面解析(ALBERT第一作者亲自讲解)

    现在是国家的非常时期,由于疫情各地陆续延迟复工,以及各大院校延期开学.作为一家AI教育领域的创业公司,我们希望在这个非常时期做点有价值的事情,并携手共渡难关.在疫情期间,我们决定联合国内外顶尖AI专家 ...

  6. NLP专题直播 | 详谈Transformer, BERT, ALBERT, XLNet(ALBERT第一作者亲自讲解)

    提到 - "预训练模型".从简单的 Word2Vec,ELMo,GPT,BERT,XLNet到ALBERT,  这几乎是NLP过去10年最为颠覆性的成果.作为一名AI从业者,或者未 ...

  7. 词向量, BERT, ALBERT, XLNet全面解析(ALBERT第一作者亲自讲解)

    Datawhale Datawhale编辑 现在是国家的非常时期,由于疫情各地陆续延迟复工,以及各大院校延期开学.作为一家 AI 教育领域的创业公司,贪心学院筹划了5期NLP专题直播课程,希望在这个非 ...

  8. 【面试必备】奉上最通俗易懂的XGBoost、LightGBM、BERT、XLNet原理解析

    一只小狐狸带你解锁 炼丹术&NLP 秘籍 在非深度学习的机器学习模型中,基于GBDT算法的XGBoost.LightGBM等有着非常优秀的性能,校招算法岗面试中"出镜率"非 ...

  9. NLP专题直播 | 详谈词向量技术:从Word2Vec, BERT到XLNet

    现在是国家的非常时期,由于疫情各地陆续延迟复工,以及各大院校延期开学.作为一家AI教育领域的创业公司,我们希望在这个非常时期做点有价值的事情,并携手共渡难关.在疫情期间,我们决定联合国内外顶尖AI专家 ...

最新文章

  1. 构建用户界面 Android 应用中一些常用的小部件
  2. python调用git生成log文件_python解析git log后生成页面显示git更新日志信息
  3. 高并发C/S的TCP版本golang实现
  4. 【SpringBoot零基础案例08】【IEDA 2021.1】SpringBoot获取核心配置文件application.properties中的自定义配置
  5. 福禄克OFP光纤测试仪5个强大的功能
  6. Csharp实例:武汉智能安检闸机数据接收和解析
  7. CTF常用python库PwnTools的使用学习
  8. 前端学习(548):node的自定义模块
  9. void类型及void指针
  10. 20 个免费的 jQuery 的工具提示插件:
  11. NLP简报(Issue#9)
  12. 服务器内容推送技术(转)
  13. (转)MapReduce中的两表join几种方案简介
  14. 比较List和ArrayList的性能及ArrayList和LinkedList优缺点
  15. 力软敏捷开发框架源码7.0.6解析
  16. 21天学通C语言-学习笔记(13)
  17. 小程序毕设作品之微信二手交易小程序毕业设计成品(6)开题答辩PPT
  18. linux的磁盘busy,Linux umount 报 device is busy 的处理方法
  19. 介绍中国传统节日的网页html,介绍中国传统节日的作文4篇
  20. CountDownLatch 用法和源码解析

热门文章

  1. 想知道垃圾回收暂停的过程中发生了什么吗?查查垃圾回收日志就知道了!
  2. 开源 免费 java CMS - FreeCMS1.9 全文检索
  3. mysql数据库备份、恢复文档
  4. windows mobile做一个摄象头预览程序
  5. DOS命令大全(经典收藏)
  6. android id 重名_Android App 自定义权限重名不能安装解决办法
  7. 洛谷P2763 试题库问题
  8. linu逻辑分区动态调整大小
  9. Oracle Mutex 机制 说明
  10. Datawhale组队学习周报(第047周)