Tensorflow:tf.contrib.rnn.DropoutWrapper函数(谷歌已经为Dropout申请了专利!)、MultiRNNCell函数的解读与理解

目录

1、tf.contrib.rnn.DropoutWrapper函数解读与理解

1.1、源代码解读

1.2、案例应用

2、tf.contrib.rnn.MultiRNNCell函数解读与理解

2.1、源代码解读

2.2、案例应用


tensorflow官网API文档:https://tensorflow.google.cn/api_docs

1、tf.contrib.rnn.DropoutWrapper函数解读与理解

在机器学习的模型中,如果模型的参数太多,而训练样本又太少,训练出来的模型很容易产生过拟合的现象。在训练神经网络的时候经常会遇到过拟合的问题。过拟合具体表现在:模型在训练数据上损失函数较小,预测准确率较高;但是在测试数据上损失函数比较大,预测准确率较低。

机器学习模型训练中,过拟合现象实在令人头秃。而 2012 年 Geoffrey Hinton 提出的 Dropout 对防止过拟合有很好的效果。之后大量 Dropout 变体涌现,这项技术也成为机器学习研究者常用的训练 trick。万万没想到的是,谷歌为该项技术申请了专利,而且这项专利已经正式生效,2019-06-26 专利生效,2034-09-03 专利到期!

Dropout,指在神经网络中,每个神经单元在每次有数据流入时,以一定的概率keep_prob正常工作,否则输出0值。这是一种有效的正则化方法,可以有效降低过拟合。在RNN中进行dropout时,对于RNN的部分不进行dropout,也就是说从t-1时候的状态传递到t时刻进行计算时,这个中间不进行memory的dropout;仅在同一个t时刻中,多层cell之间传递信息的时候进行dropout。在RNN中,这里的dropout是在输入,输出,或者不用的循环层之间使用,或者全连接层,不会在同一层的循环体中使用。

1.1、源代码解读

Operator adding dropout to inputs and outputs of the given cell. 操作者将dropout添加到给定单元的输入和输出。
tf.compat.v1.nn.rnn_cell.DropoutWrapper(*args, **kwargs
)

Args:

  • cell: an RNNCell, a projection to output_size is added to it.
  • input_keep_prob: unit Tensor or float between 0 and 1, input keep probability; if it is constant and 1, no input dropout will be added.
  • output_keep_prob: unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added.
  • state_keep_prob: unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added. State dropout is performed on the outgoing states of the cell. Note the state components to which dropout is applied when state_keep_prob is in (0, 1) are also determined by the argumentdropout_state_filter_visitor (e.g. by default dropout is never applied to the c component of an LSTMStateTuple).
  • variational_recurrent: Python bool. If True, then the same dropout pattern is applied across all time steps per run call. If this parameter is set, input_size must be provided.
  • input_size: (optional) (possibly nested tuple of) TensorShape objects containing the depth(s) of the input tensors expected to be passed in to the DropoutWrapper. Required and used iff variational_recurrent = True and input_keep_prob < 1.
  • dtype: (optional) The dtype of the input, state, and output tensors. Required and used iffvariational_recurrent = True.
  • seed: (optional) integer, the randomness seed.
  • dropout_state_filter_visitor: (optional), default: (see below). Function that takes any hierarchical level of the state and returns a scalar or depth=1 structure of Python booleans describing which terms in the state should be dropped out. In addition, if the function returns True, dropout is applied across this sublevel. If the function returns False, dropout is not applied across this entire sublevel. Default behavior: perform dropout on all terms except the memory (c) state of LSTMCellState objects, and don't try to apply dropout to TensorArray objects: def dropout_state_filter_visitor(s): if isinstance(s, LSTMCellState): # Never perform dropout on the c state. return LSTMCellState(c=False, h=True) elif isinstance(s, TensorArray): return False return True
  • **kwargs: dict of keyword arguments for base layer.
参数:

  • cell:一个RNNCell,向它添加一个到output_size的投影。
  • input_keep_prob:单位张量或浮点数在0到1之间,输入保持概率;如果是常数和1,则不添加输入dropout。
  • output_keep_prob:单位张量或浮动在0和1之间,输出保持概率;如果是常数和1,则不添加输出dropout。
  • state_keep_prob:单位张量或浮点数在0到1之间,输出保持概率;如果是常数和1,则不添加输出dropout。状态退出是在计算单元的输出状态上执行的。注意,当state_keep_prob位于(0,1)中时,dropout应用到的状态组件也由argumentdropout_state_filter_visitor(例如。默认情况下,dropout从不应用于LSTMStateTuple的c组件)。
  • variational_recurrent: Python布尔类型。如果为真,则在每次运行调用的所有时间步上应用相同的退出模式。如果设置了该参数,则必须提供input_size。
  • input_size:(可选的)(可能嵌套的元组)TensorShape对象,包含期望传递给DropoutWrapper的输入张量的深度。需要和使用的iff variational_= True和input_keep_prob < 1。
  • (可选)输入、状态和输出张量的dtype。需要和使用iffvariational_= True。
  • 种子:(可选)整数,随机种子。
  • dropout_state_filter_visitor:(可选),默认:(见下)。函数,该函数接受状态的任何层次结构,并返回一个标量或深度=1的Python布尔值结构,该结构描述应该删除状态中的哪些项。此外,如果函数返回True,则在此子层上应用dropout。如果函数返回False,则不会在整个子层上应用dropout。默认行为:除了LSTMCellState对象的内存(c)状态外,在所有条件下执行dropout,并且不要试图将dropout应用到TensorArray对象:def dropout_state_filter_visitor(s): if isinstance(s, LSTMCellState): #永远不要在c状态下执行dropout。返回LSTMCellState(c=False, h=True) elif isinstance(s, TensorArray):返回False返回True
  • **kwargs:基层关键字参数的字典。

Methods

get_initial_state

View source

get_initial_state(inputs=None, batch_size=None, dtype=None
)

zero_state

View source

zero_state(batch_size, dtype
)

1.2、案例应用

相关文章:TF之LSTM:利用多层LSTM算法对MNIST手写数字识别数据集进行多分类


lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)      #定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度
lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob) #添加 dropout layer, 一般只设置 output_keep_prob

2、tf.contrib.rnn.MultiRNNCell函数解读与理解

2.1、源代码解读

RNN cell composed sequentially of multiple simple cells. RNN细胞由多个简单细胞依次组成。
tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True
)

Args:

  • cells: list of RNNCells that will be composed in this order.
  • state_is_tuple: If True, accepted and returned states are n-tuples, where n = len(cells). If False, the states are all concatenated along the column axis. This latter behavior will soon be deprecated.

参数:

单元格:按此顺序组成的RNNCells列表。
state_is_tuple:如果为真,则接受状态和返回状态为n元组,其中n = len(cell)。如果为假,则所有状态都沿着列轴连接。后一种行为很快就会被摒弃。

Methods

get_initial_state

View source

get_initial_state(inputs=None, batch_size=None, dtype=None
)

zero_state

View source

zero_state(batch_size, dtype
)

Return zero-filled state tensor(s).

Args:

  • batch_size: int, float, or unit Tensor representing the batch size.
  • dtype: the data type to use for the state.
 

Returns:

If state_size is an int or TensorShape, then the return value is a N-D tensor of shape [batch_size, state_size]filled with zeros.

If state_size is a nested list or tuple, then the return value is a nested list or tuple (of the same structure) of 2-Dtensors with the shapes [batch_size, s] for each s in state_size.

返回

如果state_size是一个int或TensorShape,那么返回值就是一个包含0的shape [batch_size, state_size]的N-D张量。

如果state_size是一个嵌套列表或元组,那么返回值就是一个嵌套列表或元组(具有相同结构)的2-张量,其中每个s的形状[batch_size, s]为state_size中的每个s。

2.2、案例应用

相关文章:DL之LSTM:LSTM算法论文简介(原理、关键步骤、RNN/LSTM/GRU比较、单层和多层的LSTM)、案例应用之详细攻略

