In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Customizing AdaNet With TensorFlow Hub Modules

From the customizing AdaNet tutorial, you know how to define your own neural architecture search space for the AdaNet algorithm to explore. You can simplify this process further by using TensorFlow Hub modules as the basic building blocks for AdaNet. These modules have already been pre-trained on large corpora of data, which lets you leverage the power of transfer learning.

In this tutorial, we will create a custom search space for a sentiment analysis dataset using TensorFlow Hub text embedding modules.
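
For instance, here is a minimal sketch (assuming a TF 1.x graph-mode runtime, as used throughout this notebook) of how one of these text embedding modules maps raw strings to fixed-size vectors. It uses the same nnlm-en-dim128 module that appears later in this tutorial; the example sentences are purely illustrative:


In [0]:
#@test {"skip": true}
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

# Load a pre-trained text embedding module and embed two example sentences.
embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim128/1")
embeddings = embed(["The movie was great!", "The movie was terrible."])

with tf.Session() as sess:
  # Hub text modules use internal lookup tables, so initialize those as well.
  sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
  print(sess.run(embeddings).shape)  # (2, 128): one 128-d vector per sentence.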


In [0]:
#@test {"skip": true}
# If you are running this in Colab, first install the adanet package:
!pip install adanet

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import re
import shutil
import numpy as np 
import pandas as pd 

import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

import adanet
from adanet.examples import simple_dnn

# The random seed to use.
RANDOM_SEED = 42

LOG_DIR = '/tmp/models'

Getting started

Data

We will try to solve the Large Movie Review Dataset v1.0 task (Maas et al., 2011). The dataset consists of IMDB movie reviews labeled by positivity from 1 to 10. The task is to label the reviews as negative or positive.


In [0]:
def load_directory_data(directory):
  data = {}
  data["sentence"] = []
  data["sentiment"] = []
  for file_path in os.listdir(directory):
    with tf.gfile.GFile(os.path.join(directory, file_path), "r") as f:
      data["sentence"].append(f.read())
      data["sentiment"].append(re.match("\d+_(\d+)\.txt", file_path).group(1))
  return pd.DataFrame.from_dict(data)

def load_dataset(directory):
  pos_df = load_directory_data(os.path.join(directory, "pos"))
  neg_df = load_directory_data(os.path.join(directory, "neg"))
  pos_df["polarity"] = 1
  neg_df["polarity"] = 0
  return pd.concat([pos_df, neg_df]).sample(frac=1).reset_index(drop=True)

def download_and_load_datasets(force_download=False):
  dataset = tf.keras.utils.get_file(
    fname="aclImdb.tar.gz",
    origin="http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz",
    extract=True
  )
  train_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                      "aclImdb", "train"))
  test_df = load_dataset(os.path.join(os.path.dirname(dataset),
                                      "aclImdb", "test"))
  return train_df, test_df

tf.logging.set_verbosity(tf.logging.INFO)

train_df, test_df = download_and_load_datasets()
train_df.head()

Supply the data in TensorFlow

Our first task is to supply the data to TensorFlow. Using tf.estimator.inputs.pandas_input_fn, we define three input_fns that will be used for training and evaluation later.


In [0]:
FEATURES_KEY = "sentence"

train_input_fn = tf.estimator.inputs.pandas_input_fn(
  train_df, train_df["polarity"], num_epochs=None, shuffle=True)

predict_train_input_fn = tf.estimator.inputs.pandas_input_fn(
  train_df, train_df["polarity"], shuffle=False)

predict_test_input_fn = tf.estimator.inputs.pandas_input_fn(
  test_df, test_df["polarity"], shuffle=False)

Launch TensorBoard

Let's run TensorBoard to visualize model training over time. We'll use ngrok to tunnel traffic to localhost.

The instructions for setting up TensorBoard were obtained from https://www.dlology.com/blog/quick-guide-to-run-tensorboard-in-google-colab/

Run the next cell and follow the link to open TensorBoard in a new tab.


In [0]:
#@test {"skip": true}

get_ipython().system_raw(
    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'
    .format(LOG_DIR)
)

# Install ngrok binary.
! wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip
! unzip ngrok-stable-linux-amd64.zip

# Delete old logs dir.
shutil.rmtree(LOG_DIR, ignore_errors=True)

print("Follow this link to open TensorBoard in a new tab.")
get_ipython().system_raw('./ngrok http 6006 &')
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

Establish baselines

The next task is to establish some baselines to see how our model performs on this dataset.

Let's define some information to share with all our tf.estimator.Estimators:


In [0]:
NUM_CLASSES = 2

loss_reduction = tf.losses.Reduction.SUM_OVER_BATCH_SIZE

head = tf.contrib.estimator.binary_classification_head(
  loss_reduction=loss_reduction)

hub_columns = hub.text_embedding_column(
    key=FEATURES_KEY, 
    module_spec="https://tfhub.dev/google/nnlm-en-dim128/1")

def make_config(experiment_name):
  # Estimator configuration.
  return tf.estimator.RunConfig(
    save_checkpoints_steps=1000,
    save_summary_steps=1000,
    tf_random_seed=RANDOM_SEED,
    model_dir=os.path.join(LOG_DIR, experiment_name))

Let's start simple, and train a linear model:


In [0]:
#@test {"skip": true}
#@title Parameters
LEARNING_RATE = 0.001 #@param {type:"number"}
TRAIN_STEPS = 5000 #@param {type:"integer"}

estimator = tf.estimator.LinearClassifier(
  feature_columns=[hub_columns],
  n_classes=NUM_CLASSES,
  optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),
  loss_reduction=loss_reduction,
  config=make_config("linear"))

results, _ = tf.estimator.train_and_evaluate(
  estimator,
  train_spec=tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=TRAIN_STEPS),
  eval_spec=tf.estimator.EvalSpec(
    input_fn=predict_test_input_fn,
    steps=None))

print("Accuracy: ", results["accuracy"])
print("Loss: ", results["average_loss"])

The linear model with default parameters achieves about 78% accuracy.

Let's see if we can do better with the simple_dnn AdaNet:


In [0]:
#@test {"skip": true}
#@title Parameters
LEARNING_RATE = 0.003  #@param {type:"number"}
TRAIN_STEPS = 5000  #@param {type:"integer"}
ADANET_ITERATIONS = 2  #@param {type:"integer"}

estimator = adanet.Estimator(
    head=head,
    
    # Define the generator, which defines our search space of subnetworks
    # to train as candidates to add to the final AdaNet model.
    subnetwork_generator=simple_dnn.Generator(
        feature_columns=[hub_columns],
        optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),
        seed=RANDOM_SEED),
    
    # The number of train steps per iteration.
    max_iteration_steps=TRAIN_STEPS // ADANET_ITERATIONS,
    
    # The evaluator will evaluate the model on the full training set to
    # compute the overall AdaNet loss (train loss + complexity
    # regularization) to select the best candidate to include in the
    # final AdaNet model.
    evaluator=adanet.Evaluator(
        input_fn=predict_train_input_fn,
        steps=1000),
    
    # Configuration for Estimators.
    config=make_config("simple_dnn"))

results, _ = tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(
        input_fn=train_input_fn,
        max_steps=TRAIN_STEPS),
    eval_spec=tf.estimator.EvalSpec(
        input_fn=predict_test_input_fn,
        steps=None))
print("Accuracy:", results["accuracy"])
print("Loss:", results["average_loss"])

The simple_dnn AdaNet model with default parameters achieves about 80% accuracy.

This improvement can be attributed to simple_dnn searching over fully-connected neural networks, which have more expressive power than the linear model due to their non-linear activations.

The above simple_dnn generator only generates subnetworks that take embedding results from one module. We can add diversity to the search space by building subnetworks that take embeddings from different modules, which might improve performance. To do that, we need to define a custom adanet.subnetwork.Builder and adanet.subnetwork.Generator.

Define an AdaNet model with TensorFlow Hub text embedding modules

Creating a new search space for AdaNet to explore is straightforward. There are two abstract classes you need to extend:

  1. adanet.subnetwork.Builder
  2. adanet.subnetwork.Generator

Similar to the tf.estimator.Estimator model_fn, adanet.subnetwork.Builder allows you to define your own TensorFlow graph for creating a neural network and to specify its training operations.

Below we define one that first applies text embedding using a TensorFlow Hub text module, and then a fully-connected layer to predict the sentiment polarity.


In [0]:
class SimpleNetworkBuilder(adanet.subnetwork.Builder):
  """Builds a simple subnetwork with text embedding module."""

  def __init__(self, learning_rate, max_iteration_steps, seed,
               module_name, module):
    """Initializes a `SimpleNetworkBuilder`.

    Args:
      learning_rate: The float learning rate to use.
      max_iteration_steps: The number of steps per iteration.
      seed: The random seed.
      module_name: The name of the TensorFlow Hub module to use.
      module: The handle of the TensorFlow Hub text embedding module to load.

    Returns:
      An instance of `SimpleNetworkBuilder`.
    """
    self._learning_rate = learning_rate
    self._max_iteration_steps = max_iteration_steps
    self._seed = seed
    self._module_name = module_name
    self._module = module

  def build_subnetwork(self,
                       features,
                       logits_dimension,
                       training,
                       iteration_step,
                       summary,
                       previous_ensemble=None):
    """See `adanet.subnetwork.Builder`."""
    sentence = features["sentence"]
    # Load module and apply text embedding, setting trainable=True.
    m = hub.Module(self._module, trainable=True)
    x = m(sentence)
    kernel_initializer = tf.keras.initializers.he_normal(seed=self._seed)

    # The `Head` passed to adanet.Estimator will apply the sigmoid activation.
    logits = tf.layers.dense(
        x, units=1, activation=None, kernel_initializer=kernel_initializer)

    # Use a constant complexity measure, since all subnetworks have the same
    # architecture and hyperparameters.
    complexity = tf.constant(1)

    return adanet.Subnetwork(
        last_layer=x,
        logits=logits,
        complexity=complexity,
        persisted_tensors={})

  def build_subnetwork_train_op(self, 
                                subnetwork, 
                                loss, 
                                var_list, 
                                labels, 
                                iteration_step,
                                summary, 
                                previous_ensemble=None):
    """See `adanet.subnetwork.Builder`."""

    learning_rate = tf.train.cosine_decay(
        learning_rate=self._learning_rate,
        global_step=iteration_step,
        decay_steps=self._max_iteration_steps)
    optimizer = tf.train.MomentumOptimizer(learning_rate, .9)
    # NOTE: The `adanet.Estimator` increments the global step.
    return optimizer.minimize(loss=loss, var_list=var_list)

  def build_mixture_weights_train_op(self, loss, var_list, logits, labels,
                                     iteration_step, summary):
    """See `adanet.subnetwork.Builder`."""
    return tf.no_op("mixture_weights_train_op")

  @property
  def name(self):
    """See `adanet.subnetwork.Builder`."""
    return self._module_name

Next, we extend adanet.subnetwork.Generator, which defines the search space of candidate SimpleNetworkBuilders to consider including in the final network. It can create one or more at each iteration with different parameters, and the AdaNet algorithm will select the candidate that best improves the overall neural network's adanet_loss on the training set.

The one below cycles through the text embedding modules listed in MODULES and gives each candidate a different random seed at each iteration. These modules are selected from the TensorFlow Hub text modules:


In [0]:
MODULES = [
    "https://tfhub.dev/google/nnlm-en-dim50/1",
    "https://tfhub.dev/google/nnlm-en-dim128/1",
    "https://tfhub.dev/google/universal-sentence-encoder/1"
]

In [0]:
class SimpleNetworkGenerator(adanet.subnetwork.Generator):
  """Generates a `SimpleNetwork` at each iteration.
  """

  def __init__(self, learning_rate, max_iteration_steps, seed=None):
    """Initializes a `Generator` that builds `SimpleNetwork`.

    Args:
      learning_rate: The float learning rate to use.
      max_iteration_steps: The number of steps per iteration.
      seed: The random seed.

    Returns:
      An instance of `Generator`.
    """
    self._seed = seed
    self._dnn_builder_fn = functools.partial(
        SimpleNetworkBuilder,
        learning_rate=learning_rate,
        max_iteration_steps=max_iteration_steps)

  def generate_candidates(self, previous_ensemble, iteration_number,
                          previous_ensemble_reports, all_reports):
    """See `adanet.subnetwork.Generator`."""
    module_index = iteration_number % len(MODULES)
    module_name = MODULES[module_index].split("/")[-2]
    
    print("generating candidate: %s" % module_name)
    
    seed = self._seed
    # Change the seed according to the iteration so that each subnetwork
    # learns something different.
    if seed is not None:
      seed += iteration_number
    return [self._dnn_builder_fn(seed=seed, 
                                 module_name=module_name, 
                                 module=MODULES[module_index])]

With these defined, we pass them into a new adanet.Estimator:


In [0]:
#@title Parameters
LEARNING_RATE = 0.05  #@param {type:"number"}
TRAIN_STEPS = 7500  #@param {type:"integer"}
ADANET_ITERATIONS = 3  #@param {type:"integer"}

max_iteration_steps = TRAIN_STEPS // ADANET_ITERATIONS
estimator = adanet.Estimator(
    head=head,
    subnetwork_generator=SimpleNetworkGenerator(
        learning_rate=LEARNING_RATE,
        max_iteration_steps=max_iteration_steps,
        seed=RANDOM_SEED),
    max_iteration_steps=max_iteration_steps,
    evaluator=adanet.Evaluator(input_fn=train_input_fn, 
                               steps=10),
    report_materializer=None,
    adanet_loss_decay=.99,
    config=make_config("tfhub"))

results, _ = tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(input_fn=train_input_fn,
                                      max_steps=TRAIN_STEPS),
    eval_spec=tf.estimator.EvalSpec(input_fn=predict_test_input_fn, 
                                    steps=None))
print("Accuracy:", results["accuracy"])
print("Loss:", results["average_loss"])



def ensemble_architecture(result):
  """Extracts the ensemble architecture from evaluation results."""

  architecture = result["architecture/adanet/ensembles"]
  # The architecture is a serialized Summary proto for TensorBoard.
  summary_proto = tf.summary.Summary.FromString(architecture)

Our SimpleNetworkGenerator code achieves about 87% accuracy, which is almost 7% higher than using just one embedding module directly.

You can see how the performance improves step by step:

Linear baseline | AdaNet + simple_dnn | AdaNet + TensorFlow Hub
      78%       |         80%         |           87%

Generating predictions on our trained model

Now that we've got a trained model, we can use it to generate predictions on new input. To keep things simple, here we'll generate predictions with our estimator on the first 10 examples from the test set.


In [0]:
predict_input_fn = tf.estimator.inputs.pandas_input_fn(
  test_df.iloc[:10], test_df["polarity"].iloc[:10], shuffle=False)

predictions = estimator.predict(input_fn=predict_input_fn)

for i, val in enumerate(predictions):
    predicted_class = val['class_ids'][0]
    prediction_confidence = val['probabilities'][predicted_class] * 100
    
    print('Actual text: ' + test_df["sentence"][i])
    print('Predicted class: %s, confidence: %s%%' 
          % (predicted_class, round(prediction_confidence, 3)))

Conclusion and next steps

In this tutorial, you learned how to customize AdaNet to encode your understanding of a particular dataset, and how to explore novel search spaces using TensorFlow Hub modules.

As an exercise, you can swap out the ACL IMDB dataset for another text dataset in this notebook and see how the SimpleNetworkGenerator performs.
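
As a starting point, here is a minimal sketch (the tiny DataFrame below is purely hypothetical) showing that any pandas DataFrame with a string "sentence" column and an integer "polarity" label can be fed through the same pandas_input_fn pipeline used above:


In [0]:
import pandas as pd
import tensorflow.compat.v1 as tf

# Hypothetical stand-in for another text dataset: any DataFrame with a
# "sentence" column and a binary "polarity" label fits this notebook's pipeline.
other_train_df = pd.DataFrame({
    "sentence": ["I loved this book.", "The service was awful."],
    "polarity": [1, 0],
})

# The same kind of input_fn used earlier for the IMDB data.
other_train_input_fn = tf.estimator.inputs.pandas_input_fn(
    other_train_df, other_train_df["polarity"], num_epochs=None, shuffle=True)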