引言

TensorFlow很容易上手,但是TensorFlow的很多trick却是提升TensorFlow心法的法门,之前说过TensorFlow的read心法,现在想说一说TensorFlow在RNN上的心法,简直好用到哭 【以下实验均是基于TensorFlow1.0】

简要介绍tensorflow的RNN

其实在前面多篇都已经提到了TensorFlow的RNN,也在我之前的文章TensorFlow实现文本分类文章中用到了BasicLSTM的方法,通常的,使用RNN的时候,我们需要指定num_step,也就是TensorFlow的roll step步数,但是对于变长的文本来说,指定num_step就不可避免的需要进行padding操作,在之前的文章TensorFlow高阶读写教程也使用了dynamic_padding方法实现自动padding,但是这还不够,因为在跑一遍RNN/LSTM之后,还是需要对padding部分的内容进行删除,我称之为“反padding”,无可避免的,我们就需要指定mask矩阵了,这就有点不优雅,但是TensorFlow提供了一个很优雅的解决方法,让mask去见马克思去了,那就是dynamic_rnn

tf.dynamic_rnn

tensorflow 的dynamic_rnn方法,我们用一个小例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,last_states,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,last_states是由(c,h)组成的tuple,均为[batch,128]。

到这里并没有什么不同,但是dynamic有个参数:sequence_length,这个参数用来指定每个example的长度,比如上面的例子中,我们令 sequence_length为[20,13],表示第一个example有效长度为20,第二个example有效长度为13,当我们传入这个参数的时候,对于第二个example,TensorFlow对于13以后的padding就不计算了,其last_states将重复第13步的last_states直至第20步,而outputs中超过13步的结果将会被置零。

dynamic_rnn例子

#coding=utf-8
import tensorflow as tf
import numpy as np
# 创建输入数据
X = np.random.randn(2, 10, 8)# 第二个example长度为6
X[1,6:] = 0
X_lengths = [10, 6]cell = tf.contrib.rnn.BasicLSTMCell(num_units=64, state_is_tuple=True)outputs, last_states = tf.nn.dynamic_rnn(
cell=cell,
dtype=tf.float64,
sequence_length=X_lengths,
inputs=X)result = tf.contrib.learn.run_n(
{"outputs": outputs, "last_states": last_states},
n=1,
feed_dict=None)print result[0]assert result[0]["outputs"].shape == (2, 10, 64)# 第二个example中的outputs超过6步(7-10步)的值应该为0
assert (result[0]["outputs"][1,7,:] == np.zeros(cell.output_size)).all()

我们看输出:

{'outputs': array([[[ 0.02343191,  0.05894056,  0.01552576, ..., -0.06954119,-0.02693178, -0.02773715],[-0.01897412,  0.00430241,  0.05111675, ..., -0.12161507,0.00998021, -0.0282588 ],[-0.01222279, -0.00742003,  0.1395104 , ...,  0.06212089,0.05438172, -0.10756982],..., [ 0.04471944,  0.03058323, -0.08105398, ..., -0.08458089,-0.00789265,  0.00711049],[ 0.07910491, -0.0015225 , -0.08136954, ..., -0.03702021,-0.02530194,  0.07729477],[ 0.06114135,  0.0263763 ,  0.0153004 , ..., -0.07590827,-0.00899063, -0.031571  ]],[[ 0.04057412,  0.0379415 ,  0.01818413, ...,  0.00513165,0.09185232, -0.16915748],[ 0.08922272,  0.04556143, -0.06847201, ..., -0.03329186,0.07859877, -0.22903247],[ 0.04083256, -0.0191676 , -0.00690892, ..., -0.00552511,0.07809589, -0.16655875],..., [ 0.        ,  0.        ,  0.        , ...,  0.        ,0.        ,  0.        ],[ 0.        ,  0.        ,  0.        , ...,  0.        ,0.        ,  0.        ],[ 0.        ,  0.        ,  0.        , ...,  0.        ,0.        ,  0.        ]]]), 'last_states': LSTMStateTuple(c=array([[  1.17486513e-01,   4.53374791e-02,   3.27930624e-02,1.88688948e-01,  -9.18940578e-02,   1.10607361e-01,7.69938294e-02,   1.02080487e-01,   2.35188842e-01,-6.99273490e-02,   1.98158514e-01,  -2.66004847e-02,-2.00984914e-01,  -1.22899439e-01,  -9.09574947e-03,1.25963024e-01,   8.78420353e-02,  -4.48895848e-02,1.41703260e-02,   7.78878760e-03,  -3.56721497e-02,-1.02126920e-01,  -9.31018826e-02,  -1.18749056e-01,-2.15687558e-02,  -6.48136325e-02,  -6.67117612e-02,2.06457878e-01,   1.05809077e-01,   3.25519072e-02,6.68543364e-02,  -1.25674027e-01,   1.65443839e-01,-8.19379933e-02,  -2.68197695e-02,  -1.26924280e-01,9.66936841e-02,   2.45289838e-02,  -3.15856903e-02,-9.30471642e-02,   2.28047923e-02,   1.64577723e-01,-2.13811172e-02,   2.31624708e-01,  -5.05328136e-02,-2.15352598e-01,   1.17756556e-01,   1.24231633e-01,2.17948294e-01,  -1.88141852e-01,   5.56704829e-02,1.85995614e-04,  -1.63170139e-02,   4.14733115e-02,-1.42410828e-01,  -2.10698220e-02,   1.13032204e-01,1.16487820e-01,   1.14937607e-01,   1.15206014e-01,9.07994735e-02,  -1.47575747e-01,  -1.67919061e-02,-5.57344372e-02],[ -1.87032883e-01,  -4.50730933e-02,   1.65264860e-01,-1.57064693e-01,  -1.02704183e-01,  -1.42700035e-01,-1.82858618e-01,  -5.69656656e-02,  -3.19701571e-01,-9.45731981e-04,  -8.96991629e-02,   6.37877888e-02,-7.24395155e-02,   2.24324167e-01,  -2.26432828e-01,-2.12203247e-02,  -9.89278157e-02,  -1.79787292e-01,1.17519710e-01,  -2.43337123e-01,   6.08713955e-02,3.71411367e-01,   3.96845821e-02,  -1.34371544e-01,-1.54702491e-01,  -1.80343050e-02,   7.06988306e-02,-1.58112671e-01,  -1.74782878e-01,   1.24460790e-01,-2.01408352e-02,  -2.19578859e-01,  -1.09101701e-01,-3.36411660e-02,  -4.12966791e-02,  -2.62211522e-01,6.09266090e-02,   5.15926436e-02,   1.31553677e-01,3.85248320e-02,   6.82502698e-02,   3.20785503e-01,6.02489641e-02,   1.03486249e-02,  -1.98853998e-01,2.42482932e-01,  -3.03208095e-03,   3.26806427e-02,1.43904791e-01,   4.83002308e-02,   1.06806422e-01,2.19021559e-01,  -1.04280654e-01,   7.02105858e-02,-1.08238911e-01,   5.31858915e-02,  -1.30427149e-01,-3.14307444e-02,   2.60903800e-02,  -3.49547176e-03,3.15445855e-02,   1.26248331e-01,   2.98049766e-01,-1.35553357e-01]]), h=array([[  6.11413522e-02,   2.63763025e-02,   1.53004046e-02,1.00835659e-01,  -4.07618767e-02,   6.39206416e-02,4.17340362e-02,   5.10448527e-02,   9.37222463e-02,-3.43376107e-02,   1.00684542e-01,  -1.28972917e-02,-1.20061738e-01,  -6.48411970e-02,  -4.66407837e-03,6.29309198e-02,   4.64027731e-02,  -1.80123985e-02,7.18521681e-03,   4.55297690e-03,  -1.95851481e-02,-4.94828658e-02,  -4.56579935e-02,  -5.68909598e-02,-1.03985798e-02,  -2.80805943e-02,  -3.67050137e-02,1.11822759e-01,   4.82685695e-02,   1.51483196e-02,3.61371426e-02,  -4.92942874e-02,   8.74024618e-02,-3.75624886e-02,  -1.54172618e-02,  -6.26848414e-02,3.92306304e-02,   1.08791341e-02,  -1.76010076e-02,-4.68257540e-02,   1.11274774e-02,   7.26592349e-02,-1.10059670e-02,   1.25391653e-01,  -2.45894375e-02,-1.10484543e-01,   5.64758454e-02,   6.85158790e-02,1.05166465e-01,  -9.38722289e-02,   2.87157035e-02,9.68917170e-05,  -7.59567519e-03,   2.00130197e-02,-5.71313903e-02,  -1.06302802e-02,   6.53980752e-02,5.53559936e-02,   5.63571469e-02,   5.87699760e-02,4.93030711e-02,  -7.59082740e-02,  -8.99063316e-03,-3.15710039e-02],[ -8.75580540e-02,  -2.40814362e-02,   7.62920499e-02,-7.99111282e-02,  -5.25187098e-02,  -6.82907819e-02,-9.22920867e-02,  -2.82334342e-02,  -1.35842188e-01,-4.41795008e-04,  -4.67307509e-02,   3.26420635e-02,-3.43710296e-02,   1.08600958e-01,  -1.19684674e-01,-1.15702585e-02,  -5.29742132e-02,  -8.58632779e-02,5.49293634e-02,  -1.28582904e-01,   3.30139501e-02,1.91180419e-01,   2.06462597e-02,  -6.48707477e-02,-8.20119830e-02,  -8.35309469e-03,   3.54353392e-02,-7.91071596e-02,  -8.36684223e-02,   6.17335216e-02,-1.01217617e-02,  -1.00540861e-01,  -5.48336196e-02,-1.71105389e-02,  -2.12356078e-02,  -1.14496268e-01,2.93849624e-02,   2.36536930e-02,   6.08473933e-02,1.81132892e-02,   3.16145248e-02,   1.56376674e-01,3.24342202e-02,   5.35344708e-03,  -9.31969777e-02,1.23855219e-01,  -1.54691975e-03,   1.70947532e-02,7.22062554e-02,   2.54588642e-02,   5.57794494e-02,9.75779489e-02,  -4.55104484e-02,   3.46636330e-02,-5.55832345e-02,   2.72228363e-02,  -7.08426689e-02,-1.49771182e-02,   1.34402453e-02,  -1.72122309e-03,1.56672952e-02,   6.92526562e-02,   1.50181313e-01,-7.16690686e-02]]))}

