反向传播(Back Propagation)是一种与最优化方法(比如梯度下降方法)结合、用来训练人工神经网络的常见方法。该方法对网络中所有权重计算损失函数的梯度,然后将梯度反馈给最优化方法,用来更新网络权重以最小化损失函数。

在神经网络中一个典型的问题就是梯度消失(Gradient Vanishing)的问题,其原因在于是随着神经网络层数的加深,梯度逐渐减小甚至接近0,当梯度变的非常小的时候,就不能为学习提供足够的信息。

Recurrent Neural Networks(递归神经网络,RNN)也存在梯度消失的问题,当输入的序列足够长时,RNN前期的层通常通常由于梯度消失而停止学习,从而导致RNN只拥有短期记忆。也就是说如果输入的序列的足够长,RNN在处理序列后面的信息时,可能已经将序列前面的信息丢失或者遗忘了,RNN很难完整的传递完整的长序列信息。

从而引出其解决方案LSTM和GRU

LSTM和GRU的短期记忆的解决方案,它通过门控(Gates)机制调节信息的流向。Gates可以学习到序列数据中哪些信息是重要的,需要保留;哪些信息是不重要的,可以丢弃,从而解决长序列的信息传递问题。

关于RNN基础

推荐阅读:一文搞懂RNN(循环神经网络)基础篇

下面我以机器翻译为例,RNN先将句子和词语转换为机器可识别的词向量(Word Vector),然后把词向量one by one的交给RNN进行处理。

RNN处理信息过程中,Hidden State作为神经网络的记忆数据,保留了神经网络已经看到的历史信息。通过将前一单元的Hidden State传递给后一个RNN单元,从而实现对历史信息的记忆。

RNN计算Hidden State的过程如下:它将当前RNN单元的输入和前一个RNN单元输出的Hidden State组合起来,经过一个Tanh激活函数,生成当前单元的Hidden State。

Tanh激活函数将输入值压缩至-1和1之间。

如果不采用Tanh激活函数,数据流经多层神经网络后,个别维度会出现急剧膨胀,甚至变成一个天文数字。下图是神经网络每层都对输入数据做了一个x3的操作的效果。

Tanh函数确保网络的输出值在-1与1之间,下图是同样的输入数据流经激活函数为Tanh的多层神经网络的效果。

LSTM

LSTM与RNN有相似的数据流控制机制,差别在于LSTM Cell内部的运作逻辑

LSTM的核心概念是Cell States和各种Gates。Cell State在整个序列的处理过程中都携带相关的信息,即使较早时刻的信息也可以很好的保留,从而降低了短期记忆问题的影响。

Gates都是不同神经网络,它们可以决定哪些信息需要保留在Cell States中,哪些信息需要遗忘。

Sigmoid

Gates中使用了Sigmoid激活函数,Sigmoid激活函数与Tanh激活函数类似,只不过它不是将所有输入数据压缩到(-1, 1)之间,而是将输入数据压缩到(0, 1)之间。Sigmoid激活函数对于Gates数据更新或者遗忘数据非常有用,因为任意数值乘以0都等于0,从而使得这些数据被遗忘或则消失;任意数值乘以1都等于原数值,从而使得这些信息保留下来。所以最终Gates通过训练可以哪些数据是重要的,需要保留;哪些数据是不重要的,需要遗忘。

下面在深入看看各种Gates都做了什么。

Forget Gate

Forget Gate决定哪些信息需要丢弃,哪些信息需要保留。它合并前一个Hidden State和当前的Input信息,然后输入Sigmoid激活函数,输出(0,1)之间的数值。输出值接近0的信息需要被遗忘,输出值接近1的信息需要被保留。

Input Gate

Input Gate首先将前一个Hidden State和当前Input合并起来,送入Sigmoid函数,输出(0,1)之间的值,0表示信息不重要,1表示信息重要。

Hidden State和Input的输入同时也被送入Tanh函数,输出(-1, 1)之间的数值。

Sigmoid的输出和Tanh的输出相乘,决定哪些Tanh的输出信息需要保留,哪些输出信息需要丢弃。

Cell State

前一个Cell State的输出,首先与Forget Gate的输出相乘,选择性的遗忘不重要的信息,再与Input Gate的输出相加,从而实现将当前的Input信息添加到Cell State中,输出新的Cell State。

Output Gate

Output Gate用于输出Hidden State。Output Gate首先将前一个Hidden State和当前Input送入Sigmoid函数,然后与新的Cell State通过Tanh函数的输出相乘,决定Hidden State要将哪些信息携带到下一个Time Step。

