Long short-term memory (LSTM) RNNs


In [1]:
from __future__ import print_function
import mxnet as mx
from mxnet import nd, autograd
import numpy as np
mx.random.seed(1)
# ctx = mx.gpu(0)
ctx = mx.cpu(0)

In [2]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from datetime import datetime
# import mpld3
sns.set_style('whitegrid')
#sns.set_context('notebook')
sns.set_context('poster')
# Make inline plots vector graphics instead of raster graphics
from IPython.display import set_matplotlib_formats
#set_matplotlib_formats('pdf', 'svg')
set_matplotlib_formats('pdf', 'png')


/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:98: DeprecationWarning: DisplayFormatter._formatters_default is deprecated: use @default decorator instead.
  def _formatters_default(self):
/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:677: DeprecationWarning: PlainTextFormatter._deferred_printers_default is deprecated: use @default decorator instead.
  def _deferred_printers_default(self):
/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:669: DeprecationWarning: PlainTextFormatter._singleton_printers_default is deprecated: use @default decorator instead.
  def _singleton_printers_default(self):
/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:672: DeprecationWarning: PlainTextFormatter._type_printers_default is deprecated: use @default decorator instead.
  def _type_printers_default(self):
/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:672: DeprecationWarning: PlainTextFormatter._type_printers_default is deprecated: use @default decorator instead.
  def _type_printers_default(self):
/usr/local/lib/python2.7/site-packages/IPython/core/formatters.py:677: DeprecationWarning: PlainTextFormatter._deferred_printers_default is deprecated: use @default decorator instead.
  def _deferred_printers_default(self):

In [3]:
SEQ_LENGTH = 100 + 1  # needs to be at least the seq_length for training + 1 because of the time shift between inputs and labels
NUM_SAMPLES_TRAINING = 5000 + 1
NUM_SAMPLES_TESTING = 100 + 1

Dataset: "Some time-series"


In [4]:
def gimme_one_random_number():
    return nd.random_uniform(low=0, high=1, shape=(1,1)).asnumpy()[0][0]

def create_one_time_series(seq_length=10):
  freq = (gimme_one_random_number()*0.5) + 0.1  # 0.1 to 0.6
  ampl = gimme_one_random_number() + 0.5  # 0.5 to 1.5
  x = np.sin(np.arange(0, seq_length) * freq) * ampl
  return x

In [5]:
def create_batch_time_series(seq_length=10, num_samples=4):
    column_labels = ['t'+str(i) for i in range(0, seq_length)]
    df = pd.DataFrame(create_one_time_series(seq_length=seq_length)).transpose()
    df.columns = column_labels
    df.index = ['s'+str(0)]
    for i in range(1, num_samples):
        more_df = pd.DataFrame(create_one_time_series(seq_length=seq_length)).transpose()
        more_df.columns = column_labels
        more_df.index = ['s'+str(i)]
        df = pd.concat([df, more_df], axis=0)
    return df # returns a dataframe of shape (num_samples, seq_length)

In [6]:
# Create some time-series
# uncomment below to force predictible random numbers
# mx.random.seed(1)
data_train = create_batch_time_series(seq_length=SEQ_LENGTH, num_samples=NUM_SAMPLES_TRAINING)  
data_test = create_batch_time_series(seq_length=SEQ_LENGTH, num_samples=NUM_SAMPLES_TESTING)

# Write data to csv
data_train.to_csv("../data/timeseries/train.csv", header=False, index=False)
data_test.to_csv("../data/timeseries/test.csv", header=False, index=False)

Check the data real quick


In [7]:
# num_sampling_points = min(SEQ_LENGTH, 50)
# (data_train.sample(4).transpose().iloc[range(0, SEQ_LENGTH, SEQ_LENGTH//num_sampling_points)]).plot()

Preparing the data for training


In [8]:
# print(data_train.loc[:,data_train.columns[:-1]]) # inputs
# print(data_train.loc[:,data_train.columns[1:]])  # outputs (i.e. inputs shift by +1)

In [9]:
batch_size = 32
batch_size_test = 1
seq_length = 16

num_batches_train = data_train.shape[0] // batch_size
num_batches_test = data_test.shape[0] // batch_size_test

num_features = 1  #  we do 1D time series for now, this is like vocab_size = 1 for characters

# inputs are from t0 to t_seq_length - 1. because the last point is kept for the output ("label") of the penultimate point 
data_train_inputs = data_train.loc[:,data_train.columns[:-1]]
data_train_labels = data_train.loc[:,data_train.columns[1:]]
data_test_inputs = data_test.loc[:,data_test.columns[:-1]]
data_test_labels = data_test.loc[:,data_test.columns[1:]]

train_data_inputs = nd.array(data_train_inputs.values).reshape((num_batches_train, batch_size, seq_length, num_features))
train_data_labels = nd.array(data_train_labels.values).reshape((num_batches_train, batch_size, seq_length, num_features))
test_data_inputs = nd.array(data_test_inputs.values).reshape((num_batches_test, batch_size_test, seq_length, num_features))
test_data_labels = nd.array(data_test_labels.values).reshape((num_batches_test, batch_size_test, seq_length, num_features))

train_data_inputs = nd.swapaxes(train_data_inputs, 1, 2)
train_data_labels = nd.swapaxes(train_data_labels, 1, 2)
test_data_inputs = nd.swapaxes(test_data_inputs, 1, 2)
test_data_labels = nd.swapaxes(test_data_labels, 1, 2)


print('num_samples_training={0} | num_batches_train={1} | batch_size={2} | seq_length={3}'.format(NUM_SAMPLES_TRAINING, num_batches_train, batch_size, seq_length))
print('train_data_inputs shape: ', train_data_inputs.shape)
print('train_data_labels shape: ', train_data_labels.shape)
# print(data_train_inputs.values)
# print(train_data_inputs[0]) # see what one batch looks like


num_samples_training=5001 | num_batches_train=156 | batch_size=32 | seq_length=16
train_data_inputs shape:  (156L, 16L, 32L, 1L)
train_data_labels shape:  (156L, 16L, 32L, 1L)

Long short-term memory (LSTM) RNNs

An LSTM block has mechanisms to enable "memorizing" information for an extended number of time steps. We use the LSTM block with the following transformations that map inputs to outputs across blocks at consecutive layers and consecutive time steps: $\newcommand{\xb}{\mathbf{x}} \newcommand{\RR}{\mathbb{R}}$

$$g_t = \text{tanh}(X_t W_{xg} + h_{t-1} W_{hg} + b_g),$$$$i_t = \sigma(X_t W_{xi} + h_{t-1} W_{hi} + b_i),$$$$f_t = \sigma(X_t W_{xf} + h_{t-1} W_{hf} + b_f),$$$$o_t = \sigma(X_t W_{xo} + h_{t-1} W_{ho} + b_o),$$$$c_t = f_t \odot c_{t-1} + i_t \odot g_t,$$$$h_t = o_t \odot \text{tanh}(c_t),$$

where $\odot$ is an element-wise multiplication operator, and for all $\xb = [x_1, x_2, \ldots, x_k]^\top \in \RR^k$ the two activation functions:

$$\sigma(\xb) = \left[\frac{1}{1+\exp(-x_1)}, \ldots, \frac{1}{1+\exp(-x_k)}]\right]^\top,$$$$\text{tanh}(\xb) = \left[\frac{1-\exp(-2x_1)}{1+\exp(-2x_1)}, \ldots, \frac{1-\exp(-2x_k)}{1+\exp(-2x_k)}\right]^\top.$$

In the transformations above, the memory cell $c_t$ stores the "long-term" memory in the vector form. In other words, the information accumulatively captured and encoded until time step $t$ is stored in $c_t$ and is only passed along the same layer over different time steps.

Given the inputs $c_t$ and $h_t$, the input gate $i_t$ and forget gate $f_t$ will help the memory cell to decide how to overwrite or keep the memory information. The output gate $o_t$ further lets the LSTM block decide how to retrieve the memory information to generate the current state $h_t$ that is passed to both the next layer of the current time step and the next time step of the current layer. Such decisions are made using the hidden-layer parameters $W$ and $b$ with different subscripts: these parameters will be inferred during the training phase by gluon.

Allocate parameters


In [10]:
num_inputs = num_features
num_hidden = 8
num_outputs = num_features

########################
#  Weights connecting the inputs to the hidden layer
########################
Wxg = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01
Wxi = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01
Wxf = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01
Wxo = nd.random_normal(shape=(num_inputs,num_hidden), ctx=ctx) * .01

########################
#  Recurrent weights connecting the hidden layer across time steps
########################
Whg = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01
Whi = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01
Whf = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01
Who = nd.random_normal(shape=(num_hidden,num_hidden), ctx=ctx)* .01

########################
#  Bias vector for hidden layer
########################
bg = nd.random_normal(shape=num_hidden, ctx=ctx) * .01
bi = nd.random_normal(shape=num_hidden, ctx=ctx) * .01
bf = nd.random_normal(shape=num_hidden, ctx=ctx) * .01
bo = nd.random_normal(shape=num_hidden, ctx=ctx) * .01

########################
# Weights to the output nodes
########################
Why = nd.random_normal(shape=(num_hidden,num_outputs), ctx=ctx) * .01
by = nd.random_normal(shape=num_outputs, ctx=ctx) * .01

Attach the gradients


In [11]:
params = [Wxg, Wxi, Wxf, Wxo, Whg, Whi, Whf, Who, bg, bi, bf, bo, Why, by]

for param in params:
    param.attach_grad()

Softmax Activation


In [12]:
def softmax(y_linear, temperature=1.0):
    lin = (y_linear-nd.max(y_linear)) / temperature
    exp = nd.exp(lin)
    partition = nd.sum(exp, axis=0, exclude=True).reshape((-1,1))
    return exp / partition

Define the model


In [13]:
def lstm_rnn(inputs, h, c, temperature=1.0):
    outputs = []
    # inputs is one BATCH of sequences so its shape is number_of_seq, seq_length, features_dim 
    # (latter is 1 for a time series, vocab_size for a character, n for a n different times series)
    for X in inputs:
#         print('shape of inputs, X: ', inputs.shape, X.shape)
        # X is batch of one time stamp. E.g. if each batch has 37 sequences, then the first value of X will be a set of the 37 first values of each of the 37 sequences 
        # that means each iteration on X corresponds to one time stamp, but it is done in batches of different sequences
        g = nd.tanh(nd.dot(X, Wxg) + nd.dot(h, Whg) + bg)
        i = nd.sigmoid(nd.dot(X, Wxi) + nd.dot(h, Whi) + bi)
        f = nd.sigmoid(nd.dot(X, Wxf) + nd.dot(h, Whf) + bf)
        o = nd.sigmoid(nd.dot(X, Wxo) + nd.dot(h, Who) + bo)
        #######################
        #
        #######################
        c = f * c + i * g
        h = o * nd.tanh(c)
        #######################
        #
        #######################
        yhat_linear = nd.dot(h, Why) + by
        # yhat is a batch of several values of the same time stamp
        # this is basically the prediction of the sequence, which overlaps most of the input sequence, plus one point (character or value)
#         yhat = softmax(yhat_linear, temperature=temperature)
#         yhat = nd.sigmoid(yhat_linear)
#         yhat = nd.tanh(yhat_linear)
        yhat = yhat_linear # we cant use a 1.0-bounded activation function since amplitudes can be greater than 1.0
        outputs.append(yhat) # outputs has same shape as inputs, i.e. a list of batches of data points.
#     print('some shapes... yhat outputs', yhat.shape, len(outputs) )
    return (outputs, h, c)

Cross-entropy loss function


In [14]:
def cross_entropy(yhat, y):
    return - nd.mean(nd.sum(y * nd.log(yhat), axis=0, exclude=True))

# root_mean_squared_error = mx.metric.RMSE()

def rmse(yhat, y):
#     root_mean_squared_error.update(labels = y, preds = yhat)
#     return root_mean_squared_error.get()
#     print("LOOOL ", nd.power(y - yhat, 2))
#     return - nd.mean(nd.sum(y * nd.log(yhat), axis=0, exclude=True))
    return nd.mean(nd.sqrt(nd.sum(nd.power(y - yhat, 2), axis=0, exclude=True)))

Averaging the loss over the sequence


In [15]:
def average_ce_loss(outputs, labels):
    assert(len(outputs) == len(labels))
    total_loss = 0.
    for (output, label) in zip(outputs,labels):
        total_loss = total_loss + cross_entropy(output, label)
    return total_loss / len(outputs)

def average_rmse_loss(outputs, labels):
    assert(len(outputs) == len(labels))
    total_loss = 0.
    for (output, label) in zip(outputs,labels):
        total_loss = total_loss + rmse(output, label)
    return total_loss / len(outputs)

Optimizer


In [16]:
def SGD(params, lr):
    for param in params:
        param[:] = param - lr * param.grad

Test and visualize predictions


In [17]:
def test_prediction(one_input_seq, one_label_seq, temperature=1.0):
    #####################################
    # Set the initial state of the hidden representation ($h_0$) to the zero vector
    #####################################  # some better initialization needed??
    h = nd.zeros(shape=(1, num_hidden), ctx=ctx)
    c = nd.zeros(shape=(1, num_hidden), ctx=ctx)
    
    outputs, h, c = lstm_rnn(one_input_seq, h, c, temperature=temperature)
    loss = rmse(outputs[-1][0], one_label_seq)
    return outputs[-1][0].asnumpy()[-1], one_label_seq.asnumpy()[-1], loss.asnumpy()[-1], outputs, one_label_seq

