手把手教你用 TensorFlow 实现文本分类(下)

本文作者:AI研习社 2017-05-29 13:36
导语:文本分类全流程解析。

雷锋网(公众号:雷锋网)按:本文作者张庆恒,原文载于作者个人博客,雷锋网(公众号:雷锋网)已获授权。

本篇文章主要记录对之前用神经网络做文本识别的初步优化,进一步将准确率由原来的65%提高到80%,这里优化的几个方面包括:

● 随机打乱训练数据

● 增加隐层,和验证集

● 正则化

● 对原数据进行PCA预处理

● 调节训练参数(迭代次数,batch大小等)

随机化训练数据

观察训练数据集,发现训练集是按类别存储,读进内存后在仍然是按类别顺序存放。这样顺序取一部分作为验证集,很大程度上会减少一个类别的训练样本数,对该类别的预测准确率会有所下降。所以首先考虑打乱训练数据。

在已经向量化的训练数据的基础上打乱数据,首先合并data和label,打乱后再将数据和标签分离为trian.txt和train_label.txt。这里可以直接使用shell命令:

1、将labels加到trian.txt的第一列

paste -d" " train_labels.txt train.txt > train_to_shuf.txt

2、随机打乱文件行

shuf train_to_shuf.txt -o train.txt

3、 提取打乱后文件的第一列,保存到train_labels.txt

cat train.txt | awk '{print $1}' > train_labels.txt

4、删除第一列label.

awk '{$1="";print $0}'  train.txt

这样再次以相同方式训练,准确率由65%上升到75% 。

改变网络结构,增加隐层

之前的网络直接对输入数据做softmax回归,这里考虑增加隐层,数量并加入验证集观察准确率的变化情况。这里加入一个隐层,隐层节点数为500,激励函数使用Relu。替换原来的网络结构,准确率进一步上升。

正则化,改善过拟合

观察模型对训练集的拟合程度到90%+,而通过上步对训练数据的准确率为76%,一定程度上出现了过拟合的现象,这里在原有cost function中上加入正则项,希望减轻过拟合的现象。这里使用L2正则。连同上步部分的代码如下:

#!/usr/bin/python

#-*-coding:utf-8-*-

LAYER_NODE1 = 500 # layer1 node num

INPUT_NODE = 5000

OUTPUT_NODE = 10

REG_RATE = 0.01

import tensorflow as tf

from datasets import datasets

def interface(inputs, w1, b1, w2,b2):

"""

compute forword progration result

"""

lay1 = tf.nn.relu(tf.matmul(inputs, w1) + b1)

return tf.nn.softmax(tf.matmul(lay1, w2) + b2) # need softmax??

data_sets = datasets()

data_sets.read_train_data(".", True)

sess = tf.InteractiveSession()

x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")

y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")

w1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER_NODE1], stddev=0.1))

b1 = tf.Variable(tf.constant(0.0, shape=[LAYER_NODE1]))

w2 = tf.Variable(tf.truncated_normal([LAYER_NODE1, OUTPUT_NODE], stddev=0.1))

b2 = tf.Variable(tf.constant(0.0, shape=[OUTPUT_NODE]))

y = interface(x, w1, b1, w2, b2)

cross_entropy = -tf.reduce_sum(y_ * tf.log(y + 1e-10))

regularizer = tf.contrib.layers.l2_regularizer(REG_RATE)

regularization = regularizer(w1) + regularizer(w2)

loss = cross_entropy + regularization

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

#training

tf.global_variables_initializer().run()

saver = tf.train.Saver()

cv_feed = {x: data_sets.cv.text, y_: data_sets.cv.label}

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

for i in range(5000):

if i % 200 == 0:

cv_acc = sess.run(acc, feed_dict=cv_feed)

print "train steps: %d, cv accuracy is %g " % (i, cv_acc)

batch_xs, batch_ys = data_sets.train.next_batch(100)

train_step.run({x: batch_xs, y_: batch_ys})

path = saver.save(sess, "./model4/model.md")

PCA处理

一方面对文本向量集是严重稀疏的矩阵,而且维度较大,一方面影响训练速度,一方面消耗内存。这里考虑对数据进行PCA处理。该部分希望保存99%的差异率,得到相应的k,即对应的维度。

#!/usr/bin/python

#-*-coding:utf-8-*-

"""

PCA for datasets

"""

import os

import sys

import commands

import numpy

from contextlib import nested

from datasets import datasets

ORIGIN_DIM = 5000

def pca(origin_mat):

"""

gen matrix using pca

row of origin_mat is one sample of dataset

col of origin_mat is one feature

return matrix  U, s and  V

"""

# mean,normaliza1on

avg = numpy.mean(origin_mat, axis=0)

# covariance matrix

cov = numpy.cov(origin_mat-avg,rowvar=0)

#Singular Value Decomposition

U, s, V = numpy.linalg.svd(cov, full_matrices=True)

k = 1;

sigma_s = numpy.sum(s)

# chose smallest k for 99% of variance retained

for k in range(1, ORIGIN_DIM+1):

variance = numpy.sum(s[0:k]) / sigma_s

print "k = %d, variance is %f" % (k, variance)

if variance >= 0.99:

break

if k == ORIGIN_DIM:

print "some thing unexpected , k is same as ORIGIN_DIM"

exit(1)

return U[:, 0:k], k

if __name__ == '__main__':

"""

main, read train.txt, and do pca

save file to train_pca.txt

"""

data_sets = datasets()

train_text, _ = data_sets.read_from_disk(".", "train", one_hot=False)

U, k = pca(train_text)

print "U shpae: ", U.shape

print "k is : ", k

text_pca = numpy.dot(train_text, U)

text_num = text_pca.shape[0]

print "text_num in pca is ", text_num

with open("./train_pca.txt", "a+") as f:

for i in range(0, text_num):

f.write(" ".join(map(str, text_pca[i,:])) + "\n")

最终得到k=2583。该部分准确率有所提高但影响不大。

调整网络参数

该部分主要根据严重集和测试集的表现不断调整网路参数,包括学习率、网路层数、每层节点个数、正则损失、迭代次数、batch大小等。最终得到80%的准确率。

小结

对神经网路进行初步优化,由原来的65%的准确率提高到80%,主要的提高在于训练数据的随机化,以及网络结构的调整。为提升训练速度,同时减少内存消耗,对数据进行了降维操作。

之后对代码的结构进行了整理,这里没有提及,该部分代码包括 nn_interface.py 和 nn_train.py 分别实现对网络结构的定义以及训练流程的管理。

后面会结合tensorflow的使用技巧对训练进行进一步优化。

雷锋网相关文章:

手把手教你用 TensorFlow 实现文本分类(上)

手把手教你如何用 TensorFlow 实现基于 DNN 的文本分类

手把手教你用 TensorFlow 实现文本分类(下)相关推荐

  1. html文本分类输出,手把手教你用 TensorFlow 实现文本分类(上)

    雷锋网(公众号:雷锋网)按:本文作者张庆恒,原文载于作者个人博客,雷锋网已获授权. 由于需要学习语音识别,期间接触了深度学习的算法.利用空闲时间,想用神经网络做一个文本分类的应用, 目的是从头到尾完成 ...

  2. 手把手教 | 使用Bert预训练模型文本分类(内附源码)

    作者:GjZero 标签:Bert, 中文分类, 句子向量 本文约1500字,建议阅读8分钟. 本文从实践入手,带领大家进行Bert的中文文本分类和作为句子向量进行使用的教程. Bert介绍 Bert ...

  3. 实战七:手把手教你用TensorFlow进行验证码识别(上)

    实战七:手把手教你用TensorFlow进行验证码识别(上) github下载地址 目录 准备模型开发环境 生成验证码数据集 输入与输出数据处理 模型结构设计 模型损失函数设计 模型训练过程分析 模型 ...

  4. 手把手教你使用TensorFlow训练出自己的模型

    手把手教你使用TensorFlow训练出自己的模型 一.前言 搭建TensorFlow开发环境一直是初学者头疼的问题,为了帮忙初学者快速使用TensorFlow框架训练出自己的模型,作者开发了一款基于 ...

  5. 实战六:手把手教你用TensorFlow进行手写数字识别

    手把手教你用TensorFlow进行手写数字识别 github下载地址 目录 手写体数字MNIST数据集介绍 MNIST Softmax网络介绍 实战MNIST Softmax网络 MNIST CNN ...

  6. 报名 | NVIDIA线下交流会:手把手教你搭建TensorFlow Caffe深度学习服务器

    7月21日(周六)下午14:30,量子位与NVIDIA英伟达开发者社区联合举办线下交流会,拥有丰富一线开发经验的NVIDIA开发者社区经理Ken He,将手把手教你搭建TensorFlow & ...

  7. 今晚直播 | 谷歌资深工程师手把手教你使用TensorFlow最新API构建学习模型

    目前,深度学习的研究和应用大受追捧,各种开源的深度学习框架层出不穷.TensorFlow 作为目前最受欢迎的深度学习框架,已经在 GitHub 上获得了 112194 个 star,受欢迎程序可见一斑 ...

  8. 官方资源帖!手把手教你在TensorFlow 2.0中实现CycleGAN,推特上百赞

    铜灵 发自 凹非寺 量子位 出品| 公众号 QbitAI CycleGAN,一个可以将一张图像的特征迁移到另一张图像的酷算法,此前可以完成马变斑马.冬天变夏天.苹果变桔子等一颗赛艇的效果. 这行被顶会 ...

  9. 手把手教你用TensorFlow、Keras打造美剧《硅谷》中的“识别热狗”APP

    来源:机械鸡 作者:瑶瑶 本文长度为10000字,建议阅读20分钟+ 本文手把手教你开发自己的app~ HBO热播剧<硅谷>最近推出了一款能够识别"热狗"和" ...

最新文章

  1. Brian 的 Perl 问题之万能指南
  2. perl学习之(not install YAML)解决
  3. Qt编程之QTreeWidget使用方法
  4. arcgis拆分多部件要素
  5. css3组件实战--绚丽效果篇
  6. 避免线上故障的10条建议
  7. 从零开始,DIY一个jQuery(2)
  8. ESX的VSWITCH坏了,如何转移到新建的虚拟交换机上?
  9. VMware安装linux系统镜像教程
  10. 87.3 laravel中常见问题以及解决方案
  11. PrimeNG之DataTable
  12. 鹰迪电商|抖音发布作品定位可以随便设置吗?
  13. 使用Java统计英文文章的单词频率。
  14. Python基础之闭包函数
  15. sql中模糊查询的字段中包含百分号%的语句
  16. 线性代数学习笔记——第二十一讲——矩阵秩的等式
  17. 会议或期刊是否被EI
  18. 灰色 GM(1,1)模型在重庆商品房销售价格预测中的应用
  19. java123456
  20. 【中级软考】数字签名的概念及其作用

热门文章

  1. oracle批量加载,Oracle教程:使用SQL*Loader高速批量数据加载工具
  2. Oracle存储过程以及游标
  3. C++虚继承和虚基类详解(二)
  4. linux下使用 du查看某个文件或目录占用磁盘空间的大小
  5. 概率统计笔记:分布的核
  6. 数据中台产品经理面试指南(二)
  7. tableau必知必会之用蝴蝶图(旋风图)实现数据之间对比
  8. Python 函数式编程
  9. 台式电脑可以连wifi吗_[Windows] wifi音箱:台式电脑也可以连接蓝牙音箱了
  10. Python入门100题 | 第072题