概括来说,就是Forget Gate决定哪些历史信息要保留;Input Gate决定哪些新的信息要添加进来;Output Gate决定下一个Hidden State要携带哪些历史信息。

下面以杰伦哥的歌词数据集为例

  • 读取数据集
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
  • 初始化模型参数
#下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数
num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)def get_params():def _one(shape):ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)return torch.nn.Parameter(ts, requires_grad=True)def _three():return (_one((num_inputs, num_hiddens)),_one((num_hiddens, num_hiddens)),torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))W_xi, W_hi, b_i = _three()  # 输入门参数W_xf, W_hf, b_f = _three()  # 遗忘门参数W_xo, W_ho, b_o = _three()  # 输出门参数W_xc, W_hc, b_c = _three()  # 候选记忆细胞参数# 输出层参数W_hq = _one((num_hiddens, num_outputs))b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)return nn.ParameterList([W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q])
  • 定义模型
#在初始化函数中,长短期记忆的隐藏状态需要返回额外的形状为(批量大小, 隐藏单元个数)的值为0的记忆细胞。def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), torch.zeros((batch_size, num_hiddens), device=device))

下面根据长短期记忆的计算表达式定义模型。需要注意的是,只有隐藏状态会传递到输出层,而记忆细胞不参与输出层的计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)C_tilda = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)C = F * C + I * C_tildaH = O * C.tanh()Y = torch.matmul(H, W_hq) + b_qoutputs.append(Y)return outputs, (H, C)
  • 训练模型并创作歌词

我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,vocab_size, device, corpus_indices, idx_to_char,char_to_idx, False, num_epochs, num_steps, lr,clipping_theta, batch_size, pred_period, pred_len,prefixes)

输出:

epoch 40, perplexity 211.416571, time 1.37 sec- 分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我- 不分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我
epoch 80, perplexity 67.048346, time 1.35 sec- 分开 我想你你 我不要再想 我不要这我 我不要这我 我不要这我 我不要这我 我不要这我 我不要这我 我不- 不分开 我想你你想你 我不要这不样 我不要这我 我不要这我 我不要这我 我不要这我 我不要这我 我不要这我
epoch 120, perplexity 15.552743, time 1.36 sec- 分开 我想带你的微笑 像这在 你想我 我想你 说你我 说你了 说给怎么么 有你在空 你在在空 在你的空 - 不分开 我想要你已经堡 一样样 说你了 我想就这样着你 不知不觉 你已了离开活 后知后觉 我该了这生活 我
epoch 160, perplexity 4.274031, time 1.35 sec- 分开 我想带你 你不一外在半空 我只能够远远著她 这些我 你想我难难头 一话看人对落我一望望我 我不那这- 不分开 我想你这生堡 我知好烦 你不的节我 后知后觉 我该了这节奏 后知后觉 又过了一个秋 后知后觉 我该

在Gluon中我们可以直接调用rnn模块中的LSTM类。

lr = 1e-2 # 注意调整学习率
lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

GRU

推荐阅读:人人都能看懂的GRU
GRU与LSTM非常相似,但它去除了Cell State,使用Hidden State来传递信息。GRU只有两个Gates: Reset Gate和Update Gate。
图片来自李宏毅老师

其中 xt和 ht−1是GRU的输入, yt和 ht是GRU的输出。 \text { 其中 } x^{t} \text { 和 } h^{t-1} \text { 是GRU的输入, } y^{t} \text { 和 } h^{t} \text { 是GRU的输出。 } 其中 xt 和 ht−1 是GRU的输入, yt 和 ht 是GRU的输出。 

如上图所示,r是Reset Gate,z为Update Gate。

通过重置门(Reset Gate)处理前一个Cell的输出:


再将h t−1与 x′t^{t-1} \stackrel{\prime}{\text { 与 } x}{ }^{t}t−1 与 x′t 进行拼接, 送入Tanh激活函数得到 h′h^{\prime}h′
最后进行记忆更新的步骤:
ht=(1−z)⊙ht−1+z⊙h′h^{t}=(1-z) \odot h^{t-1}+z \odot h^{\prime} ht=(1−z)⊙ht−1+z⊙h′
Update Gate z的范围为0 1,它的值越接近1,代表记忆数据越多; 它的值越接近
0,则代表遗忘的越多。

相比LSTM, GRU的Tensor Operation更少, 因而训练速度更快, 并且效果与LSTM不相上下。

参考自;

  1. 人人都能看懂的GRU
  2. Illustrated Guide to LSTM’s and GRU’s: A step by step explanation
  3. 长短期记忆(LSTM)

LSTM和GRU介绍相关推荐

  1. 难以置信!LSTM和GRU的解析从未如此清晰(动图+视频)

    作者 | Michael Nguyen 编译 | 蔡志兴.费棋 编辑 | Jane 出品 | AI科技大本营 [导语]机器学习工程师 Michael Nguyen 在其博文中发布了关于 LSTM 和 ...

  2. 循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用

    循环神经网络实现文本情感分类之Pytorch中LSTM和GRU模块使用 1. Pytorch中LSTM和GRU模块使用 1.1 LSTM介绍 LSTM和GRU都是由torch.nn提供 通过观察文档, ...

  3. 从LSTM到GRU基于门控的循环神经网络总结

    1.概述 为了改善基本RNN的长期依赖问题,一种方法是引入门控机制来控制信息的累积速度,包括有选择性地加入新的信息,并有选择性遗忘之前累积的信息.下面主要介绍两种基于门控的循环神经网络:长短时记忆网络 ...

  4. 【强烈推荐】最好理解的LSTM与GRU教程

    AI识别你的语音.回答你的问题.帮你翻译外语,都离不开一种特殊的循环神经网络(RNN):长短期记忆网络(Long short-term memory,LSTM). 最近,国外有一份关于LSTM及其变种 ...

  5. (Unfinished)RNN-循环神经网络之LSTM和GRU-04介绍及推导

    (Unfinished)尚未完成 一.说明 关于LSTM的cell结构和一些计算在之前已经介绍了,可以点击这里查看 本篇博客主要涉及一下内容: LSTM前向计算说明(之前的博客中LSTM部分实际已经提 ...

  6. 了解LSTM和GRU

    lstm和gru 深度学习 , 自然语言处理 (Deep Learning, Natural Language Processing) In my last article, I have intro ...

  7. 深度学习 《LSTM和GRU模型》

    前言: 前面我们学习了标准的单向单层和单向多层的RNN,这一博文我来介绍RNN的改进版本LSTM和GRU,至于为什么有这个改进的方案,以及如何理解它们,我会尽量用最通俗的语言俩表达. 学习自博客htt ...

  8. RNN、LSTM、GRU的原理和实现

    个人博客:http://www.chenjianqu.com/ 原文链接:http://www.chenjianqu.com/show-41.html 用python实现了经典的RNN,LSTM和GR ...

  9. 循环神经网络之LSTM、GRU

    循环神经网络之LSTM.GRU 1. 什么是 LSTM? LSTM(Long short-term memory,长短期记忆)是一种特殊的RNN,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题 ...

最新文章

  1. 在Docker应用场景下 如何使用新技术快速实现DevOps
  2. 项目: 用Easyx绘制围棋和象棋的棋盘
  3. 操作Frame和IFrame中页面元素
  4. Vue中计算属性与class,style绑定
  5. android fragment 嵌套,Fragment嵌套Fragment时遇到的那些坑
  6. 从C语言到C++语言
  7. 人与人之间交往最重要的是什么?
  8. IPTV软件如何做自己的广告系统?
  9. MatLab 2016b下载资源
  10. 交换机VLAN 模式trunk和access 区别
  11. ubuntu system setting no everthing
  12. 【电动车】电动汽车两阶段优化调度策略(Matlab代码实现)
  13. win10计算机用户名和密码怎么找到,win10怎么查看wifi账号密码_win10电脑怎么看wifi密码...
  14. 如何从Win11系统安装回win10系统?
  15. C#分享网址到QQ空间带参数
  16. Butter Knife[黄油刀]配置
  17. 11 | 向埃隆·马斯克学习任务分解
  18. 致所有初学者--助力所有ERP初学者!!!
  19. 互动机顶盒与普通机顶盒比较
  20. python读入图像是四维,需要将其转换为三维图像

热门文章

  1. Hystrix面试 - 深入 Hystrix 断路器执行原理
  2. 消息队列面试 - 如何保证消息的可靠性传输?
  3. 【Linux系列】Linux基础知识整理
  4. python 生成 和 加载 requirements.txt
  5. 函数式编程 -- 函数组合
  6. 【Java】用键盘输入若干数字,以非数字字符结束,计算这些数的和和平均值
  7. C# 9.0 新功能一览
  8. C#LeetCode刷题-双指针
  9. 编写区块链_编写由区块链驱动的在线社区的综合指南
  10. 猛男教你写代码_猛男程序员,鼓存储器和1960年代机器代码的取证分析