In [0]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Reformer: Machine Translation

This notebook was designed to run on TPU.

To use TPUs in Colab, click "Runtime" on the main menu bar and select "Change runtime type". Then set the hardware accelerator to "TPU".


In [0]:
# Install JAX.
!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .
!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .
!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

In [0]:
# Install Trax from the v1.2.3 tag of its GitHub repository.
# NOTE(review): the Gin configuration library imported below as `gin` is
# believed to be published on PyPI as `gin-config`; confirm that the `gin`
# distribution named here is the intended package.
!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3

from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import pickle
import jax
import trax
from trax.models.beam_search import Search
from trax.supervised import inputs

from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder

import numpy as np
import jax.numpy as jnp

from scipy.special import softmax

In [0]:
# Install sacreBLEU. It is used below both as a command-line tool (to fetch the
# test set) and as a library (`smart_open`, `corpus_bleu`).
# NOTE(review): the version is not pinned, unlike the JAX and Trax installs
# earlier in this notebook; consider pinning a release known to provide the
# APIs used here.
!pip install sacrebleu
import sacrebleu

Load WMT14 data


In [0]:
# Download the newstest2014 English-to-German translation pairs.
# Each command redirects sacreBLEU's `--echo` output for the wmt14/full en-de
# test set into a local file: first the `src` side, then the `ref` side.
!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src
!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref

In [0]:
def _read_lines(path):
  """Reads a text file into a list of lines without their trailing newline.

  Args:
    path: Path of the file to read; it is opened with `sacrebleu.smart_open`.

  Returns:
    A list with one string per line of the file, in file order.
  """
  # The `with` block closes the handle as soon as reading finishes instead of
  # leaving it open until garbage collection.
  with sacrebleu.smart_open(path) as f:
    # Only the single trailing '\n' is dropped; other whitespace is preserved.
    return [line[:-1] if line.endswith('\n') else line for line in f]


# Load the source text and reference translations into Python
refs = _read_lines('wmt14-en-de.ref')
srcs = _read_lines('wmt14-en-de.src')
# The BLEU computation pairs hypotheses (one per source line) with `refs`
# position by position, so the two files must have the same number of lines.
assert len(srcs) == len(refs), (len(srcs), len(refs))

In [0]:
# Set up our sub-word tokenizer
# The encoder is built from the subword vocabulary file stored on GCS.
vocab_file = (
    'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')
tokenizer = SubwordTextEncoder(vocab_file)

In [0]:
# Encode source sentences using the tokenizer
MAX_INPUT_LEN = 128  # Fixed width of every row of `input_ids`.
EOS_ID = 1  # Id written directly after the last token of each sentence.

# `input_ids` is a NumPy array, so it takes a NumPy dtype. Rows start out as
# zeros, so every position after the EOS id stays 0.
input_ids = np.zeros((len(srcs), MAX_INPUT_LEN), dtype=np.int64)
for i, sentence in enumerate(srcs):
  token_ids = tokenizer.encode(sentence)
  # One slot per row is reserved for the EOS id, so a sentence may use at most
  # MAX_INPUT_LEN - 1 positions.
  assert len(token_ids) <= MAX_INPUT_LEN - 1, (
      'source sentence %d has %d tokens; at most %d fit'
      % (i, len(token_ids), MAX_INPUT_LEN - 1))
  input_ids[i, :len(token_ids)] = token_ids
  input_ids[i, len(token_ids)] = EOS_ID

Load the pre-trained model


In [0]:
# We'll be using a pre-trained reversible transformer-base model.
# First, load the config (which sets all needed hyperparameters).
# The config file is copied from GCS into the working directory, and Gin then
# parses that local copy.
!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin
gin.parse_config_file('./config.gin')

In [0]:
# Now we load the pre-trained model weights.
# Only the 'weights' entry of the unpickled object is kept.
# SECURITY NOTE(review): unpickling can execute arbitrary code, so this is only
# safe as long as the contents of the bucket are trusted.
with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:
  model_weights = pickle.load(f)['weights']

Beam search decoding


In [0]:
# Set up beam search.
# The Search object is built from the Reformer model class and the weights
# loaded earlier; decoding itself happens later via `beam_decoder.decode`.
beam_decoder = Search(
    trax.models.Reformer, model_weights,
    beam_size=4,
    alpha=0.6,  # For length normalization, set to 0.6 following Vaswani et al.
    eos_id=1,  # The stop token has id 1 in the vocabulary we use.
    # NOTE(review): 146 is larger than the 128-wide input rows built when the
    # sources are encoded; the reason for this value is not shown -- confirm.
    max_decode_len=146,
    )

In [13]:
# Decode the whole test set in fixed-size batches, collecting both the raw
# predicted ids and their detokenized text.
BATCH_SIZE = 1024
num_examples = input_ids.shape[0]
pred_ids, preds = [], []
for start in range(0, num_examples, BATCH_SIZE):
  # Progress line, flushed so it appears while decoding is still running.
  print(start, '/', num_examples, flush=True)
  batch = input_ids[start:start + BATCH_SIZE]
  # The final batch can hold fewer than BATCH_SIZE rows; `batch_size` is still
  # passed as BATCH_SIZE.
  seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)
  # Select highest scoring output.
  batch_pred_ids = seqs[:, -1]
  pred_ids.append(batch_pred_ids)
  # Turn each predicted id sequence back into text.
  decoded = [
      tokenizer.decode(pred.tolist(), strip_extraneous=True)
      for pred in batch_pred_ids
  ]
  preds.extend(decoded)


0 / 3003
1024 / 3003
2048 / 3003

In [14]:
# Score the decoded hypotheses against the single reference stream.
# NOTE: `lowercase=True` and `tokenize='intl'` are passed, so compare the
# reported score only with results computed using the same settings.
bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')
print(bleu)


BLEU = 27.86 59.5/33.5/21.3/14.2 (BP = 1.000 ratio = 1.020 hyp_len = 65943 ref_len = 64676)

In [0]: