On Training the News Classification Model from the Heima Programmer (黑马程序员) NLP Course

I've recently been studying NLP and picked the Heima Programmer (黑马程序员) NLP course because its material is fairly comprehensive. But in the hands-on news topic classification project, I found that the course code contains a large number of errors, including several logic bugs:

  1. First, downloading the dataset is extremely slow because the download host can only be reached through a VPN from China, so in most cases loading the dataset simply fails with a Timeout exception.
  2. Second, the course never explains how to prepare the dataset, so loading a local copy of the CSV files runs into format problems.
  3. Third, the tuple layout of the records handled by the `generate_batch()` method is wrong: each record in the news dataset has 3 fields (type, title, content), i.e. the news category, the headline, and the body text, but the course version only accounts for two (see the quick check right after this list).
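
For reference, each row of the AG_NEWS CSV really does carry three fields. Here is a minimal sketch to confirm the layout once the archive has been extracted; it assumes the ./data/ag_news_csv/ extraction path used later in this post:

import csv

# Peek at the first record of the extracted training set to confirm the
# (type, title, content) layout. The path matches the default extraction
# location used by the full script below.
with open('./data/ag_news_csv/train.csv', encoding='utf-8') as f:
    row = next(csv.reader(f))

print(len(row))  # 3 -> label, title, content
print(row[0])    # news category as a string, '1' to '4'
print(row[1])    # headline
print(row[2])    # body text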

!!! Note that the torchtext version here is 0.4. The module was moved in later releases, so if you are not on 0.4, the line `from torchtext.datasets.text_classification import ...` will most likely fail !!!
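
If you are not sure which version you have, a quick sanity check before running anything (a minimal sketch; `pip install torchtext==0.4.0` is one way to pin the version, though the exact command depends on your environment):

import torchtext

# The script below targets torchtext 0.4; in later releases the
# torchtext.datasets.text_classification module was reorganized.
assert torchtext.__version__.startswith('0.4'), (
    "expected torchtext 0.4, got %s" % torchtext.__version__)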

To address these problems, I've put together a complete version of the code that runs end to end. If it helps, please leave a like or give me a follow~

Here is the code:

# On torchtext 0.4 the * import also provides extract_archive,
# build_vocab_from_iterator, Vocab and TextClassificationDataset.
from torchtext.datasets.text_classification import *
from torchtext.datasets.text_classification import _csv_iterator, _create_data_from_iterator
import os
import time
from torch import optim
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch.nn as nn
from torchtext.data.utils import ngrams_iterator
from torchtext.data.utils import get_tokenizer

N_GRAMS = 2
BATCH_SIZE = 16

if not os.path.isdir('./data'):
    os.mkdir('./data')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Build the train/test datasets from a local copy of the AG_NEWS archive,
# so nothing has to be downloaded at runtime.
def _setup_data_set(dataset_tar='./data/ag_news_csv.tar.gz',
                    n_grams=N_GRAMS, vocab=None, include_unk=False):
    extracted_files = extract_archive(dataset_tar)
    train_csv_path = ''
    test_csv_path = ''
    for file_name in extracted_files:
        if file_name.endswith('train.csv'):
            train_csv_path = file_name
        if file_name.endswith('test.csv'):
            test_csv_path = file_name
    if vocab is None:
        print("Building Vocab based on %s" % train_csv_path)
        # Build the vocabulary from the training CSV
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams=n_grams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    print('Vocab has %d entries' % len(vocab))
    print('Creating training data')
    # The training data must come from train.csv, not test.csv
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, n_grams, yield_cls=True), include_unk)
    print('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, n_grams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    # Return the dataset instances
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))


train_data_set, test_data_set = _setup_data_set()


# Define the model: an EmbeddingBag over n-gram ids followed by a linear layer
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        init_range = 0.5
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


VOCAB_SIZE = len(train_data_set.get_vocab())
EMBED_DIM = 32
NUM_CLASS = len(train_data_set.get_labels())
# Instantiate the model
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

N_EPOCH = 20
min_valid_loss = float('inf')
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=4.0)
scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training set for validation
train_len = int(len(train_data_set) * 0.95)
sub_train_, sub_valid_ = random_split(
    train_data_set, [train_len, len(train_data_set) - train_len])


def generate_batch(batch):
    # Each entry is a (label, token-id tensor) pair; EmbeddingBag takes the
    # concatenated ids plus the offset at which each sample starts.
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in text]
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label


def train_function(sub_train_):
    total_loss = 0
    total_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        output = model(text, offsets)
        loss = criterion(output, cls)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        total_acc += (output.argmax(1) == cls).sum().item()
    # Decay the learning rate once per epoch
    scheduler.step()
    return total_loss / len(sub_train_), total_acc / len(sub_train_)


def test(data_):
    total_loss = 0
    total_acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            total_loss += loss.item()
            total_acc += (output.argmax(1) == cls).sum().item()
    return total_loss / len(data_), total_acc / len(data_)


for epoch in range(N_EPOCH):
    start_time = time.time()
    train_loss, train_acc = train_function(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)
    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60
    print('Epoch:%d' % (epoch + 1), "| time in %d minutes, %d seconds" % (mins, secs))
    print(f"\tLoss:{train_loss:.4f}(train)\t|\tAcc:{train_acc * 100:.1f}%(train)")
    print(f"\tLoss:{valid_loss:.4f}(valid)\t|\tAcc:{valid_acc * 100:.1f}%(valid)")


# Try the trained model on a fresh piece of text
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}


def predict(text, model, vocab, ngrams):
    tokenizer = get_tokenizer("basic_english")
    with torch.no_grad():
        text = torch.tensor([vocab[token]
                             for token in ngrams_iterator(tokenizer(text), ngrams)])
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1


ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season’s worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
considering the wind and the rain was a respectable showing. \
Thursday’s first round at the WGC-FedEx St. Jude Invitational \
was another story. With temperatures in the mid-80s and hardly any \
wind, the Spaniard was 13 strokes better in a flawless round. \
Thanks to his best putting performance on the PGA Tour, Rahm \
finished with an 8-under 62 for a three-stroke lead, which \
was even more impressive considering he’d never played the \
front nine at TPC Southwind."

vocab = train_data_set.get_vocab()
model = model.to("cpu")
print("This is a %s news" % ag_news_label[predict(ex_text_str, model, vocab, 2)])
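
One small addition that is not in the course code: if you want to reuse the trained model later without retraining, you can persist the weights at the end of the script (a minimal sketch; the file name is my own choice):

# Save only the parameters; to reload, rebuild TextSentiment with the same
# VOCAB_SIZE / EMBED_DIM / NUM_CLASS and call load_state_dict().
torch.save(model.state_dict(), './data/news_model.pth')

# Later:
# model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS)
# model.load_state_dict(torch.load('./data/news_model.pth'))
# model.eval()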

Run output:

Building Vocab based on ./data/ag_news_csv/train.csv
120000lines [00:06, 17621.91lines/s]
Vocab has 1308844 entries
Creating training data
7600lines [00:00, 8405.65lines/s]
Creating testing data
7600lines [00:00, 9790.11lines/s]
Epoch:1 | time in 0 minutes, 0 seconds
        Loss:0.0003(train)      |       Acc:40.9%(train)
        Loss:0.0038(valid)      |       Acc:47.1%(valid)
Epoch:2 | time in 0 minutes, 0 seconds
        Loss:0.0002(train)      |       Acc:70.4%(train)
        Loss:0.0031(valid)      |       Acc:67.9%(valid)
Epoch:3 | time in 0 minutes, 0 seconds
        Loss:0.0004(train)      |       Acc:82.4%(train)
        Loss:0.0126(valid)      |       Acc:52.6%(valid)
Epoch:4 | time in 0 minutes, 0 seconds
        Loss:0.0001(train)      |       Acc:88.3%(train)
        Loss:0.0026(valid)      |       Acc:60.8%(valid)
Epoch:5 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:91.9%(train)
        Loss:0.0002(valid)      |       Acc:79.7%(valid)
Epoch:6 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:94.8%(train)
        Loss:0.0001(valid)      |       Acc:81.8%(valid)
Epoch:7 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:96.7%(train)
        Loss:0.0001(valid)      |       Acc:83.4%(valid)
Epoch:8 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:98.5%(train)
        Loss:0.0001(valid)      |       Acc:83.4%(valid)
Epoch:9 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.3%(train)
        Loss:0.0001(valid)      |       Acc:81.1%(valid)
Epoch:10 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.6%(train)
        Loss:0.0002(valid)      |       Acc:82.1%(valid)
Epoch:11 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.8%(train)
        Loss:0.0001(valid)      |       Acc:84.7%(valid)
Epoch:12 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:99.9%(train)
        Loss:0.0001(valid)      |       Acc:83.2%(valid)
Epoch:13 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:83.7%(valid)
Epoch:14 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:83.2%(valid)
Epoch:15 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.2%(valid)
Epoch:16 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:85.0%(valid)
Epoch:17 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:85.0%(valid)
Epoch:18 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0000(valid)      |       Acc:85.5%(valid)
Epoch:19 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.2%(valid)
Epoch:20 | time in 0 minutes, 0 seconds
        Loss:0.0000(train)      |       Acc:100.0%(train)
        Loss:0.0001(valid)      |       Acc:84.5%(valid)
This is a Sports news
