Text generation using tensor2tensor on Cloud ML Engine

This notebook illustrates using the tensor2tensor library to do from-scratch, distributed training of a poetry model. Then, the trained model is used to complete new poems.


Install tensor2tensor, and specify Google Cloud Platform project and bucket

Install the necessary packages. tensor2tensor will give us the Transformer model. Project Gutenberg gives us access to historical poems.

p.s. Note that this notebook uses Python2 because Project Gutenberg relies on BSD-DB which was deprecated in Python 3 and removed from the standard library. tensor2tensor itself can be used on Python 3. It's just Project Gutenberg that has this issue.


In [ ]:
%%bash
pip freeze | grep tensor

In [ ]:
%%bash
pip install tensor2tensor==1.13.1 tensorflow==1.13.1 tensorflow-serving-api==1.13 gutenberg 
pip install tensorflow_hub 

# install from sou
#git clone https://github.com/tensorflow/tensor2tensor.git
#cd tensor2tensor
#yes | pip install --user -e .

If the following cell does not reflect the version of tensorflow and tensor2tensor that you just installed, click "Reset Session" on the notebook so that the Python environment picks up the new packages.


In [ ]:
%%bash
pip freeze | grep tensor

In [ ]:
import os
PROJECT = 'cloud-training-demos' # REPLACE WITH YOUR PROJECT ID
BUCKET = 'cloud-training-demos-ml' # REPLACE WITH YOUR BUCKET NAME
REGION = 'us-central1' # REPLACE WITH YOUR BUCKET REGION e.g. us-central1

# this is what this notebook is demonstrating
PROBLEM= 'poetry_line_problem'

# for bash
os.environ['PROJECT'] = PROJECT
os.environ['BUCKET'] = BUCKET
os.environ['REGION'] = REGION
os.environ['PROBLEM'] = PROBLEM

#os.environ['PATH'] = os.environ['PATH'] + ':' + os.getcwd() + '/tensor2tensor/tensor2tensor/bin/'

In [ ]:
%%bash
gcloud config set project $PROJECT
gcloud config set compute/region $REGION

Download data

We will get some poetry anthologies from Project Gutenberg.


In [ ]:
%%bash
rm -rf data/poetry
mkdir -p data/poetry

In [ ]:
from gutenberg.acquire import load_etext
from gutenberg.cleanup import strip_headers
import re

books = [
  # bookid, skip N lines
  (26715, 1000, 'Victorian songs'),
  (30235, 580, 'Baldwin collection'),
  (35402, 710, 'Swinburne collection'),
  (574, 15, 'Blake'),
  (1304, 172, 'Bulchevys collection'),
  (19221, 223, 'Palgrave-Pearse collection'),
  (15553, 522, 'Knowles collection') 
]

with open('data/poetry/raw.txt', 'w') as ofp:
  lineno = 0
  for (id_nr, toskip, title) in books:
    startline = lineno
    text = strip_headers(load_etext(id_nr)).strip()
    lines = text.split('\n')[toskip:]
    # any line that is all upper case is a title or author name
    # also don't want any lines with years (numbers)
    for line in lines:
      if (len(line) > 0 
          and line.upper() != line 
          and not re.match('.*[0-9]+.*', line)
          and len(line) < 50
         ):
        cleaned = re.sub('[^a-z\'\-]+', ' ', line.strip().lower())
        ofp.write(cleaned)
        ofp.write('\n')
        lineno = lineno + 1
      else:
        ofp.write('\n')
    print('Wrote lines {} to {} from {}'.format(startline, lineno, title))

