作者:Aravind Pai

翻译:王威力

校对:张一豪

本文约3400字,建议阅读10+分钟

本文介绍了利用Pytorch框架实现文本分类的关键知识点,包括使用如何处理Out of Vocabulary words,如何解决变长序列的训练问题,举出了具体实例。

标签:自然语言处理

总览

  • 学习如何使用PyTorch实现文本分类

  • 理解文本分类中的关键点

  • 学习使用压缩填充方法

介绍

在我的编程历程中,我总是求助于最先进的架构。现在得益于深度学习框架,比如说PyTorch,Keras和 TensorFlow,实现先进的架构已经变得更简单了。这些深度学习框架提供了一种实现复杂模型架构和算法的简单方式,不需要你掌握大量的专业知识和编程技能。总结来说,这是数据科学的金矿。

在本文中,我们将使用PyTorch框架,它以其快速的计算能力而闻名。因此,在本文中,我们将介绍解决文本分类问题的关键点。然后我们将在PyTorch框架实现我们的第一个文本分类器!

提示:在继续浏览前,我推荐先阅读这篇文章:

A Beginner-Friendly Guide to PyTorch and How it Works from Scratch:

https://www.analyticsvidhya.com/blog/2019/09/introduction-to-pytorch-from-scratch/?utm_source=blog&utm_medium=building-image-classification-models-cnn-pytorch

大纲

一、为什么用PyTorch来解决文本分类问题

1. 解决Out of Vocabulary words

2. 解决变长序列

3. 包装器和预训练模型

二、了解问题的场景

三、实现文本分类

一、为什么用PyTorch来解决文本分类问题?

在我们深入专业概念前,我们先快速熟悉一下PyTorch这个框架。PyTorch的基本数据单元是Tensor,类似于python中的numpy数列。使用PyTorch的最重要的两个优点是:

  • 动态的网络——训练过程中网络结构可以变化

  • 多GPU分布式训练

我肯定你在想-为什么我们要用PyTorch处理文本数据?接下来我们讨论一下PyTorch的一些令人难以置信的特性,这些特性使它不同于其他框架,特别是在处理文本数据时。

1. 处理Out of Vocabulary words问题

文本分类模型是在固定数据大小的文本数据集上进行训练的。但是对于推理问题,可能会遇到有些词并不涵盖在词汇集内,这些词叫做Out of Vocabulary words。忽略Out of Vocabulary words可能会导致信息丢失,因此这是一个重要的问题。

为了解决这个问题,PyTorch支持把训练数据集中的稀有词替换为unknown token,帮助我们处理Out of Vocabulary words问题。

另外,PyTorch还提供了处理变长序列的方法。

2. 处理变长文本序列

是否听过,循环神经网络用来解决变长序列的问题,有没有疑惑它是怎么实现的?PyTorch带来了很有用的'Packed Padding sequence',来实现动态循环神经网络。

Padding(填充)是在句子的开头或者结尾填充额外的token的过程。由于每个句子的词的数量不同,我们把长度不同的句子输入,增加padding tokens,扩充以使得句子等长。

由于大部分的框架支持的是静态网络,也就是说在模型训练过程中,模型框架是不变的,因此padding是必要的。尽管padding解决了变长序列的问题,但也带来了新的问题——padding token是增加了新的信息/数据,下面我用一个简单的图来做解释。

下边这个图中,最后一个单词表示的是一个padding token,但它也在产生输出的过程里发生了作用。这个问题可以交给pytorch中的Packed Padding sequence(压缩填充序列)来处理。

压缩填充会忽略掉padding token的部分。这些值,永远不会出现在循环神经网络的训练过程中,可以帮助我们建立动态的循环神经网络。

3. 包装和预训练的模型

Pytorch正在推出先进的框架。Hugging Face 公司的Transformers库提供了超过32个先进的框架用于自然语言理解和自然语言生成。

除此之外,pytorch还提供了一些预训练的模型,可以用很少的代码去解决文本到语言、目标检测等问题。

是不是感觉到难以置信?这些是PyTorch的一些非常有用的特性。现在让我们使用PyTorch解决文本分类问题。

二、了解问题的场景

作为本文的一部分,我们将研究一个非常有趣的问题。

