• RNN

    • 结构
    • 训练
    • 应用
  • RNN Variants
  • LSTM
    • 结构
    • 梯度消失及梯度爆炸
  • GRU
    • 结构

一般的神经网络输入和输出的维度大小都是固定的,针对序列类型(尤其是变长的序列)的输入或输出数据束手无策。RNN通过采用具有记忆的隐含层单元解决了序列数据的训练问题。LSTM、GRU属于RNN的改进,解决了RNN中梯度消失爆炸的问题,属于序列数据训练的常用方案。

RNN

结构

传统的神经网络的输入和输出都是确定的,RNN的输入和输出都是不确定的sequence数据。其结构如下:


具体地,RNN有隐含层,隐含层也是记忆层,其状态(权值)会传递到下一个状态中。

htyt=σ(xtWxh+ht−1Whh)=σ(htWhy)

\begin{split} h^t &= \sigma(x^tW_{xh} + h^{t-1}W_{hh}) \\ y^t &= \sigma(h^tW_{hy}) \end{split}

训练

训练步骤如下:

  1. 构建损失函数
  2. 求损失函数对权值的梯度
  3. 采用梯度下降法更新权值参数

关于损失函数,根据需要选择构建即可,下面提供两种常见的损失函数:

CC=12∑n=1N||yn−ŷ n||2=12∑n=1N−logynrn

\begin{split} C &= \frac{1}{2}\sum_{n=1}^N ||y^n-{\hat{y}}^n||^2 \\ C &= \frac{1}{2}\sum_{n=1}^N -\log y_{r^n}^n \end{split}

关于梯度下降,采用BPTT(Backpropagation through time)算法,该算法的核心是对每一个时间戳,计算该时间戳中权重的梯度,然后更新权重。需要注意的是,不同时间戳同样权重的梯度可能是不一样的,如下图所示都减去,相当于更新同一块内存区域中的权重。


应用

  • 多对多:词性标注pos tagging、语音识别、name entity recognition(区分poeple、organizations、places、information extration(区分place of departure、destination、time of departure、time of arrival, other)、机器翻译
  • 多对一:情感分析
  • 一对多:caption generation


RNN Variants

RNN的变种大致包含下面3个思路:

  • 增加隐含层的输入参数:例如除了ht−1,xth^{t-1}, x^t,还可以包含yt−1y^{t-1}作为输入。
  • 增加隐含层的深度
  • 双向RNN



LSTM

结构

  • 单个时间戳,RNN输入1个x,输出1个y
  • 单个时间戳,LSTM输入4个x,输出1个y

相比RNN,LSTM的输入多了3个x,对应3个gate,这3个gate分别是:

  • input gate:控制输入
  • forget gate:控制cell
  • output gate:控制输出

涉及到的激活函数共5个,其中3个控制gate的(通常用sigmoid函数,模拟gate的开闭状态),1个作用于输入上,一个作用于cell的输出上。

LSTM单个时间戳的具体执行如下:

  • 输入:4个输入xx,1个cell的状态cc
  • 输出:1个输出aa,1个更新的cell状态c′c'
c′a=g(z)f(zi)+cf(zf)=h(c′)f(zo)

\begin{split} c' &= g(z)f(z_i) + cf(z_f) \\ a &= h(c')f(z_o) \end{split}

梯度消失及梯度爆炸

首先,要明白RNN中梯度消失与梯度爆炸的原因:在时间戳的更新中,cell的状态不断乘以WhhW_{hh}。简单起见,视WhhW_{hh}为scalar值ww,那么y=xwny=xw^n,∂y∂w=nxwn−1\frac{\partial{y}}{\partial{w}}=nxw^{n-1}。根据ww的值与1的大小关系,梯度会消失或者爆炸。

接下来,要明白LSTM如何解决RNN中梯度消失与爆炸的问题。

针对梯度消失,RNN中当获取c′c'的梯度后,因为c′=cwc' = cw,为了backward获得cc的梯度,要将c′c'的梯度乘以ww;LSTM中存在梯度的快速通道,获取c′c'的梯度后,因为c′=g(z)f(zi)+cf(zf)c' = g(z)f(z_i)+cf(z_f),当forget gate打开时,c′=g(z)f(zi)+cc' = g(z)f(z_i)+c。c′c'的梯度可以直接传递给cc。
总结来说,LSTM相比RNN,将c,c′c,c'的更新关系从乘法变成了加法,因此不用乘以权值系数ww,c′c'的梯度可以直接传递给cc,解决了梯度消失的问题。

针对梯度爆炸,即使将c,c′c,c'的关系由乘法变成了加法,仍然解决不了梯度爆炸。原因便是梯度的路径不止一条,如下图所示,红色的块仍然可能造成梯度爆炸。LSTM解决这个问题的方法是clip,也就是设置梯度最大值,超过最大值的按最大值计。

GRU

结构

GRU相比LSTM的3个gate,只用了两个gate:

  • update gate:ztz_t
  • reset gate:rtr_t

