Preface

For a long time I’ve been looking for a good tutorial on implementing LSTM networks. They seemed complicated, and I’d never done anything with them before. Quick googling didn’t help, as all I found were some slides.

Fortunately, I took part in the Kaggle EEG Competition and thought that it might be fun to use LSTMs and finally learn how they work. I based my solution and this post’s code on char-rnn by Andrej Karpathy, which I highly recommend you check out.

RNN misconception

There is one important thing that, I feel, hasn’t been emphasized strongly enough (and is the main reason why I couldn’t get myself to do anything with RNNs). There isn’t much difference between an RNN and a feedforward network implementation. It’s easiest to implement an RNN just as a feedforward network with some parts of the input feeding into the middle of the stack, and a bunch of outputs coming out from there as well. There is no magic internal state kept in the network. It’s provided as part of the input!

The overall structure of RNNs is very similar to that of feedforward networks.
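
As a quick illustration of this point, here is a minimal sketch (made-up names and sizes, not this post’s code) of a “recurrent” step that is really just a feedforward module, with the state passed in and out explicitly:

require 'nn'

-- the "recurrent" part is an ordinary feedforward module; the state is just
-- another input/output that we carry around ourselves
local input_size, rnn_size, seq_len = 4, 3, 5
local step = nn.Sequential()
  :add(nn.JoinTable(2))                             -- concatenate {x_t, h_prev} along dim 2
  :add(nn.Linear(input_size + rnn_size, rnn_size))
  :add(nn.Tanh())

local h = torch.zeros(1, rnn_size)                  -- the "internal" state lives out here
for t = 1, seq_len do
  local x_t = torch.randn(1, input_size)
  h = step:forward({x_t, h})                        -- state goes in as an input and comes back out
end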

LSTM refresher

This section will cover only the formal definition of LSTMs. There are lots of other nice blog posts describing in detail how you can imagine and think about these equations.

LSTMs have many variations, but we’ll stick to a simple one. One cell consists of three gates (input, forget, output) and a cell unit. Gates use a sigmoid activation, while the input and cell state are often transformed with tanh. The LSTM cell can be defined with the following set of equations:

Gates:

$$i_t = g(W_{xi} x_t + W_{hi} h_{t-1} + b_i)$$
$$f_t = g(W_{xf} x_t + W_{hf} h_{t-1} + b_f)$$
$$o_t = g(W_{xo} x_t + W_{ho} h_{t-1} + b_o)$$

Input transform:

$$c\_in_t = \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_{c\_in})$$

State update:

$$c_t = f_t \cdot c_{t-1} + i_t \cdot c\_in_t$$
$$h_t = o_t \cdot \tanh(c_t)$$

It can be pictured like this:

Because of the gating mechanism the cell can keep a piece of information for long periods of time and protect the gradient inside the cell from harmful changes during training. Vanilla LSTMs don’t have a forget gate and add the unchanged cell state during the update (this can be seen as a recurrent connection with a constant weight of 1), which is often referred to as a Constant Error Carousel (CEC). It’s called that because it solves a serious RNN training problem of vanishing and exploding gradients, which in turn makes it possible to learn long-term relationships.
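
A quick way to see the CEC claim (ignoring the indirect paths through $h_{t-1}$): from the state update above,

$$\frac{\partial c_t}{\partial c_{t-1}} = \mathrm{diag}(f_t)$$

so with no forget gate (i.e. $f_t \equiv 1$) the error flowing back along the cell state is carried from step to step unchanged, rather than being repeatedly squashed by a nonlinearity.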

Building your own LSTM layer

The code for this tutorial will be using Torch7. Don’t worry if you don’t know it. I’ll explain everything, so you’ll be able to implement the same algorithm in your favorite framework.

The network will be implemented as a nngraph.gModule, which basically means that we’ll define a computation graph consisting of standard nn modules. We will need the following layers:

  • nn.Identity() - passes on the input (used as a placeholder for input)
  • nn.Dropout(p) - standard dropout module (drops each element with probability p)
  • nn.Linear(in, out) - an affine transform from in dimensions to out dims
  • nn.Narrow(dim, start, len) - selects a subvector along dimension dim, with len elements starting from index start (see the short demo after this list)
  • nn.Sigmoid() - applies sigmoid element-wise
  • nn.Tanh() - applies tanh element-wise
  • nn.CMulTable() - outputs the product of tensors in forwarded table
  • nn.CAddTable() - outputs the sum of tensors in forwarded table
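
If nn.Narrow, nn.CMulTable and nn.CAddTable are new to you, here is a quick sanity check with toy tensors (sizes made up for illustration):

require 'nn'

-- nn.Narrow(2, 3, 4): take 4 elements along dimension 2, starting at index 3
local x = torch.range(1, 8):view(1, 8)              -- 1x8 tensor: 1, 2, ..., 8
print(nn.Narrow(2, 3, 4):forward(x))                -- 1x4 tensor: 3, 4, 5, 6

-- nn.CMulTable and nn.CAddTable work element-wise on a table of tensors
local a, b = torch.ones(1, 4), torch.ones(1, 4) * 2
print(nn.CMulTable():forward({a, b}))               -- all 2s
print(nn.CAddTable():forward({a, b}))               -- all 3s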

Inputs

First, let’s define the input structure. The array-like objects in Lua are called tables. This network will accept a table of tensors like the one below:

local inputs = {}
table.insert(inputs, nn.Identity()())   -- network input
table.insert(inputs, nn.Identity()())   -- c at time t-1
table.insert(inputs, nn.Identity()())   -- h at time t-1
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]

Identity modules will just copy whatever we provide to the network into the graph.

Computing gate values

To make our implementation faster we will be applying the transformations of the whole LSTM layer simultaneously.

local i2h = nn.Linear(input_size, 4 * rnn_size)(input)  -- input to hidden
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)   -- hidden to hidden
local preactivations = nn.CAddTable()({i2h, h2h})       -- i2h + h2h
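
This single 4 * rnn_size transform works because all four affine maps read the same $x_t$ and $h_{t-1}$, so their weights and biases can be stacked into one matrix and one bias vector (a sketch of the equivalence, in the notation of the refresher above):

$$\begin{bmatrix} W_{xi} \\ W_{xf} \\ W_{xo} \\ W_{xc} \end{bmatrix} x_t + \begin{bmatrix} W_{hi} \\ W_{hf} \\ W_{ho} \\ W_{hc} \end{bmatrix} h_{t-1} + \begin{bmatrix} b_i \\ b_f \\ b_o \\ b_{c\_in} \end{bmatrix}$$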

If you’re unfamiliar with nngraph it probably seems strange that we’re constructing a module and then calling it once more, with a graph node as the argument. What actually happens is that the second call wraps the nn.Module in a graph node, and the argument specifies its parent in the graph.
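
If this pattern is new to you, here is a tiny standalone example of the same idea (made-up sizes, unrelated to the LSTM):

require 'nn'
require 'nngraph'

-- a minimal graph: y = sigmoid(W * x + b)
local x = nn.Identity()()            -- graph input node
local h = nn.Linear(10, 20)(x)       -- Linear node; x is its parent
local y = nn.Sigmoid()(h)            -- Sigmoid node; h is its parent
local g = nn.gModule({x}, {y})       -- wrap the whole graph into a standalone module

print(g:forward(torch.randn(1, 10))) -- a 1x20 tensor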

preactivations outputs a vector created by a linear transform of the input and the previous hidden state. These are raw values that will be used to compute the gate activations and the cell input. The vector is divided into 4 parts, each of size rnn_size. The first will be used for the in gates, the second for the forget gates, the third for the out gates and the last one as the cell input (so for cell number i the indices of the respective gates and cell input are {i, rnn_size + i, 2*rnn_size + i, 3*rnn_size + i}).

 

Next, we have to apply a nonlinearity, but while all the gates use the sigmoid, we will use a tanh for the input preactivation. Because of this, we will place two nn.Narrow modules, which will select appropriate parts of the preactivation vector.

-- gates
local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)

-- input
local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
local in_transform = nn.Tanh()(in_chunk)

After the nonlinearities we have to place a couple more nn.Narrows and we have the gates done!

local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)
 

Cell and hidden state

Having computed the gate values we can now calculate the current cell state. All that’s required are just two nn.CMulTable modules (one for $f_t \cdot c_{t-1}$ and one for $i_t \cdot c\_in_t$), and a nn.CAddTable to sum them up into the current cell state.

-- previous cell state contribution
local c_forget = nn.CMulTable()({forget_gate, prev_c})
-- input contribution
local c_input = nn.CMulTable()({in_gate, in_transform})
-- next cell state
local next_c = nn.CAddTable()({c_forget, c_input})

It’s finally time to implement the hidden state calculation. It’s the simplest part, because it just involves applying tanh to the current cell state (nn.Tanh) and multiplying it by the output gate (nn.CMulTable).

local c_transform = nn.Tanh()(next_c)
local next_h = nn.CMulTable()({out_gate, c_transform})
 

Defining the module

Now, if you want to export the whole graph as a standalone module you can wrap it like this:

-- module outputs
outputs = {}
table.insert(outputs, next_c)
table.insert(outputs, next_h)

-- packs the graph into a convenient module with a standard API (:forward(), :backward())
return nn.gModule(inputs, outputs)
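
For reference, the snippets above assemble into a constructor roughly like this (a sketch using the LSTM.create(input_size, rnn_size) signature from the examples below; the actual linked file may differ in details such as dropout):

require 'nn'
require 'nngraph'

local LSTM = {}

-- assembles the graph described above into a single module;
-- returns a gModule mapping {x_t, prev_c, prev_h} -> {next_c, next_h}
function LSTM.create(input_size, rnn_size)
  local inputs = {}
  table.insert(inputs, nn.Identity()())   -- network input
  table.insert(inputs, nn.Identity()())   -- c at time t-1
  table.insert(inputs, nn.Identity()())   -- h at time t-1
  local input  = inputs[1]
  local prev_c = inputs[2]
  local prev_h = inputs[3]

  -- all preactivations at once
  local i2h = nn.Linear(input_size, 4 * rnn_size)(input)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local preactivations = nn.CAddTable()({i2h, h2h})

  -- nonlinearities
  local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
  local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)
  local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
  local in_transform = nn.Tanh()(in_chunk)

  -- gates
  local in_gate     = nn.Narrow(2, 1, rnn_size)(all_gates)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
  local out_gate    = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)

  -- cell and hidden state
  local c_forget = nn.CMulTable()({forget_gate, prev_c})
  local c_input  = nn.CMulTable()({in_gate, in_transform})
  local next_c   = nn.CAddTable()({c_forget, c_input})
  local next_h   = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule(inputs, {next_c, next_h})
end

return LSTM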

Examples

The LSTM layer implementation is available here. You can use it like this:

th> LSTM = require 'LSTM.lua'
[0.0224s]
th> layer = LSTM.create(3, 2)
[0.0019s]
th> layer:forward({torch.randn(1,3), torch.randn(1,2), torch.randn(1,2)})
{
  1 : DoubleTensor - size: 1x2
  2 : DoubleTensor - size: 1x2
}
[0.0005s]

To make a multi-layer LSTM network you can forward subsequent layers in a for loop, taking next_h from the previous layer as the next layer’s input. You can check this example.
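
As a rough illustration of that loop (made-up variable names, not the linked example’s exact code):

local LSTM = require 'LSTM.lua'

-- build a stack of layers and zero-initialize their states
local input_size, rnn_size, n_layers = 3, 2, 2
local layers, c, h = {}, {}, {}
for l = 1, n_layers do
  layers[l] = LSTM.create(l == 1 and input_size or rnn_size, rnn_size)
  c[l], h[l] = torch.zeros(1, rnn_size), torch.zeros(1, rnn_size)
end

-- one time step: next_h of each layer becomes the next layer's input
local x = torch.randn(1, input_size)
for l = 1, n_layers do
  local out = layers[l]:forward({x, c[l], h[l]})
  c[l], h[l] = out[1], out[2]
  x = h[l]
end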

Training

If you’re interested please leave a comment and I’ll try to expand this post!

That’s it!

That’s it. It’s quite easy to implement any RNN when you understand how to deal with the hidden state. After connecting several layers, just put a regular MLP on top, connect it to the last layer’s hidden state, and you’re done!
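
Continuing the stacking sketch above (n_classes is a made-up example parameter), such a head could be as simple as:

-- a regular MLP / classifier head fed with the last layer's hidden state
local n_classes = 10
local head = nn.Sequential()
  :add(nn.Linear(rnn_size, n_classes))
  :add(nn.LogSoftMax())

local log_probs = head:forward(h[n_layers])   -- h[n_layers]: last layer's next_h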

Here are some nice papers on RNNs if you’re interested:

  • Visualizing and Understanding Recurrent Networks
  • An Empirical Exploration of Recurrent Network Architectures
  • Recurrent Neural Network Regularization
  • Sequence to Sequence Learning with Neural Networks
Source: http://apaszke.github.io/lstm-explained.html
