PyTorch RNN for News Text Classification

  • Overview
  • Dataset
  • TextRNN Model
  • Evaluation Function
  • Main Function
  • Output

Overview

An RNN (Recurrent Neural Network) is a neural network designed for sequential data, i.e. data in which the earlier inputs and the later inputs are related to each other.
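As a minimal illustration of what "sequential" means in practice, the sketch below pushes a toy batch through torch.nn.RNN; the input and hidden sizes are arbitrary and unrelated to the news dataset, they only show the recurrent layer producing one hidden state per time step.

import torch
import torch.nn as nn

# Toy recurrent layer: 8-dim inputs, 16-dim hidden state (arbitrary sizes)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)

x = torch.randn(4, 10, 8)       # 4 sequences, 10 time steps, 8 features each
output, h_n = rnn(x)            # output: (4, 10, 16), one hidden state per time step
print(output.shape, h_n.shape)  # h_n: (1, 4, 16), the final hidden state of each sequence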

Dataset

We will use a subset of the THUCNews dataset, which contains news in 10 categories with 10,000 samples per category.
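Assuming each line of the train/dev/test splits is a tab-separated "text<TAB>label index" pair (the layout the build_dataset helper expects), a quick sanity check of the class balance might look like the sketch below; the path is an assumption about where the subset is unpacked.

from collections import Counter

# Assumed location and "text<TAB>label" layout of the training split
train_path = 'THUCNews/data/train.txt'

counts = Counter()
with open(train_path, encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        _, label = line.rsplit('\t', 1)
        counts[label] += 1

print(len(counts), 'classes')
print(counts)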

TextRNN Model

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset, embedding):
        self.model_name = 'TextCNN'
        self.train_path = dataset + '/data/train.txt'                                # training set
        self.dev_path = dataset + '/data/dev.txt'                                    # dev set
        self.test_path = dataset + '/data/test.txt'                                  # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]                                # class names
        self.vocab_path = dataset + '/data/vocab.pkl'                                # vocabulary
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # path of the trained model
        self.log_path = dataset + '/log/' + self.model_name
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\
            if embedding != 'random' else None                                       # pretrained word vectors
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # device

        self.dropout = 0.5                                              # dropout rate
        self.require_improvement = 1000                                 # stop early if no improvement after 1000 batches
        self.num_classes = len(self.class_list)                         # number of classes
        self.n_vocab = 0                                                # vocabulary size, set at runtime
        self.num_epochs = 20                                            # number of epochs
        self.batch_size = 128                                           # mini-batch size
        self.pad_size = 32                                              # sequence length (shorter texts padded, longer ones truncated)
        self.learning_rate = 1e-3                                       # learning rate
        self.embed = self.embedding_pretrained.size(1)\
            if self.embedding_pretrained is not None else 300           # embedding dimension
        self.filter_sizes = (2, 3, 4)                                   # convolution kernel sizes
        self.num_filters = 256                                          # number of kernels (output channels)


'''Convolutional Neural Networks for Sentence Classification'''


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        # x is a (token_ids, seq_len) pair produced by the data iterator
        out = self.embedding(x[0])
        out = out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
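To make the convolution-and-pooling step concrete, this hedged shape check runs one of the Conv2d branches on a dummy embedded batch; pad_size=32, embed=300 and num_filters=256 are taken from the Config defaults above, everything else is made up purely for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy embedded batch: (N, seq_len, embed) -> add a channel dim, as in Model.forward
emb = torch.randn(2, 32, 300).unsqueeze(1)   # (2, 1, 32, 300)

conv = nn.Conv2d(1, 256, (3, 300))           # the k=3 branch of self.convs
x = F.relu(conv(emb)).squeeze(3)             # (2, 256, 32 - 3 + 1) = (2, 256, 30)
x = F.max_pool1d(x, x.size(2)).squeeze(2)    # max over time -> (2, 256)
print(x.shape)

# Concatenating the k=2, 3, 4 branches gives (2, 256 * 3), which matches the
# in_features of the final nn.Linear(num_filters * len(filter_sizes), num_classes).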

Evaluation Function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter


# Weight initialization, Xavier by default
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name:
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass


def train(config, model, train_iter, dev_iter, test_iter, writer):
    start_time = time.time()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    # Exponential learning-rate decay: after each epoch, lr = gamma * lr
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last dev-loss improvement
    flag = False  # whether training has stalled for too long
    # writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # scheduler.step()  # learning-rate decay
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # Report metrics on the training and dev sets every 100 batches
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Dev loss has not improved for more than 1000 batches; stop training
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()
    test(config, model, test_iter)


def test(config, model, test_iter):
    # Evaluate the best checkpoint on the test set
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'
    print(msg.format(test_loss, test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)


def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
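The code above imports get_time_dif from utils, which the article does not list. A minimal implementation consistent with how it is used here (elapsed wall-clock time formatted as whole seconds) might look like the sketch below; treat it as an assumption rather than the project's exact helper.

import time
from datetime import timedelta

def get_time_dif(start_time):
    """Return the time elapsed since start_time, rounded to whole seconds."""
    end_time = time.time()
    return timedelta(seconds=int(round(end_time - start_time)))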

Main Function

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, default="TextRNN",
                    help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # dataset directory

    # Sogou News: embedding_SougouNews.npz, Tencent: embedding_Tencent.npz, random init: 'random'
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  # TextCNN, TextRNN, ...
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # make results reproducible across runs

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter, writer)
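Assuming the script above is saved as run.py at the project root (next to the THUCNews data directory and the models package), training can then be launched from the command line as shown below; the file name and directory layout are assumptions about the project structure.

python run.py --model TextRNN --embedding pre_trained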

Output

Loading data...
Vocab size: 4762
180000it [00:03, 56090.03it/s]
10000it [00:00, 32232.86it/s]
10000it [00:00, 61166.60it/s]
Time usage: 0:00:04
<bound method Module.parameters of Model(
  (embedding): Embedding(4762, 300)
  (lstm): LSTM(300, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=256, out_features=10, bias=True)
)>
Epoch [1/10]
Iter:      0,  Train Loss:   2.3,  Train Acc: 11.52%,  Val Loss:   2.3,  Val Acc: 10.00%,  Time: 0:00:00 *
Iter:    100,  Train Loss:   1.3,  Train Acc: 50.39%,  Val Loss:   1.3,  Val Acc: 49.63%,  Time: 0:00:02 *
Iter:    200,  Train Loss:  0.72,  Train Acc: 77.54%,  Val Loss:  0.74,  Val Acc: 75.92%,  Time: 0:00:04 *
Iter:    300,  Train Loss:  0.47,  Train Acc: 84.18%,  Val Loss:  0.55,  Val Acc: 82.34%,  Time: 0:00:06 *
Epoch [2/10]
Iter:    400,  Train Loss:   0.5,  Train Acc: 83.59%,  Val Loss:  0.48,  Val Acc: 85.13%,  Time: 0:00:07 *
Iter:    500,  Train Loss:  0.41,  Train Acc: 88.48%,  Val Loss:  0.43,  Val Acc: 86.42%,  Time: 0:00:09 *
Iter:    600,  Train Loss:  0.37,  Train Acc: 88.48%,  Val Loss:  0.41,  Val Acc: 86.93%,  Time: 0:00:11 *
Iter:    700,  Train Loss:  0.42,  Train Acc: 86.33%,  Val Loss:  0.37,  Val Acc: 87.90%,  Time: 0:00:12 *
Epoch [3/10]
Iter:    800,  Train Loss:  0.35,  Train Acc: 89.06%,  Val Loss:  0.39,  Val Acc: 87.81%,  Time: 0:00:14
Iter:    900,  Train Loss:   0.3,  Train Acc: 89.06%,  Val Loss:  0.36,  Val Acc: 88.51%,  Time: 0:00:16 *
Iter:   1000,  Train Loss:   0.3,  Train Acc: 90.43%,  Val Loss:  0.36,  Val Acc: 88.81%,  Time: 0:00:17
Epoch [4/10]
Iter:   1100,  Train Loss:  0.29,  Train Acc: 90.82%,  Val Loss:  0.34,  Val Acc: 89.07%,  Time: 0:00:19 *
Iter:   1200,  Train Loss:  0.28,  Train Acc: 90.82%,  Val Loss:  0.33,  Val Acc: 89.43%,  Time: 0:00:21 *
Iter:   1300,  Train Loss:  0.28,  Train Acc: 90.62%,  Val Loss:  0.33,  Val Acc: 89.41%,  Time: 0:00:22
Iter:   1400,  Train Loss:  0.25,  Train Acc: 91.60%,  Val Loss:  0.33,  Val Acc: 89.37%,  Time: 0:00:24
Epoch [5/10]
Iter:   1500,  Train Loss:  0.26,  Train Acc: 91.80%,  Val Loss:  0.34,  Val Acc: 89.56%,  Time: 0:00:26
Iter:   1600,  Train Loss:  0.18,  Train Acc: 94.34%,  Val Loss:  0.35,  Val Acc: 89.14%,  Time: 0:00:27
Iter:   1700,  Train Loss:  0.23,  Train Acc: 92.58%,  Val Loss:  0.33,  Val Acc: 89.80%,  Time: 0:00:29 *
Epoch [6/10]
Iter:   1800,  Train Loss:  0.23,  Train Acc: 92.97%,  Val Loss:  0.34,  Val Acc: 89.46%,  Time: 0:00:31
Iter:   1900,  Train Loss:  0.18,  Train Acc: 94.34%,  Val Loss:  0.32,  Val Acc: 89.76%,  Time: 0:00:33 *
Iter:   2000,  Train Loss:  0.16,  Train Acc: 93.75%,  Val Loss:  0.34,  Val Acc: 89.28%,  Time: 0:00:34
Iter:   2100,  Train Loss:  0.22,  Train Acc: 92.19%,  Val Loss:  0.32,  Val Acc: 90.12%,  Time: 0:00:36 *
Epoch [7/10]
Iter:   2200,  Train Loss:  0.21,  Train Acc: 92.77%,  Val Loss:  0.34,  Val Acc: 89.67%,  Time: 0:00:38
Iter:   2300,  Train Loss:  0.18,  Train Acc: 94.73%,  Val Loss:  0.35,  Val Acc: 89.81%,  Time: 0:00:39
Iter:   2400,  Train Loss:  0.21,  Train Acc: 92.38%,  Val Loss:  0.36,  Val Acc: 89.21%,  Time: 0:00:41
Epoch [8/10]
Iter:   2500,  Train Loss:  0.19,  Train Acc: 93.75%,  Val Loss:  0.35,  Val Acc: 89.56%,  Time: 0:00:43
Iter:   2600,  Train Loss:  0.19,  Train Acc: 94.53%,  Val Loss:  0.31,  Val Acc: 90.38%,  Time: 0:00:45 *
Iter:   2700,  Train Loss:   0.2,  Train Acc: 93.75%,  Val Loss:  0.33,  Val Acc: 89.95%,  Time: 0:00:46
Iter:   2800,  Train Loss:  0.15,  Train Acc: 94.92%,  Val Loss:  0.33,  Val Acc: 90.05%,  Time: 0:00:48
Epoch [9/10]
Iter:   2900,  Train Loss:  0.22,  Train Acc: 93.16%,  Val Loss:  0.35,  Val Acc: 89.47%,  Time: 0:00:49
Iter:   3000,  Train Loss:  0.16,  Train Acc: 94.53%,  Val Loss:  0.36,  Val Acc: 89.72%,  Time: 0:00:51
Iter:   3100,  Train Loss:  0.19,  Train Acc: 93.95%,  Val Loss:  0.37,  Val Acc: 89.51%,  Time: 0:00:53
Epoch [10/10]
Iter:   3200,  Train Loss:  0.13,  Train Acc: 95.70%,  Val Loss:  0.35,  Val Acc: 89.67%,  Time: 0:00:54
Iter:   3300,  Train Loss:   0.2,  Train Acc: 93.36%,  Val Loss:  0.35,  Val Acc: 90.27%,  Time: 0:00:56
Iter:   3400,  Train Loss:  0.12,  Train Acc: 96.48%,  Val Loss:  0.34,  Val Acc: 89.92%,  Time: 0:00:57
Iter:   3500,  Train Loss:  0.12,  Train Acc: 95.70%,  Val Loss:  0.35,  Val Acc: 89.98%,  Time: 0:00:59
Test Loss:   0.3,  Test Acc: 90.66%
Precision, Recall and F1-Score...
               precision    recall  f1-score   support

      finance     0.8777    0.9040    0.8906      1000
       realty     0.9353    0.9110    0.9230      1000
       stocks     0.8843    0.7950    0.8373      1000
    education     0.9319    0.9440    0.9379      1000
      science     0.8297    0.8770    0.8527      1000
      society     0.9012    0.9210    0.9110      1000
     politics     0.9001    0.8740    0.8869      1000
       sports     0.9788    0.9680    0.9734      1000
         game     0.9299    0.9290    0.9295      1000
entertainment     0.9015    0.9430    0.9218      1000

     accuracy                         0.9066     10000
    macro avg     0.9070    0.9066    0.9064     10000
 weighted avg     0.9070    0.9066    0.9064     10000

Confusion Matrix...
[[904  11  38   5  16  10   9   1   1   5]
 [ 14 911  14   6   9  12  10   4   6  14]
 [ 72  25 795   5  57   1  33   0   9   3]
 [  2   1   2 944  10  18   7   0   5  11]
 [ 11   6  18   8 877  17  15   0  32  16]
 [  4  12   1  18   7 921  14   1   7  15]
 [ 16   3  21  14  26  29 874   4   2  11]
 [  1   1   3   1   3   2   4 968   0  17]
 [  2   1   5   5  39   4   3   1 929  11]
 [  4   3   2   7  13   8   2  10   8 943]]
Time usage: 0:00:00

