Hyperparameter tuning of BigQuery ML models

This notebook shows two ways to tune the hyperparameters of BigQuery ML models: grid search and Bayesian optimization.

1a. Grid Search from Python

Let's try to find the optimal K for a k-means clustering model by training one model for every value of K within a specified range and comparing their Davies-Bouldin indexes (lower is better).


In [1]:
# On Notebook instances in Google Cloud, these are already installed
#!python -m pip install google-cloud-bigquery
#%load_ext google.cloud.bigquery

In [5]:
from multiprocessing.dummy import Pool as ThreadPool
from google.cloud import bigquery
import numpy as np
PROJECT='cloud-training-demos'  # CHANGE THIS

In [26]:
class Range:
    def __init__(self, minvalue, maxvalue, incr=1):
        self._minvalue = minvalue
        self._maxvalue = maxvalue
        self._incr = incr
    def values(self):
        return range(self._minvalue, self._maxvalue, self._incr) 

class Params:
    def __init__(self, num_clusters):
        self._num_clusters = num_clusters
        self._model_name = 'ch09eu.london_station_clusters_{}'.format(num_clusters)
        self._train_query = """
          CREATE OR REPLACE MODEL {}
          OPTIONS(model_type='kmeans', 
                  num_clusters={}, 
                  standardize_features = true) AS
          SELECT * except(station_name)
          from ch09eu.stationstats
        """.format(self._model_name, self._num_clusters)
        self._eval_query = """
          SELECT davies_bouldin_index AS error
          FROM ML.EVALUATE(MODEL {});
        """.format(self._model_name)
        self._error = None
        
    def run(self):
        bq = bigquery.Client(project=PROJECT)
        job = bq.query(self._train_query, location='EU')
        job.result() # wait for job to finish
        evaldf = bq.query(self._eval_query, location='EU').to_dataframe()
        self._error = evaldf['error'][0]
        return self
    
    def __str__(self):
        fmt = '{!s:<40} {:>10f} {:>5d}'
        return fmt.format(self._model_name, self._error, self._num_clusters)
    
def train_and_evaluate(num_clusters: Range, max_concurrent=3):
    # grid search means to try all possible values in range
    params = []
    for k in num_clusters.values():
        params.append(Params(k))
    
    # run all the jobs
    print('Grid search of {} possible parameters'.format(len(params)))
    pool = ThreadPool(max_concurrent)
    results = pool.map(lambda p: p.run(), params)
    
    # sort in ascending order
    return sorted(results, key=lambda p: p._error)
    
params = train_and_evaluate(Range(3, 20))
print(*params, sep='\n')


Grid search of 17 possible parameters
/usr/local/lib/python3.5/dist-packages/google/auth/_default.py:66: UserWarning: Your application has authenticated using end user credentials from Google Cloud SDK. We recommend that most server applications use service accounts instead. If your application continues to use end user credentials from Cloud SDK, you might receive a "quota exceeded" or "API not enabled" error. For more information about service accounts, see https://cloud.google.com/docs/authentication/
  warnings.warn(_CLOUD_SDK_CREDENTIALS_WARNING)
ch09eu.london_station_clusters_19          1.400756    19
ch09eu.london_station_clusters_15          1.415517    15
ch09eu.london_station_clusters_18          1.429912    18
ch09eu.london_station_clusters_14          1.438088    14
ch09eu.london_station_clusters_13          1.438440    13
ch09eu.london_station_clusters_17          1.454715    17
ch09eu.london_station_clusters_16          1.456185    16
ch09eu.london_station_clusters_11          1.502263    11
ch09eu.london_station_clusters_12          1.511940    12
ch09eu.london_station_clusters_10          1.529150    10
ch09eu.london_station_clusters_7           1.551265     7
ch09eu.london_station_clusters_9           1.571020     9
ch09eu.london_station_clusters_6           1.571398     6
ch09eu.london_station_clusters_4           1.596398     4
ch09eu.london_station_clusters_8           1.621974     8
ch09eu.london_station_clusters_5           1.660766     5
ch09eu.london_station_clusters_3           1.681441     3

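The error does not fall monotonically with K, but the overall trend is downward and the lowest error occurs at K=19, the largest value tried. This suggests that we need to experiment with even more clusters; 19 may itself be too few.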
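
One quick way to check this observation is to see whether the best K sits at the upper edge of the searched range. The sketch below hard-codes the (K, Davies-Bouldin index) pairs printed above; the helper name `best_at_upper_edge` is hypothetical and not part of the notebook.

```python
# (num_clusters, davies_bouldin_index) pairs copied from the grid search output above
results = [
    (19, 1.400756), (15, 1.415517), (18, 1.429912), (14, 1.438088),
    (13, 1.438440), (17, 1.454715), (16, 1.456185), (11, 1.502263),
    (12, 1.511940), (10, 1.529150), (7, 1.551265), (9, 1.571020),
    (6, 1.571398), (4, 1.596398), (8, 1.621974), (5, 1.660766),
    (3, 1.681441),
]

def best_at_upper_edge(results):
    """Return (best_k, at_edge): the K with the lowest error, and whether
    that K is the largest value that was tried."""
    best_k, _ = min(results, key=lambda r: r[1])
    largest_k = max(k for k, _ in results)
    return best_k, best_k == largest_k

best_k, at_edge = best_at_upper_edge(results)
print(best_k, at_edge)
```

When the second value is true, the minimum lies on the boundary of the grid, so widening the range is a reasonable next step.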

1b. Grid Search through Scripting and dynamic SQL

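BigQuery scripting with dynamic SQL (EXECUTE IMMEDIATE) can express the same search in a single SQL cell, with no Python driver code. The loop below trains the models one at a time and keeps track of the best K seen so far.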


In [ ]:
%%bigquery
DECLARE NUM_CLUSTERS INT64 DEFAULT 3;
DECLARE MIN_ERROR FLOAT64 DEFAULT 1000.0;
DECLARE BEST_NUM_CLUSTERS INT64 DEFAULT -1;
DECLARE MODEL_NAME STRING;
DECLARE error FLOAT64 DEFAULT 0;
 
WHILE NUM_CLUSTERS < 8 DO
 
  SET MODEL_NAME = CONCAT('ch09eu.london_station_clusters_', 
                            CAST(NUM_CLUSTERS AS STRING));
 
  EXECUTE IMMEDIATE format("""
  CREATE OR REPLACE MODEL %s
    OPTIONS(model_type='kmeans', 
            num_clusters=%d, 
            standardize_features = true) AS
    SELECT * except(station_name)
    from ch09eu.stationstats;
  """, MODEL_NAME, NUM_CLUSTERS);
    
  EXECUTE IMMEDIATE format("""
    SELECT davies_bouldin_index FROM ML.EVALUATE(MODEL %s);
  """, MODEL_NAME) INTO error;
  
  IF error < MIN_ERROR THEN
    SET MIN_ERROR = error;
    SET BEST_NUM_CLUSTERS = NUM_CLUSTERS;
  END IF;
  SET NUM_CLUSTERS = NUM_CLUSTERS + 1;
 
END WHILE;

-- Report the best K found and its Davies-Bouldin index
SELECT BEST_NUM_CLUSTERS, MIN_ERROR;

2. Bayesian optimization

As the number of hyperparameters grows, a grid search becomes increasingly wasteful because the number of combinations multiplies with each added parameter. It is better to use a more efficient search algorithm, and that's where Cloud AI Platform's hyperparameter tuning can help.

I'll demonstrate this by tuning the feature engineering (the hour-of-day bucket boundaries) and the number of nodes in each hidden layer of a DNN model.


In [72]:
%%writefile hyperparam.yaml
trainingInput:
  scaleTier: CUSTOM
  masterType: standard   # See: https://cloud.google.com/ml-engine/docs/tensorflow/machine-types
  hyperparameters:
    goal: MINIMIZE
    maxTrials: 10
    maxParallelTrials: 2
    hyperparameterMetricTag: mean_absolute_error
    params:
    - parameterName: afternoon_start
      type: INTEGER
      minValue: 9
      maxValue: 12
      scaleType: UNIT_LINEAR_SCALE
    - parameterName: afternoon_end
      type: INTEGER
      minValue: 15
      maxValue: 19
      scaleType: UNIT_LINEAR_SCALE
    - parameterName: num_nodes_0
      type: INTEGER
      minValue: 10
      maxValue: 100
      scaleType: UNIT_LOG_SCALE
    - parameterName: num_nodes_1
      type: INTEGER
      minValue: 3
      maxValue: 10
      scaleType: UNIT_LINEAR_SCALE


Overwriting hyperparam.yaml
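
For comparison, the size of an exhaustive grid over this search space can be counted directly. This is a back-of-the-envelope sketch that assumes a grid would try every integer in each range, bounds inclusive; the dictionary simply restates the ranges from hyperparam.yaml.

```python
# Integer ranges restated from hyperparam.yaml as (minValue, maxValue), bounds inclusive
search_space = {
    'afternoon_start': (9, 12),
    'afternoon_end': (15, 19),
    'num_nodes_0': (10, 100),
    'num_nodes_1': (3, 10),
}
MAX_TRIALS = 10  # maxTrials in hyperparam.yaml

# A full grid needs one training job per combination of values
grid_size = 1
for lo, hi in search_space.values():
    grid_size *= (hi - lo + 1)

print('Full grid: {} models; tuning budget: {} trials'.format(grid_size, MAX_TRIALS))
```

The full grid would need 14,560 training jobs, whereas the tuning job above is limited to 10 trials.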

In [77]:
%%writefile setup.py
from setuptools import setup

setup(name='trainer',
      version='1.0',
      description='Tune BigQuery ML',
      url='http://github.com/GoogleCloudPlatform/bigquery-oreilly-book',
      author='V Lakshmanan',
      author_email='nobody@google.com',
      license='Apache2',
      packages=['trainer'],
      package_data={'': ['privatekey.json']},
      install_requires=[
          'google-cloud-bigquery==1.15.0',
          'cloudml-hypertune',  # Required for hyperparameter tuning.
      ],
      zip_safe=False)


Overwriting setup.py

In [74]:
%%bash
mkdir -p trainer
touch trainer/__init__.py

In [71]:
%%bash

### You can get your project number from the GCP home page

KEYFILE=trainer/privatekey.json
PROJECTNUMBER=663413318684    ## CHANGE THIS
if [ ! -f $KEYFILE ]; then
  gcloud iam service-accounts keys create \
      --iam-account ${PROJECTNUMBER}-compute@developer.gserviceaccount.com \
      $KEYFILE
fi

In [85]:
%%writefile trainer/train_and_eval.py
import argparse
import hypertune
import json
import logging
import pkgutil
from google.oauth2.service_account import Credentials as sac

from google.cloud import bigquery

def get_credentials():
    privatekey = pkgutil.get_data('trainer', 'privatekey.json')
    service_account_info = json.loads(privatekey.decode('utf-8'))
    return sac.from_service_account_info(service_account_info)
    
def train_and_evaluate(args):        
    model_name = "ch09eu.bicycle_model_dnn_{}_{}_{}_{}".format(
        args.afternoon_start, args.afternoon_end, args.num_nodes_0, args.num_nodes_1
    )
    train_query = """
        CREATE OR REPLACE MODEL {}
        TRANSFORM(* EXCEPT(start_date)
                  , IF(EXTRACT(dayofweek FROM start_date) BETWEEN 2 and 6, 'weekday', 'weekend') as dayofweek
                  , ML.BUCKETIZE(EXTRACT(HOUR FROM start_date), [5, {}, {}]) AS hourofday
        )
        OPTIONS(input_label_cols=['duration'], 
                model_type='dnn_regressor',
                hidden_units=[{}, {}])
        AS

        SELECT 
          duration
          , start_station_name
          , start_date
        FROM `bigquery-public-data`.london_bicycles.cycle_hire
    """.format(model_name, 
               args.afternoon_start, 
               args.afternoon_end,
               args.num_nodes_0,
               args.num_nodes_1)
    logging.info(train_query)
    bq = bigquery.Client(project=args.project, 
                         location=args.location, 
                         credentials=get_credentials())
    job = bq.query(train_query)
    job.result() # wait for job to finish
    
    eval_query = """
        SELECT mean_absolute_error 
        FROM ML.EVALUATE(MODEL {})
    """.format(model_name)
    logging.info(eval_query)
    evaldf = bq.query(eval_query).to_dataframe()
    return evaldf['mean_absolute_error'][0]    
    

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--afternoon_start', type=int, default=10)
    parser.add_argument('--afternoon_end', type=int, default=17)
    parser.add_argument('--num_nodes_0', type=int, default=10)
    parser.add_argument('--num_nodes_1', type=int, default=5)
    parser.add_argument('--location', type=str, default='US')
    parser.add_argument('--project', type=str, required=True)
    parser.add_argument('--job-dir', default='ignored') # output directory to save artifacts. we have none

    # get args and invoke model
    args = parser.parse_args()
    error = train_and_evaluate(args)
    logging.info('{} Resulting mean_absolute_error: {}'.format(args.__dict__, error)) 

    # report the metric to the hyperparameter tuning service so that it
    # can choose the next set of hyperparameters to try
    hpt = hypertune.HyperTune()
    hpt.report_hyperparameter_tuning_metric(
       hyperparameter_metric_tag='mean_absolute_error',
       metric_value=error,
       global_step=1)


Overwriting trainer/train_and_eval.py
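
During tuning, AI Platform runs the trainer module once per trial and passes each hyperparameter as a command-line flag named after its parameterName. A trial can therefore be simulated locally by parsing a hand-written argument list. The sketch below rebuilds only the argument parser from train_and_eval.py; the flag values, the bucket path, and the project name `my-project` are made up for illustration, and no BigQuery job is run.

```python
import argparse

def build_parser():
    # Same flags as in trainer/train_and_eval.py
    parser = argparse.ArgumentParser()
    parser.add_argument('--afternoon_start', type=int, default=10)
    parser.add_argument('--afternoon_end', type=int, default=17)
    parser.add_argument('--num_nodes_0', type=int, default=10)
    parser.add_argument('--num_nodes_1', type=int, default=5)
    parser.add_argument('--location', type=str, default='US')
    parser.add_argument('--project', type=str, required=True)
    parser.add_argument('--job-dir', default='ignored')
    return parser

# Hypothetical flags for a single trial; num_nodes_1 is omitted to show the default
trial_argv = [
    '--project=my-project', '--location=EU',
    '--afternoon_start=11', '--afternoon_end=18', '--num_nodes_0=32',
    '--job-dir=gs://my-bucket/hparam/1',
]
args = build_parser().parse_args(trial_argv)

# The model name encodes the hyperparameters, so each trial trains a distinct model
model_name = "ch09eu.bicycle_model_dnn_{}_{}_{}_{}".format(
    args.afternoon_start, args.afternoon_end, args.num_nodes_0, args.num_nodes_1)
print(model_name)
```

Because every combination of hyperparameters maps to its own model name, trials that run in parallel do not overwrite each other's models.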

In [87]:
%%bash
PROJECT=cloud-training-demos #CHANGE THIS
BUCKET=cloud-training-demos-ml  ## CHANGE THIS
JOBNAME=bqml_hparam_$(date -u +%y%m%d_%H%M%S)
REGION=europe-west1
gcloud ai-platform jobs submit training $JOBNAME \
  --runtime-version=1.13 \
  --python-version=3.5 \
  --region=$REGION \
  --module-name=trainer.train_and_eval \
  --package-path=$(pwd)/trainer \
  --job-dir=gs://$BUCKET/hparam/ \
  --config=hyperparam.yaml \
  -- \
  --project=$PROJECT --location=EU


jobId: bqml_hparam_190701_161928
state: QUEUED
Job [bqml_hparam_190701_161928] submitted successfully.
Your job is still active. You may view the status of your job with the command

  $ gcloud ai-platform jobs describe bqml_hparam_190701_161928

or continue streaming the logs with the command

  $ gcloud ai-platform jobs stream-logs bqml_hparam_190701_161928
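
Once the job finishes, the per-trial results can be read from the job description. The sketch below assumes the description has been exported as JSON (for example with `gcloud ai-platform jobs describe <job> --format=json`) and that trials appear under `trainingOutput.trials` with a `finalMetric.objectiveValue` field, as in the AI Platform Training API. The sample document and the helper `rank_trials` are hypothetical, and the metric values are placeholders, not results from this job.

```python
import json

# Made-up sample shaped like a job description; values are NOT real results
sample_description = json.loads("""
{
  "jobId": "example_job",
  "trainingOutput": {
    "trials": [
      {"trialId": "1",
       "hyperparameters": {"afternoon_start": "10", "num_nodes_0": "32"},
       "finalMetric": {"trainingStep": "1", "objectiveValue": 1200.5}},
      {"trialId": "2",
       "hyperparameters": {"afternoon_start": "11", "num_nodes_0": "64"},
       "finalMetric": {"trainingStep": "1", "objectiveValue": 1100.0}},
      {"trialId": "3",
       "hyperparameters": {"afternoon_start": "12", "num_nodes_0": "16"}}
    ]
  }
}
""")

def rank_trials(job_description):
    """Return the trials that reported a metric, lowest objective value first
    (the goal in hyperparam.yaml is MINIMIZE)."""
    trials = job_description.get('trainingOutput', {}).get('trials', [])
    reported = [t for t in trials if 'finalMetric' in t]
    return sorted(reported, key=lambda t: t['finalMetric']['objectiveValue'])

for trial in rank_trials(sample_description):
    print(trial['trialId'], trial['finalMetric']['objectiveValue'], trial['hyperparameters'])
```

Trials without a `finalMetric` (for instance, ones that failed before reporting) are skipped rather than ranked.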

Copyright 2019 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.