num_units = [128, 64]
cells = [BasicLSTMCell(num_units=n) for n in num_units]
stacked_rnn_cell = MultiRNNCell(cells)

Tensorflow:tf.contrib.rnn.DropoutWrapper函数(谷歌已经为Dropout申请了专利!)、MultiRNNCell函数的解读与理解相关推荐

  1. DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读

    DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读 目录 tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读 函 ...

  2. RNN调试错误:lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size) 方法已失效

    调试递归神经网络(RNN)的时候出现如下错误: ### module 'tensorflow.contrib.rnn' has no attribute 'core_rnn_cell' 经检查是tf. ...

  3. TensorFlow——tf.contrib.layers库中的相关API

    在TensorFlow中封装好了一个高级库,tf.contrib.layers库封装了很多的函数,使用这个高级库来开发将会提高效率,卷积函数使用tf.contrib.layers.conv2d,池化函 ...

  4. 案例:谷歌人工智能算法Dropout申请专利

    2019年6月26日,谷歌对Dropout算法提出的专利申请正式生效,专利有效期为15年,2034年9月3日到期.Dropout算法最早由Hinton于2012年提出,是一种在深度学习.训练神经网络时 ...

  5. RNN循环神经网络的直观理解:基于TensorFlow的简单RNN例子

    RNN 直观理解 一个非常棒的RNN入门Anyone Can learn To Code LSTM-RNN in Python(Part 1: RNN) 基于此文章,本文给出我自己的一些愚见 基于此文 ...

  6. TensorFlow 2——替换【tensorflow.compat.v1.contrib.rnn.LSTMCell】解决方案

    问题描述 Traceback (most recent call last):   File "D:/Code/Project/a18/ocr/demo.py", line 16, ...

  7. TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn

    原文教程:tensorflow官方教程 记录关键内容与学习感受.未完待续.. Creating Estimators in tf.contrib.learn --tf.contrib.learn框架, ...

  8. Tensorflow— 递归神经网络RNN

    代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_datamnist = input_ ...

  9. TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例

    TF之RNN:实现利用scope.reuse_variables()告诉TF想重复利用RNN的参数的案例 目录 输出结果 代码设计 输出结果 后期更新-- 代码设计 import tensorflow ...

最新文章

  1. BciPy: 一款基于Python用于BCI研究的开源软件
  2. vs调试按钮为灰色的_IntelliJ IDEA 调试 Java 8,实在太香了
  3. nginx A/B 灰色发布
  4. linux ntfs 用户权限,Linux在NTFS中创建的文件的权限
  5. Linux命令:mkdir
  6. fread函数和fwrite函数,read,write
  7. C# -- 在底图上动态生成文字和图片
  8. Visual FoxPro权威指南pdf
  9. CSDN优质博主推荐(C/C++领域)-持续更新中
  10. iphone开蓝牙wifi上网慢_苹果iphone 7手机连接wifi网速很慢怎么办?
  11. PAT乙级1068 万绿丛中一点红(测试点3、测试点5)
  12. C# Excel 新建工作表,新增工作表,更改工作表的名字
  13. 【购房必备知识】成都公积金贷款/商业贷款的一些知识记录
  14. 警惕安全档案的陷阱 | 确认偏见
  15. 电脑桌面图标变成白色图标如何处理
  16. Python爬虫+数据分析,2019年你想看的A股牛市都在这里了!
  17. 放弃谷歌实习转投ICC,我是如何曲线上岸G家的?
  18. 如何修改Moodle上传文件大小的限制
  19. js中唤醒弹框的3种方式
  20. CUDA安装失败(已解决)

热门文章

  1. Toolbar中Overflow Menu不显示问题
  2. Network device support
  3. MySQL 查询重复数据,删除重复数据保留id最小的一条作为唯一数据
  4. Android 消息机制详解(Android P)
  5. 如何通过构建以太坊智能合约来销售商品
  6. leetcode396. Rotate Function
  7. UIPasteboard
  8. 你不知道的composer自动加载
  9. 如何将已有mdf文件导入到SQL 2000 或者 SQL 2005
  10. 消息队列面试连环炮,你抗得住吗?