本文主要介绍一下如何使用 PyTorch 实现一个简单的基于self-attention机制的LSTM文本分类模型。

目录

  • LSTM
  • self-attention机制
  • 准备数据集
    • 数据集处理
    • 设置输入数据参数
  • 训练模型
    • 训练模型效果
  • 测试模型
    • 测试模型效果

LSTM

LSTM是RNN的一种变种,可以有效地解决RNN的梯度爆炸或者消失问题。

self-attention机制

总的来说self-attention机制分为三步
第一步: query 和 key 进行相似度计算,得到权值:

第二步:将权值进行归一化,得到直接可用的权重

第三步:将权重和 value 进行加权求和:

准备数据集

这里因为我只是想跑一个简单的基于LSTM文本分类的模型,所以并没有用什么非常大型的数据集,而且数据集的格式多种多样,不同的数据集有不同的处理方式,所以为了处理方便,本文选择的是搜狗新闻的数据集总数为五万条,类数为10类,每个类别共5000条数据。

数据集处理

数据集我放在了自己的百度网盘,下载好后放在自己的代码的文件夹下就好,链接如下:
链接:链接: https://pan.baidu.com/s/1eoRwR3v1-hArjxknm6psFQ .
提取码:bz76
百度网盘分享的data文件夹里包括了四个文件,其中两个是训练集和验证集,另外两个分别是停用词和一个自己写的处理文本的工具。全部放在代码运行的文件夹下即可。
下图是已经分词完毕的数据:

数据集处理代码

import re
import copy
import time
import jieba
import string
import torch
import numpy as np
import pandas as pd
from torch import nn
import torch.optim as optim
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchtext.legacy import data
from torchtext.vocab import vectors
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
from torchtext.legacy.data import Field, TabularDataset, Iterator, BucketIterator
from torch.autograd import Variable
stop_words = pd.read_csv("./stop_words.txt", header=None, names=["text"])
def chinese_pre(text_data):# 操字母转化为小写,丢除数字,text_data = text_data.lower()text_data = re.sub("\d+", "", text_data)##分词,使用精确模式a = jieba.cut(text_data, cut_all=True)# 毋丢停用词和多余空格text_data = [word.strip() for word in text_data if word not in stop_words.text.values]# 贽处理后的词语使用空格连接为字符串text_data = " ".join(a)return text_dataclass DataProcessor(object):def read_text(self, is_train_data):if (is_train_data):df = pd.read_csv("./cnews_train1.csv")else:df = pd.read_csv("./cnews_val1.csv")return df["cutword"], df["label"]def word_count(self, datas):# 统计单词出现的频次,并将其降序排列,得出出现频次最多的单词dic = {}for data in datas:data_list = data.split()for word in data_list:if (word in dic):dic[word] += 1else:dic[word] = 1word_count_sorted = sorted(dic.items(), key=lambda item: item[1], reverse=True)return word_count_sorteddef word_index(self, datas, vocab_size):# 创建词表word_count_sorted = self.word_count(datas)word2index = {}# 词表中未出现的词word2index["<unk>"] = 0# 句子添加的paddingword2index["<pad>"] = 1# 词表的实际大小由词的数量和限定大小决定vocab_size = min(len(word_count_sorted), vocab_size)for i in range(vocab_size):word = word_count_sorted[i][0]word2index[word] = i + 2return word2index, vocab_sizedef get_datasets(self, vocab_size, embedding_size, max_len):# 注,由于nn.Embedding每次生成的词嵌入不固定,因此此处同时获取训练数据的词嵌入和测试数据的词嵌入# 测试数据的词表也用训练数据创建train_datas, train_labels = self.read_text(is_train_data=True)test_datas, test_labels = self.read_text(is_train_data=False)word2index, vocab_size = self.word_index(train_datas, vocab_size)train_features = []for data in train_datas:feature = []data_list = data.split()for word in data_list:word = word.lower()  # 词表中的单词均为小写if word in word2index:feature.append(word2index[word])else:feature.append(word2index["<unk>"])  # 词表中未出现的词用<unk>代替if (len(feature) == max_len):  # 限制句子的最大长度,超出部分直接截断break# 对未达到最大长度的句子添加paddingfeature = feature + [word2index["<pad>"]] * (max_len - len(feature))train_features.append(feature)test_features = []for data in test_datas:feature = []data_list = data.split()for word in data_list:word = word.lower() #词表中的单词均为小写if word in word2index:feature.append(word2index[word])else:feature.append(word2index["<unk>"]) #词表中未出现的词用<unk>代替if(len(feature)==max_len): #限制句子的最大长度,超出部分直接截断break#对未达到最大长度的句子添加paddingfeature = feature + [word2index["<pad>"]] * (max_len - len(feature))test_features.append(feature)train_features = torch.LongTensor(train_features)train_labels = torch.LongTensor(train_labels)test_features = torch.LongTensor(test_features)test_labels = torch.LongTensor(test_labels)# 将词转化为embedding# 词表中有两个特殊的词<unk><pad>,所以词表实际大小为vocab_size + 2embed = nn.Embedding(vocab_size + 2, embedding_size)train_features = embed(train_features)test_features = embed(test_features)# 指定输入特征是否需要计算梯度train_features = Variable(train_features, requires_grad=False)train_datasets = torch.utils.data.TensorDataset(train_features, train_labels)test_features = Variable(test_features, requires_grad=False)test_datasets = torch.utils.data.TensorDataset(test_features, test_labels)return train_datasets,test_datasets

