作者:哈工大SCIR 狄东林 刘元兴 朱庆福 胡景雯

引言

随着人工智能的发展,越来越多深度学习框架如雨后春笋般涌现,例如PyTorch、TensorFlow、Keras、MXNet、Theano和PaddlePaddle等。这些基础框架提供了构建一个模型需要的基本通用工具包。但是对于NLP相关的任务,我们往往需要自己编写大量比较繁琐的代码,包括数据预处理和训练过程中的工具等。因此,大家通常基于NLP相关的深度学习框架编写自己的模型,如OpenNMT、ParlAI和AllenNLP等。借助这些框架,三两下就可以实现一个NLP相关基础任务的训练和预测。但是当我们需要对基础任务进行改动时,又被代码封装束缚,举步维艰。因此,本文主要针对于如何使用框架实现自定义模型,帮助大家快速了解框架的使用方法。

我们首先介绍广泛用于NLP/CV领域的TensorFlow框架——Tensor2Tensor,该框架提供了NLP/CV领域中常用的基本模型。然后介绍NLP领域的AllenNLP框架,该框架基于PyTorch平台开发,为NLP模型提供了统一的开发架构。接着在介绍NLP领域中重要的两个子领域,神经机器翻译和对话系统常用的框架,OpenNMT和ParlAI。通过这四个框架的介绍,希望能帮助大家了解不同开发平台,不同领域下的NLP框架的使用方式。

框架名称 应用领域 开发平台
Tensor2Tensor NLP/CV TensorFlow
AllenNLP NLP PyTorch
OpenNMT NLP-机器翻译 PyTorch/TensorFlow
ParlAI NLP-对话 PyTorch

一、Tensor2Tensor

Tensor2Tensor[1]是一个基于TensorFlow的较为综合性的库,既包括一些CV 和 NLP的基本模型,如LSTM,CNN等,也提供一些稍微高级一点的模型,如各式各样的GAN和Transformer。对NLP的各项任务支持得都比较全面,很方便容易上手。

由于该资源库仍处于不断开发过程中,截止目前为止,已经有3897次commit,66个release 版本,178 contributors。在2018年《Attention is all you need》这个全网热文中,该仓库是官方提供的Transformer模型版本,后面陆陆续续其余平台架构才逐渐补充完成。

Tensor2Tensor(Transformer)使用方法

注意:有可能随着版本迭代更新的过程中会有局部改动

安装环境

1. 安装CUDA 9.0 (一定是9.0,不能是9.2)

2. 安装TensorFlow (现在是1.12)

3. 安装Tensor2Tensor (参考官网安装)

开始使用

1. 数据预处理

这一步骤是根据自己任务自己编写一些预处理的代码,比如字符串格式化,生成特征向量等操作。

2. 编写自定义problem:

  • 编写自定义的problem代码,一定需要在自定义类名前加装饰器(@registry.registry_problem)。

  • 自定义problem的类名一定是驼峰式命名,py文件名一定是下划线式命名,且与类名对应。

  • 一定需要继承父类problem,t2t已经提供用于生成数据的problem,需要自行将自己的问题人脑分类找到对应的父类,主要定义的父类problem有:(运行 t2t-datagen 可以查看到problem list)。

  • 一定需要在__init__.py文件里导入自定义problem文件。

3. 使用t2t-datagen 将自己预处理后的数据转为t2t的格式化数据集【注意路径】

  • 运行 t2t-datagen --help 或 t2t-datagen --helpfull。例如:

1cd scripts && t2t-datagen --t2t_usr_dir=./ --data_dir=../train_data --tmp_dir=../tmp_data --problem=my_problem
  • 如果自定义problem代码的输出格式不正确,则此命令会报错

4. 使用t2t-trainer使用格式化的数据集进行训练

  • 运行t2t-trainer --help 或 t2t-trainer --helpfull。例如:

1cd scripts && t2t-trainer --t2t_usr_dir=./ --problem=my_problem --data_dir=../train_data --model=transformer --hparams_set=transformer_base --output_dir=../output --train_steps=20 --eval_steps=100

5. 使用t2t-decoder对测试集进行预测【注意路径】

  • 如果想使用某一个checkpoint时的结果时,需要将checkpoint文件中的第一行: model_checkpoint_path: “model.ckpt-xxxx” 的最后的序号修改即可。例如:

1cd scripts && t2t-decoder --t2t_usr_dir=./ --problem=my_problem --data_dir=../train_data --model=transformer --hparams_set=transformer_base --output_dir=../output --decode_hparams=”beam_size=5,alpha=0.6” --decode_from_file=../decode_in/test_in.txt --decode_to_file=../decode_out/test_out.txt

6. 使用t2t-exporter导出训练模型

7. 分析结果

附: (整体代码)

 1# coding=utf-82from tensor2tensor.utils import registry3from tensor2tensor.data_generators import problem, text_problems45@registry.register_problem6class AttentionGruFeature(text_problems.Text2ClassProblem):78    ROOT_DATA_PATH = '../data_manager/'9    PROBLEM_NAME = 'attention_gru_feature'
10
11    @property
12    def is_generate_per_split(self):
13        return True
14
15    @property
16    def dataset_splits(self):
17        return [{
18            "split": problem.DatasetSplit.TRAIN,
19            "shards": 5,
20        }, {
21            "split": problem.DatasetSplit.EVAL,
22            "shards": 1,
23        }]
24
25    @property
26    def approx_vocab_size(self):
27        return 2 ** 10  # 8k vocab suffices for this small dataset.
28
29    @property
30    def num_classes(self):
31        return 2
32
33    @property
34    def vocab_filename(self):
35        return self.PROBLEM_NAME + ".vocab.%d" % self.approx_vocab_size
36
37    def generate_samples(self, data_dir, tmp_dir, dataset_split):
38        del data_dir
39        del tmp_dir
40        del dataset_split
41
42        # with open('{}self_antecedent_generate_sentences.pkl'.format(self.ROOT_DATA_PATH), 'rb') as f:
43        #     # get all the sentences for antecedent identification
44        #     _sentences = pickle.load(f)
45        #
46        # for _sent in _sentences:
47        #     # # sum pooling, FloatTensor, Size: 400
48        #     # _sent.input_vec_sum
49        #     # # sum pooling with feature, FloatTensor, Size: 468
50        #     # _sent.input_vec_sum_feature
51        #     # # GRU, FloatTensor, Size: 6100
52        #     # _sent.input_vec_hidden
53        #     # # GRU with feature, FloatTensor, Size: 6168
54        #     # _sent.input_vec_hidden_feature
55        #     # # AttentionGRU, FloatTensor, Size: 1600
56        #     # _sent.input_vec_attention
57        #     # # AttentionGRU with feature, FloatTensor, Size: 1668
58        #     # _sent.input_vec_attention_feature
59        #     # # tag(1 for positive case, and 0 for negative case), Int, Size: 1
60        #     # _sent.antecedent_label
61        #     # # tag(1 for positive case, and 0 for negative case), Int, Size: 1
62        #     # _sent.trigger_label
63        #     # # trigger word for the error analysis, Str
64        #     # _sent.trigger
65        #     # # trigger word auxiliary type for the experiment, Str
66        #     # _sent.aux_type
67        #     # # the original sentence for the error analysis, Str
68        #     # _sent.sen
69        #
70        #     yield {
71        #         "inputs": _sent.input_vec_attention_feature,
72        #         "label": _sent.antecedent_label
73        #     }
74
75        with open('../prep_ante_data/antecedent_label.txt') as antecedent_label, open(
76                '../prep_ante_data/input_vec_attention_gru_feature.txt') as input_vec:
77            for labal in antecedent_label:
78                yield {
79                    "inputs": input_vec.readline().strip()[1:-2],
80                    "label": int(labal.strip())
81                }
82
83        antecedent_label.close()
84        input_vec.close()
85
86
87# PROBLEM_NAME='attention_gru_feature'
88# DATA_DIR='../train_data_atte_feature'
89# OUTPUT_DIR='../output_atte_feature'
90# t2t-datagen --t2t_usr_dir=. --data_dir=$DATA_DIR --tmp_dir=../tmp_data --problem=$PROBLEM_NAME
91# t2t-trainer --t2t_usr_dir=. --data_dir=$DATA_DIR --problem=$PROBLEM_NAME --model=transformer --hparams_set=transformer_base --output_dir=$OUTPUT_DIR

Tensor2Tensor使用总结

T2T 是Google 非官方提供的仓库,是社区广大爱好者共同努力建设的简单入门型框架,底层封装TF,能满足大部分CV 和 NLP的任务,很多主流成熟的模型也已经都有实现。直接继承或实现一些框架内预设的接口,就可以完成很多任务。入门起来非常友好,并且文档更新也较为及时。认真阅读文档(或阅读报错信息)就可以了解并使用该框架,方便许多非大幅创新模型的复现。

二、AllenNLP

AllenNLP是一个基于PyTorch的NLP研究库,可为开发者提供语言任务中的各种业内最佳训练模型。官网提供了一个很好的入门教程[2],能够让初学者在30分钟内就了解AllenNLP的使用方法。

AllenNLP使用方法

由于AllenNLP已经帮我们实现很多麻烦琐碎的预处理和训练框架,我们实际需要编写的只有:

1. DatasetReader

DatasetReader的示例代码如下所示。

 1from typing import Dict, Iterator23from allennlp.data import Instance4from allennlp.data.fields import TextField5from allennlp.data.dataset_readers import DatasetReader6from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer7from allennlp.data.tokenizers import WordTokenizer, Tokenizer89@DatasetReader.register('custom')
10class CustomReader(DatasetReader):
11
12    def __init__(self, tokenizer: Tokenizer = None, token_indexers: Dict[str, TokenIndexer] = None) -> None:
13        super().__init__(lazy=False)
14        self.tokenizer = tokenizer or WordTokenizer()
15        self.word_indexers = token_indexers or {"word": SingleIdTokenIndexer('word')}
16
17    def text_to_instance(self, _input: str) -> Instance:
18        fields = {}
19        tokenized_input = self.tokenizer.tokenize(_input)
20        fields['input'] = TextField(tokenized_input, self.word_indexers)
21        return Instance(fields)
22
23    def _read(self, file_path: str) -> Iterator[Instance]:
24        with open(file_path) as f:
25            for line in f:
26                yield self.text_to_instance(line)

首先需要自定义_read函数,写好读取数据集的方式,通过yield方式返回构建一个instance需要的文本。然后通过text_to_instance函数将文本转化为instance。在text_to_instance函数中,需要对输入的文本进行切分,然后构建fileld

self.tokenizer是用来切分文本成Token的。有Word级别的也有Char级别的。self.word_indexers是用来索引Token并转换为Tensor。同样TokenIndexer也有很多种,在实现自己的模型之前可以看看官方文档有没有比较符合自己需要的类型。如果你需要构建多个Vocabulary,比如源语言的vocab 和目标语言的vocab, 就需要在这里多定义一个self.word_indexers。不同indexers在vocab中,是通过SingleIdTokenIndexer函数初始化的namespace来区分的,也就是15行代码中最后一个的'word'

2. Model

与PyTorch实现model的方式一样,但需要注意的是:

@Model.register('') 注册之后可以使用JsonNet进行模型选择(如果你有多个模型,可以直接修改Json值来切换,不需要手动修改代码)。

由于AllenNLP封装了Trainer,所以我们需要在model内实现或者选择已有的评价指标,这样在训练过程中就会自动计算评价指标。具体方法是,在__init__方法中定义评价函数,可以从在官方文档[3]上看看有没有,如果没有的话就需要自己写。

1self.acc = CategoricalAccuracy()

然后在forward方法中调用评价函数计算指标

1self.acc(output, labels)

最后在model的get_metrics返回对应指标的dict结果就行了。

1def get_metrics(self, reset: bool = False) -> Dict[str, float]:
2    return {"acc": self.acc.get_metric(reset)}

3. Trainer

一般来说直接调用AllenNLP的Trainer方法就可以自动开始训练了。但是如果你有一些特殊的训练步骤,比如GAN[4],你就不能单纯地使用AllenNLP的Trainer,得把Trainer打开进行每步的迭代,可以参考[4]中trainer的写法。

AllenNLP使用总结

关于AllenNLP的学习代码,可以参考[5]。由于AllenNLP是基于PyTorch的,代码风格和PyTorch的风格基本一致,因此如果你会用PyTorch,那上手AllenNLP基本没有什么障碍。代码注释方面也比较全,模块封装方面比较灵活。AllenNLP的代码非常容易改动,就像用纯的PyTorch一样灵活。当然灵活也就意味着很多复杂的实现,AllenNLP目前还没有,大部分可能都需要自己写。AllenNLP依赖了很多Python库,近期也在更新。

三、OpenNMT

OpenNMT[6]是一个开源的神经机器翻译(neural machine translation)项目,采用目前普遍使用的编码器-解码器(encoder-decoder)结构,因此,也可以用来完成文本摘要、回复生成等其他文本生成任务。目前,该项目已经开发出PyTorch、TensorFlow两个版本,用户可以按需选取。本文以PyTorch版本[7]为例进行介绍。

OpenNMT使用方法

1. 数据处理

作为一个典型的机器翻译框架,OpenNMT的数据主要包含source和target两部分,对应于机器翻译中的源语言输入和目标语言翻译。OpenNMT采用TorchText中的Field数据结构来表示每个部分。用户自定义过程中,如需添加source和target外的其他数据,可以参照source field或target field的构建方法,如构建一个自定义的user_data数据:

1fields["user_data"] = torchtext.data.Field(
2    init_token=BOS_WORD, eos_token=EOS_WORD,
3    pad_token=PAD_WORD,
4    include_lengths=True)

其中init_token、eos_token和pad_token分别为用户自定义的开始字符、结束字符和padding字符。Include_lengths为真时,会同时返回处理后数据和数据的长度。

2. 模型

OpenNMT实现了注意力机制的编码器-解码器模型。框架定义了编码器和解码器的接口,在该接口下,进一步实现了多种不同结构的编码器解码器,可供用户按需组合,如CNN、 RNN编码器等。如用户需自定义特定结构的模块,也可以遵循该接口进行设计,以保证得到的模块可以和OpenNMT的其他模块进行组合。其中,编码器解码器接口如下:

1class EncoderBase(nn.Module):
2    def forward(self, input, lengths=None, hidden=None):
3        raise NotImplementedError
4
5class RNNDecoderBase(nn.Module):
6    def forward(self, input, context, state, context_lengths=None):
7             raise NotImplementedError

3. 训练

OpenNMT的训练由Trainer.py中Trainer类控制,该类的可定制化程度并不高,只实现了最基本的序列到序列的训练过程。对于多任务、对抗训练等复杂的训练过程,需要对该类进行较大的改动。

OpenNMT使用总结

OpenNMT提供了基于PyTorch和TensorFlow这两大主流框架的不同实现,能够满足绝大多数用户的需求。对于基础框架的封装使得其丧失了一定的灵活性,但是对于编码器-解码器结构下文本生成的任务来说,可以省去数据格式、接口定义等细节处理,将精力更多集中在其自定义模块上,快速搭建出需要的模型。

四、ParlAI

ParlAI是Facebook公司开发出的一个专注于对话领域在很多对话任务上分享,训练和评估对话模型的平台[8]。这个平台可以用于训练和测试对话模型,在很多数据集上进行多任务训练,并且集成了Amazon Mechanical Turk,以便数据收集和人工评估。

ParlAI 中的基本概念:

  • world定义了代理彼此交互的环境。世界必须实施一种parley方法。每次对parley的调用都会进行一次交互,通常每个代理包含一个动作。

  • agent可以是一个人,一个简单的机器人,可以重复它听到的任何内容,完美调整的神经网络,读出的数据集,或者可能发送消息或与其环境交互的任何其他内容。代理有两个他们需要定义的主要方法:

1def observe(self, observation): #用观察更新内部状态
2def act(self): #根据内部状态生成动作
  • observations是我们称之为代理的act函数返回的对象,并且因为它们被输入到其他代理的observe函数而被命名。这是ParlAI中代理与环境之间传递消息的主要方式。观察通常采用包含不同类型信息的python词典的形式。

  • teacher是特殊类型的代理人。他们像所有代理一样实施act和observe功能,但他们也会跟踪他们通过报告功能返回的指标,例如他们提出的问题数量或者正确回答这些问题的次数。

ParlAI 的代码包含如下几个主要的文件夹[9]:

  • core包含框架的主要代码;

  • agents包含可以和不同任务交互的代理;

  • examples包含不同循环的一些基本示例;

  • tasks包含不同任务的代码;

  • mturk包含设置 Mechanical Turk 的代码及 MTurk 任务样例。

ParlAI使用方法

ParlAI内部封装了很多对话任务(如ConvAI2)和评测(如F1值和hits@1等等)。使用ParlAI现有的数据,代码以及模型进行训练和评测,可以快速实现对话模型中的很多baseline模型。但由于代码封装性太强,不建议使用它从头搭建自己的模型。想在基础上搭建自己的模型可以详细参考官网中的教程[10]。

这里简单介绍直接利用内部的数据,代码以及模型进行训练和评测的一个简单例子(Train a Transformer on Twitter):

1. 打印一些数据集中的例子

1python examples/display_data.py -t twitter
2*# display first examples from twitter dataset*

2. 训练模型

1python examples/train_model.py -t twitter -mf /tmp/tr_twitter -m transformer/ranker -bs 10 -vtim 3600 -cands batch -ecands batch --data-parallel True
2# train transformer ranker

3. 评测之前训练出的模型

1python examples/eval_model.py -t twitter -m legacy:seq2seq:0 -mf models:twitter/seq2seq/twitter_seq2seq_model
2# Evaluate seq2seq model trained on twitter from our model zoo

4. 输出模型的一些预测

1python examples/display_model.py -t twitter -mf /tmp/tr_twitter -ecands batch
2# display predictions for model saved at specific file on twitter

ParlAI使用总结

ParlAI有自己的一套模式,例如world、agent和teacher等等。代码封装性特别好,代码量巨大,如果想查找一个中间结果,需要一层一层查看调用的函数,不容易进行修改。ParlAI中间封装了很多现有的baseline模型,对于对话研究者,可以快速实现baseline模型。目前ParlAI还在更新,不同版本之间的代码可能结构略有不同,但是ParlAI的核心使用方法大致相同。

五、总结

本文介绍了四种常见框架构建自定义模型的方法。Tensor2Tensor涵盖比较全面,但是只支持TensorFlow。AllenNLP最大的优点在于简化了数据预处理、训练和预测的过程。代码改起来也很灵活,但是一些工具目前官方还没有实现,需要自己写。如果是比较传统的编码器-解码器结构下文本生成任务,使用OpenNMT能节省很多时间。但是如果是结构比较新颖的模型,使用OpenNMT搭建模型依旧是一个不小的挑战。ParlAI内部封装了很多对话任务,方便使用者快速复现相关的baseline模型。但由于代码封装性太强和其特殊的模式,使用ParlAI从头搭建自己的模型具有一定的挑战性。每个框架都有各自的优点和弊端,大家需结合自身情况和使用方式进行选择。但是不建议每个框架都试一遍,毕竟掌握每个框架还是需要一定时间成本的。

参考资料

[1] https://github.com/tensorflow/tensor2tensor

[2] https://allennlp.org/tutorials

[3] https://allenai.github.io/allennlp-docs/api/allennlp.training.metrics.html

[4] http://www.realworldnlpbook.com/blog/training-a-shakespeare-reciting-monkey-using-rl-and-seqgan.html

[5] https://github.com/mhagiwara/realworldnlp

[6] http://opennmt.net/

[7] https://github.com/OpenNMT/OpenNMT-py

[8] http://parl.ai.s3-website.us-east-2.amazonaws.com/docs/tutorial_quick.html

[9] https://www.infoq.cn/article/2017/05/ParlAI-Facebook-AI

[10] http://parl.ai.s3-website.us-east-2.amazonaws.com/docs/tutorial_basic.html

本期责任编辑:崔一鸣

本期编辑:刘元兴

四种常见NLP框架使用总结相关推荐

  1. 最新开源LiDAR数据集LSOOD:四种常见的室外物体分类

    点云PCL免费知识星球,点云论文速读. 标题:最新开源LiDAR数据集LSOOD:四种常见的室外物体分类 作者:Y Tian 来源:https://github.com/Tian-Yifei/LSOO ...

  2. (转载)四种常见的 POST 提交数据方式

    转载地址:https://imququ.com/post/four-ways-to-post-data-in-http.html 四种常见的 POST 提交数据方式 HTTP/1.1 协议规定的 HT ...

  3. application/json 四种常见的 POST 提交数据方式

    四种常见的 POST 提交数据方式   HTTP/1.1 协议规定的 HTTP 请求方法有 OPTIONS.GET.HEAD.POST.PUT.DELETE.TRACE.CONNECT 这几种.其中 ...

  4. 四种常见的 POST 提交数据方式对应的content-type取值

    做前后端分离一般都有第3中 , 第一种 基本上jquery那年代用的了 第2种在需要传文件时用的 https://www.cnblogs.com/wushifeng/p/6707248.html 四种 ...

  5. 四种常见的 POST 提交数据方式 专题

    原文地址为: 四种常见的 POST 提交数据方式 专题 定义和用法 enctype 属性规定在发送到服务器之前应该如何对表单数据进行编码. 默认地,表单数据会编码为 "application ...

  6. 采购订单管理的四种常见类型

    采购管理对于任何成功的企业都至关重要.如果你的企业没有统一的采购管理流程,那么你可能无法对你的采购进行解释,并可能犯下代价高昂的采购错误.采购订单或许是有效采购管理流程中最重要的部分.让我们来看看采购 ...

  7. 四种常见的 POST 提交数据方式--good

    http://www.cnblogs.com/softidea/p/5745369.html 四种常见的 POST 提交数据方式--good HTTP/1.1 协议规定的 HTTP 请求方法有 OPT ...

  8. 【转载】四种常见系统架构介绍

    转自于 四种常见系统架构介绍 - 宇大..大 - 博客园软件架构(software architecture)就是软件的基本结构. 合适的架构是软件成功的最重要因素之一.大型软件公司通常有专门的架构师 ...

  9. JavaScript内存管理机制以及四种常见的内存泄漏解析

    转自:http://geek.csdn.net/news/detail/238898 原文:How JavaScript works: memory management + how to handl ...

最新文章

  1. salt-api timeout 执行超时问题解决
  2. Spring Web Flow实例教程
  3. 综合布线中所需要的的带宽和数据速率
  4. Java Annotation认知(包括框架图、详细介绍、示例说明)
  5. 必须使用301重定向的运用场景
  6. 09 进程池的异步方法
  7. log4j2+ELK
  8. laravel使用artisan报错SQLSTATE[42S02]: Base table or view not found: 1146
  9. 南邮-2022年6月电子商务练习自整理 - 选择篇
  10. 路由器 接 交换机 接 路由器
  11. 各代DDR内存的速度表
  12. 2018校招笔试真题汇总
  13. 交友结婚的原则[转贴]
  14. iphone4s更换电池_更换iPhone电池有多困难?
  15. 遗传算法求解一元函数最大值
  16. python对真假的判断方式
  17. 速轩三维 - 白光/蓝光/拍照式三维扫描仪
  18. (二)Pgcluu监控
  19. Lind.DDD.Paging分页模块介绍
  20. Games101----Transformation(2)

热门文章

  1. c语言编程所得票数,C语言编程求1X2X3····Xn所得的数末尾有多少个零
  2. java串口通信DataRecive_串口通信之DataReceive事件触发时机
  3. 题目:任意给定一个浮点数,计算这个浮点数的立方根。(基于二分法和牛顿迭代法)(基于Java实现)
  4. min_sample_split 和min_sample_leaf区别
  5. 隐马尔可夫模型维特比算法与前向算法区别
  6. eclipse调试详解
  7. 基于三维数据的深度学习综述
  8. CSS Tricks网站创始人作序推荐,这本书助你成为Web开发高手
  9. Spring Boot并不重复“造轮子”
  10. 想找到女朋友,你得掌握这些算法