Flax Language Model Example

A language model learns a probability distribution over the sentences of a given corpus by modelling each token (character, word, word-piece, etc.) autoregressively, i.e. conditioned on the previously observed tokens. These conditional distributions are commonly approximated with a "Transformer" decoder stack.
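
Concretely, the joint probability of a token sequence $x_1, \dots, x_T$ factorizes autoregressively as

$$p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \dots, x_{t-1}),$$

and the Transformer decoder is trained to approximate each conditional $p(x_t \mid x_{<t})$.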

Here we adapt the main training script from FLAX's lm1b language model example to run live in the Colab environment.

Preparatory Steps

Upgrade Local JAX + FLAX Packages


In [0]:
# Install the newest JAX and FLAX versions.
!pip install --upgrade -q jax==0.1.61 jaxlib==0.1.42 flax==0.1.0rc2
# Grab flax example code
!git clone -b master https://github.com/google/flax.git flaxrepo

TPU Configuration

⚠️ Make sure the Colab Runtime is set to Accelerator: TPU.
Menu: Runtime --> Change runtime type
Popup: Hardware Accelerator --> TPU


In [0]:
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
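
As an optional sanity check (not part of the original example), we can confirm that JAX now sees the TPU cores; on a TPU runtime this should report eight devices:

In [0]:
# Optional: verify that the TPU backend is attached.
import jax
print(jax.device_count())
print(jax.devices())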

Imports


In [0]:
import functools
import itertools
import os
import time

import flax
from flax import jax_utils
from flax import nn
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils

import jax
from jax import random
import jax.nn
import jax.numpy as jnp

import numpy as np

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
tf.enable_v2_behavior()

# We directly import the FLAX Language Model example code.
from flaxrepo.examples.lm1b import decode
from flaxrepo.examples.lm1b import input_pipeline
from flaxrepo.examples.lm1b import models

Hyperparameters and Configuration


In [0]:
# Make a local directory to store run data and checkpoints, etc.
!mkdir run_1

In [0]:
model_dir = '/content/run_1'  # Directory to store data in.
save_checkpoints = True       # Save local checkpoints?
restore_checkpoints = True    # Restore from last local checkpoint?     
checkpoint_freq = 5000        # How often to save checkpoints

num_train_steps = 500000      # Max number of training steps.
eval_frequency = 1000         # How often to run model evaluation.
num_eval_steps = 20           # Number of steps to take during evaluation.
random_seed = 0               # JAX PRNG random seed.
learning_rate = 0.05          # Base learning rate.
weight_decay = 1e-1           # AdamW-style relative weight decay factor.
batch_size = 256              # "Target" Batch size.
max_target_length = 256       # Maximum input length.
max_eval_target_length = 256  # Maximum eval-set input length.

lm_emb_dim = 512              # LM initial token embedding dimension.
lm_num_heads = 8              # Number of heads in decoder layers.
lm_num_layers = 6             # Number of decoder layers.
lm_qkv_dim = 512              # Decoder query/key/value depth.
lm_mlp_dim = 2048             # Feedforward (MLP) layer depth.

prompt_str = 'The British '   # Prompt for LM Inference.
sampling_temperature = 0.6    # Temperature to sample LM at.
sampling_top_k = 20           # If > 0, use TopK temperature sampling.
max_predict_token_length = 50 # Maximum number of subword tokens to predict.

Datasets

Wikitext-2 Dataset (FAST)

Instead of waiting on the LM1B dataset to be built locally, we can train on the smaller Wikitext-2 dataset, which is extracted from a small subset of the English Wikipedia.


In [0]:
!wget --quiet https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
!unzip wikitext-2-raw-v1.zip

In [0]:
def preprocess_ds(path):
  """Extract content sentences from Wikitext-2 dataset."""
  dataset = tf.data.TextLineDataset(path)
  # Drop article headers.
  def content_filter(source):
    return tf.logical_not(tf.strings.regex_full_match(
        source, 
        '([[:space:]][=])+.+([[:space:]][=])+[[:space:]]*'))
  dataset = dataset.filter(content_filter)

  # Split paragraphs to lines.
  dataset = dataset.map(lambda x: tf.strings.split(x, ' . '))
  dataset = dataset.unbatch()

  # Remove blank lines.
  def min_filter(min_len):
    def filter_fn(source):
      return tf.greater(tf.strings.length(source), tf.constant(min_len))
    return filter_fn
  dataset = dataset.filter(min_filter(1))

  return dataset

# Get the raw train and eval datasets.
train_ds = preprocess_ds('/content/wikitext-2-raw/wiki.train.raw')
eval_ds = preprocess_ds('/content/wikitext-2-raw/wiki.valid.raw')

# Build subword tokenizer.
try:
  # If we already ran this cell, reload the cached subword vocab file.
  encoder = tfds.features.text.SubwordTextEncoder.load_from_file('wikitext2')
except tf.errors.NotFoundError:
  # Build subword tokenizer from data. Takes ~1 minute.
  encoder = tfds.features.text.SubwordTextEncoder.build_from_corpus(
      (x._numpy() for x in train_ds),
      target_vocab_size=2**13,
      max_corpus_chars=10**6)
  encoder.save_to_file('wikitext2')

# Encode strings with subword tokenizer.
def tf_encode(x):
  result = tf.py_function(lambda s: tf.constant(encoder.encode(s.numpy())), 
                          [x,], 
                          tf.int32)
  result.set_shape([None])
  return result
train_ds = train_ds.map(tf_encode)
eval_ds = eval_ds.map(tf_encode)

# Create zero-padded, length-bucketed batches.
train_ds = input_pipeline.lm1b_preprocess(train_ds,
                training=True,
                n_devices=jax.local_device_count(),
                max_target_length=max_target_length,
                max_eval_target_length=max_eval_target_length,
                batch_size=batch_size,
                drop_remainder=True)

eval_ds = input_pipeline.lm1b_preprocess(eval_ds,
                training=False,
                n_devices=jax.local_device_count(),
                max_target_length=max_target_length,
                max_eval_target_length=max_eval_target_length,
                batch_size=batch_size,
                drop_remainder=True)
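
As an optional check (not part of the original example), we can peek at one preprocessed batch; with the settings above the leading dimensions should correspond to the batch size and the (bucketed) sequence length:

In [0]:
# Peek at the shape of a single preprocessed training batch (illustrative only).
example_batch = next(iter(train_ds))
print(jax.tree_map(lambda x: x.shape, example_batch))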

⚠️ LM1B Dataset (SLOW)

The LM1B dataset is fairly large and requires significant upfront preprocessing. Doing this on a Colab VM is possible but can be frustrating, as it takes several hours to finish, during which time the VM could reset.

We strongly recommend downloading and preparing the dataset on a cloud instance and storing the prepared dataset in a GCS bucket. Another alternative is to prepare the dataset on a local machine and upload it to a Google Drive folder, which can then be mounted in Colab:

from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/My\ Drive/tensorflow_datasets/lm1b ./tensorflow_datasets/lm1b
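
For the GCS route, a minimal sketch (the bucket name below is a placeholder): authenticate the Colab runtime, then point the input pipeline at the bucket.

from google.colab import auth
auth.authenticate_user()
# Afterwards, TFDS can read directly from the bucket, e.g.
# data_dir = "gs://YOUR_BUCKET_NAME/tensorflow_datasets"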

More I/O documentation is available in the Colab I/O notebook.


In [0]:
# On Colab it takes an hour to download and several more hours to preprocess,
# and you may need to babysit the Colab VM to keep it alive. If you do this, be
# sure to copy the prepared data to a Google Drive folder or elsewhere, as
# storage on a Colab VM is ephemeral!

# builder = tfds.builder('lm1b/subwords32k')
# builder.download_and_prepare(download_dir='/content/tensorflow_datasets')

In [0]:
# (below commented out to avoid triggering on "run all")

# # Point to existing local data copied over:
# data_dir = '/content/tensorflow_datasets'
# # or a GCS Bucket:
# # data_dir = "gs://YOUR_BUCKET_NAME/tensorflow_datasets"
# train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
#       n_devices=jax.local_device_count(),
#       data_dir=data_dir,
#       batch_size=batch_size,
#       dynamic_batching=True,
#       max_target_length=max_target_length,
#       max_eval_target_length=max_eval_target_length)
# vocab_size = info_ds['text'].encoder.vocab_size
# encoder = info_ds['text'].encoder

Model

The TransformerLM model is defined in the examples/lm1b/models.py file of the cloned repository: a token embedding followed by a stack of decoder layers with causal self-attention and feedforward (MLP) sublayers.
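
As a rough intuition for the decoder layers (a sketch, not the actual implementation in models.py), causal self-attention restricts each position to attend only to itself and to earlier positions:

In [0]:
# Illustrative only: the lower-triangular causal attention mask used by
# autoregressive decoders.
seq_len = 6
causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))
print(causal_mask.astype(jnp.int32))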


In [0]:
# Init PRNG Stream.
rng = random.PRNGKey(random_seed)
rng, init_rng = random.split(rng)
# We initialize the first set of dropout PRNG keys here, but update them
# afterwards inside the main pmap'd training step for performance.
dropout_rngs = random.split(rng, jax.local_device_count())

Model, Optimizer, Learning Rate


In [0]:
@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model(key, input_shape, model_kwargs):
  """
  We create a model definition from the top-level Language Model and 
  passed in hyperparameters.
  """
  module = models.TransformerLM.partial(**model_kwargs)
  # We initialize an autoregressive Cache collection for fast, autoregressive
  # decoding through the language model's decoder layers.
  with nn.attention.Cache().mutate() as cache_def:
    # create_by_shape initializes the model parameters.
    _, model = module.create_by_shape(
        key, [(input_shape, jnp.float32)], cache=cache_def)
  return model, cache_def

# Init model and optimizer.
vocab_size = encoder.vocab_size
input_shape = (batch_size, max_target_length)
transformer_lm_kwargs = {
    'vocab_size': vocab_size,
    'emb_dim': lm_emb_dim,
    'num_heads': lm_num_heads,
    'num_layers': lm_num_layers,
    'qkv_dim': lm_qkv_dim,
    'mlp_dim': lm_mlp_dim,
    'max_len': max(max_target_length, max_eval_target_length)
}
model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)


def create_optimizer(model, learning_rate):
  """
  Here we define the AdamW optimizer we'll use.
  """
  optimizer_def = optim.Adam(
      learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=weight_decay)
  optimizer = optimizer_def.create(model)
  optimizer = optimizer.replicate()
  return optimizer

# Build an optimizer from the model.
optimizer = create_optimizer(model, learning_rate)
# Don't keep a copy of the initial model object; if needed, we access it
# directly via optimizer.target.
del model


def create_learning_rate_scheduler(base_learning_rate=0.5, warmup_steps=8000):
  """Define our learning rate schedule."""
  def step_fn(step):
    return jnp.asarray(
        base_learning_rate * 
        jnp.minimum(1.0, step / warmup_steps) /
        jnp.sqrt(jnp.maximum(step, warmup_steps)), dtype=jnp.float32)
  return step_fn

learning_rate_fn = create_learning_rate_scheduler(
    base_learning_rate=learning_rate)
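
As a quick optional check, we can evaluate the schedule at a few step counts; it warms up linearly over the first 8000 steps and then decays like the inverse square root of the step number:

In [0]:
# Inspect the learning rate schedule at a few steps (illustrative only).
for s in [1, 1000, 8000, 100000, 500000]:
  print(s, float(learning_rate_fn(s)))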

Loss Function and Auxiliary Metrics


In [0]:
def compute_weighted_cross_entropy(logits, targets, weights=None):
  """Compute weighted cross entropy and entropy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].

  Returns:
    Tuple of scalar loss and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  onehot_targets = common_utils.onehot(targets, logits.shape[-1])
  loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
  normalizing_factor = onehot_targets.sum()
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor


def compute_weighted_accuracy(logits, targets, weights=None):
  """Compute weighted accuracy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length].

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = jnp.prod(logits.shape[:-1])
  if weights is not None:
    loss = loss * weights
    normalizing_factor = weights.sum()

  return loss.sum(), normalizing_factor


def compute_metrics(logits, labels, weights):
  """Compute summary metrics."""
  loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights)
  acc, _ = compute_weighted_accuracy(logits, labels, weights)
  metrics = {
      'loss': loss,
      'accuracy': acc,
      'denominator': weight_sum,
  }
  metrics = jax.lax.psum(metrics, axis_name='batch')
  return metrics
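
As a small illustrative check (not part of the original example), padding positions, which are encoded as token id 0, receive zero weight and are excluded from both the loss sum and the normalizing factor:

In [0]:
# Toy check: uniform logits over 8 classes, two real tokens and two padding
# tokens. The weighted mean loss should be log(8) ≈ 2.079.
toy_logits = jnp.zeros((1, 4, 8))
toy_targets = jnp.array([[5, 3, 0, 0]])
toy_weights = jnp.where(toy_targets > 0, 1, 0).astype(jnp.float32)
loss_sum, denom = compute_weighted_cross_entropy(toy_logits, toy_targets, toy_weights)
print(loss_sum / denom)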

Main Training, Evaluation, and Inference Functions


In [0]:
def train_step(optimizer, inputs, learning_rate_fn, dropout_rng=None):
  """Perform a single training step."""
  weights = jnp.where(inputs > 0, 1, 0)

  # We handle PRNG splitting inside the pmap'd step rather than outside in the
  # training loop, since doing it on the host can add stalls on the devices.
  dropout_rng, new_dropout_rng = random.split(dropout_rng)

  def loss_fn(model):
    """Loss function used for training."""
    with nn.stochastic(dropout_rng):
      logits = model(inputs, train=True)
    loss, weight_sum = compute_weighted_cross_entropy(logits, inputs, weights)
    mean_loss = loss / weight_sum
    return mean_loss, logits

  step = optimizer.state.step
  lr = learning_rate_fn(step)
  new_optimizer, _, logits = optimizer.optimize(loss_fn, learning_rate=lr)
  metrics = compute_metrics(logits, inputs, weights)
  metrics['learning_rate'] = lr

  return new_optimizer, metrics, new_dropout_rng

# parallelize the training step with JAX's pmap.
p_train_step = jax.pmap(
    functools.partial(train_step, learning_rate_fn=learning_rate_fn),
    axis_name='batch')


def eval_step(model, inputs):
  weights = jnp.where(inputs > 0, 1, 0)
  logits = model(inputs, train=False)
  return compute_metrics(logits, inputs, weights)

# parallelize the evaluation step with JAX's pmap.
p_eval_step = jax.pmap(eval_step, axis_name='batch')


def predict_step(inputs, model, cache, prng_key):
  """Fast sampling of language model from prompt."""
  prefix_len = inputs.shape[1]
  pad_len = max_predict_token_length - prefix_len
  padded_inputs = jnp.pad(inputs, jnp.array([[0, 0], [0, pad_len]]))

  def tokens_ids_to_logits(ids, cache):
    """Token slice to logits from decoder model."""
    with cache.mutate() as new_cache:
      logits = model(ids, shift=False, train=False, cache=new_cache)
    # Remove singleton sequence-length dimension from model.
    # [batch, 1, vocab] --> [batch, vocab]
    logits = logits.squeeze(axis=1)
    return logits, new_cache

  sampled_seqs = decode.temperature_sample(
      padded_inputs,
      cache,
      tokens_ids_to_logits,
      prng_key,
      temperature=sampling_temperature,
      topk=sampling_top_k,
      eos_token=2**16)  # No EOS tokens used in default lm1b dataset encoding.

  return sampled_seqs

# parallelize the fast autoregressive sampler with JAX's pmap.
p_pred_step = jax.pmap(predict_step, axis_name='batch')
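
The pmap'd functions above expect a leading device axis on their array arguments. The training loop below produces this with common_utils.shard, which simply splits the host-side batch across the local devices; a minimal illustration:

In [0]:
# Illustrative only: shard a dummy host batch across the local devices.
# Expected shape: (n_devices, batch_size // n_devices, max_target_length).
dummy_batch = np.zeros((batch_size, max_target_length), dtype=np.int32)
print(common_utils.shard(dummy_batch).shape)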

Tensorboard Logging


In [0]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [0]:
# Launch an inline tensorboard panel.
%tensorboard --logdir /content/run_1

Main Training Loop


In [0]:
# Summary writers for tensorboard.
train_summary_writer = tensorboard.SummaryWriter(
    os.path.join(model_dir, 'train'))
eval_summary_writer = tensorboard.SummaryWriter(
    os.path.join(model_dir, 'eval'))

# Initialize training dataset iterator.
train_iter = iter(train_ds)
start_step = 0

if restore_checkpoints:
  # Restore unreplicated optimizer + model state from last checkpoint.
  optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
  # Grab last step from the first of the optimizer replicas.
  start_step = int(optimizer.state.step[0])

metrics_all = []    # We aggregate and average training metrics here.
tick = time.time()  # Initialize step timer.

print('Compiling XLA programs for different input shapes;'
      ' this can take 5-10 minutes.')
for step, batch in zip(range(start_step, num_train_steps), train_iter):

  # Core training step.
  batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))
  optimizer, metrics, dropout_rngs = p_train_step(
      optimizer, batch, dropout_rng=dropout_rngs)
  metrics_all.append(metrics)

  # Save a Checkpoint
  if step % checkpoint_freq == 0 and step > 0:
    if save_checkpoints:
      checkpoints.save_checkpoint(model_dir, optimizer, step)

  # Periodic metric handling.
  if step % eval_frequency == 0 and step > 0:
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
    summary['learning_rate'] = lr
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    
    # Update step timer.
    tock = time.time()
    steps_per_sec = eval_frequency / (tock - tick)
    tick = tock
    train_summary_writer.scalar('steps per second', steps_per_sec, step)

    print('train in step: %d, loss: %.4f' % (step, summary['loss']))
    for key, val in summary.items():
      train_summary_writer.scalar(key, val, step)
    train_summary_writer.flush()
    # reset metric accumulation for next evaluation cycle.
    metrics_all = []


    # Eval Metrics -----------------------------------------------------------
    eval_metrics = []
    for _, eval_batch in zip(range(num_eval_steps), iter(eval_ds)):
      eval_batch = common_utils.shard(
          jax.tree_map(lambda x: x._numpy(), eval_batch))
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,
        eval_metrics_sums)
    eval_summary['perplexity'] = jnp.clip(
        jnp.exp(eval_summary['loss']), a_max=1.0e4)

    print('eval in step: %d, loss: %.4f' % (step, eval_summary['loss']))
    for key, val in eval_summary.items():
      eval_summary_writer.scalar(key, val, step)
    eval_summary_writer.flush()


    # Fast inference of prompt extension using trained LM. -------------------
    # Update rng stream for prediction.
    rng, subrng = jax.random.split(rng)
    pred_rngs = random.split(subrng, jax.local_device_count())

    # Encode provided text prompt to initialize sampling.
    prompt = jnp.array(encoder.encode(prompt_str))
    prompt = jax_utils.replicate(prompt)
    prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))

    # Initialize the autoregressive cache, run prediction loop, collect data.
    cache = jax_utils.replicate(
        cache_def.initialize_cache((1, max_predict_token_length)))
    predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
    predicted = np.array(predicted).reshape(
      (predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])

    # Write examples for tensorboard.
    print(encoder.decode(predicted[0]))
    exemplars = ''
    for n in range(predicted.shape[0]):
      exemplars += encoder.decode(predicted[n]) + '\n\n'
    eval_summary_writer.text('samples', exemplars, step)
    eval_summary_writer.flush()

Fast Inference on the Language Model


In [0]:
# Optional: we can skip training and restore from a saved checkpoint instead.
# optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)

In [0]:
def predict(rng, prompt_str):
  # Update rng stream for prediction.
  rng, subrng = jax.random.split(rng)
  pred_rngs = random.split(subrng, jax.local_device_count())
  # Encode provided text prompt to initialize sampling.
  prompt = jnp.array(encoder.encode(prompt_str))
  prompt = jax_utils.replicate(prompt)
  prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
  # Initialize the autoregressive cache, run prediction loop, collect data.
  cache = jax_utils.replicate(
      cache_def.initialize_cache((1, max_predict_token_length)))
  predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
  predicted = np.array(predicted).reshape(
      (predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])
  # Print generated sentences.
  exemplars = ''
  for n in range(predicted.shape[0]):
    exemplars += encoder.decode(predicted[n]) + '\n\n'
  print(exemplars)
  # Return rng stream.
  return rng

In [0]:
rng = predict(rng, "The kakapo is ")