设置输入数据参数

在这一步主要是设置词表大小,词嵌入维度、句子最大长度和batch size等输入模型句子的各种参数。

processor = DataProcessor()
train_datasets,test_datasets = processor.get_datasets(vocab_size=6000, embedding_size=256, max_len=100)
batch_size = 16
train_dataloader = DataLoader(train_datasets,batch_size=batch_size,shuffle=True,drop_last=True)test_dataloader = DataLoader(test_datasets,batch_size=batch_size,shuffle=True,drop_last=True)

训练模型

# -*- coding: utf-8 -*-
"""
Created on Fri May 29 09:25:58 2020
文本分类 双向LSTM + Attention 算法
@author:
"""
import torch.nn as nn
from Chinese_data_processor import DataProcessor
temp = 0
vocab_size = 6000   #词表大小
embedding_size = 256   #词向量维度
num_classes = 10     #二分类
sentence_max_len = 100  #单个句子的长度
hidden_size = 100
num_layers = 1  #一层lstm
num_directions = 2  #双向lstm
lr = 1e-3
epochs = 40
print_every_batch = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
filename = 'lstm.pth'
#Bi-LSTM模型
class BiLSTMModel(nn.Module):def __init__(self, embedding_size,hidden_size, num_layers, num_directions, num_classes):super(BiLSTMModel, self).__init__()self.input_size = embedding_sizeself.hidden_size = hidden_sizeself.num_layers = num_layersself.num_directions = num_directionsself.lstm = nn.LSTM(embedding_size, hidden_size, num_layers = num_layers,bidirectional = True)self.attention_weights_layer = nn.Sequential(nn.Linear(hidden_size, hidden_size),nn.ReLU(inplace=True))self.liner = nn.Linear(hidden_size, num_classes)self.act_func = nn.Softmax(dim=1)def forward(self, x):#lstm的输入维度为 [seq_len, batch, input_size]#x [batch_size, sentence_length, embedding_size]x = x.permute(1, 0, 2)         #[sentence_length, batch_size, embedding_size]#由于数据集不一定是预先设置的batch_size的整数倍,所以用size(1)获取当前数据实际的batchbatch_size = x.size(1)#设置lstm最初的前项输出h_0 = torch.randn(self.num_layers*self.num_directions , batch_size, self.hidden_size).to(device)#c_0 = torch.randn(self.num_layers*self.num_directions , batch_size, self.hidden_size).to(device)##out[seq_len, batch, num_directions * hidden_size]多层lstm,out只保存最后一层每个时间步t的输出h_t
#         h_n, c_n [num_layers * num_directions, batch, hidden_size]out, (h_n, c_n) = self.lstm(x, (h_0, c_0))#将双向lstm的输出拆分为前向输出和后向输出(forward_out, backward_out) = torch.chunk(out, 2, dim = 2)out = forward_out + backward_out  #[seq_len, batch, hidden_size]out = out.permute(1, 0, 2)  #[batch, seq_len, hidden_size]
#         #为了使用到lstm最后一个时间步时,每层lstm的表达,用h_n生成attention的权重h_n = h_n.permute(1, 0, 2)  #[batch, num_layers * num_directions,  hidden_size]h_n = torch.sum(h_n, dim=1) #[batch, 1,  hidden_size]h_n = h_n.squeeze(dim=1)  #[batch, hidden_size]attention_w = self.attention_weights_layer(h_n)  #[batch, hidden_size]attention_w = attention_w.unsqueeze(dim=1) #[batch, 1, hidden_size]attention_context = torch.bmm(attention_w, out.transpose(1, 2))  #[batch, 1, seq_len]softmax_w = F.softmax(attention_context, dim=-1)  #[batch, 1, seq_len],权重归一化x = torch.bmm(softmax_w, out)  #[batch, 1, hidden_size]x = out[:,-1,:]x = x.squeeze(dim=1)  #[batch, hidden_size]x = self.liner(x)x = self.act_func(x)return x
best_loss = 100000
model = BiLSTMModel(embedding_size, hidden_size, num_layers, num_directions, num_classes)
model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()
for epoch in range(epochs):model.train()print("it is ",epoch)print_avg_loss = 0train_acc = 0all_loss = 0all_acc = 0for i,(datas, labels) in enumerate(train_dataloader):datas = datas.to(device)labels = labels.to(device)preds = model(datas)loss = loss_func(preds, labels)loss = (loss - 0.4).abs() + 0.4 optimizer.zero_grad()loss.backward()optimizer.step()# scheduler.step()#获取预测的最大概率出现的位置preds = torch.argmax(preds, dim=1)train_acc += torch.sum(preds == labels).item()print_avg_loss += loss.item()all_loss += loss.item()all_acc += torch.sum(preds == labels).item()if i % print_every_batch == (print_every_batch-1):print("Batch: %d, Loss: %.4f" % ((i+1), print_avg_loss/print_every_batch))print("Train Acc: {}".format(train_acc/(print_every_batch*batch_size)))print_avg_loss = 0train_acc = 0temp = lossif loss < best_loss:best_loss = lossstate = {"state_dict":model.state_dict(),"optimizer":optimizer.state_dict()}torch.save(model.state_dict(),'lstm.pth') 

