In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Reformer: Text Generation

This notebook was designed to run on TPU.

To use TPUs in Colab, click "Runtime" on the main menu bar and select "Change runtime type". Set "TPU" as the hardware accelerator.


In [0]:
# Install the JAX and jaxlib versions this notebook was built against.
!pip install --upgrade -q jax==0.1.57 jaxlib==0.1.37

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [0]:
!pip install --upgrade -q sentencepiece
!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.0

from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import jax
import trax
from trax.supervised import inputs

import numpy as onp
import jax.numpy as np

from scipy.special import softmax

from sentencepiece import SentencePieceProcessor

Setting up data and model

In this notebook, we'll be pushing the limits of just how many tokens we can fit on a single TPU device. The TPUs available in Colab have 8GB of memory per core, and 8 cores. We will set up a Reformer model that can fit a copy of "Crime and Punishment" on each of the 8 TPU cores (over 500,000 tokens per 8GB of memory).
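
As a quick bit of arithmetic (my own addition, not part of the original setup), the cell below spells out that per-core budget: 512 * 1024 positions per core, of which roughly 514K are real tokens and the rest is padding. The exact token count is computed after tokenization below.

In [0]:
# Back-of-the-envelope token budget per TPU core (illustrative only; the
# precise token count comes from the tokenizer set up later in the notebook).
tokens_per_core = 512 * 1024   # sequence length used throughout: 524,288
novel_tokens = 513812          # "Crime and Punishment" with the 320-type BPE vocab
print("Positions per core:", tokens_per_core)
print("Padding positions per core:", tokens_per_core - novel_tokens)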


In [0]:
# Load a copy of "Crime and Punishment", by Fyodor Dostoevsky.
with GFile('gs://trax-ml/reformer/crime-and-punishment-2554.txt') as f:
  text = f.read()

# The file read above includes metadata and licensing information.
# For training our language model, we will only use the actual novel text.
start = text.find('CRIME AND PUNISHMENT')  # skip header
start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip header
start = text.find('CRIME AND PUNISHMENT', start + 1)  # skip translator preface
end = text.rfind('End of Project')  # skip extra text at the end
text = text[start:end].strip()

In [0]:
# Load a BPE vocabulary with 320 types. This mostly consists of single letters
# and pairs of letters, but it has some common words and word pieces, too.
!gsutil cp gs://trax-ml/reformer/cp.320.* .

TOKENIZER = SentencePieceProcessor()
TOKENIZER.load('cp.320.model')

In [0]:
# Tokenize the novel and compute how much padding brings it to 512 * 1024 tokens.
IDS = TOKENIZER.EncodeAsIds(text)
IDS = onp.asarray(IDS, dtype=onp.int32)
PAD_AMOUNT = 512 * 1024 - len(IDS)
print("Number of tokens:", IDS.shape[0])


Number of tokens: 513812

As we see above, "Crime and Punishment" has just over half a million tokens with the BPE vocabulary we have selected.

Normally we would have a dataset with many examples, but for this demonstration we fit a language model on the single novel only. We don't want the model to just memorize the dataset by encoding the words in its position embeddings, so at each training iteration we will randomly select how much padding to put before the text vs. after it.

We have 8 TPU cores, so we will separately randomize the amount of padding for each core.


In [0]:
# Set up the data pipeline.
def my_inputs(n_devices):
  while True:
    inputs = []
    mask = []
    pad_amounts = onp.random.choice(PAD_AMOUNT, n_devices)
    for i in range(n_devices):
      inputs.append(onp.pad(IDS, (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                            mode='constant'))
      mask.append(onp.pad(onp.ones_like(IDS, dtype=onp.float32),
                          (pad_amounts[i], PAD_AMOUNT - pad_amounts[i]),
                          mode='constant'))
    inputs = onp.stack(inputs)
    mask = onp.stack(mask)
    yield (inputs, inputs, mask)

print("(device count, tokens per device) = ",
      next(my_inputs(trax.math.device_count()))[0].shape)


(device count, tokens per device) =  (8, 524288)
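
As a quick sanity check of the pipeline (my own addition, not part of the original recipe), the cell below verifies that each per-core sequence has length 512 * 1024 and that the mask marks exactly the novel's tokens as real, wherever the random padding happened to land.

In [0]:
# Sanity-check one batch from the generator defined above.
sample_inputs, _, sample_mask = next(my_inputs(trax.math.device_count()))
assert sample_inputs.shape == (trax.math.device_count(), 512 * 1024)
# Each core's mask should cover exactly the len(IDS) real (non-padding) tokens.
assert onp.all(sample_mask.sum(axis=-1) == len(IDS))
print("Per-core sequence length:", sample_inputs.shape[-1])
print("Real tokens per core:", int(sample_mask[0].sum()))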

In [0]:
# Configure hyperparameters.
gin.parse_config("""
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.inputs
import trax.supervised.trainer_lib

# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Our model will have 6 layers, alternating between the LSH attention proposed
# in the Reformer paper and local attention within a certain context window.
n_layers = 6
attn_type = [
  @TimeBinCausalAttention,
  @LSHCausalAttention,  
  @TimeBinCausalAttention,
  @LSHCausalAttention,
  @TimeBinCausalAttention,
  @LSHCausalAttention,
  ]
share_qk = False  # LSHCausalAttention ignores this flag and always shares q & k
n_heads = 2
attn_kv = 64
dropout = 0.05
n_tokens = 524288

# Parameters for MultifactorSchedule:
# ==============================================================================
MultifactorSchedule.constant = 0.01
MultifactorSchedule.factors = 'constant * linear_warmup * cosine_decay'
MultifactorSchedule.warmup_steps = 100
MultifactorSchedule.steps_per_cycle = 900

# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.86
Adam.b2 = 0.92
Adam.eps = 1e-9

# Parameters for TimeBinCausalAttention:
# ==============================================================================
TimeBinCausalAttention.bin_length = 64
TimeBinCausalAttention.dropout = 0.05
TimeBinCausalAttention.n_bins = None
TimeBinCausalAttention.share_qk = %share_qk

# Parameters for LSHCausalAttention:
# ==============================================================================
LSHCausalAttention.allow_duplicate_attention = False
LSHCausalAttention.attend_across_buckets = True
LSHCausalAttention.rehash_each_round = True
LSHCausalAttention.data_rotation = False
LSHCausalAttention.n_bins = 4096
LSHCausalAttention.n_buckets = 8192
LSHCausalAttention.factorize_hash = [64, 128]
LSHCausalAttention.n_hashes = 1
LSHCausalAttention.one_rng = False
LSHCausalAttention.hard_k = 0
LSHCausalAttention.dropout = 0.0
LSHCausalAttention.drop_for_hash_rate = 0.0
LSHCausalAttention.max_len_for_inference = 2048
LSHCausalAttention.bucket_capacity_for_inference = 64

# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 256
ReformerLM.d_ff = 512
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 320
ReformerLM.share_qk = %share_qk
ReformerLM.axial_pos_shape = (512, 1024)
ReformerLM.d_axial_pos_embs = (64, 192)
""")

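Most of the hyperparameters above are bound through gin, so they can be read back programmatically. The cell below is just a convenience sketch using gin's standard query_parameter API to confirm a few of the bindings; it is not required for training.

In [0]:
# Inspect a few of the gin bindings defined above (illustrative only).
print("d_model:", gin.query_parameter('ReformerLM.d_model'))
print("LSH hashes (train):", gin.query_parameter('LSHCausalAttention.n_hashes'))
print("LSH buckets:", gin.query_parameter('LSHCausalAttention.n_buckets'))
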
In [0]:
# Set up a Trainer.
output_dir = os.path.expanduser('~/train_dir/')
!rm -f ~/train_dir/model.pkl  # Remove old model
trainer = trax.supervised.Trainer(
    model=trax.models.ReformerLM,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=trax.supervised.inputs.Inputs(my_inputs),
    output_dir=output_dir,
    has_weights=True)

In [0]:
# Run one training step, to make sure the model fits in memory.
# The first time trainer.train_epoch is called, it will JIT the entire network
# architecture, which takes around 2 minutes. The JIT-compiled model is saved
# so subsequent runs will be much faster than the first.
trainer.train_epoch(n_steps=1, n_eval_steps=1)


Step      1: Ran 1 train steps in 124.84 secs
Step      1: Evaluation
Step      1: train                   accuracy |  0.00621507
Step      1: train                       loss |  6.35514784
Step      1: train         neg_log_perplexity |  6.35514784
Step      1: train weights_per_batch_per_core |  513812.00000000
Step      1: eval                    accuracy |  0.00616811
Step      1: eval                        loss |  6.35424042
Step      1: eval          neg_log_perplexity |  6.35424042
Step      1: eval  weights_per_batch_per_core |  513812.00000000
Step      1: Finished evaluation

In [0]:
# Train for 600 steps total (1 step was run above; 9 + 59 * 10 = 599 more here).
# The first ~20 steps are slow to run, but after that it reaches steady-state
# speed. This will take at least 30 minutes to run to completion, but can safely
# be interrupted by selecting "Runtime > Interrupt Execution" from the menu.
# The language model won't be exceptionally good when trained for just a few
# steps and with minimal regularization. However, we can still sample from it to
# see what it learns.
trainer.train_epoch(n_steps=9, n_eval_steps=1)
for _ in range(59):
  trainer.train_epoch(n_steps=10, n_eval_steps=1)

Sample from the model


In [0]:
# As we report in the Reformer paper, increasing the number of hashing rounds
# helps with quality. We can even increase the number of hashing rounds at
# evaluation time only.
gin.parse_config("""LSHCausalAttention.n_hashes = 4""")
model_infer = trax.models.ReformerLM(mode='predict')

In [0]:
# Prepare a jitted copy of the model.
jit_model_infer = trax.layers.base._accelerate(
    model_infer._forward_internal, trax.math.device_count())
# Set up the initial state for sampling.
infer_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1,1), dtype=np.int32))[1]
infer_state = trainer._for_n_devices(infer_state)

In [0]:
def sample(length=2048, prompt=None):
  """Sample from the ReformerLM model."""
  # Reuse the trained weights stored in the trainer's optimizer state.
  model_weights = trainer._opt_state[0][0]

  # Token id 0 is the equivalent of a "start" token
  cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)

  cur_state = infer_state
  rngs = trax.math.random.split(trax.math.random.get_prng(0), trax.math.device_count())
  all_samples = []

  if prompt is not None:
    prompt = np.asarray(
        [TOKENIZER.EncodeAsIds(prompt)] * trax.math.device_count())

  for iteration in range(length):
    logits, cur_state = jit_model_infer(
        cur_inputs,
        model_weights,
        cur_state,
        rngs)
    
    if prompt is not None and iteration < prompt.shape[1]:
      cur_samples = onp.array(prompt[:, iteration], dtype=int)
    else:
      # The network's final layer outputs log-probabilities, so exponentiating
      # them gives a normalized distribution to sample from.
      logits = onp.array(logits)[:,0,0,:]
      probs = onp.exp(logits)
      cur_samples = [onp.random.choice(probs.shape[-1], p=probs[i,:])
                     for i in range(probs.shape[0])]
      cur_samples = onp.array(cur_samples, dtype=int)
    all_samples.append(cur_samples)

    cur_inputs = np.array(cur_samples[:,None,None])
  all_samples = onp.stack(all_samples, -1)
  
  return all_samples

In [0]:
# Sample from the Reformer language model, given a prefix.
samples = sample(length=128, prompt="There was a time when")
for ids in samples:
  print(TOKENIZER.DecodeIds(ids.tolist()))


There was a time when the door, when anxious--he did most of all kicking his weary. It was a scarcely realisease talking ears fellow stood next rough in extraords and then, a large stood old woman were died in the old woman, the accusing, and her little chest in handlinters,
There was a time when came into desire any felt an injure of some of being a shopelessing, that, would certaints of fear where there is less in all true. In place would not copace of person that governoment, she is acquaintance. And yet talking office. What do you writ of some gament and
There was a time when he lister, remained a little in his property in a day a man who in the room appearance of the rive of the subject was part of one acquaintied, even huge, and various comfined at things, that instantial ovelock-spons, eager girl had not looked feet
There was a time when the balcasevsky Petrovitch who drown, scandlchedness of scanness, and forcertain rags, with coming an extremely colours and innatummed easier, and the absorbed in completely absorbed in completely forced with him at once. The red had been about
There was a time when he remembered that there could believed. The suddenly over him. Inced to one clear eagger was a look of dish and no merely pictims of quisten. The dost visit, trivial about it. As the wooden people were companion of ceiling and a kitchen correctly forgetting her
There was a time when he walked a stronger, however it was indeed from the noose inouple of the community, and it was pounced to be ashamed to be not!... Here... the address, Luished--no, the monst her! N-per circumstance of Golting it is myself for the room. But I am firmly w
There was a time when he arrival in a raggong. At time he used to wall gave a girl like a lady who had come yard. Petrovitch pale and blank table in that behaved, turned in the eyebs, unno as iron window-drivering leaden in coming, that with a certicis. Not
There was a time when he was ashes of flat. These were flound face was grate. As for a young man came into a room of notice of the room, a tortur jelooking, and with one acurg in the yard, with a minutely refinite an effect, and she laid which was walking with evidently agitation. E
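
The sample helper above can also be called without a prompt, in which case generation starts from the token id 0 "start" token on every core. The cell below is just a usage sketch of the function defined above.

In [0]:
# Unconditional sampling: each TPU core produces its own short continuation
# starting from the "start" token (token id 0).
unconditional = sample(length=64)
for ids in unconditional:
  print(TOKENIZER.DecodeIds(ids.tolist()))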

In [0]: