LSTM implementation explained

Preface

For a long time I’ve been looking for a good tutorial on implementing LSTM networks. They seemed complicated, and I’d never done anything with them before. Quick googling didn’t help, as all I found were some slides.

Fortunately, I took part in the Kaggle EEG Competition and thought that it might be fun to use LSTMs and finally learn how they work. I based my solution and this post’s code on char-rnn by Andrej Karpathy, which I highly recommend you check out.

RNN misconception

There is one important thing that, as I feel, hasn’t been emphasized strongly enough (and it’s the main reason why I couldn’t get myself to do anything with RNNs): there isn’t much difference between an RNN and a feedforward network implementation. It’s easiest to implement an RNN just as a feedforward network with some parts of the input feeding into the middle of the stack, and a bunch of outputs coming out from there as well. There is no magic internal state kept in the network; it’s provided as part of the input!

The overall structure of RNNs is very similar to that of feedforward networks.
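
To make this concrete, here is a single step of a vanilla RNN written as an ordinary feedforward computation (a minimal sketch of my own, with made-up sizes; this is not an LSTM yet, just the plain recurrence):

require 'torch'

-- next_h = tanh(Wxh * x + Whh * prev_h + b)
local function rnn_step(Wxh, Whh, b, x, prev_h)
  return torch.tanh(Wxh * x + Whh * prev_h + b)
end

-- hypothetical sizes: 3-dimensional input, 2-dimensional hidden state
local Wxh, Whh, b = torch.randn(2, 3), torch.randn(2, 2), torch.randn(2)
local h = torch.zeros(2)
for t = 1, 5 do
  h = rnn_step(Wxh, Whh, b, torch.randn(3), h)  -- the caller threads the state through
end

The "state" h lives entirely outside the function; every step just takes it in as another argument and returns the new one.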

LSTM refresher

This section will cover only the formal definition of LSTMs. There are lots of other nice blog posts describing in detail how you can imagine and think about these equations.

LSTMs have many variations, but we’ll stick to a simple one. One cell consists of three gates (input, forget, output) and a cell unit. The gates use a sigmoid activation, while the input and cell state are usually transformed with tanh. An LSTM cell can be defined with the following set of equations:

Gates:

$$i_t = g(W_{xi} x_t + W_{hi} h_{t-1} + b_i)$$

$$f_t = g(W_{xf} x_t + W_{hf} h_{t-1} + b_f)$$

$$o_t = g(W_{xo} x_t + W_{ho} h_{t-1} + b_o)$$

Input transform:

$${c\_in}_t = \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_{c\_in})$$

State update:

$$c_t = f_t \cdot c_{t-1} + i_t \cdot {c\_in}_t$$

$$h_t = o_t \cdot \tanh(c_t)$$

It can be pictured like this:

Because of the gating mechanism, the cell can keep a piece of information for long periods of time and protect the gradient inside the cell from harmful changes during training. Vanilla LSTMs don’t have a forget gate and add the unchanged cell state during the update (this can be seen as a recurrent connection with a constant weight of 1), which is often referred to as a Constant Error Carousel (CEC). It gets that name because it solves the serious RNN training problem of vanishing and exploding gradients, which in turn makes it possible to learn long-term relationships.
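
Before moving on to Torch, it may help to see the equations transcribed directly into plain tensor operations (a sketch of my own, with randomly initialized weights and hypothetical sizes; the nngraph version built below is the one actually used):

require 'torch'

local input_size, rnn_size = 3, 2

-- one weight matrix per term of the equations, plus a bias per equation
local W = {}
for _, name in ipairs({'xi', 'xf', 'xo', 'xc'}) do W[name] = torch.randn(rnn_size, input_size) end
for _, name in ipairs({'hi', 'hf', 'ho', 'hc'}) do W[name] = torch.randn(rnn_size, rnn_size) end
local b = {}
for _, name in ipairs({'i', 'f', 'o', 'c'}) do b[name] = torch.zeros(rnn_size) end

local function lstm_step(x, prev_c, prev_h)
  local i    = torch.sigmoid(W.xi * x + W.hi * prev_h + b.i)   -- input gate
  local f    = torch.sigmoid(W.xf * x + W.hf * prev_h + b.f)   -- forget gate
  local o    = torch.sigmoid(W.xo * x + W.ho * prev_h + b.o)   -- output gate
  local c_in = torch.tanh(W.xc * x + W.hc * prev_h + b.c)      -- input transform
  local next_c = torch.cmul(f, prev_c) + torch.cmul(i, c_in)   -- state update
  local next_h = torch.cmul(o, torch.tanh(next_c))
  return next_c, next_h
end

local c, h = torch.zeros(rnn_size), torch.zeros(rnn_size)
c, h = lstm_step(torch.randn(input_size), c, h)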

Building your own LSTM layer

The code for this tutorial will be using Torch7. Don’t worry if you don’t know it. I’ll explain everything, so you’ll be able to implement the same algorithm in your favorite framework.

The network will be implemented as an nngraph.gModule, which basically means that we’ll define a computation graph consisting of standard nn modules. We will need the following layers:

  • nn.Identity() - passes its input through unchanged (used as a placeholder for inputs)
  • nn.Dropout(p) - standard dropout module (zeroes each element with probability p)
  • nn.Linear(in, out) - an affine transform from in dimensions to out dimensions
  • nn.Narrow(dim, start, len) - selects a subvector of len elements along dimension dim, starting at index start
  • nn.Sigmoid() - applies sigmoid element-wise
  • nn.Tanh() - applies tanh element-wise
  • nn.CMulTable() - outputs the element-wise product of the tensors in the forwarded table
  • nn.CAddTable() - outputs the element-wise sum of the tensors in the forwarded table (see the quick check below)
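
Since it’s easy to mix up the last two, here’s a quick sanity check (my own snippet, not part of the layer code):

require 'nn'

local a = torch.Tensor{1, 2, 3}
local b = torch.Tensor{10, 20, 30}

print(nn.CAddTable():forward({a, b}))  --  11  22  33  (element-wise sum)
print(nn.CMulTable():forward({a, b}))  --  10  40  90  (element-wise product)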

Inputs

First, let’s define the input structure. The array-like objects in Lua are called tables. This network will accept a table of tensors like the one below:

local inputs = {}
table.insert(inputs, nn.Identity()())   -- network input
table.insert(inputs, nn.Identity()())   -- c at time t-1
table.insert(inputs, nn.Identity()())   -- h at time t-1
local input = inputs[1]
local prev_c = inputs[2]
local prev_h = inputs[3]

Identity modules will just copy whatever we provide to the network into the graph.
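
As a quick illustration (my own snippet), nn.Identity really does just hand back whatever it gets, whether a tensor or a table:

require 'nn'

local id = nn.Identity()
print(id:forward(torch.Tensor{1, 2, 3}))            -- 1  2  3
print(id:forward({torch.ones(2), torch.zeros(2)}))  -- the table comes back unchanged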

Computing gate values

To make our implementation faster, we will apply the transformations of the whole LSTM layer at once, computing the preactivations of all three gates and the cell input with a single pair of linear maps.

local i2h = nn.Linear(input_size, 4 * rnn_size)(input)  -- input to hidden
local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)   -- hidden to hidden
local preactivations = nn.CAddTable()({i2h, h2h})       -- i2h + h2h

If you’re unfamiliar with nngraph, it probably seems strange that we’re constructing a module and then immediately calling it once more with a graph node. What actually happens is that the second call wraps the nn.Module in a graph node, and the argument specifies its parent in the graph.

preactivations outputs a vector created by a linear transform of the input and the previous hidden state. These are raw values that will be used to compute the gate activations and the cell input. The vector is divided into 4 parts, each of size rnn_size: the first is used for the in gates, the second for the forget gates, the third for the out gates, and the last one as the cell input (so for cell number i the indices of the respective gates and the cell input are {i, rnn_size+i, 2⋅rnn_size+i, 3⋅rnn_size+i}).
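
For example (made-up numbers): with rnn_size = 4, the preactivations belonging to cell number 2 sit at the following positions of the length-16 vector:

local rnn_size, i = 4, 2
print(i, rnn_size + i, 2 * rnn_size + i, 3 * rnn_size + i)  -- 2  6  10  14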


Next, we have to apply a nonlinearity. While all the gates use a sigmoid, we use tanh for the input preactivation. Because of this, we place two nn.Narrow modules, which select the appropriate parts of the preactivation vector.

-- gates
local pre_sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(preactivations)
local all_gates = nn.Sigmoid()(pre_sigmoid_chunk)

-- input
local in_chunk = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations)
local in_transform = nn.Tanh()(in_chunk)

After the nonlinearities we have to place a few more nn.Narrow modules, and the gates are done!

local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)
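
Here is a quick check of the index arithmetic (my own snippet, with rnn_size = 2 and a fake preactivation row containing 1..8). In the real graph the first three slices are taken from all_gates rather than from the raw preactivations, but the column layout is the same:

require 'nn'

local rnn_size = 2
local fake_pre = torch.range(1, 4 * rnn_size):view(1, 4 * rnn_size)

print(nn.Narrow(2, 1, rnn_size):forward(fake_pre))                 -- columns 1, 2  (in gate)
print(nn.Narrow(2, rnn_size + 1, rnn_size):forward(fake_pre))      -- columns 3, 4  (forget gate)
print(nn.Narrow(2, 2 * rnn_size + 1, rnn_size):forward(fake_pre))  -- columns 5, 6  (out gate)
print(nn.Narrow(2, 3 * rnn_size + 1, rnn_size):forward(fake_pre))  -- columns 7, 8  (cell input)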


Cell and hidden state

Having computed the gate values, we can now calculate the current cell state. All that’s required are two nn.CMulTable modules (one for $f_t \cdot c_{t-1}$ and one for $i_t \cdot {c\_in}_t$), and an nn.CAddTable to sum them into the current cell state.

-- previous cell state contribution
local c_forget = nn.CMulTable()({forget_gate, prev_c})
-- input contribution
local c_input = nn.CMulTable()({in_gate, in_transform})
-- next cell state
local next_c = nn.CAddTable()({
  c_forget,
  c_input
})

It’s finally time to implement the hidden state calculation. It’s the simplest part, because it just involves applying tanh to the current cell state (nn.Tanh) and multiplying the result by the output gate (nn.CMulTable).

local c_transform = nn.Tanh()(next_c)
local next_h = nn.CMulTable()({out_gate, c_transform})


Defining the module

Now, if you want to export the whole graph as a standalone module, you can wrap it like this:

-- module outputs
outputs = {}
table.insert(outputs, next_c)
table.insert(outputs, next_h)

-- packs the graph into a convenient module with standard API (:forward(), :backward())
return nn.gModule(inputs, outputs)
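
For reference, here is how the pieces fit together into a single create function (a sketch of my own assembled from the snippets above; the linked LSTM.lua may differ in details, e.g. where dropout is applied):

require 'nn'
require 'nngraph'

local LSTM = {}

function LSTM.create(input_size, rnn_size)
  -- inputs: {x_t, c_{t-1}, h_{t-1}}
  local inputs = { nn.Identity()(), nn.Identity()(), nn.Identity()() }
  local input, prev_c, prev_h = inputs[1], inputs[2], inputs[3]

  -- all preactivations at once
  local i2h = nn.Linear(input_size, 4 * rnn_size)(input)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local preactivations = nn.CAddTable()({i2h, h2h})

  -- gate activations and cell input
  local all_gates = nn.Sigmoid()(nn.Narrow(2, 1, 3 * rnn_size)(preactivations))
  local in_transform = nn.Tanh()(nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(preactivations))
  local in_gate = nn.Narrow(2, 1, rnn_size)(all_gates)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(all_gates)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(all_gates)

  -- cell and hidden state
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule(inputs, {next_c, next_h})
end

return LSTM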

Examples

The LSTM layer implementation is available here. You can use it like this:

th> LSTM = require 'LSTM.lua'
                                                                      [0.0224s]
th> layer = LSTM.create(3, 2)
                                                                      [0.0019s]
th> layer:forward({torch.randn(1,3), torch.randn(1,2), torch.randn(1,2)})
{
  1 : DoubleTensor - size: 1x2
  2 : DoubleTensor - size: 1x2
}
                                                                      [0.0005s]

To make a multi-layer LSTM network, you can forward subsequent layers in a for loop, taking next_h from the previous layer as the next layer’s input. You can check this example.
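
A sketch of that loop (my own, with hypothetical layer sizes; the linked example is the authoritative version):

local LSTM = require 'LSTM.lua'

local layers = { LSTM.create(3, 2), LSTM.create(2, 2) }   -- layer 2's input size equals layer 1's rnn_size
local c = { torch.zeros(1, 2), torch.zeros(1, 2) }        -- per-layer cell states
local h = { torch.zeros(1, 2), torch.zeros(1, 2) }        -- per-layer hidden states

local x = torch.randn(1, 3)
for l = 1, #layers do
  local out = layers[l]:forward({x, c[l], h[l]})
  c[l], h[l] = out[1], out[2]   -- outputs are {next_c, next_h}
  x = h[l]                      -- next layer's input is this layer's hidden state
end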

Training

If you’re interested please leave a comment and I’ll try to expand this post!
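
I won’t cover training here, but since the wrapped gModule exposes the standard :forward()/:backward() API, a single backward pass through the layer looks like this (a sketch with dummy gradients, not a full training loop):

local LSTM = require 'LSTM.lua'
local layer = LSTM.create(3, 2)

local inputs = {torch.randn(1, 3), torch.randn(1, 2), torch.randn(1, 2)}
local outputs = layer:forward(inputs)                      -- {next_c, next_h}

-- pretend gradients flowing back into both outputs
local grad_outputs = {torch.randn(1, 2), torch.randn(1, 2)}
local grad_inputs = layer:backward(inputs, grad_outputs)   -- {d_input, d_prev_c, d_prev_h}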

That’s it!

That’s it. It’s quite easy to implement any RNN once you understand how to deal with the hidden state. After connecting several layers, just put a regular MLP on top, connect it to the last layer’s hidden state, and you’re done!
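
For instance, a classification head could look like this (my own sketch, with made-up sizes and a log-softmax output; swap in whatever output layer your task needs):

require 'nn'

local rnn_size, n_classes = 2, 10
local top_h = torch.randn(1, rnn_size)     -- stands in for the top layer's next_h

local head = nn.Sequential()
head:add(nn.Linear(rnn_size, n_classes))
head:add(nn.LogSoftMax())

local log_probs = head:forward(top_h)      -- 1 x n_classes log-probabilities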

Here are some nice papers on RNNs if you’re interested:
