Train locally

Import training data

For illustration purposes we will use the MNIST dataset. The following code downloads the dataset and puts it in ./mnist_data.

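The first 60000 images and targets are the original training set, while the last 10000 are the test set. Some copies of MNIST (such as the one formerly hosted on mldata.org) are ordered by label, so we shuffle the training set: we will use only a very small portion of the data to shorten training time, and that portion should contain every digit.


In [ ]:
from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle

# 'mnist_784' is MNIST on OpenML (70000 images, 784 pixel features); as_frame=False returns NumPy arrays.
mnist = fetch_openml('mnist_784', version=1, data_home='./mnist_data', as_frame=False)
X, y = shuffle(mnist.data[:60000], mnist.target[:60000])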

X_small = X[:100]
y_small = y[:100]

# Note: using only 10% of the training data
X_large = X[:6000]
y_large = y[:6000]
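To see why the shuffle matters, the stdlib-only sketch below uses a made-up, label-sorted dataset (not MNIST). Slicing the first 100 rows before shuffling yields a single digit. Applying one shared permutation to X and y, which is what sklearn.utils.shuffle does when given several arrays, mixes the digits while keeping each row paired with its label.

```python
import random

# Made-up stand-in for a label-sorted training set: 10 digits, 600 rows each.
y_sorted = [digit for digit in range(10) for _ in range(600)]
# Each fake "image" stores its own label in position 0 so pairing can be checked.
X_sorted = [[digit, i] for i, digit in enumerate(y_sorted)]

# Slicing before shuffling: the first 100 rows all belong to digit 0.
first_100_unshuffled = set(y_sorted[:100])

# Shuffle X and y in unison by applying one shared permutation of row indices.
order = list(range(len(y_sorted)))
random.Random(0).shuffle(order)
X_shuf = [X_sorted[i] for i in order]
y_shuf = [y_sorted[i] for i in order]

first_100_shuffled = set(y_shuf[:100])
print(sorted(first_100_unshuffled))  # [0]
print(sorted(first_100_shuffled))
```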

Instantiate the estimator and the SearchCV objects

For illustration purposes we will use the RandomForestClassifier with GridSearchCV:

http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html

http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html


In [ ]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

rfc = RandomForestClassifier(n_jobs=-1)
param_grid = {
    'max_features': [1.0, 0.9, 0.8, 0.6],
    'n_estimators': [10, 50, 100, 150, 200],
    'max_depth': [5, 10, 15, 20, None],
    'min_samples_split': [0.01, 0.05, 0.1]
}
search = GridSearchCV(estimator=rfc, param_grid=param_grid, n_jobs=-1, verbose=3)
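Grid search fits every combination in param_grid once per cross-validation fold, which is what makes it expensive. The stdlib sketch below counts the work for the grid above. The fold count is an assumption: GridSearchCV's default cv is 5 in scikit-learn 0.22 and later (3 in earlier releases).

```python
from itertools import product

# Same grid as in the cell above.
param_grid = {
    'max_features': [1.0, 0.9, 0.8, 0.6],
    'n_estimators': [10, 50, 100, 150, 200],
    'max_depth': [5, 10, 15, 20, None],
    'min_samples_split': [0.01, 0.05, 0.1],
}

# Each candidate is one dict drawn from the Cartesian product of the value lists.
candidates = [dict(zip(param_grid, values)) for values in product(*param_grid.values())]
n_candidates = len(candidates)

# Assumed fold count: scikit-learn's default cv is 5 in recent releases.
n_folds = 5
# One fit per candidate per fold, plus one refit of the best candidate (refit=True).
n_fits = n_candidates * n_folds + 1
print(n_candidates, n_fits)  # 300 1501
```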

Fit the GridSearchCV object locally

After fitting we can examine the best score (accuracy) and the best parameters that achieve that score.


In [ ]:
%time search.fit(X_small, y_small)

print(search.best_score_, search.best_params_)

Everything up to this point is what you would do when training locally. With a larger amount of data, fitting would take much longer.

Train on Google Container Engine

Set up for training on Google Container Engine

Before we can start training on the Container Engine we need to:

  • Build the Docker image that will handle the workloads.
  • Create a cluster.

For these we will first set up some configuration variables.

Your Google Cloud Platform project id.


In [ ]:
project_id = 'YOUR-PROJECT-ID'

A Google Cloud Storage bucket belonging to your project, created through either the Cloud Console or the gsutil command-line tool (gsutil mb gs://YOUR-BUCKET-NAME).

This bucket will be used for storing temporary data during Docker image building, for storing training data, and for storing trained models.

This can be an existing bucket, but we recommend you create a new one.


In [ ]:
bucket_name = 'YOUR-BUCKET-NAME'

Pick an ID for the Container Engine cluster we will create. Preferably use a new ID rather than that of an existing cluster, so that its workload is not affected.


In [ ]:
cluster_id = 'YOUR-CLUSTER-ID'

Choose a name for the image that the containers will run.


In [ ]:
image_name = 'YOUR-IMAGE-NAME'

Choose a zone to host the cluster.

List of zones: https://cloud.google.com/compute/docs/regions-zones/


In [ ]:
zone = 'us-central1-b'

Change this only if you have customized the source.


In [ ]:
source_dir = 'source'

Build a Docker image

This step builds a Docker image from the contents of the source/ folder. The image is tagged with the provided image_name so that the workers can pull it. The main script, source/worker.py, retrieves a pickled GridSearchCV object from Cloud Storage and fits it to the training data, which is also stored on Cloud Storage.

Note: This step only needs to be run the first time you follow these steps and again each time you modify the code in source/. If you have not modified source/, you can reuse the same image.

Note: This could take a couple of minutes. To monitor the build process: https://console.cloud.google.com/gcr/builds


In [ ]:
from helpers.cloudbuild_helper import build

build(project_id, source_dir, bucket_name, image_name)

Create a cluster

This step creates a cluster on the Container Engine.

You can alternatively create the cluster with the gcloud command-line tool or through the console, but you must then add the scope that grants read/write access to Google Cloud Storage: 'https://www.googleapis.com/auth/devstorage.read_write'

Note: This could take several minutes. To monitor the cluster creation process: https://console.cloud.google.com/kubernetes/list


In [ ]:
from helpers.gke_helper import create_cluster

create_cluster(project_id, zone, cluster_id, n_nodes=1, machine_type='n1-standard-64')

For GCE instance pricing: https://cloud.google.com/compute/pricing

Instantiate the GKEParallel object

The GKEParallel class is a helper wrapper around a GridSearchCV object that manages deploying fitting jobs to the Container Engine cluster created above.

We pass in the GridSearchCV object, which will be pickled and stored on Cloud Storage at a URI of the form:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/search.pkl


In [ ]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

rfc = RandomForestClassifier(n_jobs=-1)
param_grid = {
    'max_features': [1.0, 0.9, 0.8, 0.6],
    'n_estimators': [10, 50, 100, 150, 200],
    'max_depth': [5, 10, 15, 20, None],
    'min_samples_split': [0.01, 0.05, 0.1]
}
search = GridSearchCV(estimator=rfc, param_grid=param_grid, n_jobs=-1, verbose=3)

In [ ]:
from gke_parallel import GKEParallel

gke_search = GKEParallel(search, project_id, zone, cluster_id, bucket_name, image_name)

Refresh access token to the cluster

To make it easy to access the cluster through the Kubernetes client library, this sample includes a script that retrieves the cluster credentials with gcloud and refreshes the access token with kubectl.


In [ ]:
! bash get_cluster_credentials.sh $cluster_id $zone

Deploy the fitting task

GKEParallel instances implement an interface similar to (but different from!) that of GridSearchCV.

Calling fit(X, y) first uploads the training data to Cloud Storage as:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/X.pkl
gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/y.pkl

This allows reusing the same uploaded datasets for future training tasks.

For instance, if you already have pickled data on Cloud Storage:

gs://DATA-BUCKET/X.pkl
gs://DATA-BUCKET/y.pkl

then you can deploy the fitting task with:

gke_search.fit(X='gs://DATA-BUCKET/X.pkl', y='gs://DATA-BUCKET/y.pkl')
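The choice between reusing an existing Cloud Storage object and uploading in-memory data can be sketched as follows. This is an illustration under stated assumptions, not GKEParallel's actual code: the helper name resolve_data_uri is hypothetical, and a local temporary directory stands in for the task's folder in the bucket.

```python
import os
import pickle
import tempfile


def resolve_data_uri(data, name, workspace):
    """Hypothetical helper: return a URI for `data`, "uploading" it only if needed.

    Strings that already point at Cloud Storage are reused as-is. Anything else
    is pickled into `workspace`, a local stand-in for gs://BUCKET/TASK-NAME/.
    """
    if isinstance(data, str) and data.startswith('gs://'):
        return data
    path = os.path.join(workspace, '{}.pkl'.format(name))
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    return path


workspace = tempfile.mkdtemp()
x_uri = resolve_data_uri('gs://DATA-BUCKET/X.pkl', 'X', workspace)
y_uri = resolve_data_uri([3, 1, 4], 'y', workspace)
print(x_uri)  # gs://DATA-BUCKET/X.pkl
```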

Calling fit(X, y) also pickles the wrapped search object and gke_search itself, and stores them on Cloud Storage as:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/search.pkl
gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/gke_search.pkl

In [ ]:
gke_search.fit(X_large, y_large)

Inspect the GKEParallel object

In the background, the GKEParallel instance splits the param_grid into smaller param_grids.

Each smaller param_grid is pickled and stored on GCS within each worker's workspace:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/WORKER-ID/param_grid.pkl

The param_grids can be accessed as follows, showing how they are assigned to each worker.

The keys of this dictionary are the worker_ids.
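The exact splitting strategy is internal to GKEParallel. As one plausible scheme, assumed here rather than taken from the sample, the grid can be split along a single parameter so that each worker gets one of its values. The stdlib sketch below does this and checks that no candidate is lost; the function name split_param_grid and the worker-ID format are hypothetical.

```python
from itertools import product


def split_param_grid(param_grid, key):
    """Hypothetical: one sub-grid per value of `key`, keyed by a made-up worker ID."""
    sub_grids = {}
    for i, value in enumerate(param_grid[key]):
        sub_grid = dict(param_grid)  # shallow copy; the other value lists are shared
        sub_grid[key] = [value]      # this worker sees a single value of `key`
        sub_grids['worker-{}'.format(i)] = sub_grid
    return sub_grids


def candidates(grid):
    """Set of all candidate parameter tuples, using a fixed (sorted) key order."""
    keys = sorted(grid)
    return set(product(*(grid[k] for k in keys)))


param_grid = {
    'max_features': [1.0, 0.9, 0.8, 0.6],
    'n_estimators': [10, 50, 100, 150, 200],
    'max_depth': [5, 10, 15, 20, None],
    'min_samples_split': [0.01, 0.05, 0.1],
}
param_grids = split_param_grid(param_grid, 'n_estimators')
print({w: len(candidates(g)) for w, g in param_grids.items()})
# {'worker-0': 60, 'worker-1': 60, 'worker-2': 60, 'worker-3': 60, 'worker-4': 60}
```

Splitting along one parameter keeps each sub-grid a valid param_grid that a worker's GridSearchCV can consume directly; a real implementation might instead balance candidates across a chosen number of workers.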


In [ ]:
gke_search.param_grids

You could optionally specify a task_name when creating a GKEParallel instance.

If you did not specify a task_name, when you call fit(X, y) the task_name will be set to:

YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME
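The task_name format above, together with the object paths listed throughout this section, can be summarized in a small sketch. The function task_uris and the worker-ID format are hypothetical; the code only reproduces the documented layout and is not the GKEParallel implementation.

```python
import time

TOP_LEVEL_FILES = ('search.pkl', 'gke_search.pkl', 'X.pkl', 'y.pkl')
WORKER_FILES = ('param_grid.pkl', 'fitted_search.pkl')


def task_uris(bucket_name, cluster_id, image_name, worker_ids, unix_time=None):
    """Hypothetical: build the task_name and the gs:// paths this notebook describes."""
    if unix_time is None:
        unix_time = int(time.time())
    task_name = '{}.{}.{}'.format(cluster_id, image_name, unix_time)
    prefix = 'gs://{}/{}'.format(bucket_name, task_name)
    uris = {name: '{}/{}'.format(prefix, name) for name in TOP_LEVEL_FILES}
    for worker_id in worker_ids:
        for name in WORKER_FILES:
            key = '{}/{}'.format(worker_id, name)
            uris[key] = '{}/{}'.format(prefix, key)
    return task_name, uris


task_name, uris = task_uris('my-bucket', 'my-cluster', 'my-image', ['worker-0'],
                            unix_time=1500000000)
print(task_name)  # my-cluster.my-image.1500000000
print(uris['worker-0/fitted_search.pkl'])
```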


In [ ]:
gke_search.task_name

Similarly, each job is given a job_name. The dictionary of job_names can be accessed as follows. Each worker pod handles one job processing one of the smaller param_grids.

To monitor the jobs: https://console.cloud.google.com/kubernetes/workload


In [ ]:
gke_search.job_names

Cancel the task

To cancel the task, run cancel(). This will delete all the deployed worker pods and jobs, but will NOT delete the cluster, nor delete any data already persisted to Cloud Storage.


In [ ]:
#gke_search.cancel()

Monitor the progress

GKEParallel instances also implement an interface similar to (but different from!) that of Future instances. Calling done() checks whether each worker has completed its job and persisted its outcome on GCS at the URI:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/WORKER-ID/fitted_search.pkl

To monitor the jobs: https://console.cloud.google.com/kubernetes/workload

To access the persisted data directly: https://console.cloud.google.com/storage/browser/YOUR-BUCKET-NAME/
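Conceptually, done() amounts to checking whether every worker's fitted_search.pkl exists. The stdlib sketch below mimics that with a local temporary directory in place of the bucket; the names worker_dones and all_done are hypothetical, and the real method checks Cloud Storage instead.

```python
import os
import tempfile


def worker_dones(task_dir, worker_ids):
    """Hypothetical: map each worker ID to whether its fitted_search.pkl exists."""
    return {w: os.path.exists(os.path.join(task_dir, w, 'fitted_search.pkl'))
            for w in worker_ids}


def all_done(dones):
    """The task is done only when every worker has persisted its result."""
    return all(dones.values())


task_dir = tempfile.mkdtemp()
worker_ids = ['worker-0', 'worker-1']
for w in worker_ids:
    os.makedirs(os.path.join(task_dir, w))

# Simulate worker-0 finishing by writing an (empty) result file.
open(os.path.join(task_dir, 'worker-0', 'fitted_search.pkl'), 'wb').close()

dones = worker_dones(task_dir, worker_ids)
print(all_done(dones), dones)  # False {'worker-0': True, 'worker-1': False}
```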


In [ ]:
gke_search.done(), gke_search.dones

When all the jobs are finished, the pods will stop running (but the cluster will remain), and we can retrieve the fitted model.

Calling result() populates the gke_search.results attribute and returns it. This attribute records all the fitted GridSearchCV objects from the jobs. The fitted models are downloaded only if the download argument is set to True.

Calling result() also updates the pickled gke_search object on Cloud Storage:

gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/gke_search.pkl


In [ ]:
result = gke_search.result(download=False)

You can also get the logs from the pods:


In [ ]:
from helpers.kubernetes_helper import get_pod_logs

for pod_name, log in get_pod_logs().items():
    print('=' * 20)
    print('\t{}\n'.format(pod_name))
    print(log)

Once the jobs are finished, the cluster can be deleted. All the fitted models are stored on GCS.

The cluster can also be deleted from the console: https://console.cloud.google.com/kubernetes/list


In [ ]:
from helpers.gke_helper import delete_cluster

#delete_cluster(project_id, zone, cluster_id)

The next cell polls the jobs until they are all finished, then deletes the cluster and downloads the results.


In [ ]:
import time
from helpers.gke_helper import delete_cluster

while not gke_search.done():
    n_done = len([d for d in gke_search.dones.values() if d])
    print('{}/{} finished'.format(n_done, len(gke_search.job_names)))
    time.sleep(60)

delete_cluster(project_id, zone, cluster_id)
result = gke_search.result(download=True)
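If you would rather not block indefinitely, the polling loop above can be wrapped in a helper with a time limit. This is a hypothetical sketch: wait_until_done accepts any zero-argument callable (such as gke_search.done) and returns False if the deadline passes first; it does not cancel the jobs or delete the cluster.

```python
import time


def wait_until_done(is_done, poll_seconds=60, timeout_seconds=None, sleep=time.sleep):
    """Call is_done() every poll_seconds until it returns True or the timeout expires."""
    deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds
    while not is_done():
        if deadline is not None and time.monotonic() >= deadline:
            return False  # gave up; the jobs keep running
        sleep(poll_seconds)
    return True


# Demo with a fake job that reports completion on the third check.
fake_done = iter([False, False, True]).__next__
print(wait_until_done(fake_done, poll_seconds=0))  # True
```

With the notebook's objects this would be called as wait_until_done(gke_search.done, timeout_seconds=4 * 3600), leaving the decision to cancel or keep waiting to you.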

Restore the GKEParallel object

To restore the fitted gke_search object (for example from a different notebook), you can use the helper function included in this sample.


In [ ]:
from helpers.gcs_helper import download_uri_and_unpickle
gcs_uri = 'gs://YOUR-BUCKET-NAME/YOUR-CLUSTER-ID.YOUR-IMAGE-NAME.UNIX-TIME/gke_search.pkl'
gke_search_restored = download_uri_and_unpickle(gcs_uri)

Inspect the result

GKEParallel also implements part of the interface of GridSearchCV to allow easy access to best_score_, best_params_, and best_estimator_.


In [ ]:
gke_search.best_score_, gke_search.best_params_, gke_search.best_estimator_
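Since each worker fits a GridSearchCV on its own sub-grid, an overall best can be taken as the worker result with the highest best_score_. The sketch below assumes that is how the per-worker results are combined (this sample does not confirm it); SimpleNamespace objects with made-up placeholder scores stand in for the fitted searches.

```python
from types import SimpleNamespace

# Made-up placeholder stand-ins for the fitted GridSearchCV objects from each worker.
results = {
    'worker-0': SimpleNamespace(best_score_=0.91, best_params_={'n_estimators': 10}),
    'worker-1': SimpleNamespace(best_score_=0.94, best_params_={'n_estimators': 50}),
    'worker-2': SimpleNamespace(best_score_=0.93, best_params_={'n_estimators': 100}),
}

# Pick the worker whose search achieved the highest mean cross-validated score.
best_worker = max(results, key=lambda w: results[w].best_score_)
best = results[best_worker]
print(best_worker, best.best_score_, best.best_params_)
# worker-1 0.94 {'n_estimators': 50}
```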

You can also call predict(), which delegates the call to best_estimator_.

Below we calculate the accuracy on the 10000 test images.


In [ ]:
from sklearn.metrics import accuracy_score

predicted = gke_search.predict(mnist.data[60000:])
# Fraction of the 10000 test images whose predicted label matches the true label.
print(accuracy_score(mnist.target[60000:], predicted))

Clean up

To clean up, delete the cluster so your project will no longer be charged for VM instance usage. The simplest way to delete the cluster is through the console: https://console.cloud.google.com/kubernetes/list

This will not delete any data persisted on Cloud Storage.