代码基于 dennybritz/cnn-text-classification-tf 及 clayandgithub/zh_cnn_text_classify
参考文章 了解用于NLP的卷积神经网络(译) 及 TensorFlow实现CNN用于文本分类(译)
本文完整代码 - Widiot/cnn-zh-text-classification

1. 项目结构

以下是完整的目录结构示例,包括运行之后形成的目录和文件:

cnn-zh-text-classification/data/maildata/cleaned_ham_5000.utf8cleaned_spam_5000.utf8ham_5000.utf8spam_5000.utf8runs/1517572900/checkpoints/...summaries/...prediction.csvvocab.gitignoreREADME.mddata_helpers.pyeval.pytext_cnn.pytrain.py

各个目录及文件的作用如下:

  • data 目录用于存放数据
  • maildata 目录用于存放邮件文件,目前有四个文件,ham_5000.utf8 及 spam_5000.utf8 分别为正常邮件和垃圾邮件,带 cleaned 前缀的文件为清洗后的数据
  • runs 目录用于存放每次运行产生的数据,以时间戳为目录名
  • 1517572900 目录用于存放每次运行产生的检查点、日志摘要、词汇文件及评估产生的结果
  • data_helpers.py 用于处理数据
  • eval.py 用于评估模型
  • text_cnn.py 是 CNN 模型类
  • train.py 用于训练模型

2. 数据

2.1 数据格式

以分类正常邮件和垃圾邮件为例,如下是邮件数据的例子:

# 正常邮件
他们自己也是刚到北京不久 跟在北京读书然后留在这里工作的还不一样 难免会觉得还有好多东西没有安顿下来 然后来了之后还要带着四处旅游甚么什么的 却是花费很大 你要不带着出去玩,还真不行 这次我小表弟来北京玩,花了好多钱 就因为本来预定的几个地方因为某种原因没去 舅妈似乎就很不开心 结果就是钱全白花了 人家也是牢骚一肚子 所以是自己找出来的困难 退一万步说 婆婆来几个月
发文时难免欠点理智 我不怎么灌水,没想到上了十大了,拍的还挺欢,呵呵 写这个贴子,是由于自己太郁闷了,其时,我最主要的目的,是觉得,水木上肯定有一些嫁农村GG但现在很幸福的JJMM.我目前遇到的问题,我的确不知道怎么解决,所以发上来,问一下成功解决这类问题的建议.因为没有相同的经历和体会,是不会理解的,我在我身边就找不到可行的建议. 结果,无心得罪了不少人.呵呵,可能我想了太多关于城乡差别的问题,意识的比较深刻,所以不经意写了出来.
所以那些贵族1就要找一些特定的东西来章显自己的与众不同 这个东西一定是穷人买不起的,所以好多奢侈品也就营运诞生了 想想也是,他们要表也没有啊, 我要是香paris hilton那么有钱,就每天一个牌子的表,一个牌子的时装,一个牌子的汽车,哈哈,。。。要得就是这个派 俺连表都不用, 带手上都累赘, 上课又不能开手机, 所以俺只好经常退一下ppt去看右下脚的时间. 其实 贵族又不用赶时间, 要知道精确时间做啥? 表走的# 垃圾邮件
中信(国际)电子科技有限公司推出新产品: 升职步步高、做生意发大财、连找情人都用的上,详情进入 网  址:  http://www.usa5588.com/ccc 电话:020-33770208   服务热线:013650852999
以下不能正确显示请点此 IFRAME: http://www.ewzw.com/bbs/viewthread.php?tid=3809&fpage=1
尊敬的公司您好!打扰之处请见谅! 我深圳公司愿在互惠互利、诚信为本代开3厘---2点国税、地税等发票。增值税和海关缴款书就以2点---7点来代开。手机:13510631209       联系人:邝先生  邮箱:ao998@163.com     祥细资料合作告知,希望合作。谢谢!!

每个句子单独一行,正常邮件和垃圾邮件的数据分别存放在两个文件中。

2.2 数据处理

