Text Classification with TorchText

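  • 1. Access the raw dataset iterators
  • 2. Prepare the data processing pipelines
  • 3. Generate data batches and iterators
  • 4. Define the model
  • 5. Initialize an instance
  • 6. Define functions to train the model and evaluate the results
  • 7. Split the dataset and run the model
  • 8. Full code
  • Summary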

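This is a tutorial from the text section of the official PyTorch tutorials (see the Chinese translation for PyTorch 1.4, the Chinese translation for PyTorch 1.7, and the original English version). It shows how to use the text classification datasets in torchtext, and this post is a detailed annotation of it. For the official English documentation of the TorchText API, see this page and this blog post.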


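The ngrams feature captures some partial information about local word order. In practice, using bi-grams or tri-grams as word groups is more informative than using single words alone. An example: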

"load data with ngrams"
Bi-grams results: "load data", "data with", "with ngrams"
Tri-grams results: "load data with", "data with ngrams"

The TextClassification datasets support ngrams. With ngrams set to 2, each example text in the dataset becomes a list of the single words plus the bi-gram strings.
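A minimal pure-Python sketch of the "single words plus n-grams" list described above. The function name `ngrams_list` is hypothetical; torchtext has its own helper (`torchtext.data.utils.ngrams_iterator`), and its output order may differ from this sketch.

```python
def ngrams_list(tokens, n):
    """Return the unigrams followed by every 2..n-gram, each joined with a space."""
    result = list(tokens)  # unigrams first
    for size in range(2, n + 1):
        for start in range(len(tokens) - size + 1):
            result.append(" ".join(tokens[start:start + size]))
    return result


tokens = "load data with ngrams".split()
print(ngrams_list(tokens, 2))
# ['load', 'data', 'with', 'ngrams', 'load data', 'data with', 'with ngrams']
```

With `n=3` the same call appends the two tri-grams from the example above after the bi-grams.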


pip install torchtext

The line `from torchtext.datasets import text_classification` in the original tutorial no longer works, and the arguments of `text_classification.DATASETS['AG_NEWS']` have changed as well; see the English manual.


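1. Access the raw dataset iterators

The torchtext library provides a few raw dataset iterators that yield raw text strings. For example, the AG_NEWS dataset iterator yields the raw data as (label, text) tuples.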

Calling `train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))` fails with:

TimeoutError: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond


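The download URLs used by torchtext's AG_NEWS source are listed below. If the automatic download times out, the two CSV files can be downloaded manually (for example with a download manager) and placed in the directory passed as root:

URL = {'train': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv",
       'test': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv"}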
from torchtext.datasets import AG_NEWS

path = '... your path\\AG_NEWS.data'
train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))
print(next(train_data))
(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")
print(next(train_data))
(3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')
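The raw iterator behaves like an ordinary Python iterator: each `next()` yields one `(label, text)` tuple, and once it has been consumed it is exhausted. That is why the full code has to create `AG_NEWS(...)` again before it can iterate a second time. The sketch uses a hypothetical generator `fake_ag_news`, holding shortened versions of the two records shown above, as a stand-in for the real dataset.

```python
from typing import Iterator, Tuple


def fake_ag_news() -> Iterator[Tuple[int, str]]:
    """Hypothetical stand-in for AG_NEWS(split='train'): yields (label, text) tuples."""
    yield (3, "Wall St. Bears Claw Back Into the Black (Reuters)")
    yield (3, "Carlyle Looks Toward Commercial Aerospace (Reuters)")


train_iter = fake_ag_news()
first = next(train_iter)        # the first (label, text) tuple
remaining = list(train_iter)    # everything that is left
exhausted = list(train_iter)    # [] : a consumed iterator yields nothing more

train_iter = fake_ag_news()     # re-create the iterator to read the data again
labels = {label for label, _ in train_iter}
print(first, len(remaining), exhausted, labels)
```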

2. Prepare the data processing pipelines





The vocabulary built from the training set converts a list of tokens into integers:

[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]



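The text pipeline converts a string into a list of vocabulary indices, and the label pipeline converts a label string into a 0-based integer (both are defined in the full code in section 8):

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9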
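A dependency-free sketch of the two pipelines, assuming a simplified tokenizer (the hypothetical `simple_tokenizer`: lowercase plus a regex) in place of `get_tokenizer('basic_english')`, and a plain dict in place of torchtext's `Vocab`. It mimics the `Vocab` class used in the full code by reserving index 0 for `<unk>` and 1 for `<pad>` and ordering the remaining words by descending frequency, with ties broken alphabetically; the indices torchtext produces on AG_NEWS will of course differ.

```python
import re
from collections import Counter


def simple_tokenizer(line):
    """Hypothetical stand-in for get_tokenizer('basic_english'): lowercase, keep word characters."""
    return re.findall(r"[a-z0-9']+", line.lower())


def build_stoi(lines, specials=("<unk>", "<pad>")):
    """Map each token to an index: specials first, then by descending frequency (ties alphabetical)."""
    counter = Counter()
    for line in lines:
        counter.update(simple_tokenizer(line))
    words = sorted(counter.items(), key=lambda kv: kv[0])   # alphabetical
    words.sort(key=lambda kv: kv[1], reverse=True)           # stable sort by frequency
    itos = list(specials) + [w for w, _ in words]
    return {w: i for i, w in enumerate(itos)}


corpus = ["The cat sat.", "The dog sat!", "A cat"]
stoi = build_stoi(corpus)
text_pipeline = lambda x: [stoi.get(tok, 0) for tok in simple_tokenizer(x)]  # unknown words -> 0 (<unk>)
label_pipeline = lambda x: int(x) - 1                                        # '1'..'4' -> 0..3

print(text_pipeline("the cat ran"))   # [4, 2, 0]
print(label_pipeline('10'))           # 9
```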

3. Generate data batches and iterators

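torch.utils.data.DataLoader is recommended for PyTorch users (tutorial here). It works with map-style datasets, which implement the `__getitem__()` and `__len__()` protocols and represent a map from indices/keys to data samples. It also works with iterable datasets, provided the shuffle argument is False.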
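To make the map-style protocol concrete: any object with `__getitem__` and `__len__` qualifies, and a loader only needs those two methods to pick indices and group them into batches. `ToyDataset` and `iterate_batches` below are hypothetical names; this is a pure-Python imitation of `DataLoader(..., batch_size=..., shuffle=..., collate_fn=...)`, without worker processes.

```python
import random


class ToyDataset:
    """Map-style dataset: index -> (label, text) sample."""

    def __init__(self, samples):
        self.samples = list(samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


def iterate_batches(dataset, batch_size, shuffle=False, collate_fn=list, seed=0):
    """Hypothetical helper: yield collated batches the way a DataLoader would."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        yield collate_fn(batch)


ds = ToyDataset([(1, "a"), (2, "b"), (3, "c"), (4, "d"), (1, "e")])
print([len(b) for b in iterate_batches(ds, batch_size=2)])  # [2, 2, 1]
```

As with `DataLoader` and its default `drop_last=False`, the last batch is simply smaller when the dataset size is not a multiple of the batch size.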

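Before a batch of samples generated by the DataLoader is sent to the model, collate_fn processes it. The input of collate_fn is one batch of data (batch_size samples) from the DataLoader, and collate_fn processes it with the data processing pipelines declared earlier. Note that collate_fn must be declared as a top-level def; this ensures the function is available in each worker.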



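The offsets in collate_batch are computed with cumsum. Cumulative sums of a 2x3 tensor along each dimension:

import torch

x = torch.arange(0, 6).view(2, 3)
print(x)
tensor([[0, 1, 2],[3, 4, 5]])
print(x.cumsum(dim=0))      # running sum down each column
tensor([[0, 1, 2],[3, 5, 7]])
print(x.cumsum(dim=1))      # running sum along each row
tensor([[ 0,  1,  3],[ 3,  7, 12]])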
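Putting these pieces together, `collate_batch` in the full code concatenates all token ids of a batch into one flat sequence and records where each sample starts. The start positions are the cumulative sums of the previous lengths, which is what `torch.tensor(offsets[:-1]).cumsum(dim=0)` computes. The sketch below repeats that bookkeeping with plain lists; `collate_batch_py` and the toy pipelines are hypothetical, and the real function returns tensors moved to `device`.

```python
from itertools import accumulate


def collate_batch_py(batch, text_pipeline, label_pipeline):
    """Plain-list version of collate_batch: returns (labels, flat_token_ids, offsets)."""
    labels, flat_ids, lengths = [], [], []
    for raw_label, raw_text in batch:
        labels.append(label_pipeline(raw_label))
        ids = text_pipeline(raw_text)
        flat_ids.extend(ids)          # like torch.cat(text_list)
        lengths.append(len(ids))
    # start of sample i = sum of the lengths of samples 0..i-1
    offsets = [0] + list(accumulate(lengths))[:-1]
    return labels, flat_ids, offsets


toy_text_pipeline = lambda s: [len(w) for w in s.split()]  # hypothetical: word length as "token id"
toy_label_pipeline = lambda x: int(x) - 1
batch = [("1", "a bb ccc"), ("4", "dddd ee"), ("2", "f")]
print(collate_batch_py(batch, toy_text_pipeline, toy_label_pipeline))
# ([0, 3, 1], [1, 2, 3, 4, 2, 1], [0, 3, 5])
```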


4. Define the model

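The model consists of an nn.EmbeddingBag layer followed by a linear layer for classification. nn.EmbeddingBag with the default mode "mean" computes the mean of a "bag" of embeddings. Although the text entries here have different lengths, nn.EmbeddingBag needs no padding, because the text lengths are saved in the offsets.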




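Example from the nn.EmbeddingBag documentation (the weights are randomly initialized, so the output values differ between runs):

>>> import torch
>>> from torch import nn
>>> # an EmbeddingBag module containing 10 tensors of size 3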
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],[ 1.1306, -2.5798, -1.0044]])
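The offset semantics can be checked by hand. The sketch below is a pure-Python re-implementation of what `nn.EmbeddingBag` computes for `mode='sum'` and `mode='mean'` (no padding index, no per-sample weights). It uses a hypothetical deterministic weight table instead of random embeddings so the numbers are easy to verify: with `offsets=[0, 4]`, bag 0 is `input[0:4]` and bag 1 is `input[4:]`.

```python
def embedding_bag(weight, indices, offsets, mode="mean"):
    """Pure-Python sketch of nn.EmbeddingBag's forward pass for mode='sum' or 'mean'."""
    bounds = list(offsets) + [len(indices)]     # bag i covers indices[bounds[i]:bounds[i+1]]
    dim = len(weight[0])
    result = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [weight[i] for i in indices[start:end]]
        summed = [sum(row[d] for row in rows) for d in range(dim)]
        if mode == "mean" and rows:
            summed = [value / len(rows) for value in summed]
        result.append(summed)
    return result


weight = [[float(i), 10.0 * i] for i in range(10)]   # hypothetical table: row i = [i, 10*i]
inp = [1, 2, 4, 5, 4, 3, 2, 9]
offsets = [0, 4]
print(embedding_bag(weight, inp, offsets, mode="sum"))   # [[12.0, 120.0], [18.0, 180.0]]
print(embedding_bag(weight, inp, offsets, mode="mean"))  # [[3.0, 30.0], [4.5, 45.0]]
```

The "mean" row is exactly what the model in this post feeds to its linear layer: one averaged vector per text, however many tokens it has.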

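5. Initialize an instance

The model is instantiated with the vocabulary size, an embedding dimension of 64 (emsize) and the number of classes. The AG_NEWS dataset has four labels, so the number of classes is four: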


1 : World
2 : Sports
3 : Business
4 : Sci/Tec
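The training code shifts these 1-based labels to 0-based class indices (`label_pipeline = lambda x: int(x) - 1`), and `predict` shifts the argmax back with `+ 1`. A small sketch of that round trip; `to_class_index` and `predict_name` are hypothetical helper names, and `scores` stands in for one row of model output.

```python
AG_NEWS_LABEL = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}


def to_class_index(raw_label):
    """1..4 (dataset label) -> 0..3 (class index used for training)."""
    return int(raw_label) - 1


def predict_name(scores):
    """Argmax over the 4 class scores, then shift back to the 1-based label."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return AG_NEWS_LABEL[best + 1]


print(to_class_index(2), predict_name([0.1, 2.3, -0.4, 0.7]))  # 1 Sports
```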


6. Define functions to train the model and evaluate the results


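torch.optim.lr_scheduler.StepLR decays the learning rate of each parameter group by gamma every step_size epochs. Note that this decay can happen together with other changes to the learning rate made from outside this scheduler. When last_epoch=-1, the initial lr is set to lr.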
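For a plain StepLR schedule, the learning rate after `k` calls to `scheduler.step()` is `lr * gamma ** (k // step_size)`. In the full code, however, `scheduler.step()` is only called when the validation accuracy falls below the best value seen so far. The pure-Python sketch below (hypothetical function names) replays that rule on the validation accuracies from the training log in section 7, with `LR = 5`, `step_size = 1` and `gamma = 0.1` as in the code.

```python
def step_lr(base_lr, gamma, step_size, num_steps):
    """Learning rate after num_steps calls to StepLR.step()."""
    return base_lr * gamma ** (num_steps // step_size)


def replay_schedule(valid_accs, base_lr=5.0, gamma=0.1, step_size=1):
    """Mimic the tutorial loop: step the scheduler only when validation accuracy drops."""
    best, steps, lrs = None, 0, []
    for acc in valid_accs:
        if best is not None and best > acc:
            steps += 1          # scheduler.step()
        else:
            best = acc          # total_accu = accu_val
        lrs.append(step_lr(base_lr, gamma, step_size, steps))  # lr used in the next epoch
    return lrs


valid = [0.886, 0.899, 0.904, 0.908, 0.900, 0.911, 0.912, 0.913, 0.915, 0.913]
print(replay_schedule(valid))
```

The rate stays at 5 through epoch 5, drops to 0.5 after the dip at epoch 5 (the log's jump in training accuracy at epoch 6 coincides with this) and to about 0.05 after the dip at epoch 10.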

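torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) clips the gradient norm of an iterable of parameters (see the official documentation). The norm is computed over all gradients together, as if they were concatenated into a single vector, and the gradients are modified in place. In other words, this is gradient clipping: the total norm is not allowed to exceed max_norm (0.1 here).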
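A minimal sketch of that norm clipping in pure Python over lists instead of tensors (`clip_grad_norm` is a hypothetical name; only the default L2 norm is covered). It computes one total norm over all gradients and, if that exceeds `max_norm`, scales every gradient by `max_norm / (total_norm + eps)`. The small `eps` mirrors the constant PyTorch adds to avoid division by zero; treat its exact value as an assumption.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    grads is a list of gradient vectors (one list per parameter). Returns the
    total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]                 # two "parameters"; combined norm = 5
print(clip_grad_norm(grads, 0.1))      # 5.0
print(grads)                           # both scaled by about 0.1 / 5, direction unchanged
```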


tensor([[ 0.4427,  0.0830,  0.0109,  0.1273],[ 0.1601,  0.0869, -0.0540,  0.0422],...
tensor([0, 0, 0, 3, 1, 1, 1, 3, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 1, 3, 3,3, 1, 1, 2, 1, 2, 1, 1, 3, 3, 1, 1, 1, 3, 1, 3, 0, 1, 0, 0, 1, 3, 3, 3,2, 3, 1, 3, 3, 3, 1, 3, 3, 1, 1, 2, 0, 2, 1, 3])


print(predited_label.argmax(1) == label)
tensor([False,  True,  True,  True, False,  True,  True,  True,  True,  True,True, False,  True,  True,  True, False,  True,  True,  True,  True,True,  True,  True,  True,  True,  True,  True,  True,  True,  True,True, False,  True, False,  True,  True,  True,  True, False,  True,True,  True, False,  True, False,  True,  True, False,  True,  True,True, False, False,  True,  True, False,  True, False, False,  True,False,  True,  True,  True])


(predited_label.argmax(1) == label).sum().item()
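`argmax(1)` picks the highest-scoring class in each row, comparing with `label` gives one boolean per sample, and `.sum().item()` turns the number of `True` values into a Python int, which `train` and `evaluate` accumulate into `total_acc`. The same computation on plain lists, using a hypothetical helper `count_correct` and toy scores:

```python
def count_correct(scores_batch, labels):
    """Equivalent of (scores.argmax(1) == labels).sum().item() on nested lists."""
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in scores_batch]
    return sum(int(p == y) for p, y in zip(predictions, labels))


scores = [[0.44, 0.08, 0.01, 0.13], [0.16, 0.09, -0.05, 0.04], [0.0, 0.1, 0.9, 0.2]]
labels = [0, 1, 2]
correct = count_correct(scores, labels)
print(correct, correct / len(labels))   # 2 correct out of 3
```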

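7. Split the dataset and run the model

The original training set is split into 95% training and 5% validation data. Training for 10 epochs with a batch size of 64 gives the following log (the last line is the accuracy on the test set):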




| epoch   1 |   500/ 1782 batches, accuracy    0.685
| epoch   1 |  1000/ 1782 batches, accuracy    0.852
| epoch   1 |  1500/ 1782 batches, accuracy    0.876
| end of epoch   1 | time: 15.24s | valid accuracy    0.886
| epoch   2 |   500/ 1782 batches, accuracy    0.896
| epoch   2 |  1000/ 1782 batches, accuracy    0.902
| epoch   2 |  1500/ 1782 batches, accuracy    0.902
| end of epoch   2 | time: 15.20s | valid accuracy    0.899
| epoch   3 |   500/ 1782 batches, accuracy    0.915
| epoch   3 |  1000/ 1782 batches, accuracy    0.914
| epoch   3 |  1500/ 1782 batches, accuracy    0.915
| end of epoch   3 | time: 15.22s | valid accuracy    0.904
| epoch   4 |   500/ 1782 batches, accuracy    0.924
| epoch   4 |  1000/ 1782 batches, accuracy    0.924
| epoch   4 |  1500/ 1782 batches, accuracy    0.923
| end of epoch   4 | time: 15.16s | valid accuracy    0.908
| epoch   5 |   500/ 1782 batches, accuracy    0.930
| epoch   5 |  1000/ 1782 batches, accuracy    0.929
| epoch   5 |  1500/ 1782 batches, accuracy    0.931
| end of epoch   5 | time: 15.21s | valid accuracy    0.900
| epoch   6 |   500/ 1782 batches, accuracy    0.943
| epoch   6 |  1000/ 1782 batches, accuracy    0.941
| epoch   6 |  1500/ 1782 batches, accuracy    0.944
| end of epoch   6 | time: 15.17s | valid accuracy    0.911
| epoch   7 |   500/ 1782 batches, accuracy    0.943
| epoch   7 |  1000/ 1782 batches, accuracy    0.945
| epoch   7 |  1500/ 1782 batches, accuracy    0.946
| end of epoch   7 | time: 15.24s | valid accuracy    0.912
| epoch   8 |   500/ 1782 batches, accuracy    0.945
| epoch   8 |  1000/ 1782 batches, accuracy    0.944
| epoch   8 |  1500/ 1782 batches, accuracy    0.944
| end of epoch   8 | time: 15.20s | valid accuracy    0.913
| epoch   9 |   500/ 1782 batches, accuracy    0.944
| epoch   9 |  1000/ 1782 batches, accuracy    0.948
| epoch   9 |  1500/ 1782 batches, accuracy    0.946
| end of epoch   9 | time: 15.29s | valid accuracy    0.915
| epoch  10 |   500/ 1782 batches, accuracy    0.949
| epoch  10 |  1000/ 1782 batches, accuracy    0.945
| epoch  10 |  1500/ 1782 batches, accuracy    0.946
| end of epoch  10 | time: 15.19s | valid accuracy    0.913
Checking the results of test dataset.
test accuracy    0.908
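The batch count in the log can be checked by hand. AG_NEWS has 120,000 training samples; the full code keeps `int(len(train_dataset) * 0.95)` = 114,000 of them for training and 6,000 for validation, and with `BATCH_SIZE = 64` that gives ceil(114000 / 64) = 1782 batches per epoch. The sketch below mimics the index split done by `random_split` with a seeded `random.Random`; the helper `split_indices` and the seed are assumptions, and `random_split` uses torch's own generator, so the actual members of each part differ.

```python
import math
import random


def split_indices(n, train_frac=0.95, seed=0):
    """Shuffle 0..n-1 and cut it into train/valid parts (like random_split with two lengths)."""
    num_train = int(n * train_frac)          # same rounding as int(len(train_dataset) * 0.95)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return indices[:num_train], indices[num_train:]


train_idx, valid_idx = split_indices(120_000)
print(len(train_idx), len(valid_idx), math.ceil(len(train_idx) / 64))  # 114000 6000 1782
```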

Finally, the trained model is tested on a random news text, here a golf report:

"MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."


This is a Sports news

A second test string chosen by the author:

'Beijing of Automation, Beijing Institute of Technology'


This is a Sci/Tec news


8. Full code

path = '... your path\\AG_NEWS.data'

import torch
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(root=path, split='train')      # access the raw dataset iterator

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab      # older torchtext Vocab(counter, ...) API used in this post

tokenizer = get_tokenizer('basic_english')      # tokenizes an input string
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

# Prepare the data processing pipelines
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]      # token is a word; vocab[token] is its index
label_pipeline = lambda x: int(x) - 1       # map labels 1, 2, 3, 4 to classes 0, 1, 2, 3

# Generate data batches and iterators
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)      # torch.Size([41]), torch.Size([58])...
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)        # torch.Size([64])
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)      # torch.Size([64])
    text_list = torch.cat(text_list)        # concatenate the list of tensors into one tensor
    return label_list.to(device), text_list.to(device), offsets.to(device)

# dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

from torch import nn


class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)      # fill the tensor with values drawn from a uniform distribution
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)        # torch.Size([64, 64])
        output = self.fc(embedded)      # torch.Size([64, 4])
        return output


# num_class = len(set([label for (label, text) in train_iter]))   # the iterator must be re-created first: train_iter = AG_NEWS(root=path, split='train')
num_class = 4
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

import time


def train(dataloader):      # uses the globals model, optimizer, criterion and epoch
    model.train()       # training mode
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(text, offsets)
        loss = criterion(predited_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)     # the total gradient norm may not exceed max_norm
        optimizer.step()
        total_acc += (predited_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predited_label = model(text, offsets)
            # loss = criterion(predited_label, label)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count


def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text)).to(device)      # inputs must be on the same device as the model
        output = model(text, torch.tensor([0]).to(device))
        return output.argmax(1).item() + 1      # back to the 1-based AG_NEWS labels


from torch.utils.data.dataset import random_split

if __name__ == '__main__':
    # Training (commented out here; uncomment to train and save the model)
    # Hyperparameters
    # EPOCHS = 10  # epoch
    # LR = 5  # learning rate
    # BATCH_SIZE = 64  # batch size for training
    #
    # criterion = torch.nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    # total_accu = None
    # train_iter, test_iter = AG_NEWS(root=path)
    # train_dataset = list(train_iter)
    # test_dataset = list(test_iter)
    # num_train = int(len(train_dataset) * 0.95)
    # split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    #
    # train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle=True shuffles the samples randomly
    # valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    #
    # for epoch in range(1, EPOCHS + 1):
    #     epoch_start_time = time.time()
    #     train(train_dataloader)
    #     accu_val = evaluate(valid_dataloader)
    #     if total_accu is not None and total_accu > accu_val:
    #         scheduler.step()
    #     else:
    #         total_accu = accu_val
    #     print('-' * 59)
    #     print('| end of epoch {:3d} | time: {:5.2f}s | '
    #           'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
    #     print('-' * 59)
    #
    # print('Checking the results of test dataset.')
    # accu_test = evaluate(test_dataloader)
    # print('test accuracy {:8.3f}'.format(accu_test))
    #
    # torch.save(model.state_dict(), '... your path\\model_TextClassification.pth')

    # Evaluation: load the saved weights and classify one example
    model.load_state_dict(torch.load('... your path\\model_TextClassification.pth', map_location=device))
    ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}
    # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])


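Summary

  1. Getting the dataset involved a few pitfalls: the Chinese tutorial is outdated and wrong, so the English version has to be followed; and downloading from GitHub stalled again, so the files could only be fetched with IDM (a download manager).
  2. Preparing the data pipeline essentially maps each English word to an integer index in the vocabulary, the index form of a one-hot encoding; the embedding layer then looks up a dense vector for each index.
  3. DataLoader, used to generate batches and iterators, seems important: it lets the data be processed as a stream of batches instead of all at once, which keeps memory in check (note that the full code still loads the whole dataset with list(train_iter) before batching). Turning a batch into tensors with collate_fn was a bit hard to grasp the first time.
  4. The model is simple: each word is embedded, the embeddings are averaged to represent the sentence, and a linear layer classifies that vector.
  5. Training includes a learning-rate update step that is worth reusing. The validation set seemed of little use at first, but in this script the validation accuracy is exactly what triggers the learning-rate decay (scheduler.step() is called only when it drops).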
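To make summary point 2 concrete: an integer index is equivalent to a one-hot vector, and multiplying a one-hot vector by an embedding matrix simply selects one row, which is why `nn.EmbeddingBag` can work directly on indices without ever building one-hot vectors. A pure-Python check with a hypothetical 4-word, 3-dimensional embedding table:

```python
def one_hot(index, size):
    """One-hot vector of length size with a 1.0 at position index."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def vec_mat(vector, matrix):
    """Row vector times matrix: result[j] = sum_i vector[i] * matrix[i][j]."""
    return [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(len(matrix[0]))]


embedding = [[0.1, 0.2, 0.3], [1.0, 1.1, 1.2], [2.0, 2.1, 2.2], [3.0, 3.1, 3.2]]
word_id = 2
print(vec_mat(one_hot(word_id, 4), embedding) == embedding[word_id])  # True
```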


Next steps:

  1. Reproduce another TorchText experiment.
  2. Study the BERT and Transformer models and implement them.

