A language model learns a probability distribution over sentences from a given corpus by modelling each token (character, word, word-piece, etc.) autoregressively, i.e. conditioned on the previously observed tokens. This conditional distribution is commonly approximated using a "Transformer" decoder stack.
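Concretely, the joint probability of a token sequence $x_1, \ldots, x_T$ factorizes by the chain rule (this is the standard autoregressive factorization; the notation is ours, not taken from the example code):

$$p(x_1, \ldots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_1, \ldots, x_{t-1}),$$

and the Transformer decoder is trained to approximate each conditional $p(x_t \mid x_{<t})$.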
Here we adapt the main training script from Flax's lm1b language model example to run live in the Colab environment.
In [0]:
# Install the JAX, jaxlib, and Flax versions this notebook was written against.
!pip install --upgrade -q jax==0.1.61 jaxlib==0.1.42 flax==0.1.0rc2
# Grab flax example code
!git clone -b master https://github.com/google/flax.git flaxrepo
⚠️ Make sure the Colab Runtime is set to Accelerator: TPU.
Menu: Runtime --> Change runtime type
Popup: Hardware Accelerator --> TPU
In [0]:
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
resp = requests.post(url)
TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
import os
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
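If the runtime is connected correctly, JAX should now see the Colab TPU's 8 cores. This quick check is our addition and not part of the original example:
In [0]:
# Sanity check: report how many TPU cores JAX can see (expected: 8).
import jax
print('TPU cores visible to JAX:', jax.device_count())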
In [0]:
import functools
import itertools
import os
import time
import flax
from flax import jax_utils
from flax import nn
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import common_utils
import jax
from jax import random
import jax.nn
import jax.numpy as jnp
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
tf.enable_v2_behavior()
# We directly import the FLAX Language Model example code.
from flaxrepo.examples.lm1b import decode
from flaxrepo.examples.lm1b import input_pipeline
from flaxrepo.examples.lm1b import models
In [0]:
# Make a local directory to store run data and checkpoints, etc.
!mkdir -p run_1
In [0]:
model_dir = '/content/run_1' # Directory to store data in.
save_checkpoints = True # Save local checkpoints?
restore_checkpoints = True # Restore from last local checkpoint?
checkpoint_freq = 5000 # How often to save checkpoints
num_train_steps = 500000 # Max number of training steps.
eval_frequency = 1000 # How often to run model evaluation.
num_eval_steps = 20 # Number of steps to take during evaluation.
random_seed = 0 # JAX PRNG random seed.
learning_rate = 0.05 # Base learning rate.
weight_decay = 1e-1 # AdamW-style relative weight decay factor.
batch_size = 256 # "Target" Batch size.
max_target_length = 256 # Maximum input length.
max_eval_target_length = 256 # Maximum eval-set input length.
lm_emb_dim = 512 # LM initial token embedding dimension.
lm_num_heads = 8 # Number of heads in decoder layers.
lm_num_layers = 6 # Number of decoder layers.
lm_qkv_dim = 512 # Decoder query/key/value depth.
lm_mlp_dim = 2048 # Feedforward (MLP) layer depth.
prompt_str = 'The British ' # Prompt for LM Inference.
sampling_temperature = 0.6 # Temperature to sample LM at.
sampling_top_k = 20 # If > 0, use TopK temperature sampling.
max_predict_token_length = 50 # Maximum number of subword tokens to predict.
Instead of waiting on a local build of the LM1B dataset, we can train on the much smaller WikiText-2 dataset, which is extracted from a small subset of the English Wikipedia.
In [0]:
!wget --quiet https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip
!unzip wikitext-2-raw-v1.zip
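Optionally, we can peek at a few raw lines to see what the preprocessing below has to deal with. This cell is purely illustrative and not part of the original example:
In [0]:
# Print a handful of raw WikiText-2 lines; note the "= Heading =" markers and
# near-empty lines that the content and length filters below remove.
for line in tf.data.TextLineDataset('/content/wikitext-2-raw/wiki.train.raw').take(5):
  print(repr(line.numpy().decode('utf-8')))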
In [0]:
def preprocess_ds(path):
"""Extract content sentences from Wikitext-2 dataset."""
dataset = tf.data.TextLineDataset(path)
# Drop article headers.
def content_filter(source):
return tf.logical_not(tf.strings.regex_full_match(
source,
'([[:space:]][=])+.+([[:space:]][=])+[[:space:]]*'))
dataset = dataset.filter(content_filter)
# Split paragraphs to lines.
dataset = dataset.map(lambda x: tf.strings.split(x, ' . '))
dataset = dataset.unbatch()
# Remove blank lines.
def min_filter(min_len):
def filter_fn(source):
return tf.greater(tf.strings.length(source), tf.constant(min_len))
return filter_fn
dataset = dataset.filter(min_filter(1))
return dataset
# Get the raw train and eval datasets.
train_ds = preprocess_ds('/content/wikitext-2-raw/wiki.train.raw')
eval_ds = preprocess_ds('/content/wikitext-2-raw/wiki.valid.raw')
# Build subword tokenizer.
try:
# If we already ran this cell, reload the cached subword vocab file.
encoder = tfds.features.text.SubwordTextEncoder.load_from_file('wikitext2')
except tf.errors.NotFoundError:
# Build subword tokenizer from data. Takes ~1 minute.
encoder = tfds.features.text.SubwordTextEncoder.build_from_corpus(
(x._numpy() for x in train_ds),
target_vocab_size=2**13,
max_corpus_chars=10**6)
encoder.save_to_file('wikitext2')
# Encode strings with subword tokenizer.
def tf_encode(x):
result = tf.py_function(lambda s: tf.constant(encoder.encode(s.numpy())),
[x,],
tf.int32)
result.set_shape([None])
return result
train_ds=train_ds.map(tf_encode)
eval_ds=eval_ds.map(tf_encode)
# Create zero-padded, length-bucketed batches.
train_ds = input_pipeline.lm1b_preprocess(
    train_ds,
    training=True,
    n_devices=jax.local_device_count(),
    max_target_length=max_target_length,
    max_eval_target_length=max_eval_target_length,
    batch_size=batch_size,
    drop_remainder=True)
eval_ds = input_pipeline.lm1b_preprocess(
    eval_ds,
    training=False,
    n_devices=jax.local_device_count(),
    max_target_length=max_target_length,
    max_eval_target_length=max_eval_target_length,
    batch_size=batch_size,
    drop_remainder=True)
The LM1B dataset is fairly large and requires significant upfront preprocessing. Doing this on a Colab VM is possible but can be frustrating: it takes several hours to finish, during which time the VM could reset.
We strongly recommend downloading and preparing the dataset on a cloud instance and storing the prepared dataset in a GCS bucket. Another alternative is to prepare the dataset on a local machine and upload it to a Google Drive folder, which can then be mounted in Colab:
from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/drive/My\ Drive/tensorflow_datasets/lm1b ./tensorflow_datasets/lm1b
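If you stored the prepared dataset in a GCS bucket instead, the Colab VM first needs permission to read it. A minimal sketch, with YOUR_BUCKET_NAME as a placeholder for your own bucket (the commented-out cell further below can then point data_dir at the bucket directly):
from google.colab import auth
auth.authenticate_user()
!gsutil ls gs://YOUR_BUCKET_NAME/tensorflow_datasets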
More I/O documentation can be found in the Colab IO notebook.
In [0]:
# On Colab it takes an hour to download and several more hours to preprocess,
# and you may need to babysit the Colab to keep it alive. If you do this, be
# sure to copy the result to a Google Drive folder or elsewhere, as storage on
# a Colab VM is ephemeral!
# builder = tfds.builder('lm1b/subwords32k')
# builder.download_and_prepare(download_dir='/content/tensorflow_datasets')
In [0]:
# (below commented out to avoid triggering on "run all")
# # Point to existing local data copied over:
# data_dir = '/content/tensorflow_datasets'
# # or a GCS Bucket:
# # data_dir = "gs://YOUR_BUCKET_NAME/tensorflow_datasets"
# train_ds, eval_ds, info_ds = input_pipeline.get_lm1b_datasets(
# n_devices=jax.local_device_count(),
# data_dir=data_dir,
# batch_size=batch_size,
# dynamic_batching=True,
# max_target_length=max_target_length,
# max_eval_target_length=max_eval_target_length)
# vocab_size = info_ds['text'].encoder.vocab_size
# encoder = info_ds['text'].encoder
The Transformer language model itself is defined in the examples/lm1b/models.py file of the Flax repository.
In [0]:
# Init PRNG Stream.
rng = random.PRNGKey(random_seed)
rng, init_rng = random.split(rng)
# We create the first set of per-device dropout PRNG keys here, but update them
# afterwards inside the main pmap'd training step for performance.
dropout_rngs = random.split(rng, jax.local_device_count())
In [0]:
@functools.partial(jax.jit, static_argnums=(1, 2))
def create_model(key, input_shape, model_kwargs):
"""
We create a model definition from the top-level Language Model and
passed in hyperparameters.
"""
module = models.TransformerLM.partial(**model_kwargs)
  # We initialize a Cache collection used for fast autoregressive decoding
  # through the language model's decoder layers.
with nn.attention.Cache().mutate() as cache_def:
# create_by_shape initializes the model parameters.
_, model = module.create_by_shape(key,
[(input_shape, jnp.float32)],
cache=cache_def)
return model, cache_def
# Init model and optimizer.
vocab_size = encoder.vocab_size
input_shape = (batch_size, max_target_length)
transformer_lm_kwargs = {
'vocab_size': vocab_size,
'emb_dim': lm_emb_dim,
'num_heads': lm_num_heads,
'num_layers': lm_num_layers,
'qkv_dim': lm_qkv_dim,
'mlp_dim': lm_mlp_dim,
'max_len': max(max_target_length, max_eval_target_length)
}
model, cache_def = create_model(init_rng, input_shape, transformer_lm_kwargs)
def create_optimizer(model, learning_rate):
"""
Here we define the AdamW optimizer we'll use.
"""
optimizer_def = optim.Adam(
learning_rate,
beta1=0.9,
beta2=0.98,
eps=1e-9,
weight_decay=weight_decay)
optimizer = optimizer_def.create(model)
optimizer = optimizer.replicate()
return optimizer
# Build an optimizer from the model.
optimizer = create_optimizer(model, learning_rate)
# Don't keep a copy of the initial model object around.
# If needed, we access the model directly via optimizer.target instead.
del model
def create_learning_rate_scheduler(base_learning_rate=0.5, warmup_steps=8000):
"""Define our learning rate schedule."""
def step_fn(step):
return jnp.asarray(
base_learning_rate *
jnp.minimum(1.0, step / warmup_steps) /
jnp.sqrt(jnp.maximum(step, warmup_steps)), dtype=jnp.float32)
return step_fn
learning_rate_fn = create_learning_rate_scheduler(
base_learning_rate=learning_rate)
In [0]:
def compute_weighted_cross_entropy(logits, targets, weights=None):
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
Returns:
Tuple of scalar loss and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
(str(logits.shape), str(targets.shape)))
onehot_targets = common_utils.onehot(targets, logits.shape[-1])
loss = -jnp.sum(onehot_targets * nn.log_softmax(logits), axis=-1)
normalizing_factor = onehot_targets.sum()
if weights is not None:
loss = loss * weights
normalizing_factor = weights.sum()
return loss.sum(), normalizing_factor
def compute_weighted_accuracy(logits, targets, weights=None):
"""Compute weighted accuracy for log probs and targets.
Args:
logits: [batch, length, num_classes] float array.
targets: categorical targets [batch, length] int array.
    weights: None or array of shape [batch, length].
Returns:
Tuple of scalar accuracy and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
(str(logits.shape), str(targets.shape)))
loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
normalizing_factor = jnp.prod(logits.shape[:-1])
if weights is not None:
loss = loss * weights
normalizing_factor = weights.sum()
return loss.sum(), normalizing_factor
def compute_metrics(logits, labels, weights):
"""Compute summary metrics."""
loss, weight_sum = compute_weighted_cross_entropy(logits, labels, weights)
acc, _ = compute_weighted_accuracy(logits, labels, weights)
metrics = {
'loss': loss,
'accuracy': acc,
'denominator': weight_sum,
}
metrics = jax.lax.psum(metrics, axis_name='batch')
return metrics
In [0]:
def train_step(optimizer, inputs, learning_rate_fn, dropout_rng=None):
"""Perform a single training step."""
weights = jnp.where(inputs > 0, 1, 0)
  # We handle PRNG splitting inside the top pmap rather than outside in the
  # training loop, since handling it outside can add stalls on the devices.
dropout_rng, new_dropout_rng = random.split(dropout_rng)
def loss_fn(model):
"""Loss function used for training."""
with nn.stochastic(dropout_rng):
logits = model(inputs, train=True)
loss, weight_sum = compute_weighted_cross_entropy(logits, inputs, weights)
mean_loss = loss / weight_sum
return mean_loss, logits
step = optimizer.state.step
lr = learning_rate_fn(step)
new_optimizer, _, logits = optimizer.optimize(loss_fn, learning_rate=lr)
metrics = compute_metrics(logits, inputs, weights)
metrics['learning_rate'] = lr
return new_optimizer, metrics, new_dropout_rng
# parallelize the training step with JAX's pmap.
p_train_step = jax.pmap(
functools.partial(train_step, learning_rate_fn=learning_rate_fn),
axis_name='batch')
def eval_step(model, inputs):
weights = jnp.where(inputs > 0, 1, 0)
logits = model(inputs, train=False)
return compute_metrics(logits, inputs, weights)
# parallelize the evaluation step with JAX's pmap.
p_eval_step = jax.pmap(eval_step, axis_name='batch')
def predict_step(inputs, model, cache, prng_key):
"""Fast sampling of language model from prompt."""
prefix_len = inputs.shape[1]
pad_len = max_predict_token_length - prefix_len
padded_inputs = jnp.pad(inputs, jnp.array([[0, 0], [0, pad_len]]))
def tokens_ids_to_logits(ids, cache):
"""Token slice to logits from decoder model."""
with cache.mutate() as new_cache:
logits = model(ids, shift=False, train=False, cache=new_cache)
# Remove singleton sequence-length dimension from model.
# [batch, 1, vocab] --> [batch, vocab]
logits = logits.squeeze(axis=1)
return logits, new_cache
sampled_seqs = decode.temperature_sample(
padded_inputs,
cache,
tokens_ids_to_logits,
prng_key,
temperature=sampling_temperature,
topk=sampling_top_k,
eos_token=2**16) # No EOS tokens used in default lm1b dataset encoding.
return sampled_seqs
# parallelize the fast autoregressive sampler with JAX's pmap.
p_pred_step = jax.pmap(predict_step, axis_name='batch')
In [0]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard
In [0]:
# Launch an inline tensorboard panel.
%tensorboard --logdir /content/run_1
In [0]:
# Summary writers for tensorboard.
train_summary_writer = tensorboard.SummaryWriter(
os.path.join(model_dir, 'train'))
eval_summary_writer = tensorboard.SummaryWriter(
os.path.join(model_dir, 'eval'))
# Initialize training dataset iterator.
train_iter = iter(train_ds)
start_step = 0
if restore_checkpoints:
# Restore unreplicated optimizer + model state from last checkpoint.
optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
# Grab last step from the first of the optimizer replicas.
start_step = int(optimizer.state.step[0])
metrics_all = [] # We aggregate and average training metrics here.
tick = time.time() # Initialize step timer.
print('Compiling XLA programs for different input shapes;'
      ' this can take 5-10 minutes.')
for step, batch in zip(range(start_step, num_train_steps), train_iter):
# Core training step.
batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))
optimizer, metrics, dropout_rngs = p_train_step(
optimizer, batch, dropout_rng=dropout_rngs)
metrics_all.append(metrics)
# Save a Checkpoint
if step % checkpoint_freq == 0 and step > 0:
if save_checkpoints:
checkpoints.save_checkpoint(model_dir, optimizer, step)
# Periodic metric handling.
if step % eval_frequency == 0 and step > 0:
metrics_all = common_utils.get_metrics(metrics_all)
lr = metrics_all.pop('learning_rate').mean()
metrics_sums = jax.tree_map(jnp.sum, metrics_all)
denominator = metrics_sums.pop('denominator')
summary = jax.tree_map(lambda x: x / denominator, metrics_sums)
summary['learning_rate'] = lr
summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
# Update step timer.
tock = time.time()
steps_per_sec = eval_frequency / (tock - tick)
tick = tock
train_summary_writer.scalar('steps per second', steps_per_sec, step)
print('train in step: %d, loss: %.4f' %(step, summary['loss']))
for key, val in summary.items():
train_summary_writer.scalar(key, val, step)
train_summary_writer.flush()
# reset metric accumulation for next evaluation cycle.
metrics_all = []
# Eval Metrics -----------------------------------------------------------
eval_metrics = []
for _, eval_batch in zip(range(num_eval_steps), iter(eval_ds)):
eval_batch = common_utils.shard(
jax.tree_map(lambda x: x._numpy(), eval_batch))
metrics = p_eval_step(optimizer.target, eval_batch)
eval_metrics.append(metrics)
eval_metrics = common_utils.get_metrics(eval_metrics)
eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
eval_denominator = eval_metrics_sums.pop('denominator')
eval_summary = jax.tree_map(
lambda x: x / eval_denominator,
eval_metrics_sums)
eval_summary['perplexity'] = jnp.clip(
jnp.exp(eval_summary['loss']), a_max=1.0e4)
print('eval in step: %d, loss: %.4f'%(step, eval_summary['loss']))
for key, val in eval_summary.items():
eval_summary_writer.scalar(key, val, step)
eval_summary_writer.flush()
# Fast inference of prompt extension using trained LM. -------------------
# Update rng stream for prediction.
rng, subrng = jax.random.split(rng)
pred_rngs = random.split(subrng, jax.local_device_count())
# Encode provided text prompt to initialize sampling.
prompt = jnp.array(encoder.encode(prompt_str))
prompt = jax_utils.replicate(prompt)
prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
# Initialize the autoregressive cache, run prediction loop, collect data.
cache = jax_utils.replicate(
cache_def.initialize_cache((1, max_predict_token_length)))
predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
predicted = np.array(predicted).reshape(
(predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])
# Write examples for tensorboard.
print(encoder.decode(predicted[0]))
exemplars = ''
for n in range(predicted.shape[0]):
exemplars += encoder.decode(predicted[n]) + '\n\n'
eval_summary_writer.text('samples', exemplars, step)
eval_summary_writer.flush()
In [0]:
# Optional - we can skip training and restore from a saved checkpoint.
# optimizer = checkpoints.restore_checkpoint(model_dir, optimizer)
In [0]:
def predict(rng, prompt_str):
# Update rng stream for prediction.
rng, subrng = jax.random.split(rng)
pred_rngs = random.split(subrng, jax.local_device_count())
# Encode provided text prompt to initialize sampling.
prompt = jnp.array(encoder.encode(prompt_str))
prompt = jax_utils.replicate(prompt)
prompt = jnp.reshape(prompt, (prompt.shape[0], 1, prompt.shape[1]))
# Initialize the autoregressive cache, run prediction loop, collect data.
cache = jax_utils.replicate(
cache_def.initialize_cache((1, max_predict_token_length)))
predicted = p_pred_step(prompt, optimizer.target, cache, pred_rngs)
predicted = np.array(predicted).reshape(
(predicted.shape[0] * predicted.shape[1],) + predicted.shape[2:])
# Print generated sentences.
exemplars = ''
for n in range(predicted.shape[0]):
exemplars += encoder.decode(predicted[n]) + '\n\n'
print(exemplars)
# Return rng stream.
return rng
In [0]:
rng = predict(rng, "The kakapo is ")
In [0]: