本文将会介绍如何在PyTorch中使用CNN模型进行中文文本分类。
  使用CNN实现中文文本分类的基本思路:

  • 文本预处理
  • 将字(或token)进行汇总,形成字典文件,可保留前n个字
  • 文字转数字,不在字典文件中用表示
  • 对文本进行阶段与填充,填充用,将文本向量长度统一
  • 建立Embedding层
  • 建立CNN模型
  • 训练模型,调整参数得到最优表现的模型,获取模型评估指标
  • 保存模型,并在新样本上进行预测

  我们以搜狗小分类数据集为例,使用CNN模型对其进行文本分类。

数据集介绍

  搜狗小分类数据集,共有5个类别,分别为体育、健康、军事、教育、汽车。划分为训练集和测试集,其中训练集每个类别800条样本,测试集每个类别100条样本。

文本预处理

  读取训练集中的文本数据,形成文字列表,打乱顺序,保留前N个文字,形成Pickle文件,并保存类别列表至Pickle文件,便于后续处理,Python代码如下:

# -*- coding: utf-8 -*-
# @Time : 2023/3/16 10:32
# @Author : Jclian91
# @File : preprocessing.py
# @Place : Minghang, Shanghai
from random import shuffle
import pandas as pd
from collections import Counter, defaultdictfrom params import TRAIN_FILE_PATH, NUM_WORDS
from pickle_file_operaor import PickleFileOperatorclass FilePreprossing(object):def __init__(self, n):# 保留前n个字self.__n = ndef _read_train_file(self):train_pd = pd.read_csv(TRAIN_FILE_PATH)label_list = train_pd['label'].unique().tolist()# 统计文字频数character_dict = defaultdict(int)for content in train_pd['content']:for key, value in Counter(content).items():character_dict[key] += value# 不排序sort_char_list = [(k, v) for k, v in character_dict.items()]shuffle(sort_char_list)print(f'total{len(character_dict)}characters.')print('top 10 chars: ', sort_char_list[:10])# 保留前n个文字top_n_chars = [_[0] for _ in sort_char_list[:self.__n]]return label_list, top_n_charsdef run(self):label_list, top_n_chars = self._read_train_file()PickleFileOperator(data=label_list, file_path='labels.pk').save()PickleFileOperator(data=top_n_chars, file_path='chars.pk').save()if __name__ == '__main__':processor = FilePreprossing(NUM_WORDS)processor.run()# 读取pickle文件labels = PickleFileOperator(file_path='labels.pk').read()print(labels)content = PickleFileOperator(file_path='chars.pk').read()print(content)

   文字转数字,不在字典文件中用UNK表示。对文本进行阶段与填充,填充用PAD,将文本向量长度统一。Python实现代码如下:

# -*- coding: utf-8 -*-
# @Time : 2023/3/16 11:15
# @Author : Jclian91
# @File : text_featuring.py
# @Place : Minghang, Shanghai
import pandas as pd
import numpy as np
import torch as T
from torch.utils.data import Dataset, DataLoader, random_splitfrom params import (PAD_NO,UNK_NO,START_NO,SENT_LENGTH,TEST_FILE_PATH,TRAIN_FILE_PATH)
from pickle_file_operaor import PickleFileOperator# load csv file
def load_csv_file(file_path):df = pd.read_csv(file_path)samples, y_true = [], []for index, row in df.iterrows():y_true.append(row['label'])samples.append(row['content'])return samples, y_true# 读取pickle文件
def load_file_file():labels = PickleFileOperator(file_path='labels.pk').read()chars = PickleFileOperator(file_path='chars.pk').read()label_dict = dict(zip(labels, range(len(labels))))char_dict = dict(zip(chars, range(len(chars))))return label_dict, char_dict# 文本预处理
def text_feature(labels, contents, label_dict, char_dict):samples, y_true = [], []for s_label, s_content in zip(labels, contents):# one_hot_vector = [0] * len(label_dict)# one_hot_vector[label_dict[s_label]] = 1# y_true.append([one_hot_vector])y_true.append(label_dict[s_label])train_sample = []for char in s_content:if char in char_dict:train_sample.append(START_NO + char_dict[char])else:train_sample.append(UNK_NO)# 补充或截断if len(train_sample) < SENT_LENGTH:samples.append(train_sample + [PAD_NO] * (SENT_LENGTH - len(train_sample)))else:samples.append(train_sample[:SENT_LENGTH])return samples, y_true# dataset
class CSVDataset(Dataset):# load the datasetdef __init__(self, file_path):label_dict, char_dict = load_file_file()samples, y_true = load_csv_file(file_path)x, y = text_feature(y_true, samples, label_dict, char_dict)self.X = T.from_numpy(np.array(x)).long()self.y = T.from_numpy(np.array(y))# number of rows in the datasetdef __len__(self):return len(self.X)# get a row at an indexdef __getitem__(self, idx):return [self.X[idx], self.y[idx]]# get indexes for train and test rowsdef get_splits(self, n_test=0.3):# determine sizestest_size = round(n_test * len(self.X))train_size = len(self.X) - test_size# calculate the splitreturn random_split(self, [train_size, test_size])if __name__ == '__main__':p = CSVDataset().__getitem__(1)print(p)