def check_prediction(index):
    o, label, loss, outputs, labels = test_prediction(test_data_inputs[index], test_data_labels[index], temperature=1.0)
    prediction = round(o, 3)
    true_label = round(label, 3)
    outputs = [float(i.asnumpy().flatten()) for i in outputs]
    true_labels = list(test_data_labels[index].asnumpy().flatten())
    # print(outputs, '\n----\n', true_labels)
    df = pd.DataFrame([outputs, true_labels]).transpose()
    df.columns = ['predicted', 'true']
    # print(df)
    rel_error = round(100. * (prediction / true_label - 1.0), 2)
#     print('\nprediction = {0} | actual_value = {1} | rel_error = {2}'.format(prediction, true_label, rel_error))
    return df

In [19]:
epochs = 45
moving_loss = 0.
learning_rate = .03

# needed to update plots on the fly
%matplotlib notebook
fig, axxx = plt.subplots(1,1)

# for i in range(4):
#     l1 = test_data_inputs[i].asnumpy().flatten()
#     l2 = test_data_labels[i].asnumpy().flatten()
#     df = pd.DataFrame([l1, l2]).transpose()
#     df.columns = ['predicted', 'true']
#     axxx.clear()
# #     plt.pause(0.3)
#     df.plot(ax=axxx)
#     fig.canvas.draw()
# #     plt.draw()
# #     plt.pause(0.5)
# #     df.plot(fig.axes)
#     time.sleep(0.3)  
# %matplotlib inline
        


# state = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)
for e in range(epochs):
    ############################
    # Attenuate the learning rate by a factor of 2 every 100 epochs.
    ############################
    if ((e+1) % 100 == 0):
        learning_rate = learning_rate / 2.0
    h = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)
    c = nd.zeros(shape=(batch_size, num_hidden), ctx=ctx)
    for i in range(num_batches_train):
        data_one_hot = train_data_inputs[i]
        label_one_hot = train_data_labels[i]
        with autograd.record():
#             print('SHAPE OF data_one_hot ', data_one_hot.shape)
            outputs, h, c = lstm_rnn(data_one_hot, h, c)
            loss = average_rmse_loss(outputs, label_one_hot)
            loss.backward()
        SGD(params, learning_rate)

        ##########################
        #  Keep a moving average of the losses
        ##########################
        if (i == 0) and (e == 0):
            moving_loss = nd.mean(loss).asscalar()
        else:
            moving_loss = .99 * moving_loss + .01 * nd.mean(loss).asscalar()
      
#     print("Epoch %s. Loss: %s" % (e, moving_loss)) 
    data_prediction_df = check_prediction(index=e)
    axxx.clear()
    data_prediction_df.plot(ax=axxx)
    fig.canvas.draw()
    prediction = round(data_prediction_df.tail(1)['predicted'].values.flatten()[-1], 3)
    true_label = round(data_prediction_df.tail(1)['true'].values.flatten()[-1], 3)
    rel_error = round(100. * (prediction / true_label - 1.0), 2)
    print("Epoch = {0} | Loss = {1} | Prediction = {2} True = {3} Error = {4}".format(e, moving_loss, prediction, true_label, rel_error ))
    
%matplotlib inline