记忆网络RNN、LSTM与GRU相关推荐

  1. 循环神经网络(RNN)与长短期记忆网络(LSTM)讲解

    循环神经网络(RNN) 对于典型的深度神经网络(DNN),就是通过在输入层与输出层之间增加隐藏层来构建网络,如下图所示. 与DNN不同的是,循环神经网络(RNN)赋予了网络对前面的内容的一种" ...

  2. 57 长短期记忆网络(LSTM)【动手学深度学习v2】

    57 长短期记忆网络(LSTM)[动手学深度学习v2] 深度学习学习笔记 学习视频:https://www.bilibili.com/video/BV1JU4y1H7PC/?spm_id_from=a ...

  3. 长短期记忆网络(LSTM)

    长短期记忆网络(LSTM) 1.LSTM介绍 LSTM 表示长短期记忆网络,当我们的神经网络需要在记忆最近的事物和很久以前的事情之间切换时,LSTM 是非常有用的. 2.RNN vs LSTM RNN ...

  4. 动手学深度学习(四十)——长短期记忆网络(LSTM)

    文章目录 一.长短期记忆网络(LSTM) 1.1 门控记忆单元 1.2 输入门.遗忘门与输出门 1.3候选记忆单元 1.4 记忆单元 1.5 隐藏状态 二.从零实现LSTM 2.1 初始化模型参数 2 ...

  5. 09.2. 长短期记忆网络(LSTM)

    文章目录 9.2. 长短期记忆网络(LSTM) 9.2.1. 门控记忆元 9.2.1.1. 输入门.忘记门和输出门 9.2.1.2. 候选记忆元 9.2.1.3. 记忆元 9.2.1.4. 隐状态 9 ...

  6. Tensorflow使用CNN卷积神经网络以及RNN(Lstm、Gru)循环神经网络进行中文文本分类

    Tensorflow使用CNN卷积神经网络以及RNN(Lstm.Gru)循环神经网络进行中文文本分类 本案例采用清华大学NLP组提供的THUCNews新闻文本分类数据集的一个子集进行训练和测试http ...

  7. LSTM -长短期记忆网络(RNN循环神经网络)

    文章目录 基本概念及其公式 输入门.输出门.遗忘门 候选记忆元 记忆元 隐状态 从零开始实现 LSTM 初始化模型参数 定义模型 训练和预测 简洁实现 小结 基本概念及其公式 LSTM,即(long ...

  8. 长短期记忆网络(LSTM)简述

    本文是学习LSTMs入门知识的总结. LSTM(Long-Short Term Memory)是递归神经网络(RNN:Recurrent Neutral Network)的一种. RNNs也叫递归神经 ...

  9. tensorflow实现循环神经网络——经典网络(LSTM、GRU、BRNN)

    参考链接: https://www.cnblogs.com/tensorflownews/p/7293859.html http://www.360doc.com/content/17/0321/10 ...

最新文章

  1. linux特殊符号大全
  2. python那么慢为什么还有人用-Python执行效率慢,为什么还这么火?【黑马程序员】...
  3. html怎么使用伪类清除浮动,JS中使用 after 伪类清除浮动实例
  4. 压缩感知(I) A Compressed Sense of Compressive Sensing (I)
  5. 搭建一个简易的https
  6. Vue——整合Katex
  7. java opencv人脸识别_java+opencv+intellij idea实现人脸识别
  8. python会什么比c慢
  9. 不规则图形数格子的方法_北师大版五年级数学上册数学6.1组合图形的面积微课堂、同步练习、图文解读...
  10. Random Forest
  11. 一个景点的给input域一个默认值,然后在聚焦的时候清空它 jquery方法
  12. 跟小海一起看下雪——用HTML、CSS和JS实现简单的下雪特效
  13. Python读取PDF文档并翻译
  14. MODIS数据火点提取方法
  15. ubuntu18.04播放MP4
  16. 利用沙盒技术破解APP的API协议加密
  17. dijkstra模板(fast)
  18. 修改移动硬盘盘符(G盘--E盘)
  19. R数通杀思路分享-反部分混淆解析canvas和fonts指纹
  20. nodejs控制台打印图案

热门文章

  1. LIGA Stereo:基于双目3D检测的Lidar几何感知表示学习(ICCV2021)
  2. 实操教程|详细记录solov2的ncnn实现和优化
  3. 双一流校长:学校要扩大博士生规模!适当控制硕士生规模,因为住宿条件跟不上了...
  4. 嵌入式的我们为什么要学ROS
  5. bootstrap-datetimepicker时间控件添加清除按钮
  6. spring 源码 找不到 taskprovider_一步一步构建Spring5源码
  7. mysql内连接和外连接的区别_Swoole4创建Mysql连接池
  8. Nature子刊:宏基因组中挖掘原核基因组的分析流程
  9. SEL | 植物通过根系分泌物招募假单孢菌协助抵抗地上部病原菌侵染
  10. 【Plant Cell】突破!加入一种酵母,可显著提高水稻氮利用率及产量!