Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
#@title Full license text
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

LSTMs in Haiku

Haiku is a simple neural network library for JAX.

This notebook walks through a simple LSTM in JAX with Haiku.

For first-time Haiku users, we recommend that you first check out out our Quickstart and MNIST example first.


In [0]:
!pip install git+

In [0]:
import functools
import math
from typing import Tuple, TypeVar
import warnings

import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import optix
import numpy as np
import pandas as pd
import plotnine as gg

T = TypeVar('T')
Pair = Tuple[T, T]


Generating Data

In this notebook, we generate many sine waves (of the same period), and try to predict the next value in the wave based on its previous values.

For simplicity, we generate static-sized datasets and wrap them with an iterator-based API.

In [0]:
def sine_seq(
    phase: float,
    seq_len: int,
    samples_per_cycle: int,
) -> Pair[np.ndarray]:
  """Returns x, y in [T, B] tensor."""
  t = np.arange(seq_len + 1) * (2 * math.pi / samples_per_cycle)
  t = t.reshape([-1, 1]) + phase
  sine_t = np.sin(t)
  return sine_t[:-1, :], sine_t[1:, :]

def generate_data(
    seq_len: int,
    train_size: int,
    valid_size: int,
) -> Pair[Pair[np.ndarray]]:
  phases = np.random.uniform(0., 2 * math.pi, [train_size + valid_size])
  all_x, all_y = sine_seq(phases, seq_len, 3 * seq_len / 4)

  all_x = np.expand_dims(all_x, -1)
  all_y = np.expand_dims(all_y, -1)
  train_x = all_x[:, :train_size]
  train_y = all_y[:, :train_size]

  valid_x = all_x[:, train_size:]
  valid_y = all_y[:, train_size:]

  return (train_x, train_y), (valid_x, valid_y)

class Dataset:
  """An iterator over a numpy array, revealing batch_size elements at a time."""

  def __init__(self, xy: Pair[np.ndarray], batch_size: int):
    self._x, self._y = xy
    self._batch_size = batch_size
    self._length = self._x.shape[1]
    self._idx = 0
    if self._length % batch_size != 0:
      msg = 'dataset size {} must be divisible by batch_size {}.'
      raise ValueError(msg.format(self._length, batch_size))

  def __next__(self) -> Pair[np.ndarray]:
    start = self._idx
    end = start + self._batch_size
    x, y = self._x[:, start:end], self._y[:, start:end]
    if end >= self._length:
      end = end % self._length
      assert end == 0  # Guaranteed by ctor assertion.
    self._idx = end
    return x, y

In [0]:
TRAIN_SIZE = 2 ** 14
SEQ_LEN = 64

train, valid = generate_data(SEQ_LEN, TRAIN_SIZE, VALID_SIZE)

# Plot an observation/target pair.
df = pd.DataFrame({'x': train[0][:, 0, 0], 'y': train[1][:, 0, 0]}).reset_index()
df = pd.melt(df, id_vars=['index'], value_vars=['x', 'y'])
plot = gg.ggplot(df) + gg.aes(x='index', y='value', color='variable') + gg.geom_line()

train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)
del train, valid  # Don't leak temporaries.

Training an LSTM

To train the LSTM, we define a Haiku function which unrolls the LSTM over the input sequence, generating predictions for all output values. The LSTM always starts with its initial state at the start of the sequence.

The Haiku function is then transformed into a pure function through hk.transform, and is trained with Adam on an L2 prediction loss.

In [0]:
def unroll_net(seqs: jnp.ndarray):
  """Unrolls an LSTM over seqs, mapping each output to a scalar."""
  # seqs is [T, B, F].
  core = hk.LSTM(32)
  batch_size = seqs.shape[1]
  outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
  # We could include this Linear as part of the recurrent core!
  # However, it's more efficient on modern accelerators to run the linear once
  # over the entire sequence than once per sequence element.
  return hk.BatchApply(hk.Linear(1))(outs), state

model = hk.transform(unroll_net)

def train_model(train_ds: Dataset, valid_ds: Dataset) -> hk.Params:
  """Initializes and trains a model on train_ds, returning the final params."""
  rng = jax.random.PRNGKey(428)
  opt = optix.adam(1e-3)

  def loss(params, x, y):
    pred, _ = model.apply(params, None, x)
    return jnp.mean(jnp.square(pred - y))

  def update(step, params, opt_state, x, y):
    l, grads = jax.value_and_grad(loss)(params, x, y)
    grads, opt_state = opt.update(grads, opt_state)
    params = optix.apply_updates(params, grads)
    return l, params, opt_state

  # Initialize state.
  sample_x, _ = next(train_ds)
  params = model.init(rng, sample_x)
  opt_state = opt.init(params)

  for step in range(2001):
    if step % 100 == 0:
      x, y = next(valid_ds)
      print("Step {}: valid loss {}".format(step, loss(params, x, y)))

    x, y = next(train_ds)
    train_loss, params, opt_state = update(step, params, opt_state, x, y)
    if step % 100 == 0:
      print("Step {}: train loss {}".format(step, train_loss))

  return params

In [0]:
trained_params = train_model(train_ds, valid_ds)


The point of training models is so that they can make predictions! How can we generate predictions with the trained model?

If we're allowed to feed in the ground truth, we can just run the original model's apply function.

In [0]:
def plot_samples(truth: np.ndarray, prediction: np.ndarray) -> gg.ggplot:
  assert truth.shape == prediction.shape
  df = pd.DataFrame({'truth': truth.squeeze(), 'predicted': prediction.squeeze()}).reset_index()
  df = pd.melt(df, id_vars=['index'], value_vars=['truth', 'predicted'])
  plot = (
      + gg.aes(x='index', y='value', color='variable')
      + gg.geom_line()
  return plot

In [0]:
# Grab a sample from the validation set.
sample_x, _ = next(valid_ds)
sample_x = sample_x[:, :1]  # Shrink to batch-size 1.

# Generate a prediction, feeding in ground truth at each point as input.
predicted, _ = model.apply(trained_params, None, sample_x)

plot = plot_samples(sample_x[1:], predicted[:-1])
del sample_x, predicted

# Typically: the beginning of the predictions are a bit wonky, but the curve
# quickly smoothes out.

If we can't feed in the ground truth (because we don't have it), we can also run the model autoregressively.

In [0]:
def autoregressive_predict(
    trained_params: hk.Params,
    context: jnp.ndarray,
    seq_len: int,
  """Given a context, autoregressively generate the rest of a sine wave."""
  ar_outs = []
  context = jax.device_put(context)
  for _ in range(seq_len - context.shape[0]):
    full_context = jnp.concatenate([context] + ar_outs)
    outs, _ = jax.jit(model.apply)(trained_params, None, full_context)
    # Append the newest prediction to ar_outs.
  # Return the final full prediction.
  return outs

sample_x, _ = next(valid_ds)
context_length = SEQ_LEN // 8
# Cut the batch-size 1 context from the start of the sequence.
context = sample_x[:context_length, :1]

# We can reuse params we got from training for inference - as long as the
# declaration order is the same.
predicted = autoregressive_predict(trained_params, context, SEQ_LEN)

plot = plot_samples(sample_x[1:, :1], predicted)
plot += gg.geom_vline(xintercept=len(context), linetype='dashed')
del predicted

Sharing parameters with a different function.

Unfortunately, this is a bit slow - we're doing O(N^2) computation for a sequence of length N.

It'd be better if we could do the autoregressive sampling all at once - but we need to write a new Haiku function for that.

We're in luck - if the Haiku module names match, the same parameters can be used for multiple Haiku functions.

This can be achieved through a combination of two techniques:

  1. If we manually give a unique name to a module, we can ensure that the parameters are directed to the right places.
  2. If modules are instantiated in the same order, they'll have the same names in different functions.

Here, we rely on method #2 to create a fast autoregressive prediction.

In [0]:
def fast_autoregressive_predict_fn(context, seq_len):
  """Given a context, autoregressively generate the rest of a sine wave."""
  core = hk.LSTM(32)
  dense = hk.Linear(1)
  state = core.initial_state(context.shape[1])
  # Unroll over the context using `hk.dynamic_unroll`.
  # As before, we `hk.BatchApply` the Linear for efficiency.
  context_outs, state = hk.dynamic_unroll(core, context, state)
  context_outs = hk.BatchApply(dense)(context_outs)

  # Now, unroll one step at a time using the running recurrent state.
  ar_outs = []
  x = context_outs[-1]
  for _ in range(seq_len - context.shape[0]):
    x, state = core(x, state)
    x = dense(x)
  return jnp.concatenate([context_outs, jnp.stack(ar_outs)])

fast_ar_predict = hk.transform(fast_autoregressive_predict_fn)
fast_ar_predict = jax.jit(fast_ar_predict.apply, static_argnums=3)
# Reuse the same context from the previous cell.
predicted = fast_ar_predict(trained_params, None, context, SEQ_LEN)
# The plots should be equivalent!
plot = plot_samples(sample_x[1:, :1], predicted[:-1])
plot += gg.geom_vline(xintercept=len(context), linetype='dashed')

In [0]:
%timeit autoregressive_predict(trained_params, context, SEQ_LEN)
%timeit fast_ar_predict(trained_params, None, context, SEQ_LEN)