Epoch = 0 | Loss = 0.615221960254 | Prediction = 0.047 True = -1.381 Error = -103.4
Epoch = 1 | Loss = 0.592363745208 | Prediction = 0.182 True = 0.226 Error = -19.47
Epoch = 2 | Loss = 0.496086866163 | Prediction = 0.547 True = 1.344 Error = -59.3
Epoch = 3 | Loss = 0.38580032295 | Prediction = -0.836 True = -0.446 Error = 87.44
Epoch = 4 | Loss = 0.325791319924 | Prediction = -0.811 True = -1.271 Error = -36.19
Epoch = 5 | Loss = 0.285158004289 | Prediction = 0.975 True = 0.655 Error = 48.85
Epoch = 6 | Loss = 0.252333167044 | Prediction = -1.145 True = -1.355 Error = -15.5
Epoch = 7 | Loss = 0.22183133572 | Prediction = -1.235 True = -1.254 Error = -1.52
Epoch = 8 | Loss = 0.186673355969 | Prediction = -1.221 True = -1.08 Error = 13.06
Epoch = 9 | Loss = 0.142382053405 | Prediction = -1.077 True = -0.841 Error = 28.06
Epoch = 10 | Loss = 0.116652795879 | Prediction = -0.794 True = -0.554 Error = 43.32
Epoch = 11 | Loss = 0.107185860356 | Prediction = -0.435 True = -0.233 Error = 86.7
Epoch = 12 | Loss = 0.101929478208 | Prediction = -0.142 True = -0.269 Error = -47.21
Epoch = 13 | Loss = 0.0981680197267 | Prediction = -0.728 True = -0.768 Error = -5.21
Epoch = 14 | Loss = 0.0953045538405 | Prediction = -1.117 True = -1.16 Error = -3.71
Epoch = 15 | Loss = 0.0932019116822 | Prediction = -1.341 True = -1.39 Error = -3.53
Epoch = 16 | Loss = 0.0916855492985 | Prediction = -1.44 True = -1.424 Error = 1.12
Epoch = 17 | Loss = 0.0905339271424 | Prediction = -1.386 True = -1.259 Error = 10.09
Epoch = 18 | Loss = 0.0895710152233 | Prediction = 1.195 True = 1.267 Error = -5.68
Epoch = 19 | Loss = 0.0887062786193 | Prediction = 0.158 True = 0.157 Error = 0.64
Epoch = 20 | Loss = 0.0879057692828 | Prediction = -1.192 True = -1.142 Error = 4.38
Epoch = 21 | Loss = 0.0871477199253 | Prediction = -0.998 True = -1.066 Error = -6.38
Epoch = 22 | Loss = 0.0864323772884 | Prediction = 0.33 True = 0.294 Error = 12.24
Epoch = 23 | Loss = 0.0857719560485 | Prediction = 1.302 True = 1.3 Error = 0.15
Epoch = 24 | Loss = 0.0851621272648 | Prediction = 0.709 True = 0.741 Error = -4.32
Epoch = 25 | Loss = 0.0846084357904 | Prediction = 1.462 True = 1.287 Error = 13.6
Epoch = 26 | Loss = 0.0841026926626 | Prediction = -1.026 True = -1.04 Error = -1.35
Epoch = 27 | Loss = 0.0836393457114 | Prediction = -0.737 True = -0.445 Error = 65.62
Epoch = 28 | Loss = 0.0832201183717 | Prediction = 1.461 True = 1.401 Error = 4.28
Epoch = 29 | Loss = 0.082823677887 | Prediction = -0.668 True = -0.687 Error = -2.77
Epoch = 30 | Loss = 0.0824482649787 | Prediction = -1.161 True = -0.845 Error = 37.4
Epoch = 31 | Loss = 0.0820940062173 | Prediction = -0.533 True = -0.541 Error = -1.48
Epoch = 32 | Loss = 0.0817520185556 | Prediction = -0.399 True = -0.407 Error = -1.97
Epoch = 33 | Loss = 0.0814229234636 | Prediction = -0.191 True = -0.193 Error = -1.04
Epoch = 34 | Loss = 0.0810967448618 | Prediction = 0.059 True = 0.058 Error = 1.72
Epoch = 35 | Loss = 0.0807809258815 | Prediction = 0.306 True = 0.298 Error = 2.68
Epoch = 36 | Loss = 0.0804760479562 | Prediction = 0.487 True = 0.479 Error = 1.67
Epoch = 37 | Loss = 0.0801744776573 | Prediction = 0.611 True = 0.549 Error = 11.29
Epoch = 38 | Loss = 0.0798811066981 | Prediction = 1.291 True = 1.263 Error = 2.22
Epoch = 39 | Loss = 0.0795865975717 | Prediction = 1.02 True = 1.096 Error = -6.93
Epoch = 40 | Loss = 0.0792954185166 | Prediction = 0.17 True = 0.164 Error = 3.66
Epoch = 41 | Loss = 0.0790092078105 | Prediction = -0.978 True = -0.882 Error = 10.88
Epoch = 42 | Loss = 0.0787270771903 | Prediction = -1.283 True = -1.313 Error = -2.28
Epoch = 43 | Loss = 0.0784530613636 | Prediction = 1.164 True = 1.168 Error = -0.34
Epoch = 44 | Loss = 0.0781817384367 | Prediction = 1.088 True = 1.127 Error = -3.46

MY ANOMALY DETECTOR IS GONNA BE SO GREAT THAT IT WONT STOP DETECTING YOU

Conclusions

For whinges or inquiries, open an issue on GitHub.