Chinese Text Classification with a CNN in TensorFlow

During graduate school I used TensorFlow to implement a simple CNN for sentiment analysis, which was only a basic binary classification task, and later extended it to multi-class classification; almost all of that work, however, was trained on English corpora. While sorting out my project materials recently, I picked this small text-classification project up again and spent some time adapting the original English text-classification code to Chinese. The results are quite good: on the THUCNews Chinese dataset the accuracy is about 93.9%. As usual, the source code comes first.

GitHub source code: nlp-learning-tutorials/THUCNews at master · PanJinquan/nlp-learning-tutorials · GitHub. Please give it a "Star"!


Contents

Chinese Text Classification with a CNN in TensorFlow

1. Project Overview

1.1 Directory Structure

1.2 The THUCNews Dataset

2. CNN Model Architecture

3. Text Preprocessing

1. jieba Chinese Word Segmentation

2. Training a word2vec Model with gensim

3. THUCNews Data Processing

4. Training

5. Testing


1. Project Overview

1.1 Directory Structure

GitHub source code: nlp-learning-tutorials/THUCNews at master · PanJinquan/nlp-learning-tutorials · GitHub. Please give it a "Star"!

Other resources:

  • 1. THUCTC official dataset, link: THUCTC: An Efficient Chinese Text Classification Toolkit
  • 2. THUCTC on Baidu Netdisk, extraction code: bbpe
  • 3. Pre-trained word2vec model on Baidu Netdisk, extraction code: mtrj
  • 4. THUCNews data already processed into word-vector indices, Baidu Netdisk, extraction code: m9dx

1.2 The THUCNews Dataset

THUCNews was generated by filtering the historical data of the Sina News RSS feeds from 2005 to 2011. It contains 740,000 news documents (2.19 GB), all in plain UTF-8 text. On top of the original Sina news taxonomy, the documents were re-organized into 14 candidate categories: finance, lottery, real estate, stocks, home, education, technology, society, fashion, politics, sports, horoscope, games and entertainment. For more details, see THUCTC: An Efficient Chinese Text Classification Toolkit.

Download links:
1. Official dataset: http://thuctc.thunlp.org/message
2. Baidu Netdisk: https://pan.baidu.com/s/1DT5xY9m2yfu1YGaGxpWiBQ (extraction code: bbpe)

2. CNN Model Architecture

The network architecture of the CNN for text classification is shown below:

A brief analysis:

(1) We assume the input to the CNN is two-dimensional, where each row represents one sample (one word), such as "I" and "like" in the figure. Each sample (word) has d dimensions, which is the length of its word vector, i.e. the dimension of each word; in the code this is embedding_dim.

(2) A CNN convolution is then applied to this two-dimensional input. In image CNNs the kernel size is typically 3*3, 5*5 and so on, but that does not work in NLP, because here every row of the input is a complete sample. If the kernel size is [filter_height, filter_width], then filter_height can be any value such as 1, 2 or 3, while filter_width must equal embedding_dim, so that each sample (word vector) is covered in full.
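To make the dimensions concrete, here is a minimal shape walk-through for one convolution + max-pool branch. It is plain Python for illustration only, not part of the project code, and it uses the hyperparameter values listed further down (max_sentence_length=300, embedding_dim=128, num_filters=200):

# Illustrative shape walk-through for one conv + max-pool branch of TextCNN.
max_sentence_length = 300   # words per sample (padded / truncated)
embedding_dim = 128         # word-vector dimension
num_filters = 200           # filters per kernel height

for filter_height in [3, 4, 5, 6]:
    # input to tf.nn.conv2d : [batch, max_sentence_length, embedding_dim, 1]
    # kernel shape          : [filter_height, embedding_dim, 1, num_filters]
    # "VALID" padding slides the kernel only along the sentence axis:
    conv_height = max_sentence_length - filter_height + 1
    conv_width = 1              # the kernel spans the whole word vector
    # max-pool with ksize=[1, conv_height, 1, 1] keeps one value per filter:
    print("filter_height={}: conv output [batch, {}, {}, {}] -> pooled [batch, 1, 1, {}]".format(
        filter_height, conv_height, conv_width, num_filters, num_filters))

# Concatenating the four pooled branches gives [batch, 4 * num_filters] = [batch, 800].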

Below is the CNN text-classification network (TextCNN) implemented in TensorFlow. The main hyperparameters are:

max_sentence_length = 300   # maximum sentence length, i.e. the maximum number of words per sample; shorter samples are zero-padded, longer ones are truncated
embedding_dim = 128         # word-vector length, i.e. the dimension of each word
filter_sizes = [3, 4, 5, 6] # convolution kernel sizes (heights)
num_filters = 200           # number of filters per filter size
base_lr = 0.001             # learning rate
dropout_keep_prob = 0.5
l2_reg_lambda = 0.0         # L2 regularization lambda (default: 0.0)

import tensorflow as tf
import numpy as np


class TextCNN(object):
    '''
    A CNN for text classification.
    Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.
    '''
    def __init__(self, sequence_length, num_classes,
                 embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
        # Placeholders for input, output and dropout
        self.input_x = tf.placeholder(tf.float32, [None, sequence_length, embedding_size], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")

        # Keeping track of l2 regularization loss (optional)
        l2_loss = tf.constant(0.0)

        # Embedding layer
        # self.embedded_chars          = [None(batch_size), sequence_size, embedding_size]
        # self.embedded_chars_expended = [None(batch_size), sequence_size, embedding_size, 1(num_channels)]
        self.embedded_chars = self.input_x
        self.embedded_chars_expended = tf.expand_dims(self.embedded_chars, -1)

        # Create a convolution + maxpool layer for each filter size
        pooled_outputs = []
        for i, filter_size in enumerate(filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Convolution layer
                # filter_shape = [height, width, in_channels, output_channels]
                filter_shape = [filter_size, embedding_size, 1, num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name="b")
                conv = tf.nn.conv2d(self.embedded_chars_expended,
                                    W,
                                    strides=[1, 1, 1, 1],
                                    padding="VALID",
                                    name="conv")
                # Apply nonlinearity
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                # Maxpooling over the outputs
                pooled = tf.nn.max_pool(h,
                                        ksize=[1, sequence_length - filter_size + 1, 1, 1],
                                        strides=[1, 1, 1, 1],
                                        padding="VALID",
                                        name="pool")
                pooled_outputs.append(pooled)

        # Combine all the pooled features
        num_filters_total = num_filters * len(filter_sizes)
        self.h_pool = tf.concat(pooled_outputs, 3)
        # self.h_pool = tf.concat(3, pooled_outputs)
        self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

        # Add dropout
        with tf.name_scope("dropout"):
            self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)

        # Final (unnormalized) scores and predictions
        with tf.name_scope("output"):
            W = tf.get_variable("W",
                                shape=[num_filters_total, num_classes],
                                initializer=tf.contrib.layers.xavier_initializer())
            b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
            l2_loss += tf.nn.l2_loss(W)
            l2_loss += tf.nn.l2_loss(b)
            self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
            self.predictions = tf.argmax(self.scores, 1, name="predictions")

        # Calculate mean cross-entropy loss
        with tf.name_scope("loss"):
            losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y)
            self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss

        # Accuracy
        with tf.name_scope("accuracy"):
            correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")

3. Text Preprocessing

This post uses the jieba toolkit for Chinese word segmentation; training on words gives better results than training on single characters.

This part has already been explained in detail in my other CSDN post, 使用gensim训练中文语料word2vec (Training word2vec on a Chinese corpus with gensim), so please refer to it.

1. jieba Chinese Word Segmentation

jieba has to be installed by yourself: pip install jieba or pip3 install jieba
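A minimal segmentation example (the sample sentence is made up purely for illustration):

# -*- coding: utf-8 -*-
# Minimal jieba example: split a Chinese sentence into a list of words.
import jieba

sentence = "我爱北京天安门"        # illustrative sample sentence
words = jieba.lcut(sentence)       # precise mode, returns a Python list
print("/".join(words))             # e.g. 我/爱/北京/天安门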

2. Training a word2vec Model with gensim

Here jieba is used to segment the THUCNews dataset, and gensim is used to train a word2vec model on THUCNews. A pre-trained word2vec model is provided here: https://pan.baidu.com/s/1n4ZgiF0gbY0zsK0706wZiw (extraction code: mtrj)
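If you would rather train the word2vec model yourself than download it, the following is a minimal sketch, assuming gensim 3.x (where the dimension argument is size; gensim 4+ renamed it to vector_size) and a hypothetical corpus file with one raw news article per line:

# -*- coding: utf-8 -*-
# Minimal word2vec training sketch (gensim 3.x API); paths are illustrative.
import jieba
from gensim.models import Word2Vec

corpus_file = "data/THUCNews_corpus.txt"   # hypothetical: one article per line

sentences = []
with open(corpus_file, encoding="utf-8") as f:
    for line in f:
        words = jieba.lcut(line.strip())   # segment each article into words
        if words:
            sentences.append(words)

# size=128 matches the embedding_dim used by the CNN in this post.
model = Word2Vec(sentences, size=128, window=5, min_count=5, workers=4)
model.save("THUCNews_word2Vec_128.model")  # later loaded with Word2Vec.load(...)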

3. THUCNews Data Processing

With the word2vec model, the THUCNews data can now be processed with word2vec word vectors: jieba first splits each Chinese sentence into words, and each word is then mapped, via the word2vec model, to its embedding index; from the index, the embedding itself can be looked up later. These index data are saved as npy files, so during training the CNN only has to read the npy files and convert the indices back into embeddings.
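At its core this pipeline is just two lookups, sketched below under the assumption of the gensim 3.x API used by the project code (wv.vocab[word].index for word-to-index; gensim 4+ uses wv.key_to_index instead); the model path and the sentence are illustrative:

# -*- coding: utf-8 -*-
# Word -> index -> embedding round trip (gensim 3.x API).
import numpy as np
import jieba
from gensim.models import Word2Vec

w2v = Word2Vec.load("THUCNews_word2Vec_128.model")   # illustrative path

words = jieba.lcut("这是一条测试新闻")                # illustrative sentence

# word -> index (index 0 is reused for unknown words, as in the project code)
indices = [w2v.wv.vocab[w].index if w in w2v.wv.vocab else 0 for w in words]
index_mat = np.array(indices, dtype="int32")

# index -> embedding: one fancy-indexing lookup into the embedding matrix
vectors = w2v.wv.vectors[index_mat]                  # shape: (len(words), 128)
print(index_mat.shape, vectors.shape)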

Download link for the processed THUCNews data: https://pan.baidu.com/s/12Hdf36QafQ3y6KgV_vLTsw (extraction code: m9dx)

The code below does the following: it uses jieba to split each Chinese text into words, converts the words into index matrices via the word2vec model, and saves these index matrices as *.npy files. In the source code, batchSize=20000 means that every 20,000 Chinese TXT files are segmented, converted into one index matrix and saved as a single *.npy file; splitting the data across several *.npy files mainly compresses it and avoids any single file becoming too large.

# -*-coding: utf-8 -*-
"""@Project: nlp-learning-tutorials@File   : create_word2vec.py@Author : panjq@E-mail : pan_jinquan@163.com@Date   : 2018-11-08 17:37:21
"""
from gensim.models import Word2Vec
import random
import numpy as np
import os
import math
from utils import files_processing, segment


def info_npy(file_list):
    sizes = 0
    for file in file_list:
        data = np.load(file)
        print("data.shape:{}".format(data.shape))
        size = data.shape[0]
        sizes += size
    print("files nums:{}, data nums:{}".format(len(file_list), sizes))
    return sizes


def save_multi_file(files_list, labels_list, word2vec_path, out_dir, prefix,
                    batchSize, max_sentence_length, labels_set=None, shuffle=False):
    '''
    Map the file contents to index matrices and save the data as multiple files.
    :param files_list:
    :param labels_list:
    :param word2vec_path: path of the word2vec model
    :param out_dir: directory where the files are saved
    :param prefix:  prefix of the saved file names
    :param batchSize: how many source files are packed into one saved file
    :param labels_set: set of labels
    :return:
    '''
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    # delete all existing files in the output directory
    files_processing.delete_dir_file(out_dir)

    if shuffle:
        random.seed(100)
        random.shuffle(files_list)
        random.seed(100)
        random.shuffle(labels_list)

    sample_num = len(files_list)
    w2vModel = load_wordVectors(word2vec_path)
    if labels_set is None:
        labels_set = files_processing.get_labels_set(label_list)
    labels_list, labels_set = files_processing.labels_encoding(labels_list, labels_set)
    labels_list = labels_list.tolist()
    batchNum = int(math.ceil(1.0 * sample_num / batchSize))

    for i in range(batchNum):
        start = i * batchSize
        end = min((i + 1) * batchSize, sample_num)
        batch_files = files_list[start:end]
        batch_labels = labels_list[start:end]
        # read the file contents and segment them into words
        batch_content = files_processing.read_files_list_to_segment(batch_files,
                                                                    max_sentence_length,
                                                                    padding_token='<PAD>',
                                                                    segment_type='word')
        # convert the words into an index matrix
        batch_indexMat = word2indexMat(w2vModel, batch_content, max_sentence_length)
        batch_labels = np.asarray(batch_labels)
        batch_labels = batch_labels.reshape([len(batch_labels), 1])
        # save the *.npy file
        filename = os.path.join(out_dir, prefix + '{0}.npy'.format(i))
        labels_indexMat = cat_labels_indexMat(batch_labels, batch_indexMat)
        np.save(filename, labels_indexMat)
        print('step:{}/{}, save:{}, data.shape{}'.format(i, batchNum, filename, labels_indexMat.shape))


def cat_labels_indexMat(labels, indexMat):
    indexMat_labels = np.concatenate([labels, indexMat], axis=1)
    return indexMat_labels


def split_labels_indexMat(indexMat_labels, label_index=0):
    labels = indexMat_labels[:, 0:label_index + 1]   # the first column holds the labels
    indexMat = indexMat_labels[:, label_index + 1:]  # the rest is the index matrix
    return labels, indexMat


def load_wordVectors(word2vec_path):
    w2vModel = Word2Vec.load(word2vec_path)
    return w2vModel


def word2vector_lookup(w2vModel, sentences):
    '''
    Convert words into word vectors.
    :param w2vModel: the word2vec model
    :param sentences: type->list[list[str]]
    :return: word vectors of the sentences, type->list[list[ndarray]]
    '''
    all_vectors = []
    embeddingDim = w2vModel.vector_size
    embeddingUnknown = [0 for i in range(embeddingDim)]
    for sentence in sentences:
        this_vector = []
        for word in sentence:
            if word in w2vModel.wv.vocab:
                v = w2vModel[word]
                this_vector.append(v)
            else:
                this_vector.append(embeddingUnknown)
        all_vectors.append(this_vector)
    all_vectors = np.array(all_vectors)
    return all_vectors


def word2indexMat(w2vModel, sentences, max_sentence_length):
    '''
    Convert words into an index matrix.
    :param w2vModel:
    :param sentences:
    :param max_sentence_length:
    :return:
    '''
    nums_sample = len(sentences)
    indexMat = np.zeros((nums_sample, max_sentence_length), dtype='int32')
    rows = 0
    for sentence in sentences:
        indexCounter = 0
        for word in sentence:
            try:
                index = w2vModel.wv.vocab[word].index  # get the index of the word
                indexMat[rows][indexCounter] = index
            except:
                indexMat[rows][indexCounter] = 0       # index for unknown words
            indexCounter = indexCounter + 1
            if indexCounter >= max_sentence_length:
                break
        rows += 1
    return indexMat


def indexMat2word(w2vModel, indexMat, max_sentence_length=None):
    '''
    Convert an index matrix back into words.
    :param w2vModel:
    :param indexMat:
    :param max_sentence_length:
    :return:
    '''
    if max_sentence_length is None:
        row, col = indexMat.shape
        max_sentence_length = col
    sentences = []
    for Mat in indexMat:
        indexCounter = 0
        sentence = []
        for index in Mat:
            try:
                word = w2vModel.wv.index2word[index]   # get the word at this index
                sentence += [word]
            except:
                sentence += ['<PAD>']
            indexCounter = indexCounter + 1
            if indexCounter >= max_sentence_length:
                break
        sentences.append(sentence)
    return sentences


def save_indexMat(indexMat, path):
    np.save(path, indexMat)


def load_indexMat(path):
    indexMat = np.load(path)
    return indexMat


def indexMat2vector_lookup(w2vModel, indexMat):
    '''
    Convert an index matrix into word vectors.
    :param w2vModel:
    :param indexMat:
    :return: word vectors
    '''
    all_vectors = w2vModel.wv.vectors[indexMat]
    return all_vectors


def pos_neg_test():
    positive_data_file = "./data/ham_5000.utf8"
    negative_data_file = './data/spam_5000.utf8'
    word2vec_path = 'out/trained_word2vec.model'
    sentences, labels = files_processing.load_pos_neg_files(positive_data_file, negative_data_file)
    # embedding_test(positive_data_file, negative_data_file)
    sentences, max_document_length = segment.padding_sentences(sentences, '<PADDING>', padding_sentence_length=190)
    # train_wordVectors(sentences, embedding_size=128, word2vec_path=word2vec_path)  # train word2vec and save it to word2vec_path
    w2vModel = load_wordVectors(word2vec_path)  # load the trained word2vec model
    '''
    There are two ways to convert words into word vectors:
    [1] direct:   map words directly to word vectors: word2vector_lookup
    [2] indirect: first convert words into an index matrix, then map the index matrix to word vectors: word2indexMat -> indexMat2vector_lookup
    '''
    # [1] map words directly to word vectors
    x1 = word2vector_lookup(w2vModel, sentences)

    # [2] words -> index matrix -> word vectors
    indexMat_path = 'out/indexMat.npy'
    indexMat = word2indexMat(w2vModel, sentences, max_sentence_length=190)  # convert words into an index matrix
    save_indexMat(indexMat, indexMat_path)
    x2 = indexMat2vector_lookup(w2vModel, indexMat)  # map the index matrix to word vectors
    print("x.shape = {}".format(x2.shape))
    # shape=(10000, 190, 128) -> (10000 samples, 190 words per sample, 128-dim vector per word)


if __name__ == '__main__':
    # THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/test'
    # THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/spam'
    THUCNews_path = '/home/ubuntu/project/tfTest/THUCNews/THUCNews'

    # read the full list of files and labels
    files_list, label_list = files_processing.gen_files_labels(THUCNews_path)
    max_sentence_length = 300
    word2vec_path = "../../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"

    # build the label set and save it locally
    # labels_set = ['星座', '财经', '教育']
    # labels_set = files_processing.get_labels_set(label_list)
    labels_file = '../data/THUCNews_labels.txt'
    # files_processing.write_txt(labels_file, labels_set)

    # split the data into train/val sets
    train_files, train_label, val_files, val_label = files_processing.split_train_val_list(files_list, label_list, facror=0.9, shuffle=True)
    # contents, labels = files_processing.read_files_labels(files_list, label_list)
    # word2vec_path = 'out/trained_word2vec.model'
    train_out_dir = '../data/train_data'
    prefix = 'train_data'
    batchSize = 20000
    labels_set = files_processing.read_txt(labels_file)
    # labels_set2 = files_processing.read_txt(labels_file)
    save_multi_file(files_list=train_files,
                    labels_list=train_label,
                    word2vec_path=word2vec_path,
                    out_dir=train_out_dir,
                    prefix=prefix,
                    batchSize=batchSize,
                    max_sentence_length=max_sentence_length,
                    labels_set=labels_set,
                    shuffle=True)
    print("*******************************************************")
    val_out_dir = '../data/val_data'
    prefix = 'val_data'
    save_multi_file(files_list=val_files,
                    labels_list=val_label,
                    word2vec_path=word2vec_path,
                    out_dir=val_out_dir,
                    prefix=prefix,
                    batchSize=batchSize,
                    max_sentence_length=max_sentence_length,
                    labels_set=labels_set,
                    shuffle=True)

4. Training

The training code is shown below. Note that large files cannot be uploaded to GitHub, so you need to download the files provided above and put them into the corresponding directories before training.

During training the training data, i.e. the *.npy files, must be read. Since these files store index data, they have to be converted into the embedding input of the CNN. This is done by the function indexMat2vector_lookup: train_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, train_batch_data)
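Before the full training script, here is a minimal sketch of what reading one saved *.npy file back and turning it into CNN input looks like. It reuses the helpers from create_word2vec.py shown above; the file name and shapes are illustrative (the first column stores the label, the remaining 300 columns the word indices):

# -*- coding: utf-8 -*-
# Load one *.npy batch file and convert the indices back into embeddings.
import numpy as np
from utils import create_word2vec   # project module shown in section 3

w2vModel = create_word2vec.load_wordVectors(
    "../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model")

data = np.load("./data/train_data/train_data0.npy")             # illustrative file name, shape (N, 1 + 300)
labels, index_mat = create_word2vec.split_labels_indexMat(data)

# indices -> embeddings, ready to be fed into cnn.input_x
batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, index_mat)
print(labels.shape, index_mat.shape, batch_data.shape)          # (N, 1) (N, 300) (N, 300, 128)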

#! /usr/bin/env python
# encoding: utf-8

import tensorflow as tf
import numpy as np
import os
from text_cnn import TextCNN
from utils import create_batch_data, create_word2vec, files_processing


def train(train_dir, val_dir, labels_file, word2vec_path, batch_size,
          max_steps, log_step, val_step, snapshot, out_dir):
    '''
    Training...
    :param train_dir: training data directory
    :param val_dir:   validation data directory
    :param labels_file:  path of the labels file
    :param word2vec_path: word2vec model file
    :param batch_size: batch size
    :param max_steps:  maximum number of iterations
    :param log_step:  logging interval
    :param val_step:  validation interval
    :param snapshot:  checkpoint-saving interval
    :param out_dir:   output directory for model checkpoints and summaries
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200    # Number of filters per filter size
    base_lr = 0.001      # learning rate
    dropout_keep_prob = 0.5
    l2_reg_lambda = 0.0  # L2 regularization lambda (default: 0.0)

    allow_soft_placement = True   # allow TF to fall back to another device if the specified one does not exist
    log_device_placement = False  # whether to log device placement

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)

    train_file_list = create_batch_data.get_file_list(file_dir=train_dir, postfix='*.npy')
    train_batch = create_batch_data.get_data_batch(train_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                   shuffle=False, one_hot=True)
    val_file_list = create_batch_data.get_file_list(file_dir=val_dir, postfix='*.npy')
    val_batch = create_batch_data.get_data_batch(val_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                 shuffle=False, one_hot=True)

    print("train data info *****************************")
    train_nums = create_word2vec.info_npy(train_file_list)
    print("val data   info *****************************")
    val_nums = create_word2vec.info_npy(val_file_list)
    print("labels_set info *****************************")
    files_processing.info_labels_set(labels_set)

    # Training
    with tf.Graph().as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
                                      log_device_placement=log_device_placement)
        sess = tf.Session(config=session_conf)
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)

            # Define Training procedure
            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.AdamOptimizer(learning_rate=base_lr)
            # optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
            grads_and_vars = optimizer.compute_gradients(cnn.loss)
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

            # Keep track of gradient values and sparsity (optional)
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            # Output directory for models and summaries
            print("Writing to {}\n".format(out_dir))

            # Summaries for loss and accuracy
            loss_summary = tf.summary.scalar("loss", cnn.loss)
            acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

            # Train summaries
            train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # Dev summaries
            dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            # Checkpoint directory. TensorFlow assumes this directory already exists so we need to create it
            checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
            checkpoint_prefix = os.path.join(checkpoint_dir, "model")
            if not os.path.exists(checkpoint_dir):
                os.makedirs(checkpoint_dir)
            saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

            # Initialize all variables
            sess.run(tf.global_variables_initializer())

            def train_step(x_batch, y_batch):
                """
                A single training step
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: dropout_keep_prob
                }
                _, step, summaries, loss, accuracy = sess.run(
                    [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                if step % log_step == 0:
                    print("training: step {}, loss {:g}, acc {:g}".format(step, loss, accuracy))
                train_summary_writer.add_summary(summaries, step)

            def dev_step(x_batch, y_batch, writer=None):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                step, summaries, loss, accuracy = sess.run(
                    [global_step, dev_summary_op, cnn.loss, cnn.accuracy], feed_dict)
                if writer:
                    writer.add_summary(summaries, step)
                return loss, accuracy

            for i in range(max_steps):
                train_batch_data, train_batch_label = create_batch_data.get_next_batch(train_batch)
                train_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, train_batch_data)
                train_step(train_batch_data, train_batch_label)
                current_step = tf.train.global_step(sess, global_step)
                if current_step % val_step == 0:
                    val_losses = []
                    val_accs = []
                    # for k in range(int(val_nums / batch_size)):
                    for k in range(100):
                        val_batch_data, val_batch_label = create_batch_data.get_next_batch(val_batch)
                        val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, val_batch_data)
                        val_loss, val_acc = dev_step(val_batch_data, val_batch_label, writer=dev_summary_writer)
                        val_losses.append(val_loss)
                        val_accs.append(val_acc)
                    mean_loss = np.array(val_losses, dtype=np.float32).mean()
                    mean_acc = np.array(val_accs, dtype=np.float32).mean()
                    print("--------Evaluation:step {}, loss {:g}, acc {:g}".format(current_step, mean_loss, mean_acc))
                if current_step % snapshot == 0:
                    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                    print("Saved model checkpoint to {}\n".format(path))


def main():
    # Data preprocess
    labels_file = 'data/THUCNews_labels.txt'
    word2vec_path = "../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"
    max_steps = 100000    # number of iterations
    batch_size = 128
    out_dir = "./models"  # output directory for checkpoints and summaries
    train_dir = './data/train_data'
    val_dir = './data/val_data'
    train(train_dir=train_dir,
          val_dir=val_dir,
          labels_file=labels_file,
          word2vec_path=word2vec_path,
          batch_size=batch_size,
          max_steps=max_steps,
          log_step=50,
          val_step=500,
          snapshot=1000,
          out_dir=out_dir)


if __name__ == "__main__":
    main()

5. Testing

Two testing methods are provided:

(1) text_predict(files_list, labels_file, models_path, word2vec_path, batch_size)

This method classifies raw Chinese text files directly.

(2) batch_predict(val_dir, labels_file, models_path, word2vec_path, batch_size)

This method is used for batch evaluation: the val_dir directory holds the test data as npy files, i.e. the THUCNews data processed with word2vec indices as described above.

#! /usr/bin/env python
# encoding: utf-8

import tensorflow as tf
import numpy as np
import os
from text_cnn import TextCNN
from utils import create_batch_data, create_word2vec, files_processing
import math


def text_predict(files_list, labels_file, models_path, word2vec_path, batch_size):
    '''
    Prediction...
    :param files_list:   list of text files to classify
    :param labels_file:  path of the labels file
    :param models_path:  model checkpoint file
    :param word2vec_path: word2vec model file
    :param batch_size: batch size
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200    # Number of filters per filter size
    l2_reg_lambda = 0.0  # L2 regularization lambda (default: 0.0)

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)

    sample_num = len(files_list)
    labels_list = [-1]
    labels_list = labels_list * sample_num

    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)
            # Initialize all variables and restore the checkpoint
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, models_path)

            def pred_step(x_batch):
                """
                Run predictions on one batch
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                pred = sess.run([cnn.predictions], feed_dict)
                return pred

            batchNum = int(math.ceil(1.0 * sample_num / batch_size))
            for i in range(batchNum):
                start = i * batch_size
                end = min((i + 1) * batch_size, sample_num)
                batch_files = files_list[start:end]
                # read the file contents and segment them into words
                batch_content = files_processing.read_files_list_to_segment(batch_files,
                                                                            max_sentence_length,
                                                                            padding_token='<PAD>')
                # [1] convert words to an index matrix, then map it to word vectors
                batch_indexMat = create_word2vec.word2indexMat(w2vModel, batch_content, max_sentence_length)
                val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, batch_indexMat)
                # [2] map words directly to word vectors
                # val_batch_data = create_word2vec.word2vector_lookup(w2vModel, batch_content)
                pred = pred_step(val_batch_data)
                pred = pred[0].tolist()
                pred = files_processing.labels_decoding(pred, labels_set)
                for k, file in enumerate(batch_files):
                    print("{}, pred:{}".format(file, pred[k]))


def batch_predict(val_dir, labels_file, models_path, word2vec_path, batch_size):
    '''
    Prediction...
    :param val_dir:      validation data directory
    :param labels_file:  path of the labels file
    :param models_path:  model checkpoint file
    :param word2vec_path: word2vec model file
    :param batch_size: batch size
    :return:
    '''
    max_sentence_length = 300
    embedding_dim = 128
    filter_sizes = [3, 4, 5, 6]
    num_filters = 200    # Number of filters per filter size
    l2_reg_lambda = 0.0  # L2 regularization lambda (default: 0.0)

    print("Loading data...")
    w2vModel = create_word2vec.load_wordVectors(word2vec_path)
    labels_set = files_processing.read_txt(labels_file)
    labels_nums = len(labels_set)

    val_file_list = create_batch_data.get_file_list(file_dir=val_dir, postfix='*.npy')
    val_batch = create_batch_data.get_data_batch(val_file_list, labels_nums=labels_nums, batch_size=batch_size,
                                                 shuffle=False, one_hot=True)
    print("val data   info *****************************")
    val_nums = create_word2vec.info_npy(val_file_list)
    print("labels_set info *****************************")
    files_processing.info_labels_set(labels_set)

    # Testing
    with tf.Graph().as_default():
        sess = tf.Session()
        with sess.as_default():
            cnn = TextCNN(sequence_length=max_sentence_length,
                          num_classes=labels_nums,
                          embedding_size=embedding_dim,
                          filter_sizes=filter_sizes,
                          num_filters=num_filters,
                          l2_reg_lambda=l2_reg_lambda)
            # Initialize all variables and restore the checkpoint
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess, models_path)

            def dev_step(x_batch, y_batch):
                """
                Evaluates model on a dev set
                """
                feed_dict = {
                    cnn.input_x: x_batch,
                    cnn.input_y: y_batch,
                    cnn.dropout_keep_prob: 1.0
                }
                loss, accuracy = sess.run([cnn.loss, cnn.accuracy], feed_dict)
                return loss, accuracy

            val_losses = []
            val_accs = []
            for k in range(int(val_nums / batch_size)):
                # for k in range(int(10)):
                val_batch_data, val_batch_label = create_batch_data.get_next_batch(val_batch)
                val_batch_data = create_word2vec.indexMat2vector_lookup(w2vModel, val_batch_data)
                val_loss, val_acc = dev_step(val_batch_data, val_batch_label)
                val_losses.append(val_loss)
                val_accs.append(val_acc)
                print("--------Evaluation:step {}, loss {:g}, acc {:g}".format(k, val_loss, val_acc))
            mean_loss = np.array(val_losses, dtype=np.float32).mean()
            mean_acc = np.array(val_accs, dtype=np.float32).mean()
            print("--------Evaluation:step {}, mean loss {:g}, mean acc {:g}".format(k, mean_loss, mean_acc))


def main():
    # Data preprocess
    labels_file = 'data/THUCNews_labels.txt'
    # word2vec_path = 'word2vec/THUCNews_word2vec300.model'
    word2vec_path = "../word2vec/models/THUCNews_word2Vec/THUCNews_word2Vec_128.model"
    models_path = 'models/checkpoints/model-30000'
    batch_size = 128
    val_dir = './data/val_data'
    batch_predict(val_dir=val_dir,
                  labels_file=labels_file,
                  models_path=models_path,
                  word2vec_path=word2vec_path,
                  batch_size=batch_size)

    test_path = '/home/ubuntu/project/tfTest/THUCNews/my_test'
    files_list = files_processing.get_files_list(test_path, postfix='*.txt')
    text_predict(files_list, labels_file, models_path, word2vec_path, batch_size)


if __name__ == "__main__":
    main()
