原理可以参考:使用textCNN进行文本分类的原理
Keras的另一个实现可以参考:Keras实现textCNN文本分类

模型构建与训练

定义网络结构

定义一个textCNN类, 代码为tensorflow2.x版本。

from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Embedding, Dense, Conv1D, GlobalMaxPooling1D, Concatenate, Dropoutclass TextCNN(object):def __init__(self, maxlen, max_features, embedding_dims,class_num=5,last_activation='softmax'):self.maxlen = maxlen               # 句子最大长度self.max_features = max_features   # 词表大小self.embedding_dims = embedding_dimsself.class_num = class_numself.last_activation = last_activation  # 激活函数def get_model(self):input = Input((self.maxlen,))   # batch_size留空embedding = Embedding(self.max_features, self.embedding_dims, input_length=self.maxlen)(input)convs = []for kernel_size in [3, 4, 5]:c = Conv1D(128, kernel_size, activation='relu')(embedding)c = GlobalMaxPooling1D()(c)convs.append(c)x = Concatenate()(convs)  # 拼接output = Dense(self.class_num, activation=self.last_activation)(x)model = Model(inputs=input, outputs=output)return model

定义通用的工具函数

注意:此处在utils.py中定义了一些工具函数
utils.py的代码如下所示:

# coding: utf-8import sys
from collections import Counter
import numpy as np
import tensorflow.keras as kr
import osif sys.version_info[0] > 2:is_py3 = True
else:reload(sys)sys.setdefaultencoding("utf-8")is_py3 = Falsedef open_file(filename, mode='r'):"""常用文件操作,可在python2和python3间切换.mode: 'r' or 'w' for read or write"""if is_py3:return open(filename, mode, encoding='utf-8', errors='ignore')else:return open(filename, mode)def read_file(filename):"""读取单个文件,文件中包含多个类别"""contents = []labels = []with open_file(filename) as f:for line in f:try:raw = line.strip().split("\t")content = raw[1].split(' ')if content:contents.append(content)labels.append(raw[0])except:passreturn contents, labels    def read_single_file(filename):"""读取单个文件,文件为单一类别"""contents = []    label = filename.split('/')[-1].split('.')[0]with open_file(filename) as f:for line in f:try:content = line.strip().split(' ')if content:contents.append(content)except:passreturn contents, labeldef read_files(dirname):"""读取文件夹"""contents = []labels = []files = [f for f in os.listdir(dirname) if f.endswith(".txt")]for filename in files:content, label = read_single_file(os.path.join(dirname, filename))contents.extend(content)labels.extend([label]*len(content))return contents, labelsdef build_vocab(train_dir, vocab_file, vocab_size=5000):"""根据训练集构建词汇表,存储"""data_train, _ = read_files(train_dir)all_data = []for content in data_train:all_data.extend(content)counter = Counter(all_data)count_pairs = counter.most_common(vocab_size - 1)words, _ = list(zip(*count_pairs))# 添加一个 <PAD> 来将所有文本pad为同一长度words = ['<PAD>'] + list(words)open_file(vocab_file, mode='w').write('\n'.join(words) + '\n')def read_vocab(vocab_file):"""读取词汇表"""# words = open_file(vocab_dir).read().strip().split('\n')with open_file(vocab_file) as fp:# 如果是py2 则每个值都转化为unicodewords = [_.strip() for _ in fp.readlines()]word_to_id = dict(zip(words, range(len(words))))return words, word_to_iddef read_category():"""读取分类,编码"""categories = ['car', 'entertainment', 'military', 'sports', 'technology']cat_to_id = dict(zip(categories, range(len(categories))))return categories, cat_to_iddef encode_cate(content, words):"""将id表示的内容转换为文字"""return [(words[x] if x in words else 40000) for x in content]def encode_sentences(contents, words):"""将id表示的内容转换为文字"""return [encode_cate(x,words) for x in contents]def process_file(filename, word_to_id, cat_to_id, max_length=600):"""将文件转换为id表示"""contents, labels = read_file(filename)data_id, label_id = [], []for i in range(len(contents)):data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])label_id.append(cat_to_id[labels[i]])# 使用keras提供的pad_sequences来将文本pad为固定长度x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 将标签转换为one-hot表示return x_pad, y_paddef batch_iter(x, y, batch_size=64):"""生成批次数据"""data_len = len(x)num_batch = int((data_len - 1) / batch_size) + 1indices = np.random.permutation(np.arange(data_len))x_shuffle = x[indices]y_shuffle = y[indices]for i in range(num_batch):start_id = i * batch_sizeend_id = min((i + 1) * batch_size, data_len)yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

数据处理与训练

