本教程展示如何在torchtext中调用文本分类数据集,包括:

  • AG_NEWS,
  • SogouNews,
  • DBpedia,
  • YelpReviewPolarity,
  • YelpReviewFull,
  • YahooAnswers,
  • AmazonReviewPolarity,
  • AmazonReviewFull

这个例子展示了如何用这些文本分类TextClassification数据集之一训练一个有监督学习算法。

使用ngrams加载数据

一组ngrams特征被用于获取局部单词顺序的一些局部信息。实际中,bi-gramtri-gram 作为词组被用于获取比单独一个单词更好的效益。比如:

"load data with ngrams"
Bi-grams results: "load data", "data with", "with ngrams"
Tri-grams results: "load data with", "data with ngrams"

文本分类TextClassification 数据集支持ngrams方法。通过设定ngrams为2,数据集中的样本会被处理为单个单词加bi-grams字符串的列表。

%matplotlib inline
import torch
import torchtext
from torchtext.datasets import text_classification
NGRAMS = 2
import os
if not os.path.isdir('./.data'):os.mkdir('./.data')
# 此处为原教程中使用的代码,可以自动下载所需数据,国内由于网络原因会导致连接失败,后面会导入本地数据
# train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
#     root='./.data', ngrams=NGRAMS, vocab=None)
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

手动下载数据并导入

由于网络原因国内不能自动下载数据(百度网盘地址 提取码: 2vj9),需要手动下载数据压缩包ag_news_csv.tar.gz,并将其放到'./.data/'文件夹下

# 导入需要的库及函数
import logging
from torchtext.utils import extract_archive, unicode_csv_reader
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator,_create_data_from_iterator# 定义创建数据集函数,原函数在torchtext.datasets.text_classification文件中,本教程所需参数直接设成了默认值
def _setup_datasets(dataset_tar='./.data/ag_news_csv.tar.gz',dataset_name="AG_NEWS", root='./.data', ngrams=NGRAMS, vocab=None, include_unk=False):# 注释掉下载数据的代码#     dataset_tar = download_from_url(URLS[dataset_name], root=root)extracted_files = extract_archive(dataset_tar)  #解压数据文件for fname in extracted_files:if fname.endswith('train.csv'):train_csv_path = fnameif fname.endswith('test.csv'):test_csv_path = fnameif vocab is None:logging.info('Building Vocab based on {}'.format(train_csv_path))vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams)) #创建词典else:if not isinstance(vocab, Vocab):raise TypeError("Passed vocabulary is not of type Vocab")logging.info('Vocab has {} entries'.format(len(vocab)))logging.info('Creating training data')train_data, train_labels = _create_data_from_iterator(   #创建训练数据vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk) logging.info('Creating testing data')test_data, test_labels = _create_data_from_iterator(   #创建测试数据vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)if len(train_labels ^ test_labels) > 0:raise ValueError("Training and test labels don't match")return (TextClassificationDataset(vocab, train_data, train_labels),  #返回数据集实例TextClassificationDataset(vocab, test_data, test_labels))
train_dataset, test_dataset = _setup_datasets()

输出:

120000lines [00:07, 16793.47lines/s]
120000lines [00:13, 9134.56lines/s]
7600lines [00:00, 9351.33lines/s]

定义模型(Define the model)

模型是由EmbeddingBag
层和线性层组成,如下图所示。nn.EmbeddingBag计算嵌入层的“bag”的平均值。这里的文本条目有不同的长度。nn.EmbeddingBag要求没有填充,因此文本长度被存储在偏置中。
此外,由于 nn.EmbeddingBag积累了词嵌入的平均值,nn.EmbeddingBag 可以提高处理张量序列的效果和内存效率。

import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super().__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

初始化实例(Initiate an instance)

AG_NEWS 数据集有四种标签,因此类别数量为4.

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

词典大小等于词典长度(包括单个单词和ngrams)。类别数量等于标签数量,对于 AG_NEWS 来说是4。

VOCAB_SIZE = len(train_dataset.get_vocab())
EMBED_DIM = 32
NUN_CLASS = len(train_dataset.get_labels())
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)

产生训练批次的函数

由于文本长度不同,自定义了generate_batch()函数用于产生文本批次和偏置。函数被传递给 torch.utils.data.DataLoader中的collate_fncollate_fn的输入是大小为batch_size的张量的列表,collate_fn将列表打包进最小批次(mini-batch)。请注意,要确保collate_fn被声明为顶层def(声明)。这确保了这个函数可以在任何位置被调用。

原始数据批次输入的文本条目被打包到一个列表并串联成了一个单独张量作为nn.EmbeddingBag的输入。偏置(offsets)是由分隔符组成的张量,分隔符用来表示文本张量中每个独立序列的起始索引。标签(Label)是保存每个文本条目的标签的张量。

def generate_batch(batch):label = torch.tensor([entry[0] for entry in batch])text = [entry[1] for entry in batch]offsets = [0] + [len(entry) for entry in text]# torch.Tensor.cumsum 返回dim维度元素的累积和# torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)text = torch.cat(text)return text, offsets, label

定义函数训练模型并验证结果

为PyTorch用户推荐torch.utils.data.DataLoader,它可以轻松的使数据加载并行化(相关教程)。这里使用DataLoader 加载AG_NEWS数据集并将其传递给模型用于训练/验证。

from torch.utils.data import DataLoaderdef train_func(sub_train_):# 训练模型train_loss = 0train_acc = 0data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,collate_fn=generate_batch)for i, (text, offsets, cls) in enumerate(data):optimizer.zero_grad()text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)output = model(text, offsets)loss = criterion(output, cls)train_loss += loss.item()loss.backward()optimizer.step()train_acc += (output.argmax(1) == cls).sum().item()# 调节学习率scheduler.step()return train_loss / len(sub_train_), train_acc / len(sub_train_)def test(data_):loss = 0acc = 0data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)for text, offsets, cls in data:text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)with torch.no_grad():output = model(text, offsets)loss = criterion(output, cls)loss += loss.item()acc += (output.argmax(1) == cls).sum().item()return loss / len(data_), acc / len(data_)

分割数据集及训练模型

由于原始AG_NEWS数据没有校验集,本文按照0.95(训练)和0.05(校验)的比例将训练数据分割为训练/校验集。这里使用PyTorch核心库中的torch.utils.data.dataset.random_split函数。

CrossEntropyLoss标准将nn.LogSoftmax()nn.NLLLoss()合并到一个类。当训练C个类别的分类问题时非常有用。SGD作为优化器实现了随机梯度下降方法。初始学习率被设置为4.0。StepLR用于每次训练(epochs)调整学习率。

import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \random_split(train_dataset, [train_len, len(train_dataset) - train_len])for epoch in range(N_EPOCHS):start_time = time.time()train_loss, train_acc = train_func(sub_train_)valid_loss, valid_acc = test(sub_valid_)secs = int(time.time() - start_time)mins = secs / 60secs = secs % 60print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

输出:

Epoch: 1  | time in 0 minutes, 22 secondsLoss: 0.0261(train) |   Acc: 84.7%(train)Loss: 0.0001(valid)    |   Acc: 90.5%(valid)
Epoch: 2  | time in 0 minutes, 17 secondsLoss: 0.0119(train)    |   Acc: 93.6%(train)Loss: 0.0001(valid)    |   Acc: 90.9%(valid)
Epoch: 3  | time in 0 minutes, 9 secondsLoss: 0.0069(train) |   Acc: 96.5%(train)Loss: 0.0000(valid)    |   Acc: 89.9%(valid)
Epoch: 4  | time in 0 minutes, 22 secondsLoss: 0.0039(train)    |   Acc: 98.1%(train)Loss: 0.0000(valid)    |   Acc: 91.4%(valid)
Epoch: 5  | time in 0 minutes, 22 secondsLoss: 0.0022(train)    |   Acc: 99.0%(train)Loss: 0.0000(valid)    |   Acc: 91.3%(valid)

使用测试数据评价模型

print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(f'\tLoss: {test_loss:.4f}(test)\t|\tAcc: {test_acc * 100:.1f}%(test)')

输出:

Checking the results of test dataset...Loss: 0.0003(test)    |   Acc: 88.3%(test)

在随机新闻上测试

使用目前最好的模型并测试一条高尔夫新闻。标签信息可以在此获得。

import re
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizerag_news_label = {1 : "World",2 : "Sports",3 : "Business",4 : "Sci/Tec"}def predict(text, model, vocab, ngrams):tokenizer = get_tokenizer("basic_english")with torch.no_grad():text = torch.tensor([vocab[token]for token in ngrams_iterator(tokenizer(text), ngrams)])output = model(text, torch.tensor([0]))return output.argmax(1).item() + 1ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \enduring the season’s worst weather conditions on Sunday at The \Open on his way to a closing 75 at Royal Portrush, which \considering the wind and the rain was a respectable showing. \Thursday’s first round at the WGC-FedEx St. Jude Invitational \was another story. With temperatures in the mid-80s and hardly any \wind, the Spaniard was 13 strokes better in a flawless round. \Thanks to his best putting performance on the PGA Tour, Rahm \finished with an 8-under 62 for a three-stroke lead, which \was even more impressive considering he’d never played the \front nine at TPC Southwind."vocab = train_dataset.get_vocab()
model = model.to("cpu")print("This is a %s news" %ag_news_label[predict(ex_text_str, model, vocab, 2)])

输出:

This is a Sports news

这是一个运动新闻

样例代码见此笔记。

[翻译Pytorch教程]NLP部分:使用TorchText进行文本分类相关推荐

  1. [翻译Pytorch教程]NLP从零开始:使用序列到序列网络和注意力机制进行翻译

    翻译自官网手册:NLP From Scratch: Translation with a Sequence to Sequence Network and Attention Author: Sean ...

  2. [翻译Pytorch教程]NLP从零开始:使用字符级RNN进行名字生成

    翻译自官网手册:NLP From Scratch: Generating Names with a Character-Level RNN Author: Sean Robertson 原文githu ...

  3. 【NLP傻瓜式教程】手把手带你RCNN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  4. 【NLP傻瓜式教程】手把手带你HAN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  5. 【NLP傻瓜式教程】手把手带你fastText文本分类(附代码)

    写在前面 已经发布: [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) 继续NLP傻瓜式教程系列,今天的教程是基于FAIR的Bag of ...

  6. 【NLP保姆级教程】手把手带你RNN文本分类(附代码)

    写在前面 这是NLP保姆级教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关于循环神经网络在多任务文本分类上的应用:Recur ...

  7. pytorch 定义torch类型数据_PyTorch 使用TorchText进行文本分类

    本教程演示如何在 torchtext 中使用文本分类数据集,包括 - AG_NEWS, - SogouNews, - DBpedia, - YelpReviewPolarity, - YelpRevi ...

  8. pytorch 定义torch类型数据_PyTorch 使用 TorchText 进行文本分类

    本教程介绍了如何使用torchtext中的文本分类数据集,包括- AG_NEWS, - SogouNews, - DBpedia, - YelpReviewPolarity, - YelpReview ...

  9. 『NLP学习笔记』TextCNN文本分类原理及Pytorch实现

    TextCNN文本分类原理及Pytorch实现 文章目录 一. TextCNN网络结构 1.1. CNN在文本分类上得应用 1.2. 回顾CNN以及Pytorch解析 1.2.1. CNN特点 1.2 ...

最新文章

  1. springboot 集成jpa_基于Spring Boot+JPA Restful 风格的数据
  2. 系统视频教学视频教程_太极拳教学视频教程,董氏太极拳基本功训练方法视频...
  3. window.location操作url对象
  4. int, float, double之间不得不说的故事
  5. nginx下虚拟目录配置301域名重定向
  6. 换一种方式“写代码 编程序“,为自己的程序生涯找条新路
  7. chrome 历史版本和chrome webDriver历史版本
  8. GaussDB系列数据库简介
  9. matlab 入射线反射线,ray 射线追踪的仿真小程序最多可以模拟三次反射, 出 图 matlab 272万源代码下载- www.pudn.com...
  10. SEO整体优化有哪些操作步骤
  11. 牛客 检测命令是否正确
  12. 英语入门学习笔记2:英语语法知识树
  13. 手机的内核版本、基带版本等都是什么意思?
  14. 【洛谷】P3957 [NOIP2017 普及组] 跳房子
  15. 勒索病毒现状和防御勒索病毒最佳实践(云端和线下个人电脑,服务器都可部署)
  16. PMP考试 工作绩效数据 工作绩效信息 工作绩效报告 区别与联系
  17. 南华大学计算机系宿舍,2021年南华大学新生宿舍条件和宿舍环境图片
  18. 3、Qt5 主窗口点击按钮 弹出另一个自定义窗口
  19. [矩阵的QR分解系列四] QR(正交三角)分解
  20. 数据库主流容灾方案对比分析

热门文章

  1. 更改Mysql数据库密码
  2. 勒索病毒-特洛伊木马变种
  3. Wince Battery driver
  4. 安卓讲课笔记2.1Activity概述——上机操作
  5. 森林防火综合解决方案
  6. 国内外几个主流的CMS系统推荐
  7. 讲道理 | 特征值和特征向量意义
  8. 数值模拟使用matlab实现案例
  9. 怎么写网站的需求文档
  10. word表格分页时怎样能自动生成表头