2019-12-31 10:50:59

选自GitHub

作者:Jay Alammar

参与:王子嘉、Geek AI

如果你是一名自然语言处理从业者,那你一定听说过最近大火的 BERT 模型。本文是一份使用简化版的 BERT 模型——DisTillBERT 完成句子情感分类任务的详细教程,是一份不可多得的 BERT 快速入门指南。

在过去的几年中,用于处理语言的机器学习模型取得了突飞猛进的进展。这些进展已经走出了实验室,开始为一些先进的数字产品赋能。最近成为 Google 搜索背后主要力量的 BERT 就是一个很好的例子。Google 认为这一进步(或者说在搜索中应用自然语言理解技术的进步)代表着「过去五年中最大的飞跃,也是搜索历史上最大的飞跃之一」。

本文是一篇简单的教程,介绍如何使用不同的 BERT 对句子进行分类。本文中的例子深入浅出,也足以展示 BERT 使用过程中所涉及的关键概念。
除了这篇博文,我还准备了一份对应的 notebook 代码,链接如下:
https://github.com/jalammar/jalammar.github.io/blob/master/notebooks/bert/A_Visual_Notebook_to_Using_BERT_for_the_First_Time.ipynb。
你也可以在 colab 中运行这份代码:
https://colab.research.google.com/github/jalammar/jalammar.github.io/blob/master/notebooks/bert/A_Visual_Notebook_to_Using_BERT_for_the_First_Time.ipynb

数据集:SST2

本文示例中使用的数据集为「SST2」,该数据集收集了一些影评中的句子,每个句子都有标注,好评被标注为 1,差评标注为 0。

sentencelabela stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films1apparently reassembled from the cutting room floor of any given daytime soap0they presume their audience won't sit still for a sociology lesson0this is a visually stunning rumination on love , memory , history and the war between art and commerce1jonathan parker 's bartleby should have been the be all end all of the modern office anomie films1

模型:句子情感分类

我们的目标是创建一个分类器,它的输入是一句话(即类似于数据集中的句子),并输出一个 1(表示这句话体现出了积极的情感)或是 0(表示这句话体现出了消极的情感)。整体的框架如下图所示:

实际上,整个模型由两个子模型组成:

  • DistilBERT 先对句子进行处理,并将它提取到的信息传给下个模型。DistilBERT 是 BERT 的缩小版,它是由 HuggingFace 的一个团队开发并开源的。它是一种更轻量级并且运行速度更快的模型,同时基本达到了 BERT 原有的性能。
  • 另外一个模型,是 scikit learn 中的 Logistic 回归模型,它会接收 DistilBERT 处理后的结果,并将句子分类为积极或消极(0 或 1)。

我们在两个模型之间传递的数据是 768 维的向量。我们可以把这个向量看做能够用于分类任务的句子嵌入。

如果你读过我以前写的关于 Iluustrated BERT 的文章(https://jalammar.github.io/illustrated-bert/),这个向量其实就是以词 [CLS] 为输入的第一个位置上的结果。

模型训练

尽管整个模型包含了两个子模型,但是我们只需要训练 logistic 回归模型。至于 DistillBERT,我们会直接使用已经在英文上预训练好的模型。然而,这个模型不需要被训练,也不需要微调,就可以进行句子分类。BERT 训练时的一般目标让我们已经有一定的句子分类能力了,尤其是 BERT 对于第一个位置的输出(也就是与 [CLS] 对应的输出)。我觉得这主要得益于 BERT 的第二个训练目标——次句预测(Next sentence classification)。该目标似乎使得它第一个位置的输出封装了句子级的信息。Transformers 库给我们提供了 DistilBERT 的一种实现以及预训练好的模型。

教程概览

这篇教程的计划如下:我们首先用预训练好的 distilBERT 来生成 2,000 个句子的嵌入。

后面我们就不会再涉及 distilBERT 了。剩下的都是 Scikit Learn 的工作了,我们接下来要做的就是将数据集分成训练集和测试集。

将 distilBert(模型 #1)的输出划分为训练集和测试集后,就得到了我们训练和评估 logistic 回归(模型 #2)的数据集了。请注意,在现实中,sklearn 在将数据划分为训练集和测试集之前,会将样本打乱,而不是直接按照数据原来的顺序取前 75%。
接下来,我们要在训练集上训练逻辑 logistic 回归模型了:

每个预测值是怎么计算出来的?

在我们深入研究训练模型的代码前,让我们先看一下训练好的模型是如何计算出其预测值的。假设我们正在对句子「a visually stunning rumination on love」进行分类。第一步要做的就是用 BERT 分词器(tokenizer)将这些单词(word)划分成词(token)。然后,我们再加入句子分类所需的特殊词(在句子开始加入 [CLS],末端加入 [SEP])。

分词器要做的第三步就是查表,将这些词(token)替换为嵌入表中对应的编号,我们可以从预训练模型中得到这张嵌入表。如果对词嵌入不太了解,请参阅「The Illustrated Word2vec」:
https://jalammar.github.io/illustrated-word2vec/。

注意,只需要一行代码就可以完成分词器的所有工作:

tokenizer.encode("avisuallystunningruminationonlove",add_special_tokens=True)

现在,我们输入的句子是可以传递给 DistilBERT 的形式。
如果你读过「Illustrated BERT」(https://jalammar.github.io/illustrated-bert/),这一步也可以被可视化为下图:

DistilBERT 的数据流

将输入向量传递给 DistilBert 之后的工作方式就跟在 BERT 中一样,每个输入的词都会得到一个由 768 个浮点数组成的输出向量。

由于这是个句子分类任务,我们只关心第一个向量(与 [CLS] 对应的向量)。
该向量就是我们输入给 logistic 回归模型的向量。

从这一步开始,logistic 回归模型的任务就是:根据它在训练阶段学到的经验,对这个向量进行分类。我们可以把整个预测过程想象成这样:

我们将在下一节中讨论整个训练过程及其相关代码。

代码

在本节中,我们重点介绍训练这个句子分类模型的代码。读者可以从 colab 和 github 上获取完成本任务所需的所有 notebook 代码(链接见本文第二段)。
首先,我们需要导入相关的程序包。

importnumpyasnpimportpandasaspdimporttorchimporttransformersasppb#pytorchtransformersfromsklearn.linear_modelimportLogisticRegressionfromsklearn.model_selectionimportcross_val_scorefromsklearn.model_selectionimporttrain_test_split

数据集的链接如下:https://github.com/clairett/pytorch-sentiment-classification/。我们可以直接将其导入为一个 pandas 数据帧。

df=pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv',delimiter='\t',header=None)

我们可以通过「df.head()」查看前五行数据帧的内容:

df.head()

输出如下:

导入预训练好的 DistilBERT 模型与分词器

model_class,tokenizer_class,pretrained_weights=(ppb.DistilBertModel,ppb.DistilBertTokenizer,'distilbert-base-uncased')##WantBERTinsteadofdistilBERT?Uncommentthefollowingline:#model_class,tokenizer_class,pretrained_weights=(ppb.BertModel,ppb.BertTokenizer,'bert-base-uncased')#Loadpretrainedmodel/tokenizertokenizer=tokenizer_class.from_pretrained(pretrained_weights)model=model_class.from_pretrained(pretrained_weights)

现在,我们可以对数据集进行分词操作了。请注意,我们这里要做的事情跟上文中的示例不太一样。上文中的例子只对一个句子进行了分词和处理。在这里,我们要将所有句子作为同一批进行分词和处理。在 notebook 中,考虑到算力的问题,我们只用了 2000 个句子。

分词

tokenized=df[0].apply((*lambda*x:tokenizer.encode(x,add_special_tokens=True)))T

上面的代码将所有句子转化为了编号的列表。

现在,数据集变成了一个包含很多列表结构(或是 Pandas 的数组和数据帧)的列表。在 DistilBERT 可以将它们作为输入数据处理之前,我们需要把这些向量整理成相同的维度(在较短的句子后面填充上编号「0」)。你可以在 notebook 中看到「填充」操作对应的代码,它们就是 python 字符串和数组的简单操作。
完成「填充」操作后,我们就得到了 BERT 可以接收的矩阵/张量了。

DistilBERT 的处理

现在,我们根据填充后的词(token)矩阵创建一个输入张量,并将其传递给 DistilBERT

input_ids=torch.tensor(np.array(padded))withtorch.no_grad():last_hidden_states=model(input_ids)

这一步完成后,会将 DistilBERT 的输出赋给「last_hidden_states」。这是一个维度为(句子数,序列中的最大词数,DistilBERT 模型中的隐藏层数)的元组。在本例中,这个维度就是(2000,66,768),因为我们有 2000 个句子,2000 个句子中最长的序列长度为 66,DistilBERT 模型中有 768 个隐藏层。

展开 BERT 的输出张量

接下来我们要把这个三维的输出张量展开。我们可以先分析一下它的维度:

回顾这些句子的「一生」

每一行代表数据集中的一个句子。第一个句子的处理过程如下图所示:

切片得到重要的部分

对于句子分类任务来说,我们只对 BERT 得到的 [CLS] 对应的输出感兴趣,所以我们只保留「立方体」中 [CLS] 对应的切片,删掉了其它内容。

这就是从三维张量中获取我们需要的二维张量的方法:

#Slicetheoutputforthefirstpositionforallthesequences,takeallhiddenunitoutputsfeatures=last_hidden_states[0][:,0,:].numpy()

现在,「features」是一个二维的 numpy 数组,它包含了我们数据集中所有句子的嵌入。

我们从 BERT 输出中切片得到的张量。

Logistic 回归中的数据集

现在我们已经得到了 BERT 的输出,并且把训练 logistic 回归模型的数据集组装好了。这 768 列是特征,我们还给出了原始数据集中的标签。

我们用于训练 logistic 回归的有标签数据集。这些特征就是 BERT 中对应 [CLS](位置 #0)的部分,也就是我们前一张图中切片得到的部分。每一行对应着数据集中的一个句子,每一列对应着 BERT/DistilBERT 模型中顶层 transformer 模块中前馈神经网络的隐藏单元。
在完成了机器学习传统的训练集/测试集划分操作后,我们就可以创建我们自己的 logistic 回归模型,并用我们的数据集进行训练了。

labels=df[1]train_features,test_features,train_labels,test_labels=train_test_split(features,labels)

上面的代码把数据集分成了训练集和测试集:

接下来,我们要在训练集上训练 logistic 回归模型。

lr_clf=LogisticRegression()lr_clf.fit(train_features,train_labels)

现在模型已经训练好了,我们可以用测试集对其进行评估。

lr_clf.score(test_features,test_labels)

上面一段代码的输出告诉我们这个模型获得了 81% 的准确率。

评分标准

作为参照,在这个数据集上所能达到的最高的准确率是 96.8。在这个任务中,我们同样可以训练 DistilBERT 来提高其得分,也就是通常所说的调优过程,该过程可以更新 BERT 的权重参数,并让其在句子分类任务中(通常被称为「下游任务」)获得更好的性能。经过调优后的 DistilBERT 可以达到 90.7 的准确率,而完整的 BERT 可以达到 94.9。

Notebook 代码
本文开头给出了相关 notebook 代码的链接,自己去探索一下吧!
这就是本文的全部内容了,这些内容对于第一次接触 BERT 的人来说应该是非常适用的。而下一步就是查看文档并试着进行调优了。你也可以回过头来把代码中的 DistilBERT 转换成 BERT,看看它又是如何工作的。

BERT模型超酷炫,上手又太难?请查收这份BERT快速入门指南相关推荐

  1. BERT模型超酷炫,上手又太难?请查收这份BERT快速入门指南!

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 来自 | GitHub    作者 | Jay Alammar 转自 | 机器之心 如 ...

  2. pr cpu100%_6款超酷炫又小众的PR插件 据说都用过就是大神!

    PR作为目前最主流的视频编辑软件 拥有大量的配套插件模板等等 我们也推送过PR大量常用的插件 如转场.调色.特效等等 但PR还许多超酷炫又小众的插件 本期菌菌就为大家精选了6款插件 帮助小伙伴更好的进 ...

  3. Lively Wallpaper ---- 超酷炫的桌面壁纸

    最近在github 上找到一个超酷炫的桌面壁纸程序,支持各种格式的素材作为壁纸 下载链接: https://github.com/rocksdanister/lively/releases/tag/v ...

  4. 有什么做GIF的软件?这3款APP超酷炫

    有什么做GIF的软件?喜欢聊微信的你们都有这种感觉吧:聊天的时候,大部分情况下都是动态图片比静态图片更能表达情绪.经常会看到一张有意思的动态图片后,像中了魔一样反复观看,说不定还会嘻嘻嘻地笑出声来吓到 ...

  5. HTML5七夕情人节表白网页制作【飘落蒲公英动画超酷炫的HTML5页面】HTML+CSS+JavaScript

    这是程序员表白系列中的100款网站表白之一,旨在让任何人都能使用并创建自己的表白网站给心爱的人看. 此波共有100个表白网站,可以任意修改和使用,很多人会希望向心爱的男孩女孩告白,生性腼腆的人即使那个 ...

  6. HTML5七夕情人节表白网页_飘落蒲公英动画超酷炫的HTML5页面_ HTML+CSS+JS 求婚 html生日快乐祝福代码网页 520情人节告白代码 程序员表白源码 抖音3D旋转相册

    HTML5七夕情人节表白网页❤飘落蒲公英动画超酷炫的HTML5页面❤ HTML+CSS+JS 求婚 html生日快乐祝福代码网页 520情人节告白代码 程序员表白源码 抖音3D旋转相册 js烟花代码 ...

  7. 推荐系统[一]:超详细知识介绍,一份完整的入门指南,解答推荐系统相关算法流程、衡量指标和应用,以及如何使用jieba分词库进行相似推荐,业界广告推荐技术最新进展

    搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排精排重排混排).系统架构.常见问题.算法项目实战总结.技术细节以及项目实战(含码源) 专栏详细介绍:搜索推荐系统专栏简介:搜索推荐全流程讲解(召回粗排 ...

  8. micro-app-vue2 vue3 超详细快速入门指南 学习记录

    micro-app-vue 快速入门指南 简介 micro-app是京东零售推出的一款微前端框架,它基于类WebComponent进行渲染,从组件化的思维实现微前端,旨在降低上手难度.提升工作效率.它 ...

  9. 计算机怎么检测扫描机,再见扫描仪!微信这个功能太赞了,100份合同快速录入电脑...

    原标题:再见扫描仪!微信这个功能太赞了,100份合同快速录入电脑 最近领导丢了好几份合同给我,说是要把这些合同给统一录入到电脑里然后发给他,眼看着最后的期限就要到了,还有N份文件躺在桌上嗷嗷待录~ 与 ...

最新文章

  1. android程序排序算法实现
  2. 用户sa 登陆失败 SQLServer 错误18456----解决方法
  3. uniapp滑动切换tab标签_Web前端,Tab切换,缓存,页面处理的几种方式
  4. Django环境配置
  5. 两边双虚线是什么意思_行星减速机生产厂家解析行星减速机双支撑与单支撑
  6. 【项目合作】最高50万元!道路缝隙检测、目标跟踪优化、机器人平台开发
  7. 全局替换资源_BitLocker+VHD替换TrueCrypt及其后继VeraCrypt
  8. 面向资源的权限体系设计随想
  9. 构建smaba服务器
  10. e.style.opacity 通过javascript调用元素的样式属性
  11. Palantir:野心贼大,想做世界的创新引擎(附纪要)| 国君计算机李沐华
  12. 君康人寿2019年排名_2019中国保险公司竞争力报告出炉 君康人寿盈利能力排名第二...
  13. 发动机压缩比怎么计算公式_加几号油它说的算 解析发动机压缩比奥秘
  14. 域名已注册好,如何做网站?
  15. raw的服务器镜像是什么系统,如何将Ceph Raw格式镜像转换成Qcow2格式并上传云平台创建云主机...
  16. 华为教父任正非的别样视野(转)
  17. Dw如何创建网页链接?
  18. Outlook打不开的问题
  19. 嵌入式硬件(四)常用模拟集成电路
  20. maven 打包报错:The following files had format violations

热门文章

  1. sequence.pad_sequences 的用法举例
  2. 运行NER/formal_bert_lstm_crf.py“ 报错ModuleNotFoundError: No module named ‘keras_contrib‘
  3. HMM和贝叶斯网络的关系
  4. 自然语言处理好的 实体分词 及BERT
  5. 《强化学习周刊》第21期:EMNLP 2020-2021强化学习的最新研究与应用
  6. Percy Liang、李飞飞等百余位学者联名发布:「基础模型」的机遇与挑战
  7. 线上直播丨KDD 2021预训练Workshop,谷歌MSRA等5位顶尖研究者参与研讨
  8. 领域适配前沿研究——场景、方法与模型选择
  9. 区块链论文:OmniLedger,一种区块链分片技术
  10. Python图像处理:形态学操作