写在前面

这是NLP保姆级教程的第二篇----基于RNN的文本分类实现(Text RNN)

参考的的论文是来自2016年复旦大学IJCAI上的发表的关于循环神经网络在多任务文本分类上的应用:Recurrent Neural Network for Text Classification with Multi-Task Learning[1]

论文概览

在先前的许多工作中,模型的学习都是基于单任务,对于复杂的问题,也可以分解为简单且相互独立的子问题来单独解决,然后再合并结果,得到最初复杂问题的结果。这样做看似合理,其实是不正确的,因为现实世界中很多问题不能分解为一个一个独立的子问题,即使可以分解,各个子问题之间也是相互关联的,通过一些共享因素或「共享表示(share representation)」 联系在一起。把现实问题当做一个个独立的单任务处理,往往会忽略了问题之间所富含的丰富的关联信息。

上面的问题引出了本文的重点——「多任务学习(Multi-task learning)」,把多个相关(related)的任务(task)放在一起学习。多个任务之间共享一些因素,它们可以在学习过程中,共享它们所学到的信息,这是单任务学习没有具备的。相关联的多任务学习比单任务学习能去的更好的泛化(generalization)效果。本文基于 RNN 循环神经网络,提出三种不同的信息共享机制,整体网络是基于所有的任务共同学习得到。

下图展示的是单任务学习和多任务学习的流程图,可以对比一下区别。

下面具体介绍一下文章中的三个模型。

Model I: Uniform-Layer Architecture

在他们提出的第一个模型中,不同的任务共享一个LSTM网络层和一个embedding layer,此外每个任务还有其自己的embedding layer。所以对于上图中的任务m,输入x包括了两个部分:

其中等号右侧第一项和第二项分别表示该任务「特有」的word embedding和该模型中「共享」的word embedding,两者做一个concatenation。

LSTM网络层是所有任务所共享的,对于任务m的最后sequence representation为LSTM的输出:

Model II: Coupled-Layer Architecture

在第二个模型中,为每个任务都指定了「特定」的LSTM layer,但是不同任务间的LSTM layer可以共享信息。

为了更好地控制在不同LSTM layer之间的信息流动,作者提出了一个global gating unit,使得模型具有决定信息流动程度的能力。

为此,他们改写了LSTM中的表达式:

其中,

Model III: Shared-Layer Architecture

与模型二相似,作者也为每个单独的任务指派了特定的LSTM层,但是对于整体的模型使用了双向的LSTM,这样可以使得信息共享更为准确。

模型表现

论文作者在4个数据集上对上述模型做了评价,并和其他state-of-the-art的网络模型进行了对比,均显示最好的效果。

代码实现

RNN的代码框架和上一篇介绍的CNN类似,首先定义一个RNN类来实现论文中的模型

class RNN(BaseModel):"""A RNN class for sentence classificationWith an embedding layer + Bi-LSTM layer + FC layer + softmax"""def __init__(self, sequence_length, num_classes, vocab_size,embed_size, learning_rate, decay_steps, decay_rate,hidden_size, is_training, l2_lambda, grad_clip,initializer=tf.random_normal_initializer(stddev=0.1)):

这里的模型包括了一层embedding,一层双向LSTM,一层全连接层最后接上一个softmax分类函数。

然后依次定义模型,训练,损失等函数在后续调用。

def inference(self):"""1. embedding layer2. Bi-LSTM layer3. concat Bi-LSTM output4. FC(full connected) layer5. softmax layer"""# embedding layerwith tf.name_scope('embedding'):self.embedded_words = tf.nn.embedding_lookup(self.Embedding, self.input_x)# Bi-LSTM layerwith tf.name_scope('Bi-LSTM'):lstm_fw_cell = rnn.BasicLSTMCell(self.hidden_size)lstm_bw_cell = rnn.BasicLSTMCell(self.hidden_size)if self.dropout_keep_prob is not None:lstm_fw_cell = rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=self.dropout_keep_prob)lstm_bw_cell = rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=self.dropout_keep_prob)outputs, output_states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell,self.embedded_words,dtype=tf.float32)output = tf.concat(outputs, axis=2)output_last = tf.reduce_mean(output, axis=1)# FC layerwith tf.name_scope('output'):self.score = tf.matmul(output_last, self.W_projection) + self.b_projectionreturn self.scoredef loss(self):# losswith tf.name_scope('loss'):losses = tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.score)data_loss = tf.reduce_mean(losses)l2_loss = tf.add_n([tf.nn.l2_loss(cand_v) for cand_v in tf.trainable_variables()if 'bias' not in cand_v.name]) * self.l2_lambdadata_loss += l2_lossreturn data_lossdef train(self):learning_rate = tf.train.exponential_decay(self.learning_rate, self.global_step,self.decay_steps, self.decay_rate, staircase=True)optimizer = tf.train.AdamOptimizer(learning_rate)grads_and_vars = optimizer.compute_gradients(self.loss_val)grads_and_vars = [(tf.clip_by_norm(grad, self.grad_clip), val) for grad, val in grads_and_vars]train_op = optimizer.apply_gradients(grads_and_vars, global_step=self.global_step)return train_op

训练部分的数据集这里就直接采用CNN那篇文章相同的数据集(懒...),预处理的方式与函数等都是一样的,,,

def train(x_train, y_train, vocab_processor, x_dev, y_dev):with tf.Graph().as_default():session_conf = tf.ConfigProto(# allows TensorFlow to fall back on a device with a certain operation implementedallow_soft_placement= FLAGS.allow_soft_placement,# allows TensorFlow log on which devices (CPU or GPU) it places operationslog_device_placement=FLAGS.log_device_placement)sess = tf.Session(config=session_conf)with sess.as_default():# initialize cnnrnn = RNN(sequence_length=x_train.shape[1],num_classes=y_train.shape[1],vocab_size=len(vocab_processor.vocabulary_),embed_size=FLAGS.embed_size,l2_lambda=FLAGS.l2_reg_lambda,is_training=True,grad_clip=FLAGS.grad_clip,learning_rate=FLAGS.learning_rate,decay_steps=FLAGS.decay_steps,decay_rate=FLAGS.decay_rate,hidden_size=FLAGS.hidden_size)# output dir for models and summariestimestamp = str(time.time())out_dir = os.path.abspath(os.path.join(os.path.curdir, 'run', timestamp))if not os.path.exists(out_dir):os.makedirs(out_dir)print('Writing to {} \n'.format(out_dir))# checkpoint dir. checkpointing – saving the parameters of your model to restore them later on.checkpoint_dir = os.path.abspath(os.path.join(out_dir, FLAGS.ckpt_dir))checkpoint_prefix = os.path.join(checkpoint_dir, 'model')if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)# Write vocabularyvocab_processor.save(os.path.join(out_dir, 'vocab'))# Initialize allsess.run(tf.global_variables_initializer())def train_step(x_batch, y_batch):"""A single training step:param x_batch::param y_batch::return:"""feed_dict = {rnn.input_x: x_batch,rnn.input_y: y_batch,rnn.dropout_keep_prob: FLAGS.dropout_keep_prob}_, step, loss, accuracy = sess.run([rnn.train_op, rnn.global_step, rnn.loss_val, rnn.accuracy],feed_dict=feed_dict)time_str = datetime.datetime.now().isoformat()print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))def dev_step(x_batch, y_batch):"""Evaluate model on a dev setDisable dropout:param x_batch::param y_batch::param writer::return:"""feed_dict = {rnn.input_x: x_batch,rnn.input_y: y_batch,rnn.dropout_keep_prob: 1.0}step, loss, accuracy = sess.run([rnn.global_step, rnn.loss_val, rnn.accuracy],feed_dict=feed_dict)time_str = datetime.datetime.now().isoformat()print("dev results:{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))# generate batchesbatches = data_process.batch_iter(list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)# training loopfor batch in batches:x_batch, y_batch = zip(*batch)train_step(x_batch, y_batch)current_step = tf.train.global_step(sess, rnn.global_step)if current_step % FLAGS.validate_every == 0:print('\n Evaluation:')dev_step(x_dev, y_dev)print('')path = saver.save(sess, checkpoint_prefix, global_step=current_step)print('Save model checkpoint to {} \n'.format(path))def main(argv=None):x_train, y_train, vocab_processor, x_dev, y_dev = prepocess()train(x_train, y_train, vocab_processor, x_dev, y_dev)if __name__ == '__main__':tf.app.run()

「完整代码可以在公众号后台回复"RNN2016"获取。」

一起交流

想和你一起学习进步!『NewBeeNLP』目前已经建立了多个不同方向交流群(机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 / 等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定要备注信息才能通过)

本文参考资料

[1]

Recurrent Neural Network for Text Classification with Multi-Task Learning: https://arxiv.org/abs/1605.05101

END -

【NLP保姆级教程】手把手带你CNN文本分类(附代码)

Transformers Assemble(PART III)

BERT源码分析(PART III)

Bug越来越少许愿池

【NLP保姆级教程】手把手带你RNN文本分类(附代码)相关推荐

  1. 【NLP傻瓜式教程】手把手带你RNN文本分类(附代码)

    文章来源于NewBeeNLP,作者kaiyuan 写在前面 这是NLP傻瓜式教程的第二篇----基于RNN的文本分类实现(Text RNN) 参考的的论文是来自2016年复旦大学IJCAI上的发表的关 ...

  2. 【NLP傻瓜式教程】手把手带你HAN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  3. 【NLP傻瓜式教程】手把手带你RCNN文本分类(附代码)

    继续之前的文本分类系列 [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) [NLP傻瓜式教程]手把手带你fastText文本分类(附代码) ...

  4. 【NLP傻瓜式教程】手把手带你fastText文本分类(附代码)

    写在前面 已经发布: [NLP傻瓜式教程]手把手带你CNN文本分类(附代码) [NLP傻瓜式教程]手把手带你RNN文本分类(附代码) 继续NLP傻瓜式教程系列,今天的教程是基于FAIR的Bag of ...

  5. 【NLP】保姆级教程:手把手带你CNN文本分类(附代码)

    分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...

  6. 【NLP保姆级教程】手把手带你CNN文本分类(附代码)

    分享一篇老文章,文本分类的原理和代码详解,非常适合NLP入门! 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classifi ...

  7. 【NLP傻瓜式教程】手把手带你CNN文本分类(附代码)

    文章来源于NewBeeNLP,作者kaiyuan 写在前面 本文是对经典论文<Convolutional Neural Networks for Sentence Classification[ ...

  8. go项目部署服务器保姆级教程(带图)

    第一步把项目打包 1.确保本地goland的操作系统为linux go env 找到GOOS如果为window就修改为Linux 修改命令为 go env -w GOOS=linux 2.打包 在项目 ...

  9. js对象、数组、字符串操作总结(保姆级教程)

    对象操作 1. 扩展运算符 作用是遍历某个对象或者数组 testMethod() {// 三个点 ... 俗称扩展运算符或延展运算符,需要注意的是扩展运算符在拷贝的时候只能深拷贝第一层,第二层及以下都 ...

最新文章

  1. 《从零开始学Swift》学习笔记(Day 7)——Swift 2.0中的print函数几种重载形式
  2. 用python画万花筒写轮眼_万花筒写轮眼画法教程
  3. SQL case when then 的用法
  4. 一维数据高斯滤波器_透彻理解高斯混合模型
  5. 『Python基础-12』各种推导式(列表推导式、字典推导式、集合推导式)
  6. 网络光端机产品特点及实际应用范围详解
  7. 35 SD配置-销售凭证设置-定义项目类别组
  8. python步长为负时的情况
  9. 三元运算符(TernaryOperator)
  10. Eclipse中移除未使用的类引用的三种办法
  11. Matlab Tricks(二十三)—— 保存图像到 pdf
  12. Linux-shell编程_xargs命令详解
  13. 《信号与系统学习笔记》—信号与系统(三)
  14. Java 并发编程的艺术
  15. 阿里巴巴字体库的下载以及三种用法
  16. SQLServer 删除表中重复数据(除ID不同的)
  17. JLINK 驱动 V7.00a 更新导致JLINK V9无法使用问题解决
  18. Django 开发收银系统六
  19. 来了!Android应用市场64位应用策略
  20. 毕业后5年,我终于变成了月薪13000的软件测试工程师

热门文章

  1. 数据分析,如何构建指标体系
  2. Axure智慧水务移动端原型、智慧泵房、水厂监控、营收管理、DMA漏损、维护管理、GIS地图、水质监控、电商系统
  3. POJ3250(单调栈)
  4. CentOS 6.5 yum安装配置lnmp服务器(Nginx+PHP+MySQL)
  5. 使用HTML5 canvas做地图(1)基础知识
  6. Netty简单样例分析[转]
  7. javascript模态窗口Demo
  8. django orm 之makemigrations和migrate命令
  9. CentOS下安装Orcale
  10. 标准模板库(STL)学习探究之Multimap容器