This notebook codifies the capabilities discussed in this blog post. In a nutshell, it uses the pre-trained Inception model as a starting point and then uses transfer learning to train it further on additional, customer-specific images. Simple flower images are used for illustration. Compared to training from scratch, the time and cost are drastically reduced.
This notebook does preprocessing, training, and prediction by calling the CloudML API instead of running them in the Datalab container. The purpose of local work is to do some initial prototyping and debugging on small-scale data, often by taking a suitable (say 0.1–1%) sample of the full data. The same basic steps can then be repeated with much larger datasets in the cloud.
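The small-sample step mentioned above can be sketched in plain Python. `sample_csv_lines` and its parameters are illustrative names for this notebook, not part of the Datalab API:

```python
import random

def sample_csv_lines(lines, fraction=0.01, seed=0):
    """Return a reproducible random sample of CSV rows for local prototyping."""
    rng = random.Random(seed)  # fixed seed so the same sample comes back each run
    return [line for line in lines if rng.random() < fraction]

# e.g. take roughly 1% of the training CSV before running the full cloud pipeline
```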
First, run the following steps only if you are running Datalab from your local desktop or laptop (not from a GCE VM):
If you run Datalab from a GCE VM, make sure the VM's project has the Machine Learning API and Dataflow API enabled.
In [4]:
import mltoolbox.image.classification as model
from google.datalab.ml import *
bucket = 'gs://' + datalab_project_id() + '-lab'
preprocess_dir = bucket + '/flowerpreprocessedcloud'
model_dir = bucket + '/flowermodelcloud'
staging_dir = bucket + '/staging'
In [ ]:
!gsutil mb $bucket
Preprocessing uses a Dataflow pipeline to convert the image format, resize images, and run the converted images through a pre-trained model to get the features (embeddings). You can also do this step with alternate technologies like Spark or plain Python code if you like. The %%ml preprocess command simplifies this task. Check out the parameters shown with the --usage flag first, and then run the command.
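Conceptually, the pipeline maps each (image_url, label) row through the stages below. This is only a sketch: the `resize` and `embed` callables are hypothetical stand-ins for the real image decoding and Inception forward pass, not the actual Dataflow code.

```python
def preprocess_row(image_bytes, label, resize, embed):
    """One logical record of the preprocessing pipeline: resize the image,
    then run it through a pre-trained model to get its embedding (bottleneck
    features). Training then only needs to fit a small classifier on top."""
    pixels = resize(image_bytes)   # convert/resize to the model's input size
    embedding = embed(pixels)      # features from the pre-trained Inception model
    return {'embedding': embedding, 'label': label}
```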
If you hit "PERMISSION_DENIED" when running the following cell, you need to enable the Cloud Dataflow API (the URL is shown in the error message).
The Dataflow job usually takes about 20 minutes to complete.
In [2]:
train_set = CsvDataSet('gs://cloud-datalab/sampledata/flower/train1000.csv', schema='image_url:STRING,label:STRING')
preprocess_job = model.preprocess_async(train_set, preprocess_dir, cloud={'num_workers': 10})
preprocess_job.wait() # Alternatively, you can query the job status via preprocess_job.state. The wait() call blocks the notebook execution.
In [3]:
train_job = model.train_async(preprocess_dir, 30, 1000, model_dir, cloud=CloudTrainingConfig('us-central1', 'BASIC'))
train_job.wait() # Alternatively, you can query the job status by train_job.state. The wait() call blocks the notebook execution.
Out[3]:
Check your job status by running the following (replace the job id with the one shown above):
Job('image_classification_train_170307_002934').describe()
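If you prefer not to block on `wait()`, you can poll `job.state` yourself. A minimal sketch, assuming the terminal state names below (the exact set of states comes from the CloudML job API):

```python
import time

def wait_for_state(job, poll_seconds=30, timeout_seconds=3600,
                   terminal=('SUCCEEDED', 'FAILED', 'CANCELLED')):
    """Poll job.state until it reaches a terminal state, with a timeout
    so the loop cannot block forever."""
    deadline = time.time() + timeout_seconds
    while time.time() < deadline:
        state = job.state
        if state in terminal:
            return state
        time.sleep(poll_seconds)
    raise RuntimeError('job did not reach a terminal state in time')
```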
TensorBoard also works with a GCS path. Note that the data usually shows up about a minute after TensorBoard starts with a GCS path.
In [ ]:
tb_id = TensorBoard.start(model_dir)
In [21]:
Models().create('flower')
ModelVersions('flower').deploy('beta1', model_dir)
Online prediction is currently in alpha. If the first call fails, retry it; the service may need a warm start.
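Since the service may need a warm start, a small retry wrapper around the prediction call can help. `predict_with_retry` is an illustrative helper for this notebook, not part of the mltoolbox API:

```python
import time

def predict_with_retry(predict_fn, attempts=3, delay_seconds=5):
    """Call predict_fn, retrying on failure; the first online prediction
    call may fail while the alpha service warms up."""
    for attempt in range(attempts):
        try:
            return predict_fn()
        except Exception:
            if attempt == attempts - 1:
                raise  # out of attempts; surface the last error
            time.sleep(delay_seconds)

# usage sketch:
# predict_with_retry(lambda: model.predict('flower.beta1', images, resize=True, cloud=True))
```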
In [6]:
images = [
'gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg',
'gs://cloud-ml-data/img/flower_photos/tulips/6876631336_54bf150990.jpg'
]
# Set resize=True to avoid sending large data in the prediction request.
model.predict('flower.beta1', images, resize=True, cloud=True)
Out[6]:
In [ ]:
import google.datalab.bigquery as bq
bq.Dataset('flower').create()
eval_set = CsvDataSet('gs://cloud-datalab/sampledata/flower/eval670.csv', schema='image_url:STRING,label:STRING')
batch_predict_job = model.batch_predict_async(eval_set, model_dir, output_bq_table='flower.eval_results_full',
cloud={'temp_location': staging_dir})
batch_predict_job.wait()
In [1]:
%%bq query --name wrong_prediction
SELECT * FROM flower.eval_results_full WHERE target != predicted
In [2]:
wrong_prediction.execute().result()
Out[2]:
In [5]:
ConfusionMatrix.from_bigquery('flower.eval_results_full').plot()
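The plot above is built from counts of (target, predicted) pairs. The same counts can be computed directly from the eval rows; `confusion_counts` is a hypothetical helper, not part of the Datalab API:

```python
from collections import Counter

def confusion_counts(rows):
    """rows: iterable of (target, predicted) pairs from the eval results table.
    Returns a Counter keyed by (target, predicted) — the raw confusion matrix."""
    return Counter((target, predicted) for target, predicted in rows)
```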
In [6]:
%%bq query --name accuracy
SELECT
target,
SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END) as correct,
COUNT(*) as total,
SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END)/COUNT(*) as accuracy
FROM
flower.eval_results_full
GROUP BY
target
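The same per-class accuracy can be computed in Python from (target, predicted) pairs; `per_class_accuracy` is an illustrative helper mirroring the query above:

```python
from collections import defaultdict

def per_class_accuracy(rows):
    """rows: iterable of (target, predicted) pairs, as in flower.eval_results_full.
    Returns {target: fraction of rows where target == predicted}."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for target, predicted in rows:
        total[target] += 1
        if target == predicted:
            correct[target] += 1
    return {t: correct[t] / total[t] for t in total}
```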
In [7]:
accuracy.execute().result()
Out[7]:
In [8]:
%%bq query --name logloss
SELECT feature, AVG(-logloss) as logloss, count(*) as count FROM
(
SELECT feature, CASE WHEN correct=1 THEN LOG(prob) ELSE LOG(1-prob) END as logloss
FROM
(
SELECT
target as feature,
CASE WHEN target=predicted THEN 1 ELSE 0 END as correct,
target_prob as prob
FROM flower.eval_results_full))
GROUP BY feature
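A Python equivalent of the query above, mirroring its CASE expression term for term: log(prob) when the prediction is correct, log(1 - prob) otherwise, averaged per class with the sign flipped. `per_class_logloss` is an illustrative name:

```python
import math

def per_class_logloss(rows):
    """rows: (target, predicted, target_prob) triples from the eval results.
    Mirrors the SQL: AVG of -(log(prob) if correct else log(1 - prob)),
    grouped by the target class."""
    sums, counts = {}, {}
    for target, predicted, prob in rows:
        ll = math.log(prob) if target == predicted else math.log(1.0 - prob)
        sums[target] = sums.get(target, 0.0) - ll
        counts[target] = counts.get(target, 0) + 1
    return {t: sums[t] / counts[t] for t in sums}
```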
In [9]:
FeatureSliceView().plot(logloss)
In [ ]:
ModelVersions('flower').delete('beta1')
Models().delete('flower')
!gsutil -m rm -r {preprocess_dir}
!gsutil -m rm -r {model_dir}