可以看出,对于第二个example超过6步的outputs,是直接被设置成0了,而last_states将7-10步的输出重复第6步的输出。可见节省了不少的计算开销

心得

对于NLP的一些任务来说,使用tf.dynamic_rnn显然比其他的RNN来的更方便和节约计算资源,因此推荐优先使用tf.dynamic_rnn

tensorflow高阶教程:tf.dynamic_rnn相关推荐

  1. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  2. 深度学习(17)TensorFlow高阶操作六: 高阶OP

    深度学习(17)TensorFlow高阶操作六: 高阶OP 1. Where(tensor) 2. where(cond, A, B) 3. 1-D scatter_nd 4. 2-D scatter ...

  3. 深度学习(16)TensorFlow高阶操作五: 张量限幅

    深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...

  4. 深度学习(15)TensorFlow高阶操作四: 填充与复制

    深度学习(15)TensorFlow高阶操作四: 填充与复制 1. Pad 2. 常用于Image Padding 3. tile 4. tile VS broadcast_to Outline pa ...

  5. 深度学习(14)TensorFlow高阶操作三: 张量排序

    深度学习(14)TensorFlow高阶操作三: 张量排序 一. Sort, argsort 1. 一维Tensor 2. 多维Tensor 二. Top_k 三. Top-k accuracy(To ...

  6. 深度学习(12)TensorFlow高阶操作一: 合并与分割

    深度学习(12)TensorFlow高阶操作一: 合并与分割 1. concat 2. stack: create new dim 3. Dim mismatch 4. unstuck 5. spli ...

  7. python openpyxl读写xlsx_python高阶教程-python操作xlsx文件(openpyxl)

    本篇内容来自原创小册子<python高阶教程>,点击查看目录. 背景 在处理一些作业时,经常会碰到统计未交人数.分数等需求,虽然我们在数据库中有了对应的数据, 但是数据库只是面向开发者的, ...

  8. hexo高阶教程:想让你的博客被更多的人在搜索引擎中搜到吗?

    本文首发在我的个人博客:http://cherryblog.site/,欢迎大家前去参观,顺便求fork,么么哒~ 上一次在掘金上发表的hexo高阶教程:hexo高阶教程next主题优化之加入网易云音 ...

  9. Tensorflow高阶内容(五)- Deep Learning

    高阶内容 5.1 Classification分类学习 5.2 什么是过拟合(Overfitting) 5.3 Dropout 解决 Overfitting 5.4 什么是卷积神经网络CNN(Conv ...

最新文章

  1. 【FFmpeg】结构体详解(二):AVStream、AVPacket、AVOutputFormat
  2. 一般将来时语法课教案_英语语法:一般现在时和现在进行时
  3. android_home is not set mac,mac解决appium-doctor报ANDROID_HOME is NOT set
  4. java调用第三方dll文件 源码_C++调用python文件(包含第三方库)
  5. compiz把xfce4系统搞崩溃后的恢复方案
  6. 前端学习(3023):vue+element今日头条管理-首页layont布局
  7. java字节流转字符串_字节流与字符流的区别及相互转换
  8. PHP文字转语音合成网源码 百度API开发
  9. 清除故障,Windows2003更加亲切
  10. 【clickhouse】clickhouse NO DELAY, INTO OUTFILE, SETTINGS, ON, FORMAT, Dot, SYNC, token
  11. bootstrap3 - 分页
  12. java玻璃效果_swing透明效果(没aero毛玻璃那么好看)
  13. python 代码分块_python大数据分块处理
  14. 阶段1 语言基础+高级_1-3-Java语言高级_05-异常与多线程_第6节 Lambda表达式_8_Lambda省略格式Lambda使用前...
  15. 超级计算机卫星云图,台风路径实时发布系统20号台风云图 台风艾莎尼高清卫星云图实时追踪...
  16. 人力资源管理系统概要设计说明书
  17. php中开通短信验证码,php利用云片网实现短信验证码功能的示例代码
  18. Jxls使用模版导出excel表格公式无法自动计算失效解决
  19. 包装exp是什么意思_药瓶说明中EXP是什么意思?
  20. Python Requests实现天气预报

热门文章

  1. 高通Q+A平台 android gcore解析环境搭建
  2. 指针 Swap交换函数
  3. laravel视图 compact 循环遍历,if判断
  4. bio linux 创建_Linux IO请求处理流程-bio和request
  5. Knald - 1.2.1 烘培贴图,利用贴图转换成其他贴图   笔记
  6. 2020研究生数学建模E题--AlexNet深度网络解法(大雾能见度估计与预测)(含代码)
  7. 【CF226C】Anniversary
  8. mysql查询出现毫秒值快速解决方法
  9. 网页出现503 service unavailable是什么意思?怎么解决?
  10. Tomcat源码解析:环境搭建