from tensorflow.keras.preprocessing import sequence
import random
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.utils import to_categorical
import sys
sys.path.append('../data/lesson2_data')
from utils import *# 路径等配置
data_dir = "../data/lesson2_data/data"
vocab_file = "../data/lesson2_data/vocab/vocab.txt"
vocab_size = 40000# 神经网络配置
max_features = 40001
maxlen = 100
batch_size = 256
embedding_dims = 50
epochs = 8print('数据预处理与加载数据...')
# 如果不存在词汇表,重建
if not os.path.exists(vocab_file):  build_vocab(data_dir, vocab_file, vocab_size)
# 获得 词汇/类别 与id映射字典
categories, cat_to_id = read_category()
words, word_to_id = read_vocab(vocab_file)# 全部数据
x, y = read_files(data_dir)
data = list(zip(x,y))
del x,y
# 乱序
random.shuffle(data)
# 切分训练集和测试集
train_data, test_data = train_test_split(data)
# 对文本的词id和类别id进行编码
x_train = encode_sentences([content[0] for content in train_data], word_to_id)
y_train = to_categorical(encode_cate([content[1] for content in train_data], cat_to_id))
x_test = encode_sentences([content[0] for content in test_data], word_to_id)
y_test = to_categorical(encode_cate([content[1] for content in test_data], cat_to_id))print('对序列做padding,保证是 samples*timestep 的维度')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)print('构建模型...')
model = TextCNN(maxlen, max_features, embedding_dims).get_model()
model.compile('adam', 'categorical_crossentropy', metrics=['accuracy'])print('训练...')
# 设定callbacks回调函数
my_callbacks = [ModelCheckpoint('../../tmp/cnn_model.h5', verbose=1),EarlyStopping(monitor='val_accuracy', patience=2, mode='max')
]# fit拟合数据
history = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,callbacks=my_callbacks,validation_data=(x_test, y_test))#print('对测试集预测...')
#result = model.predict(x_test)

训练中间信息输出(画图)