In [ ]:
!wc -l data/poetry/*.txt

Create training dataset

We are going to train a machine learning model to write poetry given a starting point. We'll give it one line, and it is going to tell us the next line. So, naturally, we will train it on real poetry. Our feature will be a line of a poem and the label will be next line of that poem.

Our training dataset will consist of two files. The first file will consist of the input lines of poetry and the other file will consist of the corresponding output lines, one output line per input line.


In [ ]:
with open('data/poetry/raw.txt', 'r') as rawfp,\
  open('data/poetry/input.txt', 'w') as infp,\
  open('data/poetry/output.txt', 'w') as outfp:
    
    prev_line = ''
    for curr_line in rawfp:
        curr_line = curr_line.strip()
        # poems break at empty lines, so this ensures we train only
        # on lines of the same poem
        if len(prev_line) > 0 and len(curr_line) > 0:       
            infp.write(prev_line + '\n')
            outfp.write(curr_line + '\n')
        prev_line = curr_line

In [ ]:
!head -5 data/poetry/*.txt

We do not need to generate the data beforehand -- instead, we can have Tensor2Tensor create the training dataset for us. So, in the code below, I will use only data/poetry/raw.txt -- obviously, this allows us to productionize our model better. Simply keep collecting raw data and generate the training/test data at the time of training.

Set up problem

The Problem in tensor2tensor is where you specify parameters like the size of your vocabulary and where to get the training data from.


In [ ]:
%%bash
rm -rf poetry
mkdir -p poetry/trainer

In [ ]:
%%writefile poetry/trainer/problem.py
import os
import tensorflow as tf
from tensor2tensor.utils import registry
from tensor2tensor.models import transformer
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import text_problems
from tensor2tensor.data_generators import generator_utils

tf.summary.FileWriterCache.clear() # ensure filewriter cache is clear for TensorBoard events file

@registry.register_problem
class PoetryLineProblem(text_problems.Text2TextProblem):
  """Predict next line of poetry from the last line. From Gutenberg texts."""

  @property
  def approx_vocab_size(self):
    return 2**13  # ~8k

  @property
  def is_generate_per_split(self):
    # generate_data will NOT shard the data into TRAIN and EVAL for us.
    return False

  @property
  def dataset_splits(self):
    """Splits of data to produce and number of output shards for each."""
    # 10% evaluation data
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 90,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 10,
    }]

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    with open('data/poetry/raw.txt', 'r') as rawfp:
      prev_line = ''
      for curr_line in rawfp:
        curr_line = curr_line.strip()
        # poems break at empty lines, so this ensures we train only
        # on lines of the same poem
        if len(prev_line) > 0 and len(curr_line) > 0:       
            yield {
                "inputs": prev_line,
                "targets": curr_line
            }
        prev_line = curr_line          


# Smaller than the typical translate model, and with more regularization
@registry.register_hparams
def transformer_poetry():
  hparams = transformer.transformer_base()
  hparams.num_hidden_layers = 2
  hparams.hidden_size = 128
  hparams.filter_size = 512
  hparams.num_heads = 4
  hparams.attention_dropout = 0.6
  hparams.layer_prepostprocess_dropout = 0.6
  hparams.learning_rate = 0.05
  return hparams

@registry.register_hparams
def transformer_poetry_tpu():
  hparams = transformer_poetry()
  transformer.update_hparams_for_tpu(hparams)
  return hparams

# hyperparameter tuning ranges
@registry.register_ranged_hparams
def transformer_poetry_range(rhp):
  rhp.set_float("learning_rate", 0.05, 0.25, scale=rhp.LOG_SCALE)
  rhp.set_int("num_hidden_layers", 2, 4)
  rhp.set_discrete("hidden_size", [128, 256, 512])
  rhp.set_float("attention_dropout", 0.4, 0.7)

In [ ]:
%%writefile poetry/trainer/__init__.py
from . import problem

In [ ]:
%%writefile poetry/setup.py
from setuptools import find_packages
from setuptools import setup

REQUIRED_PACKAGES = [
  'tensor2tensor'
]

setup(
    name='poetry',
    version='0.1',
    author = 'Google',
    author_email = 'training-feedback@cloud.google.com',
    install_requires=REQUIRED_PACKAGES,
    packages=find_packages(),
    include_package_data=True,
    description='Poetry Line Problem',
    requires=[]
)

In [ ]:
!touch poetry/__init__.py

In [ ]:
!find poetry

Generate training data

Our problem (translation) requires the creation of text sequences from the training dataset. This is done using t2t-datagen and the Problem defined in the previous section.

(Ignore any runtime warnings about np.float64. they are harmless).


In [ ]:
%%bash
DATA_DIR=./t2t_data
TMP_DIR=$DATA_DIR/tmp
rm -rf $DATA_DIR $TMP_DIR
mkdir -p $DATA_DIR $TMP_DIR
# Generate data
t2t-datagen \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR

Let's check to see the files that were output. If you see a broken pipe error, please ignore.


In [ ]:
!ls t2t_data | head

Provide Cloud ML Engine access to data

Copy the data to Google Cloud Storage, and then provide access to the data. gsutil throws an error when removing an empty bucket, so you may see an error the first time this code is run.


In [ ]:
%%bash
DATA_DIR=./t2t_data
gsutil -m rm -r gs://${BUCKET}/poetry/
gsutil -m cp ${DATA_DIR}/${PROBLEM}* ${DATA_DIR}/vocab* gs://${BUCKET}/poetry/data

In [ ]:
%%bash
PROJECT_ID=$PROJECT
AUTH_TOKEN=$(gcloud auth print-access-token)
SVC_ACCOUNT=$(curl -X GET -H "Content-Type: application/json" \
    -H "Authorization: Bearer $AUTH_TOKEN" \
    https://ml.googleapis.com/v1/projects/${PROJECT_ID}:getConfig \
    | python -c "import json; import sys; response = json.load(sys.stdin); \
    print(response['serviceAccount'])")

echo "Authorizing the Cloud ML Service account $SVC_ACCOUNT to access files in $BUCKET"
gsutil -m defacl ch -u $SVC_ACCOUNT:R gs://$BUCKET
gsutil -m acl ch -u $SVC_ACCOUNT:R -r gs://$BUCKET  # error message (if bucket is empty) can be ignored
gsutil -m acl ch -u $SVC_ACCOUNT:W gs://$BUCKET

Train model locally on subset of data

Let's run it locally on a subset of the data to make sure it works.


In [ ]:
%%bash
BASE=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/subset
gsutil -m rm -r $OUTDIR
gsutil -m cp \
    ${BASE}/${PROBLEM}-train-0008* \
    ${BASE}/${PROBLEM}-dev-00000*  \
    ${BASE}/vocab* \
    $OUTDIR

Note: the following will work only if you are running Jupyter on a reasonably powerful machine. Don't be alarmed if your process is killed.


In [ ]:
%%bash
DATA_DIR=gs://${BUCKET}/poetry/subset
OUTDIR=./trained_model
rm -rf $OUTDIR
t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --hparams_set=transformer_poetry \
  --output_dir=$OUTDIR --job-dir=$OUTDIR --train_steps=10

Option 1: Train model locally on full dataset (use if running on Notebook Instance with a GPU)

You can train on the full dataset if you are on a Google Cloud Notebook Instance with a P100 or better GPU


In [ ]:
%%bash
LOCALGPU="--train_steps=7500 --worker_gpu=1 --hparams_set=transformer_poetry"

DATA_DIR=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/model
rm -rf $OUTDIR
t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --hparams_set=transformer_poetry \
  --output_dir=$OUTDIR ${LOCALGPU}

Option 2: Train on Cloud ML Engine

tensor2tensor has a convenient --cloud_mlengine option to kick off the training on the managed service. It uses the Python API mentioned in the Cloud ML Engine docs, rather than requiring you to use gcloud to submit the job.

Note: your project needs P100 quota in the region.

The echo is because t2t-trainer asks you to confirm before submitting the job to the cloud. Ignore any error about "broken pipe". If you see a message similar to this:

    [... cloud_mlengine.py:392] Launched transformer_poetry_line_problem_t2t_20190323_000631. See console to track: https://console.cloud.google.com/mlengine/jobs/.
then, this step has been successful.


In [ ]:
%%bash
GPU="--train_steps=7500 --cloud_mlengine --worker_gpu=1 --hparams_set=transformer_poetry"

DATADIR=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/model
JOBNAME=poetry_$(date -u +%y%m%d_%H%M%S)
echo $OUTDIR $REGION $JOBNAME
gsutil -m rm -rf $OUTDIR
echo "'Y'" | t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --output_dir=$OUTDIR \
  ${GPU}

In [ ]:
%%bash
## CHANGE the job name (based on output above: You will see a line such as Launched transformer_poetry_line_problem_t2t_20190322_233159)
gcloud ml-engine jobs describe transformer_poetry_line_problem_t2t_20190323_003001

The job took about 25 minutes for me and ended with these evaluation metrics:

Saving dict for global step 8000: global_step = 8000, loss = 6.03338, metrics-poetry_line_problem/accuracy = 0.138544, metrics-poetry_line_problem/accuracy_per_sequence = 0.0, metrics-poetry_line_problem/accuracy_top5 = 0.232037, metrics-poetry_line_problem/approx_bleu_score = 0.00492648, metrics-poetry_line_problem/neg_log_perplexity = -6.68994, metrics-poetry_line_problem/rouge_2_fscore = 0.00256089, metrics-poetry_line_problem/rouge_L_fscore = 0.128194
Notice that accuracy_per_sequence is 0 -- Considering that we are asking the NN to be rather creative, that doesn't surprise me. Why am I looking at accuracy_per_sequence and not the other metrics? This is because it is more appropriate for problem we are solving; metrics like Bleu score are better for translation.

Option 3: Train on a directly-connected TPU

If you are running on a VM connected directly to a Cloud TPU, you can run t2t-trainer directly. Unfortunately, you won't see any output from Jupyter while the program is running.

Compare this command line to the one using GPU in the previous section.


In [ ]:
%%bash
# use one of these
TPU="--train_steps=7500 --use_tpu=True --cloud_tpu_name=laktpu --hparams_set=transformer_poetry_tpu"

DATADIR=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/model_tpu
JOBNAME=poetry_$(date -u +%y%m%d_%H%M%S)
echo $OUTDIR $REGION $JOBNAME
gsutil -m rm -rf $OUTDIR
echo "'Y'" | t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --output_dir=$OUTDIR \
  ${TPU}

In [ ]:
%%bash
gsutil ls gs://${BUCKET}/poetry/model_tpu

The job took about 10 minutes for me and ended with these evaluation metrics:

Saving dict for global step 8000: global_step = 8000, loss = 6.03338, metrics-poetry_line_problem/accuracy = 0.138544, metrics-poetry_line_problem/accuracy_per_sequence = 0.0, metrics-poetry_line_problem/accuracy_top5 = 0.232037, metrics-poetry_line_problem/approx_bleu_score = 0.00492648, metrics-poetry_line_problem/neg_log_perplexity = -6.68994, metrics-poetry_line_problem/rouge_2_fscore = 0.00256089, metrics-poetry_line_problem/rouge_L_fscore = 0.128194
Notice that accuracy_per_sequence is 0 -- Considering that we are asking the NN to be rather creative, that doesn't surprise me. Why am I looking at accuracy_per_sequence and not the other metrics? This is because it is more appropriate for problem we are solving; metrics like Bleu score are better for translation.

Option 4: Training longer

Let's train on 4 GPUs for 75,000 steps. Note the change in the last line of the job.


In [ ]:
%%bash

XXX This takes 3 hours on 4 GPUs. Remove this line if you are sure you want to do this.

DATADIR=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/model_full2
JOBNAME=poetry_$(date -u +%y%m%d_%H%M%S)
echo $OUTDIR $REGION $JOBNAME
gsutil -m rm -rf $OUTDIR
echo "'Y'" | t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --hparams_set=transformer_poetry \
  --output_dir=$OUTDIR \
  --train_steps=75000 --cloud_mlengine --worker_gpu=4

This job took 12 hours for me and ended with these metrics:

global_step = 76000, loss = 4.99763, metrics-poetry_line_problem/accuracy = 0.219792, metrics-poetry_line_problem/accuracy_per_sequence = 0.0192308, metrics-poetry_line_problem/accuracy_top5 = 0.37618, metrics-poetry_line_problem/approx_bleu_score = 0.017955, metrics-poetry_line_problem/neg_log_perplexity = -5.38725, metrics-poetry_line_problem/rouge_2_fscore = 0.0325563, metrics-poetry_line_problem/rouge_L_fscore = 0.210618
At least the accuracy per sequence is no longer zero. It is now 0.0192308 ... note that we are using a relatively small dataset (12K lines) and this is tiny in the world of natural language problems.

In order that you have your expectations set correctly: a high-performing translation model needs 400-million lines of input and takes 1 whole day on a TPU pod!

Check trained model


In [ ]:
%%bash
gsutil ls gs://${BUCKET}/poetry/model   #_modeltpu

Batch-predict

How will our poetry model do when faced with Rumi's spiritual couplets?


In [ ]:
%%writefile data/poetry/rumi.txt
Where did the handsome beloved go?
I wonder, where did that tall, shapely cypress tree go?
He spread his light among us like a candle.
Where did he go? So strange, where did he go without me?
All day long my heart trembles like a leaf.
All alone at midnight, where did that beloved go?
Go to the road, and ask any passing travelerThat soul-stirring companion, where did he go?
Go to the garden, and ask the gardenerThat tall, shapely rose stem, where did he go?
Go to the rooftop, and ask the watchmanThat unique sultan, where did he go?
Like a madman, I search in the meadows!
That deer in the meadows, where did he go?
My tearful eyes overflow like a riverThat pearl in the vast sea, where did he go?
All night long, I implore both moon and VenusThat lovely face, like a moon, where did he go?
If he is mine, why is he with others?
Since hes not here, to what there did he go?
If his heart and soul are joined with God,
And he left this realm of earth and water, where did he go?
Tell me clearly, Shams of Tabriz,
Of whom it is said, The sun never dieswhere did he go?

Let's write out the odd-numbered lines. We'll compare how close our model can get to the beauty of Rumi's second lines given his first.


In [ ]:
%%bash
awk 'NR % 2 == 1' data/poetry/rumi.txt | tr '[:upper:]' '[:lower:]' | sed "s/[^a-z\'-\ ]//g" > data/poetry/rumi_leads.txt
head -3 data/poetry/rumi_leads.txt

In [ ]:
%%bash
# same as the above training job ...
TOPDIR=gs://${BUCKET}
OUTDIR=${TOPDIR}/poetry/model #_tpu  # or ${TOPDIR}/poetry/model_full
DATADIR=${TOPDIR}/poetry/data
MODEL=transformer
HPARAMS=transformer_poetry #_tpu

# the file with the input lines
DECODE_FILE=data/poetry/rumi_leads.txt

BEAM_SIZE=4
ALPHA=0.6

t2t-decoder \
  --data_dir=$DATADIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$OUTDIR \
  --t2t_usr_dir=./poetry/trainer \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE

Note if you get an error about "AttributeError: 'HParams' object has no attribute 'problems'" please Reset Session, run the cell that defines the PROBLEM and run the above cell again.


In [ ]:
%%bash  
DECODE_FILE=data/poetry/rumi_leads.txt
cat ${DECODE_FILE}.*.decodes

Some of these are still phrases and not complete sentences. This indicates that we might need to train longer or better somehow. We need to diagnose the model ...

Diagnosing training run

Let's diagnose the training run to see what we'd improve the next time around. (Note that this package may not be present on Jupyter -- pip install pydatalab if necessary)

Monitor training with TensorBoard

To activate TensorBoard within the JupyterLab UI navigate to "File" - "New Launcher". Then double-click the 'Tensorboard' icon on the bottom row.

TensorBoard 1 will appear in the new tab. Navigate through the three tabs to see the active TensorBoard. The 'Graphs' and 'Projector' tabs offer very interesting information including the ability to replay the tests.

You may close the TensorBoard tab when you are finished exploring.

Looking at the loss curve, it is clear that we are overfitting (note that the orange training curve is well below the blue eval curve). Both loss curves and the accuracy-per-sequence curve, which is our key evaluation measure, plateaus after 40k. (The red curve is a faster way of computing the evaluation metric, and can be ignored). So, how do we improve the model? Well, we need to reduce overfitting and make sure the eval metrics keep going down as long as the loss is also going down.

What we really need to do is to get more data, but if that's not an option, we could try to reduce the NN and increase the dropout regularization. We could also do hyperparameter tuning on the dropout and network sizes.

Hyperparameter tuning

tensor2tensor also supports hyperparameter tuning on Cloud ML Engine. Note the addition of the autotune flags.

The transformer_poetry_range was registered in problem.py above.


In [ ]:
%%bash

XXX This takes about 15 hours and consumes about 420 ML units.  Uncomment if you wish to proceed anyway

DATADIR=gs://${BUCKET}/poetry/data
OUTDIR=gs://${BUCKET}/poetry/model_hparam
JOBNAME=poetry_$(date -u +%y%m%d_%H%M%S)
echo $OUTDIR $REGION $JOBNAME
gsutil -m rm -rf $OUTDIR
echo "'Y'" | t2t-trainer \
  --data_dir=gs://${BUCKET}/poetry/subset \
  --t2t_usr_dir=./poetry/trainer \
  --problem=$PROBLEM \
  --model=transformer \
  --hparams_set=transformer_poetry \
  --output_dir=$OUTDIR \
  --hparams_range=transformer_poetry_range \
  --autotune_objective='metrics-poetry_line_problem/accuracy_per_sequence' \
  --autotune_maximize \
  --autotune_max_trials=4 \
  --autotune_parallel_trials=4 \
  --train_steps=7500 --cloud_mlengine --worker_gpu=4

When I ran the above job, it took about 15 hours and finished with these as the best parameters:

{
      "trialId": "37",
      "hyperparameters": {
        "hp_num_hidden_layers": "4",
        "hp_learning_rate": "0.026711152525921437",
        "hp_hidden_size": "512",
        "hp_attention_dropout": "0.60589466163419292"
      },
      "finalMetric": {
        "trainingStep": "8000",
        "objectiveValue": 0.0276162791997
      }
In other words, the accuracy per sequence achieved was 0.027 (as compared to 0.019 before hyperparameter tuning, so a 40% improvement!) using 4 hidden layers, a learning rate of 0.0267, a hidden size of 512 and droput probability of 0.606. This is inspite of training for only 7500 steps instead of 75,000 steps ... we could train for 75k steps with these parameters, but I'll leave that as an exercise for you.

Instead, let's try predicting with this optimized model. Note the addition of the hp* flags in order to override the values hardcoded in the source code. (there is no need to specify learning rate and dropout because they are not used during inference). I am using 37 because I got the best result at trialId=37


In [ ]:
%%bash
# same as the above training job ...
BEST_TRIAL=28  # CHANGE as needed.
TOPDIR=gs://${BUCKET}
OUTDIR=${TOPDIR}/poetry/model_hparam/$BEST_TRIAL
DATADIR=${TOPDIR}/poetry/data
MODEL=transformer
HPARAMS=transformer_poetry

# the file with the input lines
DECODE_FILE=data/poetry/rumi_leads.txt

BEAM_SIZE=4
ALPHA=0.6

t2t-decoder \
  --data_dir=$DATADIR \
  --problem=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$OUTDIR \
  --t2t_usr_dir=./poetry/trainer \
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
  --decode_from_file=$DECODE_FILE \
  --hparams="num_hidden_layers=4,hidden_size=512"

In [ ]:
%%bash  
DECODE_FILE=data/poetry/rumi_leads.txt
cat ${DECODE_FILE}.*.decodes

Take the first three line. I'm showing the first line of the couplet provided to the model, how the AI model that we trained complets it and how Rumi completes it:

INPUT: where did the handsome beloved go
AI: where art thou worse to me than dead
RUMI: I wonder, where did that tall, shapely cypress tree go?

INPUT: he spread his light among us like a candle
AI: like the hurricane eclipse
RUMI: Where did he go? So strange, where did he go without me?

INPUT: all day long my heart trembles like a leaf
AI: and through their hollow aisles it plays
RUMI: All alone at midnight, where did that beloved go?

Oh wow. The couplets as completed are quite decent considering that:

  • We trained the model on American poetry, so feeding it Rumi is a bit out of left field.
  • Rumi, of course, has a context and thread running through his lines while the AI (since it was fed only that one line) doesn't.

"Spreading light like a hurricane eclipse" is a metaphor I won't soon forget. And it was created by a machine learning model!

Serving poetry

How would you serve these predictions? There are two ways:

  1. Use Cloud ML Engine -- this is serverless and you don't have to manage any infrastructure.
  2. Use Kubeflow on Google Kubernetes Engine -- this uses clusters but will also work on-prem on your own Kubernetes cluster. </ol>
  3. In either case, you need to export the model first and have TensorFlow serving serve the model. The model, however, expects to see encoded (i.e. preprocessed) data. So, we'll do that in the Python Flask application (in AppEngine Flex) that serves the user interface.

    
    
    In [ ]:
    %%bash
    TOPDIR=gs://${BUCKET}
    OUTDIR=${TOPDIR}/poetry/model_full2
    DATADIR=${TOPDIR}/poetry/data
    MODEL=transformer
    HPARAMS=transformer_poetry
    BEAM_SIZE=4
    ALPHA=0.6
    
    t2t-exporter \
      --model=$MODEL \
      --hparams_set=$HPARAMS \
      --problem=$PROBLEM \
      --t2t_usr_dir=./poetry/trainer \
      --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
      --data_dir=$DATADIR \
      --output_dir=$OUTDIR
    
    
    
    In [ ]:
    %%bash
    MODEL_LOCATION=$(gsutil ls gs://${BUCKET}/poetry/model_full2/export | tail -1)
    echo $MODEL_LOCATION
    saved_model_cli show --dir $MODEL_LOCATION --tag_set serve --signature_def serving_default
    

    Cloud ML Engine

    
    
    In [ ]:
    %%writefile mlengine.json
    description: Poetry service on ML Engine
    autoScaling:
        minNodes: 1  # We don't want this model to autoscale down to zero
    
    
    
    In [ ]:
    %%bash
    MODEL_NAME="poetry"
    MODEL_VERSION="v1"
    MODEL_LOCATION=$(gsutil ls gs://${BUCKET}/poetry/model_full2/export | tail -1)
    echo "Deleting and deploying $MODEL_NAME $MODEL_VERSION from $MODEL_LOCATION ... this will take a few minutes"
    gcloud ml-engine versions delete ${MODEL_VERSION} --model ${MODEL_NAME}
    #gcloud ml-engine models delete ${MODEL_NAME}
    #gcloud ml-engine models create ${MODEL_NAME} --regions $REGION
    gcloud ml-engine versions create ${MODEL_VERSION} \
           --model ${MODEL_NAME} --origin ${MODEL_LOCATION} --runtime-version=1.13 --config=mlengine.json
    

    Kubeflow

    Follow these instructions:

    • On the GCP console, launch a Google Kubernetes Engine (GKE) cluster named 'poetry' with 2 nodes, each of which is a n1-standard-2 (2 vCPUs, 7.5 GB memory) VM
    • On the GCP console, click on the Connect button for your cluster, and choose the CloudShell option
    • In CloudShell, run:
        git clone https://github.com/GoogleCloudPlatform/training-data-analyst`
        cd training-data-analyst/courses/machine_learning/deepdive/09_sequence
    • Look at ./setup_kubeflow.sh and modify as appropriate.

    AppEngine

    What's deployed in Cloud ML Engine or Kubeflow is only the TensorFlow model. We still need a preprocessing service. That is done using AppEngine. Edit application/app.yaml appropriately.

    
    
    In [ ]:
    !cat application/app.yaml
    
    
    
    In [ ]:
    %%bash
    cd application
    #gcloud app create  # if this is your first app
    #gcloud app deploy --quiet --stop-previous-version app.yaml
    

    Now visit https://mlpoetry-dot-cloud-training-demos.appspot.com and try out the prediction app!

    Copyright 2019 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License