Quora希望在他们的平台上跟踪“不真诚”的问题,以便让用户在分享知识的同时感到安全。在这种情况下,一个不真诚的问题被定义为一个旨在陈述而不是寻求有用答案的问题。为了进一步分解这一点,这里有一些特征可以表明某个特定问题是不真诚的:

  • 具有非中性音调;

  • 贬低或煽动;

  • 不是基于现实;

  • 利用性(乱伦、兽交、恋童癖)来获得震惊的价值,而不是寻求真正的答案。

训练集包括以上被问的问题,以及一个标志它是否是不真诚的标签(target=1表示不真诚)。基本事实中存在噪声,也就是说,训练集并不是完美的。我们的任务是识别一个给定的问题是否是“不真诚的”。

数据下载链接为:

https://drive.google.com/open?id=1fcip8PgsrX7m4AFgvUPLaac5pZ79mpwX

现在是时候用PyTorch编写我们自己的文本分类模型了。

三、实现文本分类

首先导入所有建模需要的包。这里有一个简单的关于我们要用到的包的概览:

  • Torch包是用来定义tensor和tensor上的数学运算;

  • TorchText包是PyTorch中NLP的库,包含一些预处理的脚本和常见的NLP数据集。

为了使结果可复现,我指定了种子值。由于深度学习模型的随机性,在执行时可能会产生不同的结果,因此指定种子值非常重要。

  • 数据预处理:

现在我们来看,如何用field(域)来做文本的预处理。这里有两种域对象——Field和LabelField。我们来快速理解一下它们的区别:

  • Field:data模块的Field对象,用于描述数据集中每一列的预处理步骤;

  • LabelField:LabelField是Field对象只用于分类任务的特例。只用于设置unk_token和sequential,默认都为None。

在使用Field之前,看一下它的不同参数和用途:

  • Tokenize:定义分词方法,这里我们用spacy tokenizer,因为它用了新的分词算法;

  • Lower:把文本转化成小写;

  • Batch_first:输入和输出的第一个维度一般都是batch size。

然后,创建元组构成的列表,每个元组都包含一个列名,第二个值是field对象。另外,按照csv文件中列的顺序,来排列元组,当我们忽略一列的时候,用(None,None)表示。

读入必要的列——问题和标签。

fields = [(None, None), ('text',TEXT),('label', LABEL)]

这个代码块中,我通过定义field对象加载了自定义数据集。

现在让我们将数据集分成训练和验证数据

  • 准备输入和输出序列:

下一步是建立文本的vocabulary并把它们转化为整数序列。Vocabulary包含整个文本中的所有的独立的词,每一个词都分配一个索引。下面是参数:

  • min_freq:当vocabulary中的词的频率低于这个参数值的时候把这个词映射为unknown token;

  • 两个特殊的token,一个unknown tokenpadding token加到vocabulary中:Unknown token是用于处理Out Of Vocabulary words;Padding token是把输入序列变为等长的。

我们建立起vocabulary,用预训练好的词嵌入来初始化单词成向量。如果你想随机初始化词嵌入,可以忽略向量的参数。

接下来,准备训练模型的batch。BucketIterator以需要最小填充量的方式形成批次。

  • 模型架构

现在需要定义模型的架构来解决这个二分类问题。Torch中的nn模块,是一个所有模型的基础模型。也就是说,每个模型都必须是nn模块的子类。

我定义了两个函数,init和forward。我来解释一下这两个函数的应用场景。

  • init:初始化类的实例时,init函数自动被调用。因此,它也叫做构造函数。类的参数需要在构造函数中初始化,我们需要定义模型需要用到的层;

  • forward:forward函数定义了inputs前向传播的计算步骤。

最后,我们理解一下各层的细节问题和参数。

嵌入层:对于任何NLP相关的任务,词嵌入都很重要,因为它可以应用数字形式表示一个单词。嵌入层得到一个查询表,其中每一行代表一个词嵌入。嵌入层可以把表示文本的整数序列转化为稠密向量矩阵。嵌入层的两个重要的参数:

  • num_embeddings:查询表中,单词的的个数;

  • embedding_dim:表示一个单词的向量维度。

LSTM:LSTM是RNN的一个变体,可以学习长的依赖关系。下面列举了LSTM的一些你应该了解的重要参数:

  • input_size:输入向量的维度;

  • hidden_size:隐藏层节点的个数;

  • num_layers:网络中的层数;

  • batch_first:如果等于True,输入输出的tensor的形式都是(batch, seq, feature);

  • dropout:默认值是0,如果设为非0,每个LSTM层输出结果都会进到dropout层,以dropout参数值的概率删掉对应比例的神经元;

  • bidirection:如果为True, LSTM是双向的。