import matplotlib.pyplot as plt
plt.switch_backend('agg')
%matplotlib inlinefig1 = plt.figure()
plt.plot(history.history['loss'],'r',linewidth=3.0)
plt.plot(history.history['val_loss'],'b',linewidth=3.0)
plt.legend(['Training loss', 'Validation Loss'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Loss',fontsize=16)
plt.title('Loss Curves :CNN',fontsize=16)
fig1.savefig('../../tmp/loss_cnn.png')
plt.show()
fig2=plt.figure()
plt.plot(history.history['accuracy'],'r',linewidth=3.0)
plt.plot(history.history['val_accuracy'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves : CNN',fontsize=16)
fig2.savefig('../../tmp/accuracy_cnn.png')
plt.show()

注意:Windows下pycharm中运行需要把%matplotlib inlineplt.switch_backend('agg')都注释掉才能出图。

模型结构打印

from tensorflow.keras.utils import plot_model
# model.summary()
plot_model(model, show_shapes=True, show_layer_names=True)

遇到问题

  1. TensorFlow版本为1.13。
    TensorFlow1.x版本中没有集成keras,需要把tensorflow.keras.xx改为keras.xx

  2. 运行报错:ValueError: Error when checking target: expected dense_1 to have shape (5,) but got array with shape (40001,)
    出现这个问题很可能是多分类的label设置不当导致的。通过检查代码发现, 经过to_categorical之后的y有40001维,且前40000维均为0,最后一维为1。定位到utils.py中的encode_cate函数,发现是label中带有文件夹名,无法映射到cat_to_id,再次定位到read_single_file,发现是Windows和linux分隔符不同导致的。于是将label = filename.split('/')[-1].split('.')[0]修改为:

    import platform
    if platform.system()=='Linux':   # Windows will be : Windows, Linux will be : Linuxlabel = filename.split('/')[-1].split('.')[0]
    else:label = filename.split('\\')[-1].split('.')[0]
    
  3. 预测时plt.plot(history.history['accuracy'],'r',linewidth=3.0)处报错KeyError: 'accuracy'
    原因是keras库版本不同。因为keras库老版本中的参数不是accuracy,而是acc,将参数accuracy替换为acc即可。同理,将val_accuracy替换为val_acc。
    打印print(history.history.keys())可得,history的四个参数为dict_keys(['val_loss', 'val_acc', 'loss', 'acc'])
    修改代码为:

    fig2=plt.figure()
    plt.plot(history.history['acc'],'r',linewidth=3.0)
    plt.plot(history.history['val_acc'],'b',linewidth=3.0)
    
  4. 画结构图时plot_model(model, show_shapes=True, show_layer_names=True)处报错ImportError: Failed to import ``pydot``. Please install ``pydot``. For example with ``pip install pydot``.然后又报错'pydotfailed to call GraphViz.' OSError: ``pydot`` failed to call GraphViz.Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.
    解决:使用plot_model得先安装好另外两个库,graphviz和pydot。

    pip3 install graphviz
    pip3 install pydot
    

    电脑还得安装graphviz并添加环境变量。下载地址http://www.graphviz.org/download/,选择自己操作系统对应的版本的.msi文件,安装完成后将安装的Graphviz2.38/bin添加到环境变量,添加完后最好重启电脑。如果在pycharm中不显示图片,就用to_file写入到文件里。
    也可以在代码中添加环境变量,如下:

    import os
    os.environ["PATH"] += os.pathsep + 'E:\Program Files (x86)\Graphviz2.38\\bin'
    plot_model(model, show_shapes=True, show_layer_names=True,to_file='test.png')
    

NLP实战之textCNN中文文本分类相关推荐

  1. NLP实战之textRNN中文文本分类

    TextRNN论文:https://www.ijcai.org/Proceedings/16/Papers/408.pdf TextRNN网络结构: 环境: windows 10.tensorflow ...

  2. NLP - 15 分钟搭建中文文本分类模型

    https://eliyar.biz/nlp_chinese_text_classification_in_15mins/

  3. 万字总结Keras深度学习中文文本分类

    摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...

  4. textcnn文本词向量_基于Text-CNN模型的中文文本分类实战

    1 文本分类 文本分类是自然语言处理领域最活跃的研究方向之一,目前文本分类在工业界的应用场景非常普遍,从新闻的分类.商品评论信息的情感分类到微博信息打标签辅助推荐系统,了解文本分类技术是NLP初学者比 ...

  5. Pytorch TextCNN实现中文文本分类(附完整训练代码)

    Pytorch TextCNN实现中文文本分类(附完整训练代码) 目录 Pytorch TextCNN实现中文文本分类(附完整训练代码) 一.项目介绍 二.中文文本数据集 (1)THUCNews文本数 ...

  6. 基于CNN中文文本分类实战

    一.前言 之前写过一篇基于循环神经网络(RNN)的情感分类文章,这次我们换种思路,采用卷积神经网络(CNN)来进行文本分类任务.倘若对CNN如何在文本上进行卷积的可以移步博主的快速入门CNN在NLP中 ...

  7. 【NLP】BERT 模型与中文文本分类实践

    简介 2018年10月11日,Google发布的论文<Pre-training of Deep Bidirectional Transformers for Language Understan ...

  8. 【NLP】Kaggle从零到实践:Bert中文文本分类

    Bert是非常强化的NLP模型,在文本分类的精度非常高.本文将介绍Bert中文文本分类的基础步骤,文末有代码获取方法. 步骤1:读取数据 本文选取了头条新闻分类数据集来完成分类任务,此数据集是根据头条 ...

  9. 【NLP】中文文本分类数据增强方法:EDA 与代码实现

    数据增强可以算作是做深度学习算法的一个小trick.该介绍主要出自论文:EDA: Easy Data Augmentation Techniques for Boosting Performance ...

最新文章

  1. Sonatype收购Vor Security,扩展对Nexus开源组件的支持
  2. SpringMVC + MyBatis整合 【转】
  3. 不惧困难,阿特拉斯机器人展示超强平衡能力
  4. python资料百度网盘-python自动保存百度盘资源到百度盘中的实例代码
  5. 数据库问题解决后,应用面对的挑战
  6. Java做一个动画效果音量调节_设计与实现一个 ISoundable 接口,该接口具有发声功能、还能调节音量大小...
  7. 【C语言】将输入的10个整数逆序输出
  8. Java基础02 位运算符<<、>>
  9. 嵌入式Linux入门12:编程规范
  10. 找出N个无序数中第K大的数
  11. pycharm 里面配置pip,安装库
  12. Oprofile工具的使用
  13. Pycharm导入python项目
  14. PN5180射频识别芯片学习笔记
  15. oki5530sc打印错误_我用的是四通oki 5530sc针式打印机,打印时提示正在打印,但就是不打印...
  16. 三相逆变器仿真matlab,在MATLAB中实现三相电压型逆变器仿真
  17. Java 9 与 Java 10
  18. PTG DAO 生态
  19. 用dup2和dup产生一份file descriptor 的拷贝
  20. 【必备算法】动态规划:LeetCode题(九)309. 最佳买卖股票时机含冷冻期,714. 买卖股票的最佳含手续费

热门文章

  1. 个人阅读作业+个人总结
  2. 信息技术基础2(课程笔记)
  3. 笔记本开机前插入耳机再开机有声音,开机后插入耳机后没声音
  4. 用计算机修改图片或照片,如何利用电脑自带的画图工具修改图片的基本属性
  5. java实现ABAC
  6. 【教程】美团联盟个人怎么注册推广做外卖cps红包
  7. 解决 unity vs2017编辑器 全范围脚本报错 : predefined type 'system.object' is not defined or imported
  8. 手机app显示服务器端异常502,修复 HTTP 502 和 HTTP 503 错误 - Azure App Service | Microsoft Docs...
  9. 柠檬班的课程怎么样,来自一个金融行业转行到软件测试行业的故事
  10. IE和标准下有哪些兼容性的写法