长短期记忆是复杂和先进的神经网络结构的重要组成部分。本文的主要思想是解释其背后的数学原理,所以阅读本文之前,建议首先对LSTM有一些了解。

介绍

上面是单个LSTM单元的图表。我知道它看起来可怕,但我们会通过一个接一个的文章,希望它会很清楚。

解释

基本上一个LSTM单元有4个不同的组件。忘记门、输入门、输出门和单元状态。我们将首先简要讨论这些部分的使用,然后深入讨论数学部分。

忘记门

顾名思义,这部分负责决定在最后一步中扔掉或保留哪些信息。这是由第一个s型层完成的。

根据ht-1(以前的隐藏状态)和xt(时间步长t的当前输入),它为单元格状态C_t-1中的每个值确定一个介于0到1之间的值。

遗忘门和上一个状态

如果为1,所有的信息保持原样,如果为0,所有的信息都被丢弃,对于其他的值,它决定有多少来自前一个状态的信息被带入下一个状态。

输入门

Christopher Olah博客的解释在输入门发生了什么:

下一步是决定在单元格状态中存储什么新信息。这包括两部分。首先,一个称为“输入门层”的sigmoid层决定我们将更新哪些值。接下来,一个tanh层创建一个新的候选值的向量,C~t,可以添加到状态中。在下一步中,我们将结合这两者来创建对状态的更新。

现在这两个值i。e i_t和c~t结合决定什么新的输入是被输入到状态。

单元状态

单元状态充当LSTM的内存。这就是它们在处理较长的输入序列时比普通RNN表现得更好的地方。在每一个时间步长,前一个单元状态(Ct-1)与遗忘门结合,以决定什么信息要被传送,然后与输入门(it和c~t)结合,形成新的单元状态或单元的新存储器。

状态的计算公式

输出门

最后,LSTM单元必须给出一些输出。从上面得到的单元状态通过一个叫做tanh的双曲函数,因此单元状态值在-1和1之间过滤。

LSTM单元的基本单元结构已经介绍完成,继续推导在实现中使用的方程。

推导先决条件

推导方程的核心概念是基于反向传播、成本函数和损失。除此以外还假设您对高中微积分(计算导数和规则)有基本的了解。

变量:对于每个门,我们有一组权重和偏差,表示为:

Wf,bf->遗忘门的权重和偏差Wi,bi->输入门的权重和偏差Wc,bc->单元状态的权重和偏差Wo,bo->输出门的权重和偏差Wv ,bv -> 与Softmax层相关的权重和偏差ft, it,ctiledet, o_t -> 输出使用的激活函数af, ai, ac, ao -> 激活函数的输入J是成本函数,我们将根据它计算导数。注意(下划线(_)后面的字符是下标)

前向传播推导

门的计算公式

状态的计算公式

以遗忘门为例说明导数的计算。我们需要遵循下图中红色箭头的路径。

我们画出一条从f_t到代价函数J的路径,也就是

ft→Ct→h_t→J。

反向传播完全发生在相同的步骤中,但是是反向的

ft←Ct←h_t←J。

J对ht求导,ht对Ct求导,Ct对f_t求导。

所以如果我们在这里观察,J和ht是单元格的最后一步,如果我们计算dJ/dht,那么它可以用于像dJ/dC_t这样的计算,因为:

dJ/dCt = dJ/dht * dht/dCt(链式法则)

同样,对第一点提到的所有变量的导数也要计算。

现在我们已经准备好了变量并且清楚了前向传播的公式,现在是时候通过反向传播来推导导数了。我们将从输出方程开始因为我们看到在其他方程中也使用了同样的导数。这时就要用到链式法则了。我们现在开始吧。

反向传播推导

lstm的输出有两个值需要计算。

Softmax:对于交叉熵损失的导数,我们将直接使用最终的方程。

隐藏状态是ht。ht是w.r的微分。根据链式法则,推导过程如下图所示。

输出门相关变量:ao和ot,微分的完整方程如下:

dJ/dVt * dVt/dht * dht/dO_t

dJ/dVt * dVt/dht可以写成dJ/dht(我们从隐藏状态得到这个值)。

ht的值= ot * tanh(ct) ->所以我们只需要对ht w.r求导。t o_t。其区别如下:

同样,a_o和J之间的路径也显示出来。微分的完整方程如下:

dJ/dVt * dVt/dht * dt /da_o

dJ/dVt * dVt/dht * dht/dOt可以写成dJ/dOt(我们从上面的o_t得到这个值)。

Ct是单元的单元状态。除此之外,我们还处理候选单元格状态ac和c~_t。

Ct的推导很简单,因为从Ct到J的路径很简单。Ct→ht→Vt→j,因为我们已经有了dJ/dht,我们直接微分ht w.r。t Ct。

ht = ot * tanh(ct) ->所以我们只需要对ht w.r求导。t C_t。

微分的完整方程如下:

dJ/dht * dht/dCt * dCt/dc~_t

可以将dJ/dht * dht/dCt写成dJ/dCt(我们在上面有这个值)。

Ct的值如图9公式5所示(下图第3行最后一个Ct缺少波浪号(~)符号->书写错误)。所以我们只需要对C_t w.r求导。t c ~ _t。

ac:如下图所示为ac到J的路径。根据箭头,微分的完整方程如下:

dJ/dht * dht/dCt * dCt/ da_c

dJ/dht * dht/dCt * dCt/dc_t可以写成dJ/dc_t(我们在上面有这个值)。

所以我们只需要对c~t w.r求导。t ac。

输入门相关变量:it和ai

微分的完整方程如下:

dt / dt * dt /dit

可以将dJ/dht * dht/dCt写入为dJ/dCt(我们在单元格状态中有这个值)。所以我们只需要对Ct w.r求导。t it。

a_i:微分的完整方程如下:

dJ/dht * dht/dCt * dt /da_i

dJ/dht * dht/dCt * dCt/dit可以写成dJ/dit(我们在上面有这个值)。所以我们只需要对i_t w.r求导。t ai。

遗忘门相关变量:ft和af

微分的完整方程如下:

dJ/dht * dht/dCt * dCt/df_t

可以将dJ/dht * dht/dCt写入为dJ/dCt(我们在单元格状态中有这个值)。所以我们只需要对Ct w.r求导。t ft。

a_f:微分的完整方程如下:

dJ/dht * dht/dCt * dft/da_t

dJ/dht * dht/dCt * dCt/dft可以写成dJ/dft(我们在上面有这个值)。所以我们只需要对ftw.r求导。t af。

Lstm的输入

每个单元格i有两个与输入相关的变量。前一个单元格状态C_t-1和前一个隐藏状态与当前输入连接,即

[ht-1,xt] > Z_t

C_t-1:这是Lstm单元的内存。图5显示了单元格状态。c - t-1的推导很简单因为只有c - t和c - t。

Zt:如下图所示,Zt进入四个不同的路径,af,ai,ao,ac。

Zt→af→ft→Ct→h_t→J。- >遗忘门

Zt→ai→it→Ct→h_t→J。- >输入门

Zt→ac→c~t→Ct→h_t→J。->单元状态

Zt→ao→ot→Ct→h_t→J。- >输出门

权重和偏差

W和b的推导很简单。下面的推导是针对Lstm的输出门的。对于其余的门,对权重和偏差也进行了类似的处理。

输入和遗忘门的权重和偏差

输出和输出门的权重和偏差

J/dWf = dJ/daf。daf / dWf ->遗忘门

dJ/dWi = dJ/dai。dai / dWi ->输入门

dJ/dWv = dJ/dVtdVt/ dWv ->输出门

dJ/dWo = dJ/dao。dao / dWo ->输出门

我们完成了所有的推导。但是有两点需要强调

到目前为止,我们所做的只是一个时间步长。现在我们要让它只进行一次迭代。

所以如果我们有总共T个时间步长,那么每一个时间步长的梯度会在T个时间步长结束时相加,所以每次迭代结束时的累积梯度为:

每次迭代结束时的累积梯度用来更新权重

总结

LSTM是非常复杂的结构,但它们工作得非常好。具有这种特性的RNN主要有两种类型:LSTM和GRU。

训练LSTMs也是一项棘手的任务,因为有许多超参数,而正确地组合通常是一项困难的任务。

作者:Rahuljha

deephub翻译组

lstm数学推导_手推公式:LSTM单元梯度的详细的数学推导相关推荐

  1. 【干货】105页周志华教授《机器学习手推公式》开源PDF

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 上述内容是手推公式的主要内容,本项目的Github主页如下:https://git ...

  2. 干货 | 《深度学习》手推公式笔记开源PDF下载!

    为大家找到的王博(Kings)的<深度学习>手推公式笔记,需要的伙伴可以在公众号"飞马会"菜单栏回复数字"91"查看获取方式. 深度学习手推笔记部分 ...

  3. 985博士《深度学习》手推公式笔记开源PDF下载!

    前几天为大家找到的王博(Kings)的笔记[机器学习手推笔记],大家都非常喜欢,近几天发现王博的Github又更新了深度学习版本笔记 GitHub地址(点击原文阅读可直达GitHub): https: ...

  4. lstm 输入数据维度_[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

    最近想了解一些关于LSTM的相关知识,在进行代码测试的时候,有个地方一直比较疑惑,关于LSTM的输入和输出问题.一直不清楚在pytorch里面该如何定义LSTM的输入和输出.首先看个pytorch官方 ...

  5. lstm数学推导_如何在训练LSTM的同时训练词向量?

    你本来也不用自己手动进行词向量更新啊,你搞这么一出最后收敛到0那不是必然的么? @霍华德 老师的答案已经给你推导出来了. 实际上你问的这个问题很简单--只要把Embedding层本身也当成模型参数的一 ...

  6. pytorch 反卷积 可视化_手推反卷积

    先手推卷积热个身 在推导反卷积之前,先推导一下卷积. 假设输入为 ,卷积核为 ,输出大小的计算公式为 .当 时,输出为 . 将输入矩阵转成一个 的列阵,卷积核扩展为 的矩阵,即 则 , 所以 . 用p ...

  7. lstm预测股票_股票相关性与lstm预测误差

    lstm预测股票 When trying to look at examples of LSTMs in Keras, I've found a lot that focus on using the ...

  8. 机器学习之求解无约束最优化问题方法(手推公式版)

    文章目录 前言 1. 基础知识 1.1 方向导数 1.2 梯度 1.3 方向导数与梯度的关系 1.4 泰勒展开公式 1.5 Jacobian矩阵与Hessian矩阵 1.6 正定矩阵 2. 梯度下降法 ...

  9. lstm 输入数据维度_理解Pytorch中LSTM的输入输出参数含义

    本文不会介绍LSTM的原理,具体可看如下两篇文章 Understanding LSTM Networks DeepLearning.ai学习笔记(五)序列模型 -- week1 循环序列模型 1.举个 ...

最新文章

  1. 【AJAX】Ajax学习总结
  2. 中国电子信息工程科技发展十大趋势(2019)发布
  3. 四则运算 - java实现(叶尚文, 张鸿)
  4. Android 5.0中的FDE功能实现
  5. 废旧纸箱做机器人图片_网购后的快递纸箱被你扔掉了吗?
  6. new chosen courses at ifm
  7. wxWidgets:wxThreadHelper类用法
  8. Java线程的使用及共享协作
  9. Python基础知识4: while循环基本使用
  10. Flinksql读取Kafka写入Iceberg 实践亲测
  11. SpringBoot中拦截器
  12. linux文件管理命令详解
  13. 给intellij IDEA设置背景颜色
  14. GET和POST 区别
  15. C语言扫雷游戏代码以及基本原理教学(一看就会)
  16. unity 物理碰撞
  17. filp_open/filp_close/vfs_read/vfs_write
  18. Intranet/Internet
  19. BLE传输速率以及抓包工具
  20. 计算机组成原理带符号的阵列乘法器,计算机组成原理阵列乘法器课程设计报告精选.doc...

热门文章

  1. java 年计算_用Java计算leap年
  2. python中matrix函数_使用python解线性矩阵方程(numpy中的matrix类)
  3. perl数组硬引用_perl引用和数组 - SibylY的个人空间 - OSCHINA - 中文开源技术交流社区...
  4. 5.9 程序示例--非线性分类-机器学习笔记-斯坦福吴恩达教授
  5. Android学习笔记:ScrollView卷轴视图
  6. 基于MATLAB的Okumura-Hata模型的仿真
  7. 考前自学系列·计算机组成原理·常见的数据寻址方式(地址码,操作数位置)
  8. 多线程题目 2019.06.02 晚
  9. Linux之ln命令
  10. Spring集成Mybatis,spring4.x整合Mybatis3.x