数据处理 data_helpers.py 的代码如下,与所参考的代码不同的是:

  • load_data_and_labels():将函数的参数修改为以逗号分隔的数据文件的路径字符串,比如 './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',这样可以读取多个类别的数据文件以实现多分类问题
  • read_and_clean_zh_file():将函数的 output_cleaned_file 修改为 boolean 类型,控制是否保存清洗后的数据,并在函数中判断,如果已经存在清洗后的数据文件则直接加载,否则进行清洗并选择保存

其他函数与所参考的代码相比变动不大:

import numpy as np
import re
import osdef load_data_and_labels(data_files):"""1. 加载所有数据和标签2. 可以进行多分类,每个类别的数据单独放在一个文件中2. 保存处理后的数据"""data_files = data_files.split(',')num_data_file = len(data_files)assert num_data_file > 1x_text = []y = []for i, data_file in enumerate(data_files):# 将数据放在一起data = read_and_clean_zh_file(data_file, True)x_text += data# 形成数据对应的标签label = [0] * num_data_filelabel[i] = 1labels = [label for _ in data]y += labelsreturn [x_text, np.array(y)]def read_and_clean_zh_file(input_file, output_cleaned_file=False):"""1. 读取中文文件并清洗句子2. 可以将清洗后的结果保存到文件3. 如果已经存在经过清洗的数据文件则直接加载"""data_file_path, file_name = os.path.split(input_file)output_file = os.path.join(data_file_path, 'cleaned_' + file_name)if os.path.exists(output_file):lines = list(open(output_file, 'r').readlines())lines = [line.strip() for line in lines]else:lines = list(open(input_file, 'r').readlines())lines = [clean_str(seperate_line(line)) for line in lines]if output_cleaned_file:with open(output_file, 'w') as f:for line in lines:f.write(line + '\n')return linesdef clean_str(string):"""1. 将除汉字外的字符转为一个空格2. 将连续的多个空格转为一个空格3. 除去句子前后的空格字符"""string = re.sub(r'[^\u4e00-\u9fff]', ' ', string)string = re.sub(r'\s{2,}', ' ', string)return string.strip()def seperate_line(line):"""将句子中的每个字用空格分隔开"""return ''.join([word + ' ' for word in line])def batch_iter(data, batch_size, num_epochs, shuffle=True):'''生成一个batch迭代器'''data = np.array(data)data_size = len(data)num_batches_per_epoch = int((data_size - 1) / batch_size) + 1for epoch in range(num_epochs):if shuffle:shuffle_indices = np.random.permutation(np.arange(data_size))shuffled_data = data[shuffle_indices]else:shuffled_data = datafor batch_num in range(num_batches_per_epoch):start_idx = batch_num * batch_sizeend_idx = min((batch_num + 1) * batch_size, data_size)yield shuffled_data[start_idx:end_idx]if __name__ == '__main__':data_files = './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'x_text, y = load_data_and_labels(data_files)print(x_text)

2.3 清洗标准

将原始数据进行清洗,仅保留汉字,并把每个汉字用一个空格分隔开,各个类别清洗后的数据分别存放在 cleaned 前缀的文件中,清洗后的数据格式如下:

本 公 司 有 部 分 普 通 发 票 商 品 销 售 发 票 增 值 税 发 票 及 海 关 代 征 增 值 税 专 用 缴 款 书 及 其 它 服 务 行 业 发 票 公 路 内 河 运 输 发 票 可 以 以 低 税 率 为 贵 公 司 代 开 本 公 司 具 有 内 外 贸 生 意 实 力 保 证 我 司 开 具 的 票 据 的 真 实 性 希 望 可 以 合 作 共 同 发 展 敬 侯 您 的 来 电 洽 谈 咨 询 联 系 人 李 先 生 联 系 电 话 如 有 打 扰 望 谅 解 祝 商 琪

3. 模型

CNN 模型类 text_cnn.py 的代码如下,修改的地方如下:

  • 将 concat 和 reshape 的操作结点放在 concat 命名空间下,这样在 TensorBoard 中的节点图更加清晰合理
  • 将计算损失值的操作修改为通过 collection 进行,并只计算 W 的 L2 损失值,删去了计算 b 的 L2 损失值的代码
import tensorflow as tf
import numpy as npclass TextCNN(object):"""字符级CNN文本分类词嵌入层->卷积层->池化层->softmax层"""def __init__(self,sequence_length,num_classes,vocab_size,embedding_size,filter_sizes,num_filters,l2_reg_lambda=0.0):# 输入,输出,dropout的占位符self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name='input_x')self.input_y = tf.placeholder(tf.float32, [None, num_classes], name='input_y')self.dropout_keep_prob = tf.placeholder(tf.float32, name='dropout_keep_prob')# l2正则化损失值(可选)#l2_loss = tf.constant(0.0)# 词嵌入层# W为词汇表,大小为0~词汇总数,索引对应不同的字,每个字映射为128维的数组,比如[3800,128]with tf.name_scope('embedding'):self.W = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),name='W')self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)# 卷积层和池化层# 为3,4,5分别创建128个过滤器,总共3×128个过滤器# 过滤器形状为[3,128,1,128],表示一次能过滤三个字,最后形成188×128的特征向量# 池化核形状为[1,188,1,1],128维中的每一维表示该句子的不同向量表示,池化即从每一维中提取最大值表示该维的特征# 池化得到的特征向量为128维pooled_outputs = []for i, filter_size in enumerate(filter_sizes):with tf.name_scope('conv-maxpool-%s' % filter_size):# 卷积层filter_shape = [filter_size, embedding_size, 1, num_filters]W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name='W')b = tf.Variable(tf.constant(0.1, shape=[num_filters]), name='b')conv = tf.nn.conv2d(self.embedded_chars_expanded,W,strides=[1, 1, 1, 1],padding='VALID',name='conv')# ReLU激活h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')# 池化层pooled = tf.nn.max_pool(h,ksize=[1, sequence_length - filter_size + 1, 1, 1],strides=[1, 1, 1, 1],padding='VALID',name='pool')pooled_outputs.append(pooled)# 组合所有池化后的特征# 将三个过滤器得到的特征向量组合成一个384维的特征向量num_filters_total = num_filters * len(filter_sizes)with tf.name_scope('concat'):self.h_pool = tf.concat(pooled_outputs, 3)self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])# dropoutwith tf.name_scope('dropout'):self.h_drop = tf.nn.dropout(self.h_pool_flat,self.dropout_keep_prob)# 全连接层# 分数和预测结果with tf.name_scope('output'):W = tf.Variable(tf.truncated_normal([num_filters_total, num_classes], stddev=0.1),name='W')b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name='b')if l2_reg_lambda:W_l2_loss = tf.contrib.layers.l2_regularizer(l2_reg_lambda)(W)tf.add_to_collection('losses', W_l2_loss)self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name='scores')self.predictions = tf.argmax(self.scores, 1, name='predictions')# 计算交叉损失熵with tf.name_scope('loss'):mse_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.scores, labels=self.input_y))tf.add_to_collection('losses', mse_loss)self.loss = tf.add_n(tf.get_collection('losses'))# 正确率with tf.name_scope('accuracy'):correct_predictions = tf.equal(self.predictions,tf.argmax(self.input_y, 1))self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, 'float'), name='accuracy')

最终的神经网络结构图在 TensorBoard 中的样式如下:

4. 训练

训练模型的 train.py 代码如下,修改的地方如下:

  • 将数据文件的路径参数修改为一个用逗号分隔开的字符串,便于实现多分类问题
  • tf.flags 重命名为 flags,更加简洁
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn# 参数
# ==================================================flags = tf.flags# 数据加载参数
flags.DEFINE_float('dev_sample_percentage', 0.1,'Percentage of the training data to use for validation')
flags.DEFINE_string('data_files','./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8','Comma-separated data source files')# 模型超参数
flags.DEFINE_integer('embedding_dim', 128,'Dimensionality of character embedding (default: 128)')
flags.DEFINE_string('filter_sizes', '3,4,5','Comma-separated filter sizes (default: "3,4,5")')
flags.DEFINE_integer('num_filters', 128,'Number of filters per filter size (default: 128)')
flags.DEFINE_float('dropout_keep_prob', 0.5,'Dropout keep probability (default: 0.5)')
flags.DEFINE_float('l2_reg_lambda', 0.0,'L2 regularization lambda (default: 0.0)')# 训练参数
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_integer('num_epochs', 10,'Number of training epochs (default: 10)')
flags.DEFINE_integer('evaluate_every', 100,'Evaluate model on dev set after this many steps (default: 100)')
flags.DEFINE_integer('checkpoint_every', 100,'Save model after this many steps (default: 100)')
flags.DEFINE_integer('num_checkpoints', 5,'Number of checkpoints to store (default: 5)')# 其他参数
flags.DEFINE_boolean('allow_soft_placement', True,'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,'Log placement of ops on devices')FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):print('{}={}'.format(attr.upper(), value))
print('')# 数据准备
# ==================================================# 加载数据
print('Loading data...')
x_text, y = data_helpers.load_data_and_labels(FLAGS.data_files)# 建立词汇表
max_document_length = max([len(x.split(' ')) for x in x_text])
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
x = np.array(list(vocab_processor.fit_transform(x_text)))# 随机混淆数据
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices]# 划分train/test数据集
# TODO: 这种做法比较暴力,应该用交叉验证
dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]del x, y, x_shuffled, y_shuffledprint('Vocabulary Size: {:d}'.format(len(vocab_processor.vocabulary_)))
print('Train/Dev split: {:d}/{:d}'.format(len(y_train), len(y_dev)))
print('')# 训练
# ==================================================with tf.Graph().as_default():session_conf = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,log_device_placement=FLAGS.log_device_placement)sess = tf.Session(config=session_conf)with sess.as_default():cnn = TextCNN(sequence_length=x_train.shape[1],num_classes=y_train.shape[1],vocab_size=len(vocab_processor.vocabulary_),embedding_size=FLAGS.embedding_dim,filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),num_filters=FLAGS.num_filters,l2_reg_lambda=FLAGS.l2_reg_lambda)# 定义训练相关操作global_step = tf.Variable(0, name='global_step', trainable=False)optimizer = tf.train.AdamOptimizer(1e-3)grads_and_vars = optimizer.compute_gradients(cnn.loss)train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)# 跟踪梯度值和稀疏性(可选)grad_summaries = []for g, v in grads_and_vars:if g is not None:grad_hist_summary = tf.summary.histogram('{}/grad/hist'.format(v.name), g)sparsity_summary = tf.summary.scalar('{}/grad/sparsity'.format(v.name), tf.nn.zero_fraction(g))grad_summaries.append(grad_hist_summary)grad_summaries.append(sparsity_summary)grad_summaries_merged = tf.summary.merge(grad_summaries)# 模型和摘要的保存目录timestamp = str(int(time.time()))out_dir = os.path.abspath(os.path.join(os.path.curdir, 'runs', timestamp))print('\nWriting to {}\n'.format(out_dir))# 损失值和正确率的摘要loss_summary = tf.summary.scalar('loss', cnn.loss)acc_summary = tf.summary.scalar('accuracy', cnn.accuracy)# 训练摘要train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])train_summary_dir = os.path.join(out_dir, 'summaries', 'train')train_summary_writer = tf.summary.FileWriter(train_summary_dir,sess.graph)# 开发摘要dev_summary_op = tf.summary.merge([loss_summary, acc_summary])dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)# 检查点目录,默认存在checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))checkpoint_prefix = os.path.join(checkpoint_dir, 'model')if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)# 写入词汇表文件vocab_processor.save(os.path.join(out_dir, 'vocab'))# 初始化变量sess.run(tf.global_variables_initializer())def train_step(x_batch, y_batch):"""一个训练步骤"""feed_dict = {cnn.input_x: x_batch,cnn.input_y: y_batch,cnn.dropout_keep_prob: FLAGS.dropout_keep_prob}_, step, summaries, loss, accuracy = sess.run([train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy], feed_dict)time_str = datetime.datetime.now().isoformat()print('{}: step {}, loss {:g}, acc {:g}'.format(time_str, step, loss, accuracy))train_summary_writer.add_summary(summaries, step)def dev_step(x_batch, y_batch, writer=None):"""在开发集上验证模型"""feed_dict = {cnn.input_x: x_batch,cnn.input_y: y_batch,cnn.dropout_keep_prob: 1.0}step, summaries, loss, accuracy = sess.run([global_step, dev_summary_op, cnn.loss, cnn.accuracy],feed_dict)time_str = datetime.datetime.now().isoformat()print('{}: step {}, loss {:g}, acc {:g}'.format(time_str, step, loss, accuracy))if writer:writer.add_summary(summaries, step)# 生成batchesbatches = data_helpers.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)# 迭代训练每个batchfor batch in batches:x_batch, y_batch = zip(*batch)train_step(x_batch, y_batch)current_step = tf.train.global_step(sess, global_step)if current_step % FLAGS.evaluate_every == 0:print('\nEvaluation:')dev_step(x_dev, y_dev, writer=dev_summary_writer)print('')if current_step % FLAGS.checkpoint_every == 0:path = saver.save(sess, checkpoint_prefix, global_step=current_step)print('Saved model checkpoint to {}\n'.format(path))

训练过程的输出如下:

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_EVERY=100
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
DEV_SAMPLE_PERCENTAGE=0.1
DROPOUT_KEEP_PROB=0.5
EMBEDDING_DIM=128
EVALUATE_EVERY=100
FILTER_SIZES=3,4,5
L2_REG_LAMBDA=0.0
LOG_DEVICE_PLACEMENT=False
NUM_CHECKPOINTS=5
NUM_EPOCHS=10
NUM_FILTERS=128Loading data...
Vocabulary Size: 3628
Train/Dev split: 9001/1000Writing to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/15177341862018-02-04T16:50:03.709761: step 1, loss 5.36006, acc 0.46875
2018-02-04T16:50:03.786874: step 2, loss 4.61227, acc 0.390625
2018-02-04T16:50:03.857796: step 3, loss 2.50795, acc 0.5625
...
2018-02-04T16:50:10.819505: step 98, loss 0.622567, acc 0.90625
2018-02-04T16:50:10.899140: step 99, loss 1.10189, acc 0.875
2018-02-04T16:50:10.983192: step 100, loss 0.359102, acc 0.9375Evaluation:
2018-02-04T16:50:11.848838: step 100, loss 0.132987, acc 0.961Saved model checkpoint to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/1517734186/checkpoints/model-1002018-02-04T16:50:12.019749: step 101, loss 0.512838, acc 0.890625
2018-02-04T16:50:12.100965: step 102, loss 0.164333, acc 0.96875
2018-02-04T16:50:12.184899: step 103, loss 0.145344, acc 0.921875
...

训练之后会在 runs 目录下生成对应的数据目录,包含检查点、日志摘要和词汇文件。

训练时的正确率变化如下:

5. 评估

评估模型的 eval.py 代码如下,修改的地方如下:

  • 同 train.py 将数据文件路径参数修改为逗号分隔开的字符串,便于实现多分类问题
  • 添加对自己未经处理的数据的清洗操作,便于直接分类评估数据
import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
import csv
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn# 参数
# ==================================================flags = tf.flags# 数据参数
flags.DEFINE_string('data_files','./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8','Comma-separated data source files')# 评估参数
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_string('checkpoint_dir', './runs/1517572900/checkpoints','Checkpoint directory from training run')
flags.DEFINE_boolean('eval_train', False, 'Evaluate on all training data')# 其他参数
flags.DEFINE_boolean('allow_soft_placement', True,'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,'Log placement of ops on devices')FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):print('{}={}'.format(attr.upper(), value))
print('')# 加载训练数据或者修改测试句子
if FLAGS.eval_train:x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.data_files)y_test = np.argmax(y_test, axis=1)
else:x_raw = ['亲爱的CFer,您获得了英雄级道具。还有全新英雄级道具在等你来拿,立即登录游戏领取吧!','第一个build错误的解决方法能再说一下吗,我还是不懂怎么解决', '请联系张经理获取最新资讯']y_test = [0, 1, 0]# 对自己的数据的处理
x_raw_cleaned = [data_helpers.clean_str(data_helpers.seperate_line(line)) for line in x_raw
]
print(x_raw_cleaned)# 将数据转为词汇表的索引
vocab_path = os.path.join(FLAGS.checkpoint_dir, '..', 'vocab')
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw_cleaned)))print('\nEvaluating...\n')# 评估
# ==================================================checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():session_conf = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement,log_device_placement=FLAGS.log_device_placement)sess = tf.Session(config=session_conf)with sess.as_default():# 加载保存的元图和变量saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_file))saver.restore(sess, checkpoint_file)# 通过名字从图中获取占位符input_x = graph.get_operation_by_name('input_x').outputs[0]# input_y = graph.get_operation_by_name('input_y').outputs[0]dropout_keep_prob = graph.get_operation_by_name('dropout_keep_prob').outputs[0]# 我们想要评估的tensorspredictions = graph.get_operation_by_name('output/predictions').outputs[0]# 生成每个轮次的batchesbatches = data_helpers.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)# 收集预测值all_predictions = []for x_test_batch in batches:batch_predictions = sess.run(predictions, {input_x: x_test_batch,dropout_keep_prob: 1.0})all_predictions = np.concatenate([all_predictions, batch_predictions])# 如果提供了标签则打印正确率
if y_test is not None:correct_predictions = float(sum(all_predictions == y_test))print('\nTotal number of test examples: {}'.format(len(y_test)))print('Accuracy: {:g}'.format(correct_predictions / float(len(y_test))))# 保存评估为csv
predictions_human_readable = np.column_stack((np.array(x_raw),all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, '..', 'prediction.csv')
print('Saving evaluation to {0}'.format(out_path))
with open(out_path, 'w') as f:csv.writer(f).writerows(predictions_human_readable)

评估过程中的输出如下:

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_DIR=./runs/1517572900/checkpoints
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
EVAL_TRAIN=False
LOG_DEVICE_PLACEMENT=False['亲 爱 的 您 获 得 了 英 雄 级 道 具 还 有 全 新 英 雄 级 道 具 在 等 你 来 拿 立 即 登 录 游 戏 领 取 吧', '第 一 个 错 误 的 解 决 方 法 能 再 说 一 下 吗 我 还 是 不 懂 怎 么 解 决', '请 联 系 张 经 理 获 取 最 新 资 讯']Evaluating...Total number of test examples: 3
Accuracy: 1
Saving evaluation to ./runs/1517572900/checkpoints/../prediction.csv

评估之后会在 runs 目录对应的文件夹下生成一个 prediction.csv 文件,如下所示:

亲爱的CFer,您获得了英雄级道具。还有全新英雄级道具在等你来拿,立即登录游戏领取吧!,0.0
第一个build错误的解决方法能再说一下吗,我还是不懂怎么解决,1.0
请联系张经理获取最新资讯,0.0

【NLP】TensorFlow实现CNN用于中文文本分类相关推荐

  1. TensorFlow使用CNN实现中文文本分类

    TensorFlow使用CNN实现中文文本分类 读研期间使用过TensorFlow实现过简单的CNN情感分析(分类),当然这是比较low的二分类情况,后来进行多分类情况.但之前的学习基本上都是在英文词 ...

  2. TensorFlow – 使用CNN进行中文文本分类

    使用卷积神经网络(CNN)处理自然语言处理(NLP)中的文本分类问题.本文将结合TensorFlow代码介绍: 词嵌入 填充 Embedding 卷积层 卷积(tf.nn.conv1d) 池化(poo ...

  3. python中文文本分析_基于CNN的中文文本分类算法(可应用于垃圾邮件过滤、情感分析等场景)...

    基于cnn的中文文本分类算法 简介 参考IMPLEMENTING A CNN FOR TEXT CLASSIFICATION IN TENSORFLOW实现的一个简单的卷积神经网络,用于中文文本分类任 ...

  4. Tensorflow使用Char-CNN实现中文文本分类(1)

    前言 在之前的中文文本分类中,使用了LSTM来进行模型的构建(详情参考: Tensorflow使用LSTM实现中文文本分类(2).使用numpy实现LSTM和RNN网络的前向传播过程).除了使用LST ...

  5. 详解CNN实现中文文本分类过程

    摘要:本文主要讲解CNN实现中文文本分类的过程,并与贝叶斯.决策树.逻辑回归.随机森林.KNN.SVM等分类算法进行对比. 本文分享自华为云社区<[Python人工智能] 二十一.Word2Ve ...

  6. Tensorflow使用LSTM实现中文文本分类(1)

    前言 使用Tensorflow,利用LSTM进行中文文本的分类. 数据集格式如下: ''' 体育 马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的 ...

  7. 用于中文文本分类的中文停用词

    用于中文文本分类的中文停用词,1893个. ! " # $ % & ' ( ) * + , - -- . .. ... ...... ................... ./ . ...

  8. 【NLP】BERT 模型与中文文本分类实践

    简介 2018年10月11日,Google发布的论文<Pre-training of Deep Bidirectional Transformers for Language Understan ...

  9. 基于cnn的中文文本分类

    资源下载地址:https://download.csdn.net/download/sheziqiong/86799359 资源下载地址:https://download.csdn.net/downl ...

最新文章

  1. Opencv腐蚀操作去除激光反光光斑
  2. python debug【】
  3. 北京内推 | 微软亚洲研究院自然语言计算组招聘NLP研究型实习生
  4. undefined reference to '__gxx_personality_v0'
  5. (译)利用ASP.NET加密和解密Web.config中连接字符串
  6. 三态输出门实验报告注意事项_数电基础知识:各种IO输出的类型
  7. CheckstyleException: cannot initialize module TreeWalker - TreeWalker is not allowed as a
  8. css中background的使用总结
  9. 百度人脸识别:功能开通
  10. 帆软帮助文档_聚焦商业智能主赛道,帆软如何取得里程碑式突破
  11. java jsp实验设计心得_jsp课程设计心得_课程设计总结心得
  12. pyqt5优秀项目python_【项目】PYQT5--Python/C++实现网络聊天室
  13. 一看就会的侧方位停车技巧 见了就收了吧
  14. 如何阅读看懂datasheet
  15. 微生物组-扩增子16S分析和可视化(2022.7)
  16. markdown java 代码高亮_Markdown 入门教程
  17. 计算机语言输入不见了,电脑输入法为什么不见了
  18. fatal: You have not concluded your cherry-pick (CHERRY_PICK_HEAD exists). Please, commit your change
  19. 【情人节特别篇】想知道玫瑰在哪些城市最畅销嘛?
  20. 在线学习算法FTRL基本原理

热门文章

  1. 测试用例设计的几种常见方法,测试用例的几大要素
  2. 使用TsLint,报错:space indentation expected
  3. Linux 自定义 RPM 包及制作 YUM 仓库
  4. 【职业学习规划】Android架构师方向
  5. python粘贴代码到word_写论文必备,在线代码高亮工具,无缝粘贴到 Word
  6. Qt5 实战No.01 桌面时钟
  7. python设置ucs2编码_UCS2编码与解码
  8. Python+Vue计算机毕业设计新疆旅游景点信息查询网站0y596(源码+程序+LW+部署)
  9. win7用什么浏览器好
  10. 电脑窗口颜色设置---保护眼睛