训练模型效果

因为LSTM模型参数较少,所以训练速度还是比较快速的,并且第二、三代的时候就可以达到收敛,训练收敛图如下:

测试模型

model.eval()
import operator
from functools import reduce
from sklearn.metrics import classification_report
loss_val = 0.0
corrects = 0.0
y_true = []
y_pred = []
for datas, labels in test_dataloader:datas = datas.to(device)labels = labels.to(device)y_true.append(labels.cpu().numpy().tolist())preds = model(datas)loss = loss_func(preds, labels)loss_val += loss.item() * datas.size(0)y_pred.append(preds.argmax(dim=1).detach().cpu().numpy().tolist())#获取预测的最大概率出现的位置preds = torch.argmax(preds, dim=1)corrects += torch.sum(preds == labels).item()test_loss = loss_val / len(test_dataloader.dataset)test_acc = corrects / len(test_dataloader.dataset)
# print("Test Loss: {}, Test Acc: {},len{}".format(test_loss, test_acc,len(test_dataloader.dataset)))
y_p =reduce(operator.add, y_pred)
y_t =reduce(operator.add, y_true)
print(classification_report(y_true=y_t, y_pred=y_p,digits=4))

测试模型效果

本文模型主要采用了准确率、精确率、召回率和F1值,这四种比较常见的文本分类评价标准,具体数据如下图所示:

史上最简单的LSTM文本分类实现:搜狗新闻文本分类(附代码)相关推荐

  1. Uber发布史上最简单的深度学习框架Ludwig!

    昨日,Uber官网重磅宣布新开源深度学习框架Ludwig,不需要懂编程知识,让专家能用的更顺手,让非专业人士也可以玩转人工智能,堪称史上最简单的深度学习框架! Ludwig是一个建立在TensorFl ...

  2. 重磅!Uber发布史上最简单的深度学习框架Ludwig!不懂编程也能玩转人工智能

    点击我爱计算机视觉标星,更快获取CVML新技术 昨日,Uber官网重磅宣布新开源深度学习框架Ludwig,不需要懂编程知识,让专家能用的更顺手,让非专业人士也可以玩转人工智能,堪称史上最简单的深度学习 ...

  3. 史上最简单MySQL教程详解(进阶篇)之索引及失效场合总结

    史上最简单MySQL教程详解(进阶篇)之索引及其失效场合总结 什么是索引及其作用 索引的种类 各存储引擎对于索引的支持 简单介绍索引的实现 索引的设置与分析 普通索引 唯一索引(Unique Inde ...

  4. Uber发布史上最简单的深度学习框架Ludwig!不懂编程也能玩转人工智能

    昨日,Uber官网重磅宣布新开源深度学习框架Ludwig,不需要懂编程知识,让专家能用的更顺手,让非专业人士也可以玩转人工智能,堪称史上最简单的深度学习框架! image Ludwig是一个建立在Te ...

  5. Android 自定义控件打造史上最简单的侧滑菜单

    侧滑菜单在很多应用中都会见到,最近QQ5.0侧滑还玩了点花样~~对于侧滑菜单,一般大家都会自定义ViewGroup,然后隐藏菜单栏,当手指滑动时,通过Scroller或者不断的改变leftMargin ...

  6. mysql交叉查询教程_史上最简单的 MySQL 教程(二十六)「连接查询(上)」

    连接查询连接查询:将多张表(大于等于 2 张表)按照某个指定的条件进行数据的拼接,其最终结果记录数可能有变化,但字段数一定会增加. 连接查询的意义:在用户查询数据的时候,需要显示的数据来自多张表. 连 ...

  7. 史上最简单的SpringCloud教程 | 第四篇:断路器(Hystrix)--里面有BUG,所以我转载改一下

    017年04月09日 21:14:05 阅读数:271535 转载请标明出处:  http://blog.csdn.net/forezp/article/details/69934399  本文出自方 ...

  8. 史上最简单的 MySQL 教程(二)「关系型数据库」

    关系型数据库 1 定义 关系型数据库,是一种建立在关系模型(数学模型)上的数据库. 至于关系模型,则是一种所谓建立在关系上的模型,其包含三个方面,分别为: 数据结构:数据存储的形式,二维表(行和列): ...

  9. 2010年史上最简单的做母盘教程

    2010年史上最简单的做母盘教程 辛苦了两个小时才把教程写完....写得不好大家多多包涵 其实做母盘是一件十分简单的事,只要大家敢去试就能成功的,这教程只给小白看的,老鸟路过指点一下. 本人是珠海信佑 ...

  10. 史上最简单的spark教程第十七章-快速开发部署第一个sparkStreaming+Java流处理程序

    第一个流处理程序sparkStreaming+Java 史上最简单的spark教程 所有代码示例地址:https://github.com/Mydreamandreality/sparkResearc ...

最新文章

  1. SAP MM初阶创建服务采购订单时订购单位和物料组的缺省值
  2. php 7月世界排名2017,TIOBE2017榜单公布,PHP还会是世界上最好的语言吗?
  3. python的工作方向-Python的就业的方向和前景
  4. Microsoft Visual Studio (VS)2010 常用快捷键大全 便捷开发
  5. flutter怎么手动刷新_如何手动刷新或重新加载Flutter Firestore StreamBuilder?
  6. 【Asp.net】Session对象
  7. Preparing Cities for Robot Cars【城市准备迎接自动驾驶汽车】
  8. centos安装多个tomcat
  9. oracle 未找到段的存储定义,Exp-00003 no storage definition found issue in oracle 11g (未找到段 (0,0) 的存储定义)...
  10. 内存映射文件mmap原理分析
  11. OCR技术系列实践:银行卡、身份证、门牌号、护照、车牌、印刷体汉字识别
  12. 南抖音北快手,智障界的两泰斗
  13. java看视频可以学会吗,看it教程视频自学Java编程可以学会吗?
  14. 四路模拟高清解码,CVI,四通道多合一同轴高清解码芯片方案
  15. a豆的使命:每一位年轻人都值得珍重
  16. 牛客网 - [牛客假日团队赛6]迷路的牛
  17. docker学习 --Compose 容器编排,常用命令等.集成spring。mysql。redis
  18. 【MDCC 2016】iOS开发峰会回顾:实战Coding演示 技术大牛带你起飞
  19. jshint详细说明【vscode插件】
  20. 企业邮箱托管选哪家好,163企业邮箱如何购买?

热门文章

  1. 计算机的表白隐藏功能,微信还有这个功能?隐藏代码还能表白!教你高级告白手段...
  2. js文本框设置必填项_显示隐藏js字段 设置必填非必填
  3. layui 图片剪切/截取
  4. 数据分析案例-二手车价格预测
  5. 【超详细】QQ空间说说爬取教程(看看你的女神在想什么~
  6. java读写excel,解决poi包中没有org.apache.poi.ss.usermodel.CellType的问题
  7. 广东又将添新高校:香山大学来了
  8. 使用x264压制视频简介
  9. poj 1113 Wall 凸包
  10. Android事件分发理解