以下面的文本为例,将其转化为向量(假设最大长度为40)后的结果为:

盖世汽车讯,特斯拉去年击败了宝马,夺得了美国豪华汽车市场的桂

[3899, 4131, 2497, 496, 3746, 221, 3273, 1986, 4002, 4882, 3238, 5114, 1516, 353, 4767, 2357, 221, 2920, 387, 353, 4434, 4930, 4079, 4187, 2497, 496, 883, 1325, 1061, 3901, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

创建模型

  创建模型:建立Embedding层,建立CNN模型,模型图如下:


Python实现代码如下:

# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nnfrom params import NUM_WORDS, SENT_LENGTH, EMBEDDING_SIZEclass TextClassifier(nn.ModuleList):def __init__(self):super(TextClassifier, self).__init__()# Parameters regarding text preprocessingself.seq_len = SENT_LENGTHself.num_words = NUM_WORDSself.embedding_size = EMBEDDING_SIZE# Dropout definitionself.dropout = nn.Dropout(0.25)# CNN parameters definition# Kernel sizesself.kernel_1 = 2self.kernel_2 = 3self.kernel_3 = 4self.kernel_4 = 5# Output size for each convolutionself.out_size = 32# Number of strides for each convolutionself.stride = 2# Embedding layer definitionself.embedding = nn.Embedding(self.num_words + 2, self.embedding_size)# Convolution layers definitionself.conv_1 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_1, self.stride)self.conv_2 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_2, self.stride)self.conv_3 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_3, self.stride)self.conv_4 = nn.Conv1d(self.seq_len, self.out_size, self.kernel_4, self.stride)# Max pooling layers definitionself.pool_1 = nn.MaxPool1d(self.kernel_1, self.stride)self.pool_2 = nn.MaxPool1d(self.kernel_2, self.stride)self.pool_3 = nn.MaxPool1d(self.kernel_3, self.stride)self.pool_4 = nn.MaxPool1d(self.kernel_4, self.stride)# Fully connected layer definitionself.fc = nn.Linear(self.in_features_fc(), 5)def in_features_fc(self):"""Calculates the number of output features after Convolution + Max poolingConvolved_Features = ((embedding_size + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1Pooled_Features = ((embedding_size + (2 * padding) - dilation * (kernel - 1) - 1) / stride) + 1source: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html"""# Calcualte size of convolved/pooled features for convolution_1/max_pooling_1 featuresout_conv_1 = ((self.embedding_size - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1out_conv_1 = math.floor(out_conv_1)out_pool_1 = ((out_conv_1 - 1 * (self.kernel_1 - 1) - 1) / self.stride) + 1out_pool_1 = math.floor(out_pool_1)# Calcualte size of convolved/pooled features for convolution_2/max_pooling_2 featuresout_conv_2 = ((self.embedding_size - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1out_conv_2 = math.floor(out_conv_2)out_pool_2 = ((out_conv_2 - 1 * (self.kernel_2 - 1) - 1) / self.stride) + 1out_pool_2 = math.floor(out_pool_2)# Calcualte size of convolved/pooled features for convolution_3/max_pooling_3 featuresout_conv_3 = ((self.embedding_size - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1out_conv_3 = math.floor(out_conv_3)out_pool_3 = ((out_conv_3 - 1 * (self.kernel_3 - 1) - 1) / self.stride) + 1out_pool_3 = math.floor(out_pool_3)# Calculate size of convolved/pooled features for convolution_4/max_pooling_4 featuresout_conv_4 = ((self.embedding_size - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1out_conv_4 = math.floor(out_conv_4)out_pool_4 = ((out_conv_4 - 1 * (self.kernel_4 - 1) - 1) / self.stride) + 1out_pool_4 = math.floor(out_pool_4)# Returns "flattened" vector (input for fully connected layer)return (out_pool_1 + out_pool_2 + out_pool_3 + out_pool_4) * self.out_sizedef forward(self, x):# Sequence of tokes is filterd through an embedding layerx = self.embedding(x)# Convolution layer 1 is appliedx1 = self.conv_1(x)x1 = torch.relu(x1)x1 = self.pool_1(x1)# Convolution layer 2 is appliedx2 = self.conv_2(x)x2 = torch.relu((x2))x2 = self.pool_2(x2)# Convolution layer 3 is appliedx3 = self.conv_3(x)x3 = torch.relu(x3)x3 = self.pool_3(x3)# Convolution layer 4 is appliedx4 = self.conv_4(x)x4 = torch.relu(x4)x4 = self.pool_4(x4)# The output of each convolutional layer is concatenated into a unique vectorunion = torch.cat((x1, x2, x3, x4), 2)union = union.reshape(union.size(0), -1)# The "flattened" vector is passed through a fully connected layerout = self.fc(union)# Dropout is appliedout = self.dropout(out)# Activation function is applied# out = nn.Softmax(dim=1)(out)return out

  训练模型,调整参数得到最优表现的模型,获取模型评估指标,Python实现代码如下:

# -*- coding: utf-8 -*-
import torch
from torch.optim import RMSprop, Adam
from torch.nn import CrossEntropyLoss, Softmax
import torch.nn.functional as F
from torch.utils.data import DataLoader
from numpy import vstack, argmax
from sklearn.metrics import accuracy_scorefrom model import TextClassifier
from text_featuring import CSVDataset
from params import TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, LEARNING_RATE, EPOCHS, TRAIN_FILE_PATH, TEST_FILE_PATH# model train
class ModelTrainer(object):# evaluate the model@staticmethoddef evaluate_model(test_dl, model):predictions, actuals = [], []for i, (inputs, targets) in enumerate(test_dl):# evaluate the model on the test setyhat = model(inputs)# retrieve numpy arrayyhat = yhat.detach().numpy()actual = targets.numpy()# convert to class labelsyhat = argmax(yhat, axis=1)# reshape for stackingactual = actual.reshape((len(actual), 1))yhat = yhat.reshape((len(yhat), 1))# storepredictions.append(yhat)actuals.append(actual)predictions, actuals = vstack(predictions), vstack(actuals)# calculate accuracyacc = accuracy_score(actuals, predictions)return acc# Model Training, evaluation and metrics calculationdef train(self, model):# calculate splittrain = CSVDataset(TRAIN_FILE_PATH)test = CSVDataset(TEST_FILE_PATH)# prepare data loaderstrain_dl = DataLoader(train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)test_dl = DataLoader(test, batch_size=TEST_BATCH_SIZE)# Define optimizeroptimizer = Adam(model.parameters(), lr=LEARNING_RATE)# Starts training phasefor epoch in range(EPOCHS):# Starts batch trainingfor x_batch, y_batch in train_dl:y_batch = y_batch.long()# Clean gradientsoptimizer.zero_grad()# Feed the modely_pred = model(x_batch)# Loss calculationloss = CrossEntropyLoss()(y_pred, y_batch)# Gradients calculationloss.backward()# Gradients updateoptimizer.step()# Evaluationtest_accuracy = self.evaluate_model(test_dl, model)print("Epoch: %d, loss: %.5f, Test accuracy: %.5f" % (epoch+1, loss.item(), test_accuracy))if __name__ == '__main__':model = TextClassifier()# 统计参数量num_params = sum(param.numel() for param in model.parameters())print(num_params)ModelTrainer().train(model)torch.save(model, 'sougou_mini_cls.pth')

模型预测

   对保存好的模型,在验证集上进行指标评估,得到的结果:accuracy为0.7960,precision,recall为0.7960,F1-score为0.7953,混淆矩阵如下:


   对新样本进行预测,Python代码如下:

# -*- coding: utf-8 -*-
# @Time : 2023/3/16 16:42
# @Author : Jclian91
# @File : model_predict.py
# @Place : Minghang, Shanghai
import torch as T
import numpy as npfrom text_featuring import load_file_file, text_feature
from model import TextClassifiermodel = T.load('sougou_mini_cls.pth')label_dict, char_dict = load_file_file()
print(label_dict)text = '盖世汽车讯,特斯拉去年击败了宝马,夺得了美国豪华汽车市场的桂冠,并在今年实现了开门红。1月份,得益于大幅降价和7500美元美国电动汽车税收抵免,特斯拉再度击败宝马,蝉联了美国豪华车销冠,并且注册量超过了排名第三的梅赛德斯-奔驰和排名第四的雷克萨斯的总和。根据Experian的数据,在所有豪华品牌中,1月份,特斯拉在美国的豪华车注册量为49,917辆,同比增长34%;宝马的注册量为31,070辆,同比增长2.5%;奔驰的注册量为23,345辆,同比增长7.3%;雷克萨斯的注册量为23,082辆,同比下降6.6%。奥迪以19,113辆的注册量排名第五,同比增长38%。凯迪拉克注册量为13,220辆,较去年同期增长36%,排名第六。排名第七的讴歌的注册量为10,833辆,同比增长32%。沃尔沃汽车排名第八,注册量为8,864辆,同比增长1.8%。路虎以7,003辆的注册量排名第九,林肯以6,964辆的注册量排名第十。'label, sample = ['汽车'], [text]
samples, y_true = text_feature(label, sample, label_dict, char_dict)
print(text)
print(samples, y_true)
x = T.from_numpy(np.array(samples)).long()
y_pred = model(x)
print(y_pred)

预测结果如下:

新文本 预测类别
盖世汽车讯,特斯拉去年击败了宝马,夺得了美国豪华汽车市场的桂冠,并在今年实现了开门红。1月份,得益于大幅降价和7500美元美国电动汽车税收抵免,特斯拉再度击败宝马,蝉联了美国豪华车销冠,并且注册量超过了排名第三的梅赛德斯-奔驰和排名第四的雷克萨斯的总和。根据Experian的数据,在所有豪华品牌中,1月份,特斯拉在美国的豪华车注册量为49,917辆,同比增长34%;宝马的注册量为31,070辆,同比增长2.5%;奔驰的注册量为23,345辆,同比增长7.3%;雷克萨斯的注册量为23,082辆,同比下降6.6%。奥迪以19,113辆的注册量排名第五,同比增长38%。凯迪拉克注册量为13,220辆,较去年同期增长36%,排名第六。排名第七的讴歌的注册量为10,833辆,同比增长32%。沃尔沃汽车排名第八,注册量为8,864辆,同比增长1.8%。路虎以7,003辆的注册量排名第九,林肯以6,964辆的注册量排名第十。 汽车
北京时间3月16日,NBA官方公布了对于灰熊球星贾-莫兰特直播中持枪事件的调查结果灰熊,由于无法确定枪支是否为莫兰特所有,也无法证明他曾持枪到过NBA场馆,因为对他处以禁赛八场的处罚,且此前已禁赛场次将算在禁赛八场的场次内,他最早将在下周复出。 体育
3月11日,由新浪教育、微博教育、择校行联合主办的“新浪&微博2023国际教育春季巡展•深圳站”于深圳凯宾斯基酒店成功举办。深圳优质学校亮相展会,上千组家庭前来逛展。近30所全国及深圳民办国际化学校、外籍人员子女学校、公办学校国际部等多元化、多类型优质学校参与了本次活动。此外,近10位国际化学校校长分享了学校的办学特色、教育理念及学生的成长案例,参展家庭纷纷表示受益匪浅。展会搭建家校沟通桥梁,帮助家长们合理规划孩子的国际教育之路。深圳国际预科书院/招生办主任沈兰Nancy Shen参加了本次活动并带来了精彩的演讲,以下为演讲实录:" 教育
指导专家:皮肤科教研室副主任、武汉协和医院皮肤性病科主任医师冯爱平教授在临床上,经常能看到有些人出现反复发作的口腔溃疡,四季不断,深受其扰。其实这已不单单是口腔问题,而是全身疾病的体现,特别是一些免疫系统疾病,不仅表现在皮肤还会损害黏膜,下列几种情况是造成“复发性口腔溃疡”的原因。缺乏维生素及微量元素。缺乏微量元素锌、铁、叶酸、维生素B12等时,会引发口角炎。很多日常生活行为可能造成维生素的缺乏,如过分淘洗米、长期进食精米面、吃素食等,很容易造成B族维生素的缺失。 健康

总结

  本项目已上传至Github,访问网址为:https://github.com/percent4/PyTorch_Learning/tree/master/cnn_text_classification

参考文献

  1. Text-Classification-CNN-PyTorch: https://github.com/FernandoLpz/Text-Classification-CNN-PyTorch

PyTorch入门(五)使用CNN模型进行中文文本分类相关推荐

  1. textcnn文本词向量_基于Text-CNN模型的中文文本分类实战

    1 文本分类 文本分类是自然语言处理领域最活跃的研究方向之一,目前文本分类在工业界的应用场景非常普遍,从新闻的分类.商品评论信息的情感分类到微博信息打标签辅助推荐系统,了解文本分类技术是NLP初学者比 ...

  2. 【NLP】BERT 模型与中文文本分类实践

    简介 2018年10月11日,Google发布的论文<Pre-training of Deep Bidirectional Transformers for Language Understan ...

  3. 基于Text-CNN模型的中文文本分类实战

    七月 上海 | 高性能计算之GPU CUDA培训 7月27-29日三天密集式学习  快速带你入门阅读全文> 正文共5260个字,21张图,预计阅读时间28分钟. Text-CNN 1.文本分类 ...

  4. huggingFace 中文模型实战——中文文本分类

    学习了哔哩哔哩up主--兰斯诺特 视频后做的学习笔记 代码网址 https://github.com/lansinuote/Huggingface_Toturials upz主推荐书:<基于Be ...

  5. tensorflow2.0五种机器学习算法对中文文本分类

    向AI转型的程序员都关注了这个号

  6. Pytorch TextCNN实现中文文本分类(附完整训练代码)

    Pytorch TextCNN实现中文文本分类(附完整训练代码) 目录 Pytorch TextCNN实现中文文本分类(附完整训练代码) 一.项目介绍 二.中文文本数据集 (1)THUCNews文本数 ...

  7. linux tf2 中文,tf2+cnn+中文文本分类优化系列(2)

    1 前言 接着上次的tf2+cnn+中文文本分类优化系列(1),本次进行优化:使用多个卷积核进行特征抽取.之前是使用filter_size=2进行2-gram特征的识别,本次使用filter_size ...

  8. 基于CNN中文文本分类实战

    一.前言 之前写过一篇基于循环神经网络(RNN)的情感分类文章,这次我们换种思路,采用卷积神经网络(CNN)来进行文本分类任务.倘若对CNN如何在文本上进行卷积的可以移步博主的快速入门CNN在NLP中 ...

  9. TensorFlow使用CNN实现中文文本分类

    TensorFlow使用CNN实现中文文本分类 读研期间使用过TensorFlow实现过简单的CNN情感分析(分类),当然这是比较low的二分类情况,后来进行多分类情况.但之前的学习基本上都是在英文词 ...

最新文章

  1. java 人脸检测_Java+OpenCV实现人脸检测并自动拍照
  2. 小猿圈之python的输入和输出
  3. mysql rpm 安装多实例_MySQL搭建系列之多实例
  4. 【算法】算法 动态规划 背包问题
  5. amf组网_【5G核心网】5G核心网SA组网方案及4G/5G互操作探讨
  6. unity性能优化初级入门篇
  7. 11.这就是搜索引擎:核心技术详解 --- 搜索引擎缓存机制
  8. 5G物联网数据网关助力工业企业转型升级
  9. PowerShell抓取电脑序列号
  10. goahead(嵌入式Web服务器)之asp、goform篇
  11. 新浪sae php,PHP+新浪微博开放平台+新浪云平台(SAE)1
  12. 华为运营商级路由器配置示例 | 配置BGP方式VPLS示例
  13. m4s格式转换mp3_如何将m4a无损转换mp3音频格式
  14. 二战暨南大学网络空间安全经验贴,纯干货!
  15. QTP_QTP学习笔记(1)
  16. 非安装tomcat,服务安装_离水的鱼_新浪博客
  17. infiniband学习总结
  18. stm32驱动rgb屏电路图_基于STM32F767驱动 LTDC LCD(RGB屏)
  19. 人工智能 —— A*算法
  20. MeasureReady TM 155 开发人员问答,第 2 部分:触摸屏设计和用户测试

热门文章

  1. Java最优化求最大公约数\最小公倍数方法
  2. mujoco强化学习环境配置,常见OSError: __glewBindBuffer错误解决方案
  3. office2013:打开和关闭Word文档提示“MicrosoftWord已停止工作”的解决办法
  4. linux.4 进程概念
  5. Java仿微信时间显示
  6. 2021辽宁高考成绩查询具体时间,2021年辽宁高考成绩什么时候出具体时间几点 具体准确时间...
  7. D3DX库的使用D3DX初始化
  8. Stable Diffusion - ReV Animated v1.2.2 的 2.5D 模型与提示词
  9. 引入vue.js的项目如何实现ie兼容
  10. 自己选择的路,跪着也要走完。(励志)