In [0]:
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This notebook was designed to run on TPU.
To use TPUs in Colab, click "Runtime" on the main menu bar and select "Change runtime type". Set "TPU" as the hardware accelerator.
In [0]:
# Install JAX.
!gsutil cp gs://trax-ml/reformer/jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl .
!gsutil cp gs://trax-ml/reformer/jax-0.1.59-cp36-none-manylinux2010_x86_64.whl .
!pip install --upgrade -q ./jaxlib-0.1.39-cp36-none-manylinux2010_x86_64.whl
!pip install --upgrade -q ./jax-0.1.59-cp36-none-manylinux2010_x86_64.whl
# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1
# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)
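If the backend is wired up correctly, JAX should now report the Colab TPU's cores (a quick sanity check; the exact device names vary by driver version):
In [0]:
# Confirm JAX sees the TPU backend; a Colab TPU exposes 8 cores.
import jax
print(jax.device_count())
print(jax.devices())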
In [0]:
!pip install --upgrade -q gin git+https://github.com/google/trax.git@v1.2.3
from tensorflow.compat.v1.io.gfile import GFile
import gin
import os
import pickle
import jax
import trax
from trax.models.beam_search import Search
from trax.supervised import inputs
from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
import numpy as np
import jax.numpy as jnp
from scipy.special import softmax
In [0]:
# Install sacreBLEU
!pip install sacrebleu
import sacrebleu
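A quick smoke test of the sacreBLEU API on a made-up pair (an exact match scores 100):
In [0]:
# corpus_bleu takes a list of hypotheses and a list of reference streams.
toy = sacrebleu.corpus_bleu(['the cat sat on the mat'],
                            [['the cat sat on the mat']])
print(toy.score)  # 100.0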
In [0]:
# Download the newstest2014 English-to-German translation pairs
!sacrebleu -t wmt14/full -l en-de --echo src > wmt14-en-de.src
!sacrebleu -t wmt14/full -l en-de --echo ref > wmt14-en-de.ref
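The full newstest2014 set should contain 3,003 sentence pairs; a quick check that both files came through with matching line counts:
In [0]:
# Source and reference files should have the same number of lines.
!wc -l wmt14-en-de.src wmt14-en-de.ref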
In [0]:
# Load the source text and reference translations into Python
refs = []
for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.ref'), 1):
  if line.endswith('\n'):
    line = line[:-1]
  refs.append(line)
srcs = []
for lineno, line in enumerate(sacrebleu.smart_open('wmt14-en-de.src'), 1):
  if line.endswith('\n'):
    line = line[:-1]
  srcs.append(line)
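A quick sanity check on the parallel data:
In [0]:
# Source and reference lists should align one-to-one.
assert len(srcs) == len(refs)
print(len(srcs), 'sentence pairs')
print('src:', srcs[0])
print('ref:', refs[0])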
In [0]:
# Set up our sub-word tokenizer
tokenizer = SubwordTextEncoder(
    'gs://trax-ml/reformer/mt/vocab.translate_ende_wmt32k.32768.subwords')
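To see what the tokenizer does, a quick round trip on a made-up sentence (the exact sub-word split depends on this vocabulary):
In [0]:
# SubwordTextEncoder round trip: encode to ids, decode back to text.
ids = tokenizer.encode('Hello, world!')
print(ids)
print(tokenizer.decode(ids))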
In [0]:
# Encode source sentences using the tokenizer
input_ids = np.zeros((len(srcs), 128), dtype=jnp.int64)
for i, x in enumerate(srcs):
  x = tokenizer.encode(x)
  assert len(x) <= 127
  input_ids[i, :len(x)] = x
  input_ids[i, len(x)] = 1  # Append the EOS token (id 1); the rest stays 0 (padding).
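Each row now holds the sub-word ids followed by the EOS marker and zero padding. A quick look at the first row (the ids and the EOS marker are all nonzero, so count_nonzero recovers the length):
In [0]:
# Show the non-padding prefix of the first encoded sentence.
row = input_ids[0]
print(row[:np.count_nonzero(row)])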
In [0]:
# We'll be using a pre-trained reversible transformer-base model.
# First, load the config (which sets all needed hyperparameters).
!gsutil cp gs://trax-ml/reformer/mt/config.gin ./config.gin
gin.parse_config_file('./config.gin')
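The gin file pins all hyperparameters (model size, attention type, vocabulary path, etc.). The simplest way to see what was set is to print the raw file; here are the first thousand characters:
In [0]:
# Peek at the hyperparameters the pre-trained model uses.
print(open('./config.gin').read()[:1000])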
In [0]:
# Now we load the pre-trained model weights.
with GFile('gs://trax-ml/reformer/mt/model.pkl', 'rb') as f:
  model_weights = pickle.load(f)['weights']
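The pickle holds a nested structure of per-layer numpy arrays. A rough parameter count using JAX's tree utilities (a minimal sketch, assuming every leaf is an array):
In [0]:
# Sum array sizes over all leaves of the weight tree.
leaves, _ = jax.tree_util.tree_flatten(model_weights)
print(sum(w.size for w in leaves), 'parameters')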
In [0]:
# Set up beam search.
beam_decoder = Search(
    trax.models.Reformer, model_weights,
    beam_size=4,
    alpha=0.6,  # For length normalization, set to 0.6 following Vaswani et al.
    eos_id=1,   # The stop token has id 1 in the vocabulary we use.
    max_decode_len=146,
)
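The alpha parameter controls length normalization. Assuming the Wu et al. (2016) penalty used in Tensor2Tensor-style beam search, each hypothesis's log-probability is divided by ((5 + length) / 6) ** alpha, so alpha = 0 means no normalization and larger values favor longer translations. A small illustration with hypothetical lengths:
In [0]:
# Illustrate the (assumed) Wu et al. length penalty at alpha = 0.6.
def length_penalty(length, alpha=0.6):
  return ((5.0 + length) / 6.0) ** alpha

for n in [10, 20, 40]:
  print(n, round(length_penalty(n), 3))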
In [13]:
pred_ids = []
preds = []
BATCH_SIZE = 1024
for start in range(0, input_ids.shape[0], BATCH_SIZE):
  print(start, '/', input_ids.shape[0], flush=True)
  batch = input_ids[start:start+BATCH_SIZE]
  seqs, scores = beam_decoder.decode(batch, batch_size=BATCH_SIZE)
  # Select highest scoring output.
  batch_pred_ids = seqs[:, -1]
  pred_ids.append(batch_pred_ids)
  preds.extend([
      tokenizer.decode(pred.tolist(), strip_extraneous=True)
      for pred in batch_pred_ids
  ])
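With decoding done, a quick look at a source sentence next to the model's output:
In [0]:
# Inspect the first source/translation pair.
print('src: ', srcs[0])
print('pred:', preds[0])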
In [14]:
bleu = sacrebleu.corpus_bleu(preds, [refs], lowercase=True, tokenize='intl')
print(bleu)
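Note that lowercase=True reports case-insensitive BLEU and tokenize='intl' uses sacreBLEU's international tokenization. The numeric score is also available directly:
In [0]:
print(round(bleu.score, 2))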