深度学习之循环神经网络(10)GRU简介

  • 1. 复位门
  • 2. 更新门
  • 3. GRU使用方法

 LSTM具有更长的记忆能力,在大部分序列任务上面都取得了比基础RNN模型更好的性能表现,更重要的是,LSTM不容易出现梯度弥散现象。但是LSTM结构相对较复杂,计算代价较高,模型参数量较大。因此科学家们尝试简化LSTM内部的计算流程,特别是减少门控数量。研究发现,遗忘门是LSTM中最重要的门控 [1],甚至发现只有遗忘门的简化版网络在多个基准数据集上面优于标准LSTM网络。在众多的简化版LSTM中, 门控循环网络(Gated Recurrent Unit,简称GRU)是应用最广泛的RNN变种之一。GRU把内部状态向量和输出向量合并,统一为状态向量 h\boldsymbol hh,门控数量也较少到2个: 复位门(Reset Gate)更新门(Update Gate),如下图所示:
GRU网络结构

 下面我们来分别介绍复位门和更新门的原理与功能。

[1] J. Westhuizen 和 J. Lasenby, “The unreasonable effectiveness of the forget gate,” CoRR, 卷 abs/1804.04849, 2018.

1. 复位门

 复位门用于控制上一个时间戳的状态ht−1\boldsymbol h_{t-1}ht−1​进入GRU的量。门控向量gr\boldsymbol g_rgr​由当前时间戳输入xt\boldsymbol x_txt​和上一时间戳状态ht−1\boldsymbol h_{t-1}ht−1​变换得到,关系如下:
