In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook was designed to run on TPU.
To use TPUs in Colab, click "Runtime" on the main menu bar and select "Change runtime type". Set "TPU" as the hardware accelerator.
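Before installing anything, a quick sanity check can confirm that a TPU runtime is actually attached. This is a minimal sketch: it only inspects the COLAB_TPU_ADDR environment variable that the rest of this notebook relies on.
In [0]:
# Optional sanity check: a Colab TPU runtime exposes its address via COLAB_TPU_ADDR.
import os
assert 'COLAB_TPU_ADDR' in os.environ, 'No TPU found; set the runtime type to TPU first.'
print('TPU address:', os.environ['COLAB_TPU_ADDR'])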
In [0]:
# Grab newest JAX version.
!pip install --upgrade -q jax==0.1.57 jaxlib==0.1.37
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
In [0]:
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.0
from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.supervised import inputs
import numpy as onp
import jax.numpy as np
from scipy.special import softmax
from sentencepiece import SentencePieceProcessor
In this notebook, we'll push the limits of how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8 cores with 8 GB of memory each. We will set up a Reformer model that fits a full copy of "Crime and Punishment" on each of the 8 TPU cores (over 500,000 tokens per 8 GB of memory).
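As a quick back-of-the-envelope check (a minimal sketch, using only values that appear later in this notebook), the cell below prints the per-device sequence length we are targeting and the number of TPU cores JAX can see.
In [0]:
# 512 * 1024 tokens per device, across all TPU cores visible to JAX.
per_device_tokens = 512 * 1024
print("Tokens per device:", per_device_tokens)
print("TPU cores visible to JAX:", jax.device_count())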
In [0]:
# Load a copy of "Crime and Punishment", by Fyodor Dostoevsky.
with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:
text = f.read()
# The file read above includes metadata and licensing information.
# For training our language model, we will only use the actual novel text.
start = text.find('CRIME AND PUNISHMENT') # skip header
start = text.find('CRIME AND PUNISHMENT', start + 1) # skip header
start = text.find('CRIME AND PUNISHMENT', start + 1) # skip translator preface
end = text.rfind('End of Project') # skip extra text at the end
text = text[start:end].strip()
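As a quick check (a minimal sketch), the sliced text should now begin at the novel's title heading and end before the Project Gutenberg footer.
In [0]:
# Peek at the beginning of the sliced text and report its length in characters.
print(text[:100])
print("Total characters:", len(text))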
In [0]:
# Load a BPE vocabulary with 320 types. This mostly consists of single letters
# and pairs of letters, but it has some common words and word pieces, too.
!gsutil cp gs://trax-ml/reformer/cp.320.* .
TOKENIZER = SentencePieceProcessor()
TOKENIZER.load('cp.320.model')
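As a quick illustration (a minimal sketch), we can round-trip a short string through the tokenizer to see how the 320-type BPE vocabulary breaks text into ids.
In [0]:
# Encode a short string to token ids and decode it back.
example_ids = TOKENIZER.EncodeAsIds("Crime and Punishment")
print(example_ids)
print(TOKENIZER.DecodeIds(example_ids))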
In [0]:
# Tokenize
IDS = TOKENIZER.EncodeAsIds(text)
IDS = onp.asarray(IDS, dtype=onp.int32)
PAD_AMOUNT = 512 * 1024 - len(IDS)
print("Number of tokens:", IDS.shape[0])
As we see above, "Crime and Punishment" has just over half a million tokens with the BPE vocabulary we have selected.
Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.
We have 8 TPU cores, so we will separately randomize the amount of padding for each core.
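The toy example below (a minimal sketch, separate from the actual pipeline) shows the padding idea on a tiny array: a fixed padding budget is split randomly between the front and the back of the sequence.
In [0]:
# Toy illustration of randomized front/back padding; not used by the pipeline.
toy_ids = onp.arange(1, 6, dtype=onp.int32)  # stand-in for token ids
budget = 4                                   # total amount of padding to add
front = onp.random.choice(budget)            # padding placed before the text
padded = onp.pad(toy_ids, (front, budget - front), mode='constant')
print(padded)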
In [0]:
# Set up the data pipeline.
def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = onp.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      inputs.append(onp.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                            mode='constant'))
      mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                          mode='constant'))
    inputs = onp.stack(inputs)
    mask = onp.stack(mask)
    yield (inputs, inputs, mask)

print("(device count, tokens per device) = ",
      next(my_inputs(trax.math.device_count()))[0].shape)
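As an optional check (a minimal sketch; the variable names below are just illustrative), each per-device mask should cover exactly the unpadded novel, i.e. sum to the number of real tokens.
In [0]:
# Draw one batch and confirm every device's mask sums to len(IDS).
batch_inputs, batch_targets, batch_mask = next(my_inputs(trax.math.device_count()))
print("Real tokens per device:", batch_mask.sum(axis=-1))  # expected: len(IDS) everywhere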
In [0]:
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Our model will have 6 layers, alternating between the LSH attention proposed
# in the Reformer paper and local attention within a certain context window.
n_layers = 6
attn_type = [
@TimeBinCausalAttention,
@LSHCausalAttention,
@TimeBinCausalAttention,
@LSHCausalAttention,
@TimeBinCausalAttention,
@LSHCausalAttention,
]
share_qk = False # LSHCausalAttention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288
# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
MultifactorSchedule.warmup_steps = 100
MultifactorSchedule.steps_per_cycle = 900
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9
# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.bin_length = 64
TimeBinCausalAttention.dropout = 0.05
TimeBinCausalAttention.n_bins = None
TimeBinCausalAttention.share_qk = %share_qk
# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
LSHCausalAttention.data_rotation = False
LSHCausalAttention.n_bins = 4096
LSHCausalAttention.n_buckets = 8192
LSHCausalAttention.factorize_hash = [64, 128]
LSHCausalAttention.n_hashes = 1
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.0
LSHCausalAttention.drop_for_hash_rate = 0.0
LSHCausalAttention.max_len_for_inference = 2048
LSHCausalAttention.bucket_capacity_for_inference = 64
# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs = (64, 192)
""")
In [0]:
# Set up a Trainer.
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl # Remove old model
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(my_inputs),
    output_dir=output_dir,
    has_weights=True)
In [0]:
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)
In [0]:
# Train for 600 steps in total: 1 step above, then 9 + 59 * 10 below.
# The first ~20 steps are slow to run, but after that it reaches steady-state
# speed. This will take at least 30 minutes to run to completion, but can safely
# be interrupted by selecting "Runtime > Interrupt Execution" from the menu.
# The language model won't be exceptionally good when trained for just a few
# steps and with minimal regularization. However, we can still sample from it to
# see what it learns.
trainer.train_epoch(n_steps=9, n_eval_steps=1)
for _ in range(59):
  trainer.train_epoch(n_steps=10, n_eval_steps=1)
In [0]:
# As we report in the Reformer paper, increasing the number of hashing rounds
# helps with quality. We can even increase the number of hashing rounds at
# evaluation time only.
gin.parse_config("""LSHCausalAttention.n_hashes = 4""")
model_infer = trax.models.ReformerLM(mode='predict')
In [0]:
# Prepare a jitted copy of the model.
jit_model_infer = trax.layers.base._accelerate(
    model_infer._forward_internal, trax.math.device_count())

# Set up the initial state for sampling.
infer_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1, 1), dtype=np.int32))[1]
infer_state = trainer._for_n_devices(infer_state)
In [0]:
def sample(length=2048, prompt=None):
  """Sample from the ReformerLM model."""
  model_weights = trainer._opt_state[0][0]

  # Token id 0 is the equivalent of a "start" token.
  cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)
  cur_state = infer_state
  rngs = trax.math.random.split(
      trax.math.random.get_prng(0), trax.math.device_count())
  all_samples = []

  if prompt is not None:
    prompt = np.asarray(
        [TOKENIZER.EncodeAsIds(prompt)] * trax.math.device_count())

  for iteration in range(length):
    logits, cur_state = jit_model_infer(
        cur_inputs,
        model_weights,
        cur_state,
        rngs)

    if prompt is not None and iteration < prompt.shape[1]:
      # While still within the prompt, force-feed the prompt tokens.
      cur_samples = onp.array(prompt[:, iteration], dtype=int)
    else:
      # Otherwise, sample the next token from the model's output distribution.
      logits = onp.array(logits)[:, 0, 0, :]
      probs = onp.exp(logits)
      cur_samples = [onp.random.choice(probs.shape[-1], p=probs[i, :])
                     for i in range(probs.shape[0])]
      cur_samples = onp.array(cur_samples, dtype=int)

    all_samples.append(cur_samples)
    cur_inputs = np.array(cur_samples[:, None, None])

  all_samples = onp.stack(all_samples, -1)
  return all_samples
In [0]:
# Sample from the Reformer language model, given a prefix.
samples = sample(length=128, prompt="There was a time when")
for ids in samples:
  print(TOKENIZER.DecodeIds(ids.tolist()))