自然语言推理:微调BERT

Natural Language Inference: Fine-Tuning BERT

SNLI数据集上的自然语言推理任务设计了一个基于注意力的体系结构。现在通过微调BERT来重新讨论这个任务。自然语言推理是一个序列级文本对分类问题,而微调BERT只需要额外的基于MLP的架构,如图1所示。

Fig. 1. This section feeds pretrained BERT to an MLP-based architecture for natural language inference.

下载一个经过预训练的小版本BERT,然后对其进行微调,以便在SNLI数据集上进行自然语言推理。

from d2l import mxnet as d2l

import json

import multiprocessing

from mxnet import autograd, gluon, init, np, npx

from mxnet.gluon import nn

import os

npx.set_np()

  1. Loading Pretrained BERT

解释了如何在WikiText-2数据集上预训练BERT(注意,原始的BERT模型是在更大的语料库上预训练的)。最初的BERT模型有上亿个参数。提供两个版本的预训练BERT:“bert.base “大约和原始的BERT基模型一样大,需要大量的计算资源进行微调,而“bert.small”是一个小版本,便于演示。

d2l.DATA_HUB[‘bert.base’] = (d2l.DATA_URL + ‘bert.base.zip’,

                         '7b3820b35da691042e5d34c0971ac3edbd80d3f4')

d2l.DATA_HUB[‘bert.small’] = (d2l.DATA_URL + ‘bert.small.zip’,

                          'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

任何一个预训练的BERT模型都包含一个“vocab.json”定义词汇集和“pretrained.params”预训练参数的文件。实现了如下加载预训练模型函数来load_pretrained_model加载预训练的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,

                      num_heads, num_layers, dropout, max_len, ctx):data_dir = d2l.download_extract(pretrained_model)# Define an empty vocabulary to load the predefined vocabularyvocab = d2l.Vocab([])vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,num_layers, dropout, max_len)# Load pretrained BERT parametersbert.load_parameters(os.path.join(data_dir, 'pretrained.params'), ctx=ctx)return bert, vocab

为了便于在大多数机器上演示,将加载并微调小版本(“bert.small”)的名称。在练习中,将演示如何微调更大的“bert.base”以显著提高测试精度。

ctx = d2l.try_all_gpus()

bert, vocab = load_pretrained_model(

'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,

num_layers=2, dropout=0.1, max_len=512, ctx=ctx)

Downloading …/data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip…

  1. The Dataset for Fine-Tuning BERT

对于SNLI数据集上的下游任务自然语言推理,定义了一个自定义的数据集类SNLIBERTDataset。在每个例子中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列,如图2所示。段IDs用于区分BERT输入序列中的前提和假设。使用预定义的BERT输入序列的最大长度(max_len),输入文本对中较长的最后一个标记会一直被删除,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,使用4个worker进程并行地生成训练或测试示例。

class SNLIBERTDataset(gluon.data.Dataset):

def __init__(self, dataset, max_len, vocab=None):all_premise_hypothesis_tokens = [[p_tokens, h_tokens] for p_tokens, h_tokens in zip(*[d2l.tokenize([s.lower() for s in sentences])for sentences in dataset[:2]])]self.labels = np.array(dataset[2])self.vocab = vocabself.max_len = max_len(self.all_token_ids, self.all_segments,self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)print('read ' + str(len(self.all_token_ids)) + ' examples')def _preprocess(self, all_premise_hypothesis_tokens):pool = multiprocessing.Pool(4)  # Use 4 worker processesout = pool.map(self._mp_worker, all_premise_hypothesis_tokens)all_token_ids = [token_ids for token_ids, segments, valid_len in out]all_segments = [segments for token_ids, segments, valid_len in out]valid_lens = [valid_len for token_ids, segments, valid_len in out]return (np.array(all_token_ids, dtype='int32'),np.array(all_segments, dtype='int32'),np.array(valid_lens))def _mp_worker(self, premise_hypothesis_tokens):p_tokens, h_tokens = premise_hypothesis_tokensself._truncate_pair_of_tokens(p_tokens, h_tokens)tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \* (self.max_len - len(tokens))segments = segments + [0] * (self.max_len - len(segments))valid_len = len(tokens)return token_ids, segments, valid_lendef _truncate_pair_of_tokens(self, p_tokens, h_tokens):# Reserve slots for '<CLS>', '<SEP>', and '<SEP>' tokens

for the BERT

    # inputwhile len(p_tokens) + len(h_tokens) > self.max_len - 3:if len(p_tokens) > len(h_tokens):p_tokens.pop()else:h_tokens.pop()def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx]), self.labels[idx]def __len__(self):return len(self.all_token_ids)

在下载SNLI数据集之后,通过实例化SNLIBERTDataset类来生成训练和测试示例。这些例子将在自然语言推理的训练和测试中分批阅读。

# Reduce batch_size if there is an out of memory error. In the original
BERT

# model, max_len = 512

batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()

data_dir = d2l.download_extract(‘SNLI’)

train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)

test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)

train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,

                               num_workers=num_workers)

test_iter = gluon.data.DataLoader(test_set, batch_size,

                              num_workers=num_workers)

read 549367 examples

read 9824 examples

  1. Fine-Tuning BERT

如图2所示,用于自然语言推理的微调BERT只需要由两个完全连接的层组成的额外MLP(参见自隐藏以及自输出在下面的BERTClassifier类中)。这种MLP将编码前提和假设信息的特殊“”标记的BERT表示转化为自然语言推理的三种输出:蕴涵、矛盾和中性。

class BERTClassifier(nn.Block):

def __init__(self, bert):super(BERTClassifier, self).__init__()self.encoder = bert.encoderself.hidden = bert.hiddenself.output = nn.Dense(3)def forward(self, inputs):tokens_X, segments_X, valid_lens_x = inputsencoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)return self.output(self.hidden(encoded_X[:, 0, :]))

接下来,将预训练的BERT模型BERT输入BERT分类器实例网络,供下游应用程序使用。在一般的BERT微调实现中,只有输出层的参数附加MLP(net.output)从零开始学习。预训练BERT编码器的所有参数(net.encoder)以及附加MLP的隐藏层(net.hidden)将进行微调。

net = BERTClassifier(bert)

net.output.initialize(ctx=ctx)

MaskLM类和NextSentencePred类在使用的mlp中都有参数。这些参数是预训练BERT模型BERT的一部分,因此也是网络中的一部分。然而,这些参数仅用于计算预训练过程中的隐含语言建模损失和下一句预测损失。这两个损失函数与下游应用的微调无关,因此当对BERT进行微调时,MaskLM和NextSentencePred中使用的MLPs的参数不会更新(过期)。
为了允许参数具有过时渐变,在d2l.train_batch_ch13的步进函数中设置标志ignore_stale_grad=True。利用SNLI的训练集(train_iter)和测试集(test_iter)来训练和评估模型网络。由于计算资源有限,训练和测试的准确性还有待进一步提高:将其讨论留在练习中。

lr, num_epochs = 1e-4, 5

trainer = gluon.Trainer(net.collect_params(), ‘adam’, {‘learning_rate’: lr})

loss = gluon.loss.SoftmaxCrossEntropyLoss()

d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx,

           d2l.split_batch_multi_inputs)

loss 0.597, train acc 0.741, test acc 0.713

8563.9 examples/sec on [gpu(0), gpu(1)]

4. Summary

· We can fine-tune the pretrained BERT model for downstream applications, such as natural language inference on the SNLI dataset.

· During fine-tuning, the BERT model becomes part of the model for the downstream application. Parameters that are only related to pretraining loss will not be updated during fine-tuning.

自然语言推理:微调BERT相关推荐

  1. 97. BERT微调、自然语言推理数据集以及代码实现

    1. 微调BERT 2. 句子分类 3. 命名实体识别 4. 问题回答 5. 总结 即使下游任务各有不同,使用BERT微调时只需要增加输出层 但根据任务的不同,输入的表示,和使用的BERT特征也会不一 ...

  2. 自然语言处理NLP星空智能对话机器人系列:深入理解Transformer自然语言处理 基于BERT模型微调实现句子分类

    自然语言处理NLP星空智能对话机器人系列:深入理解Transformer自然语言处理 基于BERT模型微调实现句子分类 目录 基于BERT模型微调实现句子分类案例实战 Installing the H ...

  3. 微调BERT:序列级和令牌级应用程序

    微调BERT:序列级和令牌级应用程序 Fine-Tuning BERT for Sequence-Level and Token-Level Applications 为自然语言处理应用程序设计了不同 ...

  4. WSDM Cup 2019自然语言推理任务获奖解题思路

    WSDM(Web Search and Data Mining,读音为Wisdom)是业界公认的高质量学术会议,注重前沿技术在工业界的落地应用,与SIGIR一起被称为信息检索领域的Top2. 刚刚在墨 ...

  5. bert 多义词_自然语言处理:Bert及其他

    以下内容主要参考了文末列出的参考文献,在此表示感谢! 2018年被认为是NLP技术的new era的开始.在这一年,提出了多种有创新性的技术,而且最后的集大成者Bert在NLP的多项任务中屠榜,造成的 ...

  6. [NLP自然语言处理]谷歌BERT模型深度解析

    BERT模型代码已经发布,可以在我的github: NLP-BERT--Python3.6-pytorch 中下载,请记得start哦 目录 一.前言 二.如何理解BERT模型 三.BERT模型解析 ...

  7. 自然语言处理——谷歌BERT模型深度解析

    BERT模型代码已经发布,可以在我的github: NLP-BERT--Python3.6-pytorch 中下载,请记得start哦 目录 一.前言 二.如何理解BERT模型 三.BERT模型解析 ...

  8. 自然语言处理之BERT

    1.简介 嘘!BERT来了,就是那个同时刷新了11个NLP任务记录的模型.从本质上来bert属于一个预训练模型,模型可以理解上下文和单词之间的关系,也可以理解句子和句子之间的关系.针对不同的任务,可以 ...

  9. 【自然语言处理】BERT GPT

    BERT & GPT 近年来,随着大规模预训练语言模型的发展,自然语言处理领域发生了巨大变革.BERT 和 GPT 是其中最流行且最有影响力的两种模型.在本篇博客中,我们将讨论 BERT 和 ...

最新文章

  1. Javascript教程:AngularJS的五个超酷特性
  2. 用JavaScript实现简单的excel列转sql字符串
  3. HDOJ HDU 2058 The sum problem ACM 2058 IN HDU
  4. 能在微软的网站找到IeWebcontrols的安装文件吗
  5. Microsoft Power BI Desktop概念学习系列之Microsoft Power BI Desktop的官网自带示例数据(图文详解)...
  6. 安装cuda 非root_linux非root用户下安装软件,搭建生产环境
  7. ML.NET Cookbook:(12)我想看看模型的系数
  8. [css] CSS content属性特殊字符有哪些?
  9. UIButton 上的标题添加下划线效果
  10. 礼橙专车、青菜拼车今日起改名啦!
  11. 升级系统服务器出错,win10更新失败80070002错误怎么办
  12. PreparedStatement对象
  13. IBM Rational总经理夏然谈程序员生涯
  14. linux初学者-磁盘配额篇
  15. [Unity 3D] Unity 3D 性能优化 (一)
  16. 最优比例生成树(0/1分数规划)
  17. Java项目部署目录结构与部署方法 打包方法attilax总结 目录 1.1. Java web项目部署目录结构 1 2. Springboot项目的部署结构 2 3. Java项目的开发模式下目录
  18. Python版本的查看
  19. 技嘉主板更新版BIOS
  20. SQL 排序,筛选,过滤,聚合函数

热门文章

  1. Docker入门六部曲——Stack
  2. liunx上mysql源码安装mysql,搞定linux上MySQL编程(一):linux上源码安装MySQL
  3. OpenCV 笔记(07)— Mat 对象输出格式设置(Python 格式、CSV 格式、NumPy 格式、C 语言格式)
  4. 2022-2028年中国模胚行业市场研究及前瞻分析报告
  5. c++一些常见的知识点
  6. python中__dict__与dir()区别
  7. 解决:UnicodeEncodeError: 'ascii' codec can't encode character u'\xa0' in position错误
  8. LeetCode简单题之区域和检索 - 数组不可变
  9. MinkowskiEngine多GPU训练
  10. 【CV】吴恩达机器学习课程笔记第10章