gr=σ(Wr[ht−1,xt]+br)\boldsymbol g_r=σ(\boldsymbol W_r [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_r)gr​=σ(Wr​[ht−1​,xt​]+br​)
其中Wr\boldsymbol W_rWr​和br\boldsymbol b_rbr​为复位门的参数,由反向传播算法自动优化,σσσ为激活函数,一般使用Sigmoid函数。门控向量gr=0\boldsymbol g_r=0gr​=0时,新输入h~t\tilde \boldsymbol h_th~t​全部来自于输入xt\boldsymbol x_txt​,不接受ht−1\boldsymbol h_{t-1}ht−1​,此时相当于复位ht−1\boldsymbol h_{t-1}ht−1​。当gr=1\boldsymbol g_r=1gr​=1时,ht−1h_{t-1}ht−1​和输入xt\boldsymbol x_txt​共同产生新输入h~t\tilde\boldsymbol h_th~t​,如下图所示:

复位门

2. 更新门

 更新门用控制上一时间戳状态ht−1\boldsymbol h_{t-1}ht−1​和新输入h~t\tilde\boldsymbol h_th~t​对新状态向量ht\boldsymbol h_tht​的影响程度。更新门控向量gz\boldsymbol g_zgz​由
gz=σ(Wz[ht−1,xt]+bz)\boldsymbol g_z=σ(\boldsymbol W_z [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_z)gz​=σ(Wz​[ht−1​,xt​]+bz​)
得到,其中Wz\boldsymbol W_zWz​和bz\boldsymbol b_zbz​为更新门的参数,由反向传播算法自动优化,σσσ为激活函数,一般使用Sigmoid函数。gz\boldsymbol g_zgz​用于控制新输入h~t\tilde\boldsymbol h_th~t​信号,1−gz1-\boldsymbol g_z1−gz​用于控制状态ht−1\boldsymbol h_{t-1}ht−1​信号:
ht=(1−gz)ht−1+gzh~t\boldsymbol h_t=(1-\boldsymbol g_z ) \boldsymbol h_{t-1}+\boldsymbol g_z \tilde\boldsymbol h_tht​=(1−gz​)ht−1​+gz​h~t​

更新门

可以看到,h~t\tilde\boldsymbol h_th~t​和ht−1\boldsymbol h_{t-1}ht−1​的更新量处于相互竞争、此消彼长的状态。当更新门gz=0\boldsymbol g_z=0gz​=0时,ht\boldsymbol h_tht​全部来自上一时间戳状态ht−1\boldsymbol h_{t-1}ht−1​;当更新门gz=1\boldsymbol g_z=1gz​=1时,ht\boldsymbol h_tht​全部来自新输入h~t\tilde\boldsymbol h_th~t​。

3. GRU使用方法

 同样地,在TensorFlow中,也有Cell方式和层方式实现GRU网络。GRUCell和GRU层的使用方法和之前的SimpleRNNCell、LSTMCell、SimpleRNN和LSTM非常类似。首先是GRUCell的使用,创建GRU Cell对象,并在时间轴上循环展开运算。例如:

import tensorflow as tf
from tensorflow.keras import layersx = tf.random.normal([2, 80, 100])
xt = x[:, 0, :]  # 得到一个时间戳的输入
# 初始化状态向量,GRU只有一个
h = [tf.zeros([2, 64])]
cell = layers.GRUCell(64)  # 新建GRU Cell,向量长度为64
# 在时间戳维度上解开,循环通过cell
for xt in tf.unstack(x, axis=1):out, h = cell(xt, h)
# 输出形状
print(out.shape)

运行结果如下所示:

(2, 64)

 通过layers.GRU类可以方便创建一层GRU网络层,通过Sequential容器可以堆叠多层GRU层的网络。例如:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Sequentialx = tf.random.normal([2, 80, 100])
xt = x[:, 0, :]  # 得到一个时间戳的输入
# 初始化状态向量,GRU只有一个
h = [tf.zeros([2, 64])]
net = keras.Sequential([layers.GRU(64, return_sequences=True),layers.GRU(64)
])
out = net(x)
# 输出形状
print(out.shape)

运行结果如下所示:

(2, 64)

深度学习之循环神经网络(10)GRU简介相关推荐

  1. 深度学习之循环神经网络(11-b)GRU情感分类问题代码

    深度学习之循环神经网络(11-b)GRU情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorfl ...

  2. 深度学习之循环神经网络(11)LSTM/GRU情感分类问题实战

    深度学习之循环神经网络(11)LSTM/GRU情感分类问题实战 1. LSTM模型 2. GRU模型  前面我们介绍了情感分类问题,并利用SimpleRNN模型完成了情感分类问题的实战,在介绍完更为强 ...

  3. 深度学习之循环神经网络(11-a)LSTM情感分类问题代码

    深度学习之循环神经网络(11-a)LSTM情感分类问题代码 1. Cell方式 代码 运行结果 2. 层方式 代码 运行结果 1. Cell方式 代码 import os import tensorf ...

  4. 深度学习之循环神经网络(6)梯度弥散和梯度爆炸

    深度学习之循环神经网络(6)梯度弥散和梯度爆炸  循环神经网络的训练并不稳定,网络的善妒也不能任意加深.那么,为什么循环神经网络会出现训练困难的问题呢?简单回顾梯度推导中的关键表达式: ∂ht∂hi= ...

  5. 深度学习之循环神经网络(4)RNN层使用方法

    深度学习之循环神经网络(4)RNN层使用方法 1. SimpleRNNCell 2. 多层SimpleRNNCell网络 3. SimpleRNN层  在介绍完循环神经网络的算法原理之后,我们来学习如 ...

  6. 深度学习之循环神经网络(2)循环神经网络原理

    深度学习之循环神经网络(2)循环神经网络原理 1. 全连接层 2. 共享权值 3. 全局语义 4. 循环神经网络  现在我们来考虑如何吃力序列信号,以文本序列为例,考虑一个句子: "I di ...

  7. 深度学习之循环神经网络(1)序列表示方法

    深度学习之循环神经网络(1)序列表示方法 序列表示方法 Embedding层 2. 预训练的词向量 前面的卷积神经网络利用数据的局部相关性和权值共享的思想大大减少了网络的参数量,非常适合于图片这种具有 ...

  8. 水很深的深度学习-Task05循环神经网络RNN

    循环神经网络 Recurrent Neural Network 参考资料: Unusual-Deep-Learning 零基础入门深度学习(5) - 循环神经网络 史上最小白之RNN详解_Tink19 ...

  9. 【深度学习】循环神经网络(RNN)的tensorflow实现

    [深度学习]循环神经网络(RNN)的tensorflow实现 一.循环神经网络原理 1.1.RNN的网络结构 1.2.RNN的特点 1.3.RNN的训练 二.循环神经网络的tensorflow实现 参 ...

最新文章

  1. mysql monday event_MySQL获取日期周、月、天,生成序号
  2. Spring Boot(04)——创建自己的自动配置
  3. 「日常训练」Skills(Codeforce Round #339 Div.2 D)
  4. 与大家分享一个我最近开始用的不错的JavaScript IDE
  5. golang 排序_常用排序算法之冒泡排序
  6. 数据类型之数字类型—运算符
  7. leetcode刷题:零钱兑换
  8. 远程下载马bypass waf
  9. POJ3264 Balanced Lineup【线段树】
  10. 在网页输入框输入角标_这个免费插件能帮我们把Excel内容快速填充到网页表单?...
  11. 09-Mysql数据库----外键的变种
  12. Bex5开发平台分辨率问题解决方法
  13. [词汇] 十四、动词
  14. oracle 产品宣传片,史上最牛宣传片!河南的美已惊艳了世界!
  15. 【数学建模】(五):MATLAB程序设计与积分
  16. Hopscotch(POJ-3050)
  17. postfix 测试邮件服务器,搭建Postfix邮件服务器
  18. 3D游戏设计——模型与动画
  19. JAVA对象布局之对象头(Object Header)
  20. iOS 中饼状图的自定义绘制

热门文章

  1. SVN trunk(主线) branch(分支) tag(标记) 用法详解和详细操作步骤
  2. php中的冒泡排序实例,PHP实现冒泡排序的简单实例,php冒泡排序_PHP教程
  3. IOS 企业版发布后,用户通过sarafi浏览器安装无效的解决方案
  4. iOS coredata 多表查询
  5. 挪车+php,还在苦苦寻找占你车位的人?关注这个微信号实现“一键挪车”
  6. WPF:Graphics绘图--Shapes形状
  7. 英国法院裁定GCHQ黑客发动网络攻击并不侵犯人权
  8. 【WPF】获取电磁笔的压感
  9. maven远程发布jar
  10. UILabel设定行间距方法