Linear Layer:线性层指的是稠密层,有两个重要参数:

  • in_features:输入特征的个数;

  • out_features:隐藏层节点个数。

压缩填充:上文已经讨论过,压缩填充用于动态循环神经网络。如果不采用压缩填充的话,填充后的输入,rnn也会处理padding输入,并返回padded元素的hidden state。但压缩填充是一个很棒的包装,它不显示填充的输入。它直接忽略填充部分并返回非填充元素部分的hidden state。

现在我们已经对这个架构中的所有板块都有了了解,现在可以去看代码了!

下一步是定义超参、初始化模型。

让我们看看模型摘要,并使用预训练的词嵌入初始化嵌入层。

在这里,我定义了模型的优化器、损失和度量:

建模的两个阶段:

  • 训练阶段:model.train() 设置了模型进入训练,并激活dropout层;

  • 预测阶段:model.eval() 开始模型的评估阶段并关闭dropout层。

接下来是定义用于训练模型的函数的代码块。

所以我们有一个函数来训练模型,但是我们也需要一个函数来评估模型。我们来吧 !

最后,我们将对模型进行一定数量的训练,并保存每个时期的最佳模型。

让我们加载最佳模型并定义一个推理函数,它接受用户定义的输入并进行预测太神了!让我们用这个模型来预测几个问题。

小结

我们已经看到了如何在PyTorch中构建自己的文本分类模型,并了解了压缩填充的重要性。您可以随意使用长短期模型的超参数,如隐藏节点数、隐藏层数等,以进一步提高性能。

原文链接:

https://www.analyticsvidhya.com/blog/2020/01/first-text-classification-in-pytorch/

原文标题:

Build Your First Text Classification model using PyTorch


如您想与我们保持交流探讨、持续获得数据科学领域相关动态,包括大数据技术类、行业前沿应用、讲座论坛活动信息、各种活动福利等内容,敬请扫码加入数据派THU粉丝交流群,红数点恭候各位。

编辑:黄继彦

校对:林亦霖

译者简介

王威力,养老医疗行业BI从业者。保持学习。

翻译组招募信息

工作内容:需要一颗细致的心,将选取好的外文文章翻译成流畅的中文。如果你是数据科学/统计学/计算机类的留学生,或在海外从事相关工作,或对自己外语水平有信心的朋友欢迎加入翻译小组。

你能得到:定期的翻译培训提高志愿者的翻译水平,提高对于数据科学前沿的认知,海外的朋友可以和国内技术应用发展保持联系,THU数据派产学研的背景为志愿者带来好的发展机遇。

其他福利:来自于名企的数据科学工作者,北大清华以及海外等名校学生他们都将成为你在翻译小组的伙伴。

点击文末“阅读原文”加入数据派团队~

转载须知

如需转载,请在开篇显著位置注明作者和出处(转自:数据派ID:DatapiTHU),并在文章结尾放置数据派醒目二维码。有原创标识文章,请发送【文章名称-待授权公众号名称及ID】至联系邮箱,申请白名单授权并按要求编辑。

发布后请将链接反馈至联系邮箱(见下方)。未经许可的转载以及改编者,我们将依法追究其法律责任。

点击“阅读原文”拥抱组织

独家 | 教你用Pytorch建立你的第一个文本分类模型!相关推荐

  1. Python 教你训练一个98%准确率的微博抑郁文本分类模型(含数据)

    Paddle是一个比较高级的深度学习开发框架,其内置了许多方便的计算单元可供使用,我们之前写过PaddleHub相关的文章: 1.Python 识别文本情感就这么简单 2.比PS还好用!Python ...

  2. 从零开始构建基于textcnn的文本分类模型(上),word2vec向量训练,预训练词向量模型加载,pytorch Dataset、collete_fn、Dataloader转换数据集并行加载

    伴随着bert.transformer模型的提出,文本预训练模型应用于各项NLP任务.文本分类任务是最基础的NLP任务,本文回顾最先采用CNN用于文本分类之一的textcnn模型,意在巩固分词.词向量 ...

  3. 独家 | 教你用Scrapy建立你自己的数据集(附视频)

    原文标题:Using Scrapy to Build your Own Dataset 作者:Michael Galarnyk 翻译:李清扬 全文校对:丁楠雅 本文长度为2400字,建议阅读5分钟 数 ...

  4. 从零开始学Pytorch(四)之softmax与分类模型

    softmax的基本概念 分类问题 一个简单的图像分类问题,输入图像的高和宽均为2像素,色彩为灰度. 图像中的4像素分别记为x1,x2,x3,x4x_1, x_2, x_3, x_4x1​,x2​,x ...

  5. 手把手教你搭建Bert文本分类模型,快点看过来吧!

    1 赛题名称 基于文本挖掘的企业隐患排查质量分析模型 2 赛题背景 企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义.企业在填报隐患时,往往存在不认真填报的情况,"虚报.假 ...

  6. Pytorch高阶API示范——DNN二分类模型

    代码部分: import numpy as np import pandas as pd from matplotlib import pyplot as plt import torch from ...

  7. 【Pytorch神经网络理论篇】 31 图片分类模型:ResNet模型+DenseNet模型+EffcientNet模型

    1 ResNet模型 在深度学习领域中,模型越深意味着拟合能力越强,出现过拟合问题是正常的,训练误差越来越大却是不正常的. 1.1 训练误差越来越大的原因 在反向传播中,每一层的梯度都是在上一层的基础 ...

  8. 【Pytorch神经网络理论篇】 30 图片分类模型:Inception模型

    1 Inception系列模型 Incepton系列模型包括V1.V2.V3.V4等版本,主要解决深层网络的三个问题: 训练数据集有限,参数太多,容易过拟合: 网络越大,计算复杂度越大,难以应用: 网 ...

  9. 手把手教你如何用 TensorFlow 实现基于 DNN 的文本分类

    许多开发者向新手建议:如果你想要入门机器学习,就必须先了解一些关键算法的工作原理,然后再开始动手实践.但我不这么认为. 我觉得实践高于理论,新手首先要做的是了解整个模型的工作流程,数据大致是怎样流动的 ...

最新文章

  1. 组队学习:学习者参考手册
  2. 计算机专业西电和大工怎么选,放弃985大连理工,选择211西安电子科大,其实很多人都错了...
  3. Centos7无法使用ssh登陆及解决方案
  4. mysql数据库设计实践_MYSQL教程分享20个数据库设计的最佳实践
  5. qt 主动打开虚拟键盘_ipad键盘有用吗?
  6. 设计企业网站大纲_企业网站设计布局
  7. 面试官扎心一问:Tomcat 在 SpringBoot 中是如何启动的?
  8. maven部署项目到tomcat8中
  9. 计算机一级安装的软件要钱吗,电脑没装这5个软件,基本算是废了
  10. 撰写论文时如何复制参考文献公式----Mathpix及Mathtype教程
  11. 【Oracle】建立关联三个表的视图
  12. 史上最贵的merge代码,新浪程序员因加班错失年会77万大奖!
  13. NginxWebUI--强大的nginx可视化配置工具
  14. Google Play 管理中心新增战略指南,助力游戏收入增长
  15. pythoon_interview_redit
  16. C语言:从键盘输入一个整数,分别输出它的个位数、十位数、百位数.....
  17. c4d-造型工具-6
  18. 01-复杂度2 Maximum Subsequence Sum (25分)(数据结构)(C语言实现)
  19. 实现电路阻抗匹配的两个方法
  20. matplotlib的读书笔记

热门文章

  1. python怎么输出文本_python输出语句怎么用
  2. 学习junit和hamcrest的使用
  3. Android 关于::app:clean :app:preBuild UP-TO-DATE :app:preDebugBuild UP-TO-DATE,引用jar冲突问题...
  4. 《C语言编程初学者指南》一2.9 理解运算符优先级
  5. 【Xamarin挖墙脚系列:现有IPhone/IPad 设备尺寸】
  6. 23、OSPF配置实验之特殊区域Totally NSSA
  7. GIT如何查看本地分支与远程分支的关联配置(git branch --set-upstream)
  8. netstat常用命令
  9. mysql 是否有归档模式_查看oracle数据库是否归档和修改归档模式
  10. 字符串数组-获取两个字符串中最大的相同子串(最大相同子串有且只有一个)