动手学深度学习(三十九)——门控循环单元GRU
文章目录
- 门控循环单元(GRU)
- 一、门控隐藏状态
- 1.1 重置门和更新门
- 1.2候选隐藏状态
- 1.3 隐藏状态
- 二、从零实现GRU
- 2.1 初始化模型参数
- 2.2 定义模型
- 2.3 训练与预测
- 2.4 简洁实现
- 三、小结
- 四、练习
再次声明:本文主要参考李沐老师B站动手学深度学习课程进行笔记整理和代码复现。如果需要看视频,可以click。感谢沐神的分享!!!
门控循环单元(GRU)
我们想想对于一个序列而言,有的早期观测值对所有的未来观测值都非常有用,有的观测值对所有的未来预测都没有用,或者说有的序列各个部分之间有逻辑中断。总结起来就是:
- 并不是每个观测值都是同等重要
- 想要只记住相关的观察需要:
- 能关注的机制(更新门)
- 能遗忘的机制(重置门)
在学术界已经提出了许多方法来解决这个问题。其中最早的方法是"长-短期记忆" (long-short-term memory, LSMT):(Hochreiter.Schmidhuber.1997
) 。门控循环单元(gated recurrent unit, GRU)(Cho.Van-Merrienboer.Bahdanau.ea.2014
) 是一个稍微简化的变体,通常能够提供同等的效果,并且计算 (Chung.Gulcehre.Cho.ea.2014
) 的速度明显更快。由于它更简单,就让我们从门控循环单元开始。
来都来了,把这几篇文章都贴出来吧:
【1】LSTM:Long Short-Term Memory
【2】GRU:Learning Phrase Representations using RNN Encoder–Decoderfor Statistical Machine Translation
一、门控隐藏状态
普通的循环神经网络和门控循环单元之间的关键区别在于后者支持隐藏状态的门控(或者说选通)。这意味着有专门的机制来确定应该何时 更新 隐藏状态,以及应该何时 重置 隐藏状态。这些机制是可学习的,并且能够解决了上面列出的问题。
例如,如果第一个标记非常重要,我们将学会在第一次观测之后不更新隐藏状态。同样,我们也可以学会跳过不相关的临时观测。最后,我们还将学会在需要的时候重置隐藏状态。
1.1 重置门和更新门
我们首先要介绍的是 重置门(reset gate)和 更新门(update gate)。我们把它们设计成 (0,1)(0, 1)(0,1) 区间中的向量,这样我们就可以进行凸组合。例如,重置门允许我们控制可能还想记住的过去状态的数量。同样,更新门将允许我们控制新状态中有多少个是旧状态的副本。
我们从构造这些门控开始。下图描述了门控循环单元中的重置门和更新门的输入,输入是由当前时间步的输入和前一时间步的隐藏状态给出。两个门的输出是由使用 sigmoid 激活函数的两个全连接层给出。
数学描述,对于给定的时间步 ttt,假设输入是一个小批量 Xt∈Rn×d\mathbf{X}_t \in \mathbb{R}^{n \times d}Xt∈Rn×d (样本个数:nnn,输入个数:ddd),上一个时间步的隐藏状态是 Ht−1∈Rn×h\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}Ht−1∈Rn×h(隐藏单元个数:hhh)。然后,重置门 Rt∈Rn×h\mathbf{R}_t \in \mathbb{R}^{n \times h}Rt∈Rn×h 和更新门 Zt∈Rn×h\mathbf{Z}_t \in \mathbb{R}^{n \times h}Zt∈Rn×h 的计算如下:
Rt=σ(XtWxr+Ht−1Whr+br),Zt=σ(XtWxz+Ht−1Whz+bz),\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned} Rt=σ(XtWxr+Ht−1Whr+br),Zt=σ(XtWxz+Ht−1Whz+bz),
其中 Wxr,Wxz∈Rd×h\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}Wxr,Wxz∈Rd×h 和 Whr,Whz∈Rh×h\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}Whr,Whz∈Rh×h 是权重参数,br,bz∈R1×h\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}br,bz∈R1×h 是偏置参数。请注意,在求和过程中会触发广播机制(请参阅 :numref:subsec_broadcasting
)。我们使用 sigmoid 函数(如 :numref:sec_mlp
中介绍的)将输入值转换到区间 (0,1)(0, 1)(0,1)。
1.2候选隐藏状态
接下来,让我们将重置门 Rt\mathbf{R}_tRt 与RNN中的常规隐状态更新机制集成,得到在时间步 ttt 的候选隐藏状态 H~t∈Rn×h\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}H~t∈Rn×h。
H~t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh),\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),H~t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh),
其中 Wxh∈Rd×h\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}Wxh∈Rd×h 和 Whh∈Rh×h\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}Whh∈Rh×h 是权重参数,bh∈R1×h\mathbf{b}_h \in \mathbb{R}^{1 \times h}bh∈R1×h 是偏置项,符号 ⊙\odot⊙ 是哈达码乘积(按元素乘积)运算符。在这里,我们使用 tanh 非线性激活函数来确保候选隐藏状态中的值保持在区间 (−1,1)(-1, 1)(−1,1) 中。
计算的结果是 候选者(candidate),因为我们仍然需要结合更新门的操作。与基础的RNN相比 候选隐藏状态中的 Rt\mathbf{R}_tRt 和 Ht−1\mathbf{H}_{t-1}Ht−1 的元素相乘可以减少以往状态的影响。每当重置门 Rt\mathbf{R}_tRt 中的项接近 111 时,我们恢复一个如基本RNN中的普通的循环神经网络。对于重置门 Rt\mathbf{R}_tRt 中所有接近 000 的项,候选隐藏状态是以 Xt\mathbf{X}_tXt 作为输入的多层感知机的结果。因此,任何预先存在的隐藏状态都会被 重置 为默认值。下图明了应用重置门之后的计算流程。
1.3 隐藏状态
最后,我们需要结合更新门 Zt\mathbf{Z}_tZt 的效果。这确定新的隐藏状态 Ht∈Rn×h\mathbf{H}_t \in \mathbb{R}^{n \times h}Ht∈Rn×h 在多大程度上就是旧的状态 Ht−1\mathbf{H}_{t-1}Ht−1 ,以及对新的候选状态 H~t\tilde{\mathbf{H}}_tH~t 的使用量。更新门 Zt\mathbf{Z}_tZt 仅需要在 Ht−1\mathbf{H}_{t-1}Ht−1 和 H~t\tilde{\mathbf{H}}_tH~t 之间进行按元素的凸组合就可以实现这个目标。这就得出了门控循环单元的最终更新公式:
Ht=Zt⊙Ht−1+(1−Zt)⊙H~t.\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.Ht=Zt⊙Ht−1+(1−Zt)⊙H~t.
每当更新门 Zt\mathbf{Z}_tZt 接近 111 时,我们就只保留旧状态。此时,来自 Xt\mathbf{X}_tXt 的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 ttt。相反,当 Zt\mathbf{Z}_tZt 接近 000 时,新的隐藏状态 Ht\mathbf{H}_tHt 就会接近候选的隐藏状态 H~t\tilde{\mathbf{H}}_tH~t。==这些设计可以帮助我们处理循环神经网络中的梯度消失问题,并更好地捕获时间步距离很长的序列的依赖关系。==例如,如果整个子序列的所有时间步的更新门都接近于 111,则无论序列的长度如何,在序列起始时间步的旧隐藏状态都将很容易保留并传递到序列结束。下图说明了更新门起作用后的计算流。
总之,门控循环单元具有以下两个显著特征:
- 重置门有助于捕获序列中的短期依赖关系。
- 更新门有助于捕获序列中的长期依赖关系。
二、从零实现GRU
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
2.1 初始化模型参数
从标准差为0.01的高斯分布中提取权重,偏置设置为0,使用超参数num_hidden
定义隐藏单元的数量,实例化与更新门、重置门、候选状态和输出层相关的所有权重和偏置
def get_params(vocab_size,num_hiddens,device):num_inputs= num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape,device=device)*0.01def three():return (normal((num_inputs,num_hiddens)),normal((num_hiddens,num_hiddens)),torch.zeros(num_hiddens,device=device))W_xz,W_hz,b_z = three() # 更新门参数W_xr,W_hr,b_r = three() # 重置门参数W_xh,W_hh,b_h = three() # 候选状态参数# 输出层参数W_hq = normal((num_hiddens,num_outputs))b_q = torch.zeros(num_outputs,device=device)# 附加梯度# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params
2.2 定义模型
# 定义一个隐藏状态的初始化函数,返回一个形状为(批量大小,隐藏单元数)的张量,值全为0
def init_gru_state(batch_size,num_hiddens,device):return (torch.zeros((batch_size,num_hiddens),device=device),)
Rt=σ(XtWxr+Ht−1Whr+br),Zt=σ(XtWxz+Ht−1Whz+bz),H~t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh),Ht=Zt⊙Ht−1+(1−Zt)⊙H~t.\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),\\ \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),\\ \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t. \end{aligned} Rt=σ(XtWxr+Ht−1Whr+br),Zt=σ(XtWxz+Ht−1Whz+bz),H~t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh),Ht=Zt⊙Ht−1+(1−Zt)⊙H~t.
# 定义GRU模型
def gru(inputs,state,params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H)@W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
2.3 训练与预测
训练和预测的工作方式与RNN中的实现完全相同。训练结束后,我们分别打印输出训练集的困惑度和前缀“time traveler”和“traveler”的预测序列上的困惑度。
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 57290.3 tokens/sec on cuda:0
time travelleryou can show black is white by argument said filby
traveller with a slight accession ofcheerfulness really thi
2.4 简洁实现
高级API包含了前文介绍地全部配置细节,所以可以直接实例化GRU。其使用编译好的运算符来进行计算,而非python处理其中的许多细节
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs,num_hiddens)
model = d2l.RNNModel(gru_layer,len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 1.0, 353447.0 tokens/sec on cuda:0
time traveller for so it will be convenient to speak of himwas e
travelleryou can show black is white by argument said filby
三、小结
- 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。
- 重置门有助于捕获序列中的短期依赖关系。
- 更新门有助于捕获序列中的长期依赖关系。
- 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。
四、练习
- 假设我们只想使用时间步 t′t't′ 的输入来预测时间步 t>t′t > t't>t′ 的输出。对于每个时间步,重置门和更新门的最佳值是什么?
更新们和重置门都为0表示不使用之前的隐藏状态数据
- 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
- 比较
rnn.RNN
和rnn.GRU
的不同实现对运行时间、困惑度和输出字符串的影响。 - 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?
动手学深度学习(三十九)——门控循环单元GRU相关推荐
- 花书+吴恩达深度学习(十七)序列模型之长短期记忆 LSTM 和门控循环单元 GRU
目录 0. 前言 1. 长短期记忆 LSTM 2. 门控循环单元 GRU 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十五)序列模型之循环神经网络 ...
- 推荐系统遇上深度学习(三十九)-推荐系统中召回策略演进!
推荐系统中的核心是从海量的商品库挑选合适商品最终展示给用户.由于商品库数量巨大,因此常见的推荐系统一般分为两个阶段,即召回阶段和排序阶段.召回阶段主要是从全量的商品库中得到用户可能感兴趣的一小部分候选 ...
- 动手学深度学习(十四)——权重衰退
文章目录 1. 如何缓解过拟合? 2. 如何衡量模型的复杂度? 3. 通过限制参数的选择范围来控制模型容量(复杂度) 4. 正则化如何让权重衰退? 5. 可视化地看看正则化是如何利用权重衰退来达到缓解 ...
- 门控循环单元-GRU单元(Gated Recurrent Unit)
来源:Coursera吴恩达深度学习课程 接下来我们将会学习门控循环单元(Gated Recurrent Unit),它改变了RNN的隐藏层,使其可以更好地捕捉深层连接,并改善了梯度消失问题,让我们看 ...
- 【动手学习pytorch笔记】24.门控循环单元GRU
GRU 序列中并不是所有信息都同等重要,为了记住重要的信息和遗忘不重要的信息,最早的方法是"长短期记忆"(long-short-term memory,LSTM),这节门控循环单元 ...
- 动手学深度学习(PyTorch实现)(九)--VGGNet模型
VGGNet模型 1. VGGNet模型介绍 1.1 VGGNet的结构 1.2 VGGNet结构举例 2. VGGNet的PyTorch实现 2.1 导入相应的包 2.2 基本网络单元block 2 ...
- 现代循环神经网络-1.门控循环单元(GRU)【动手学深度学习v2】
文章目录 1.门控循环单元(GRU) 1.1 门控隐状态 A.重置门与更新门 B.候选隐状态 C.隐状态 1.2 GRU的实现 A.从零实现 B.简洁实现 1.门控循环单元(GRU) GRU是一个比较 ...
- 【动手学深度学习】李沐——循环神经网络
本文内容目录 序列模型 文本预处理 语言模型和数据集 循环神经网络 RNN的从零开始实现 RNN的简洁实现 通过时间反向传播 门控循环单元GRU 长短期记忆网络(LSTM) 深度循环神经网络 双向循环 ...
- 回归预测 | MATLAB实现CNN-GRU(卷积门控循环单元)多输入单输出
回归预测 | MATLAB实现CNN-GRU(卷积门控循环单元)多输入单输出 目录 回归预测 | MATLAB实现CNN-GRU(卷积门控循环单元)多输入单输出 基本介绍 模型结构 CNN神经网络 G ...
最新文章
- mysql主键约束和唯一性约束
- Spring整合CXF,发布RSETful 风格WebService
- linux打开文件命令occ,Linux系统查看文件内容的命令有哪些?
- react-native 签名
- js日期初始化总结:new Date()参数设置
- 程序猿 自己所擅长的还是码代码 请远离 业务。
- 如何在不同开发语言中使用绑定变量_linux C/C++服务器后台开发面试题总结(编程语言篇)...
- wpf textbox能扫描不能手输_3D扫描仪性能怎么样 3D扫描仪价格介绍【详解】
- 架设WIN32汇编程序的开发环境
- win11如何禁用后台应用权限 Windows11禁用后台应用权限的设置方法
- OpenGL 坐标变换(2)
- lightbox的一个ajax效果
- SQL Server 不允许保存更改的解决方法
- Python 正则表达模块详解
- 计算机id dns知识,智能DNS解析知识集锦
- related knowledge points about protein
- vs2019 C#提示程序未兼容
- 6.7.1 机器人系统仿真/URDF、Gazebo与Rviz综合运用/机器人运动控制以及里程计信息显示
- Excel 表格删除重复数据
- window各版本回顾