原链接:Text classification with the torchtext library — PyTorch Tutorials 1.11.0+cu102 documentation

(1)导入数据集(经常会出现数据集下载失败的情况),有大佬的网盘:https://pan.baidu.com/s/1Rz_XoaTZWSRiHGOwkACosQ,提取码:j0no 

下载完直接放到当前打开jupyter notebook的目录下,地址就到AG_NEWS.data文件夹即可

(现在的版本好像要加上root=‘地址’,不然会报错)

import torch
from torchtext.datasets import AG_NEWS
path = r'E:\Notebook\自然语言处\Text_classification_with_the_torchtext_library\AG_NEWS.data'
train_iter = iter(AG_NEWS(root=path, split='train'))

(2)构建词汇表

from torchtext.data.utils import get_tokenizer #导入分词工具
from torchtext.vocab import build_vocab_from_iterator #使用迭代器构建词表tokenizer = get_tokenizer('basic_english') #创建分词器对象,采用英文分词
train_iter = AG_NEWS(root=path, split='train')  #获取数据集,并生成迭代器def yield_tokens(data_iter):for _, text in data_iter: #获取每一条的标签label和内容textyield tokenizer(text) #对获取内容分词,并返回。yield返回一个迭代器对象#将未能识别的单词设置为<unk>
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"]) #设置<unk>的索引为默认索引,一旦遇到不能识别单词,转为<unk>的索引值
vocab.set_default_index(vocab['<unk>'])

(3)获取每条数据的label和text

text_pipeline = lambda x: vocab(tokenizer(x)) #获取每一条的text的索引表示
label_pipeline = lambda x: int(x) - 1 #获取对应的label#演示
text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5297]
label_pipeline('10')
>>> 9

(4)生成批数据和迭代器

offset是定界符的张量,表示文本张量中各个序列的起始索引

label_list:batch中每个文本的标签

text_list:batch的每个文本转换成词汇表的索引

offsets:batch中每个文本的长度

from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")def collate_batch(batch):label_list, text_list, offsets = [], [], [0] for (_label, _text) in batch:label_list.append(label_pipeline(_label))processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)offsets.append(processed_text.size(0)) #text.size(0)获取text的长度label_list = torch.tensor(label_list, dtype=torch.int64)offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)text_list = torch.cat(text_list)return label_list.to(device), text_list.to(device), offsets.to(device)

其中: offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)类似于从第一个数开始(不包括最后一个数),将每个数依次向后累加,得到的新结果再向后累加。10加到20上为30, 30又加到30上成了60(最后的40不算):

(举个栗子)

>>> offsets = [10, 20, 30, 40]

>>> offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)

>>>offsets变成了tensor[10, 30, 60]

在案例中的含义,offsets列表可以记录每一个text的起始位置索引,从0开始,[0, text_index1, text_index2,....],索引之间相减就可以算出每个text的长度。

cat()将多个tensor融合为一个:

text_list:[ tensor([1, 2, 3]) , tensor([4 ,  5 ,  6]) ]

text_list = torch.cat(text_list) => tensor([1 , 2 , 3 , 4,  5 , 6])

(5)定义模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange)self.fc.weight.data.uniform_(-initrange, initrange)self.fc.bias.data.zero_()def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

(6)定义训练和验证

import timedef train(dataloader):model.train()total_acc, total_count = 0, 0log_interval = 500start_time = time.time()for idx, (label, text, offsets) in enumerate(dataloader):optimizer.zero_grad()predicted_label = model(text, offsets)loss = criterion(predicted_label, label)loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)optimizer.step()total_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:3d} | {:5d}/{:5d} batches ''| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),total_acc/total_count))total_acc, total_count = 0, 0start_time = time.time()def evaluate(dataloader):model.eval()total_acc, total_count = 0, 0with torch.no_grad():for idx, (label, text, offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)total_acc += (predicted_label.argmax(1) == label).sum().item()total_count += label.size(0)return total_acc/total_count

(7)设置参数,函数,开始训练模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS(root=path)
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \random_split(train_dataset, [num_train, len(train_dataset) - num_train])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)accu_val = evaluate(valid_dataloader)if total_accu is not None and total_accu > accu_val:scheduler.step()else:total_accu = accu_valprint('-' * 59)print('| end of epoch {:3d} | time: {:5.2f}s | ''valid accuracy {:8.3f} '.format(epoch,time.time() - epoch_start_time,accu_val))print('-' * 59)

(8)检查测试集的准确率

print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

结果:

Checking the results of test dataset.
        test accuracy    0.907

(9)随机测试一篇新闻,进行分类

ag_news_label = {1: "World",2: "Sports",3: "Business",4: "Sci/Tec"}def predict(text, text_pipeline):with torch.no_grad():text = torch.tensor(text_pipeline(text))output = model(text, torch.tensor([0]))return output.argmax(1).item() + 1ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \enduring the season’s worst weather conditions on Sunday at The \Open on his way to a closing 75 at Royal Portrush, which \considering the wind and the rain was a respectable showing. \Thursday’s first round at the WGC-FedEx St. Jude Invitational \was another story. With temperatures in the mid-80s and hardly any \wind, the Spaniard was 13 strokes better in a flawless round. \Thanks to his best putting performance on the PGA Tour, Rahm \finished with an 8-under 62 for a three-stroke lead, which \was even more impressive considering he’d never played the \front nine at TPC Southwind."model = model.to("cpu")print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])

结果:

This is a Sports news

Torchtext下的AG_NEWS数据集进行分类(官方文档代码)相关推荐

  1. 跟着官方文档一步一步搭建基于gozero的微服务

    1. 创建项目 mkdir microservice cd microservice go mod init zzh.com/microservice 2. 下载go-zero和goctl 2.1 设 ...

  2. Hyperledger Fabric 2.0 官方文档中文版 第6章 教程(下)

    Hyperledger Fabric 2.0 官方文档中文版 第6章 教程下 总目录 6.教程(下) 使用CouchDB 为什么使用CouchDB? 在Hyperledger Fabric中启用Cou ...

  3. linux3.10.53编译,根据官方文档在Linux下编译安装Apache

    根据官方文档在Linux下编译安装Apache 前言 永远记住官方文档才是最准确的安装手册,这篇文章仅为对官方文档的解读和补充,学习提升务必阅读官方文档: http://httpd.apache.or ...

  4. Windows环境下安装RabbitMQ(官方文档中文版)

    安装RabbitMQ 软件提取 链接:https://pan.baidu.com/s/1bIqTqEf7b8bcGZ6_jSIccg  提取码:nnju 本文按照官方文档按步骤详细解读,废话不多说,下 ...

  5. Sklearn官方文档中文整理4——随机梯度下降和最近邻篇

    Sklearn官方文档中文整理4--随机梯度下降和最近邻篇 1. 监督学习 1.5. 随机梯度下降 1.5.1. 分类[linear_model.SGDClassifier] 1.5.2. 回归[li ...

  6. 文件标识符必须为双精度类型的整数值标量_【翻译】VTK官方文档 - vtk文件格式

    本文翻译自vtk官方文档:vtk_file_format 文末有链接 VTK提供了许多源对象和编写器对象,用于读取和写入流行的数据文件格式,此外,VTK也提供了自己的文件格式.创建一种数据文件格式的主 ...

  7. 《Spring Cloud Netflix官方文档》1.服务发现:Eureka客户端

    1.     服务发现:Eureka客户端 服务发现是微服务架构的关键原则之一.使用手动配置或一些约定方式来处理多服务多实例的方式是非常困难,并且十分脆弱的.Eureka同时是Netflix服务发现的 ...

  8. spark之4:基础指南(源自官方文档)

    spark之4:基础指南(源自官方文档) @(SPARK)[spark, 大数据] spark之4基础指南源自官方文档 一简介 二接入Spark 三初始化Spark 一使用Shell 四弹性分布式数据 ...

  9. 【TensorFlow官方文档】MNIST机器学习入门

    MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:它也包含每一张图片对应的标签,告诉我们这个是数字几.比如,下面这四张图片的标签分别是5,0,4,1. 从一个很简单的数学模型开始:训练 ...

最新文章

  1. emWin GUIBuilder V5.40a 无法保存文件问题解决方案
  2. 4月11日 GCD 总结(二)
  3. 高逼格的程序员这样度过十一假期
  4. hexo 搜索_Hexo+Github博客最简教程-Dockerfile自动搭建
  5. 动手设计 CPU(一)—— 各类元件功能表
  6. JVM、JRE、JDK、Java SE、Java EE、Java ME
  7. 数据库之SQL行列转换
  8. 分享一款实用的太阳能充电电路(室内光照可用)
  9. 计算机再带word打不开,电脑word文档打不开怎么办(附:4种解决办法)
  10. 华为S5700_交换机_基础管理配置
  11. 基于WordPress搭建个人网站
  12. Vue中添加背景图片
  13. 基于ArcGIS Pro 的AI地图配色工具
  14. Redis——过期时间/过期回收策略
  15. 诗词锦集(持续更新)
  16. 又一家硅谷巨头掀裁员潮!员工称像经历一场“屠杀”
  17. 超简单地输出所有水仙花数(Java实现)
  18. 国际短信平台短信路由搭建后台软件定制-移讯云短信系统
  19. sqlserver查询所有表名及表备注
  20. postgresql通过拼音排序

热门文章

  1. 小白教程!x2m和x3m音频格式转换.mp3格式
  2. Java使用Jsoup写爬虫
  3. Navicat for Mysql连接数据库时遇到10038错误
  4. 如何理解照片后期处理
  5. html页面怎么保存和读取cookie 推荐MDN封装的cookie方法
  6. Unity基础—Transform类
  7. JavaSE----2
  8. MCE | 间歇性禁食增强抗癌疗效
  9. 基于空间金字塔网络的光流估计
  10. UVALive - 4394 (区间dp)