Efficient training for image classification

Transfer learning using Inception Package - Local Run Experience

Traditionally, image classification required a very large corpus of training data - often millions of images, which may not be available - and a long time to train on those images, which is expensive and time consuming. That has changed with transfer learning, which can be used readily with Cloud ML Engine, without deep knowledge of image classification algorithms, through the ML toolbox in Datalab.

This notebook codifies the capabilities discussed in this blog post. In a nutshell, it uses the pre-trained Inception model as a starting point and then applies transfer learning to train it further on additional, customer-specific images. Simple flower images are used for illustration. Compared to training from scratch, the training data requirements, time, and cost are drastically reduced.

This notebook performs all operations in the Datalab container without calling the Cloud ML API, which is why they are called "local" operations - though Datalab itself most often runs on a GCE VM. See the corresponding cloud notebook for the cloud experience, which only adds the --cloud parameter and some configuration to the local commands. The purpose of local work is initial prototyping and debugging on small-scale data, often a suitable sample (say 0.1 - 1%) of the full data. The same basic steps can then be repeated with much larger datasets in the cloud.
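
As a hypothetical illustration of how small that delta is, the cloud variants might look like the following. The `cloud` argument shown here is an assumption about the toolbox API, not a confirmed signature - see the cloud notebook for the real invocations.

# Hypothetical sketch only; the exact cloud configuration depends on the
# toolbox version. See the cloud notebook for the actual calls.
# model.preprocess(train_set, preprocessed_dir, cloud={'num_workers': 10})
# model.train(preprocessed_dir, 30, 800, model_dir, cloud={'region': 'us-central1'})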

Setup

All data is available under gs://cloud-datalab/sampledata/flower. eval100 is a subset of eval300, which is a subset of eval670; the train files are nested the same way.


In [5]:
!mkdir -p /content/flowerdata

In [6]:
!gsutil -m cp gs://cloud-datalab/sampledata/flower/* /content/flowerdata


Copying gs://cloud-datalab/sampledata/flower/all.csv...
Copying gs://cloud-datalab/sampledata/flower/eval100.csv...
Copying gs://cloud-datalab/sampledata/flower/eval300.csv...
Copying gs://cloud-datalab/sampledata/flower/eval670.csv...
Copying gs://cloud-datalab/sampledata/flower/train1000.csv...
Copying gs://cloud-datalab/sampledata/flower/train200.csv...
Copying gs://cloud-datalab/sampledata/flower/train300.csv...
Copying gs://cloud-datalab/sampledata/flower/train3000.csv...

Define directories for preprocessing, model, and prediction.


In [7]:
import mltoolbox.image.classification as model
from google.datalab.ml import *

worker_dir = '/content/datalab/tmp/flower'
preprocessed_dir = worker_dir + '/flowerrunlocal'
model_dir = worker_dir + '/tinyflowermodellocal'
prediction_dir = worker_dir + '/flowermodelevallocal'
images_dir = worker_dir + '/images'
local_train_file = '/content/flowerdata/train200local.csv'
local_eval_file = '/content/flowerdata/eval100local.csv'

In [8]:
!mkdir -p {images_dir}

For the best efficiency, we download the images to local disk and create training and evaluation files that reference the local paths instead of the GCS paths. Note that the original files referencing GCS image paths work too, although a bit more slowly.


In [9]:
import csv
import datalab.storage as gcs
import os


def download_images(input_csv, output_csv, images_dir):
  """Copies the images referenced in input_csv from GCS to images_dir and
  writes output_csv with the image_url column rewritten to the local paths."""
  with open(input_csv) as csvfile:
    data = list(csv.DictReader(csvfile, fieldnames=['image_url', 'label']))
  for x in data:
    url = x['image_url']
    out_file = os.path.join(images_dir, os.path.basename(url))
    # Write in binary mode since these are JPEG bytes.
    with open(out_file, 'wb') as f:
      f.write(gcs.Item.from_url(url).read_from())
    x['image_url'] = out_file

  with open(output_csv, 'w') as w:
    csv.DictWriter(w, fieldnames=['image_url', 'label']).writerows(data)


download_images('/content/flowerdata/train200.csv', local_train_file, images_dir)
download_images('/content/flowerdata/eval100.csv', local_eval_file, images_dir)

The effect of the code above is best illustrated by the comparison below.


In [10]:
!head /content/flowerdata/train200.csv -n 5


gs://cloud-ml-data/img/flower_photos/daisy/754296579_30a9ae018c_n.jpg,daisy
gs://cloud-ml-data/img/flower_photos/dandelion/18089878729_907ed2c7cd_m.jpg,dandelion
gs://cloud-ml-data/img/flower_photos/dandelion/284497199_93a01f48f6.jpg,dandelion
gs://cloud-ml-data/img/flower_photos/dandelion/3554992110_81d8c9b0bd_m.jpg,dandelion
gs://cloud-ml-data/img/flower_photos/daisy/4065883015_4bb6010cb7_n.jpg,daisy

In [11]:
!head {local_train_file} -n 5


/content/datalab/tmp/flower/images/754296579_30a9ae018c_n.jpg,daisy
/content/datalab/tmp/flower/images/18089878729_907ed2c7cd_m.jpg,dandelion
/content/datalab/tmp/flower/images/284497199_93a01f48f6.jpg,dandelion
/content/datalab/tmp/flower/images/3554992110_81d8c9b0bd_m.jpg,dandelion
/content/datalab/tmp/flower/images/4065883015_4bb6010cb7_n.jpg,daisy

Preprocess

Preprocessing uses a Dataflow pipeline to convert the image format, resize the images, and run each converted image through a pre-trained model to obtain its features, or embeddings. You can also do this step using alternative technologies such as Spark or plain Python code if you like.
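
For intuition, here is a minimal sketch of the convert/resize step in plain Python, assuming the Pillow library is available (the real Dataflow pipeline additionally runs each image through the Inception checkpoint to produce the embeddings):

from PIL import Image

def convert_and_resize(input_path, output_path, size=(299, 299)):
  # Normalize to an RGB JPEG at the Inception v3 input size of 299x299.
  img = Image.open(input_path).convert('RGB')
  img.resize(size, Image.ANTIALIAS).save(output_path, 'JPEG')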

The following cell takes ~5 minutes on an n1-standard-1 VM. Preprocessing the full 3,000 images takes about one hour.


In [12]:
# Instead of local_train_file, this also accepts '/content/flowerdata/train200.csv', but processing will be slower.
train_set = CsvDataSet(local_train_file, schema='image_url:STRING,label:STRING')
model.preprocess(train_set, preprocessed_dir)


/usr/local/lib/python2.7/dist-packages/apache_beam/coders/typecoders.py:136: UserWarning: Using fallback coder for typehint: Any.
  warnings.warn('Using fallback coder for typehint: %r.' % typehint)
WARNING:root:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.
completed

Train

The next step is to train the Inception model on the preprocessed images using transfer learning. Transfer learning retains most of the pre-trained Inception model and replaces only the final layer, which is retrained on the new labels.
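
To make that concrete, here is a minimal sketch of the replaced final layer (an illustration only, not the toolbox's actual training graph): the frozen Inception v3 bottleneck yields a 2048-dimensional embedding per image, and only a new softmax layer on top of it is trained.

import tensorflow as tf

NUM_CLASSES = 5        # daisy, dandelion, roses, sunflowers, tulips
EMBEDDING_SIZE = 2048  # size of the Inception v3 bottleneck output

# The embeddings come from the frozen pre-trained layers; only the
# weights and biases below are learned from the flower data.
embeddings = tf.placeholder(tf.float32, [None, EMBEDDING_SIZE])
labels = tf.placeholder(tf.int64, [None])
weights = tf.Variable(tf.truncated_normal([EMBEDDING_SIZE, NUM_CLASSES], stddev=0.001))
biases = tf.Variable(tf.zeros([NUM_CLASSES]))
logits = tf.matmul(embeddings, weights) + biases
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)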


In [13]:
import logging
logging.getLogger().setLevel(logging.INFO)
# Positional arguments: preprocessed data dir, batch size (30), max steps (800), model output dir.
model.train(preprocessed_dir, 30, 800, model_dir)
logging.getLogger().setLevel(logging.WARNING)


INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:global_step/sec: 0
INFO:root:Train [master/0], step 1 (0.084 sec) 11.9 global steps/s, 11.9 local steps/s
INFO:root:Eval, step 1:
- on train set loss: 1.746, accuracy: 0.250
-- on eval set loss: 1.841, accuracy: 0.133
INFO:root:Eval, step 800:
- on train set loss: 0.000, accuracy: 1.000
-- on eval set loss: 0.891, accuracy: 0.767
INFO:root:Exporting prediction graph to /content/datalab/tmp/flower/tinyflowermodellocal/model
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /content/datalab/tmp/flower/tinyflowermodellocal/model/saved_model.pb
INFO:tensorflow:SavedModel written to: /content/datalab/tmp/flower/tinyflowermodellocal/model/saved_model.pb
completed

Run TensorBoard to visualize the completed training. Review accuracy and loss in particular.


In [14]:
tb_id = TensorBoard.start(model_dir)


TensorBoard was started successfully with pid 5711. Click here to access it.

We can check the TF summary events from training.


In [15]:
summary = Summary(model_dir)
summary.list_events()


Out[15]:
{u'accuracy': {'/content/datalab/tmp/flower/tinyflowermodellocal/eval_set',
  '/content/datalab/tmp/flower/tinyflowermodellocal/train_set'},
 u'batch/fraction_of_450_full': {'/content/datalab/tmp/flower/tinyflowermodellocal/eval_set',
  '/content/datalab/tmp/flower/tinyflowermodellocal/train_set'},
 u'global_step/sec': {'/content/datalab/tmp/flower/tinyflowermodellocal/train'},
 u'input_producer/fraction_of_32_full': {'/content/datalab/tmp/flower/tinyflowermodellocal/eval_set',
  '/content/datalab/tmp/flower/tinyflowermodellocal/train_set'},
 u'loss': {'/content/datalab/tmp/flower/tinyflowermodellocal/eval_set',
  '/content/datalab/tmp/flower/tinyflowermodellocal/train_set'}}

In [16]:
summary.plot('accuracy')
summary.plot('loss')


Predict

Let's start with a quick check by taking a couple of images and using the model to predict the type of flower locally.


In [17]:
images = [
  'gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg',
  'gs://cloud-ml-data/img/flower_photos/tulips/6876631336_54bf150990.jpg'
]
# Set show_image to False to skip displaying the pictures.
model.predict(model_dir, images, show_image=True)


Predicting...

daisy(0.99904)

tulips(0.99940)

Out[17]:
image_url label score
0 gs://cloud-ml-data/img/flower_photos/daisy/152... daisy 0.999044
1 gs://cloud-ml-data/img/flower_photos/tulips/68... tulips 0.999405

Evaluate

We did a quick test of the model with a few samples, but to understand how the model really performs we need to evaluate it against a much larger amount of labeled data. In the initial preprocessing step, we set aside enough images for that purpose. Next, we run normal batch prediction and compare the results with the previously labeled targets.

The following batch prediction and loading of results takes ~3 minutes.


In [18]:
import google.datalab.bigquery as bq
bq.Dataset('flower').create()


Out[18]:
Dataset bradley-playground.flower

In [19]:
eval_set = CsvDataSet(local_eval_file, schema='image_url:STRING,label:STRING')
model.batch_predict(eval_set, model_dir, output_bq_table='flower.eval_results_local')


completed

Now that the predictions and the expected labels are loaded in a BigQuery table, let's analyze the errors and plot the confusion matrix.


In [20]:
%%bq query --name wrong_prediction

SELECT * FROM flower.eval_results_local WHERE target != predicted

In [21]:
wrong_prediction.execute().result()


Out[21]:
image_url | target | predicted | target_prob | predicted_prob
/content/datalab/tmp/flower/images/5032376020_2ed312306c.jpg | sunflowers | daisy | 0.0645321905613 | 0.908665060997
/content/datalab/tmp/flower/images/14674389605_df3c0bcfa1_m.jpg | tulips | roses | 0.0192162115127 | 0.953346252441
/content/datalab/tmp/flower/images/24459750_eb49f6e4cb_m.jpg | sunflowers | roses | 0.0856640711427 | 0.563797473907
/content/datalab/tmp/flower/images/7320089276_87b544e341.jpg | daisy | tulips | 0.000804728071671 | 0.998973965645
/content/datalab/tmp/flower/images/9338237628_4d2547608c.jpg | roses | tulips | 0.0445365905762 | 0.954923868179
/content/datalab/tmp/flower/images/17040847367_b54d05bf52.jpg | roses | tulips | 0.030935710296 | 0.968226492405
/content/datalab/tmp/flower/images/850416050_31b3ff7086.jpg | roses | tulips | 0.246313557029 | 0.740412175655
/content/datalab/tmp/flower/images/3705716290_cb7d803130_n.jpg | roses | tulips | 0.16103386879 | 0.829070448875
/content/datalab/tmp/flower/images/6166888942_7058198713_m.jpg | sunflowers | tulips | 0.0520406626165 | 0.919339179993
/content/datalab/tmp/flower/images/10994032453_ac7f8d9e2e.jpg | daisy | dandelion | 0.0184413604438 | 0.887112736702
/content/datalab/tmp/flower/images/14088053307_1a13a0bf91_n.jpg | daisy | sunflowers | 0.0203623976558 | 0.604844450951
/content/datalab/tmp/flower/images/3145692843_d46ba4703c.jpg | roses | sunflowers | 2.97755486827e-05 | 0.999744951725
/content/datalab/tmp/flower/images/14128835667_b6a916222c.jpg | dandelion | sunflowers | 0.371412545443 | 0.59988039732
/content/datalab/tmp/flower/images/19586799286_beb9d684b5.jpg | dandelion | sunflowers | 0.161707937717 | 0.626576542854
/content/datalab/tmp/flower/images/5598845098_13e8e9460f.jpg | dandelion | sunflowers | 0.216445609927 | 0.783076047897

(rows: 15, time: 2.3s, 0B processed, job: job_fMUjz1byvxaZKDqh-u0orw8WrQU)

A confusion matrix is a common way of summarizing a model's confusion - aggregate counts of where the predicted label did not match the expected label.


In [22]:
ConfusionMatrix.from_bigquery('flower.eval_results_local').plot()
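
Under the hood this is a simple aggregation; a roughly equivalent query (a sketch, not necessarily what the helper runs) is:

%%bq query
SELECT target, predicted, COUNT(*) as count
FROM flower.eval_results_local
GROUP BY target, predicted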


More advanced analysis can be done using the feature slice view. For that, let's define SQL queries that compute accuracy and log loss per label, and then plot the metrics.


In [23]:
%%bq query --name accuracy

SELECT
  target,
  SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END) as correct,
  COUNT(*) as total,
  SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END)/COUNT(*) as accuracy
FROM
  flower.eval_results_local
GROUP BY
  target

In [24]:
accuracy.execute().result()


Out[24]:
target | correct | total | accuracy
tulips | 22 | 23 | 0.95652173913
dandelion | 20 | 23 | 0.869565217391
sunflowers | 11 | 14 | 0.785714285714
roses | 14 | 19 | 0.736842105263
daisy | 18 | 21 | 0.857142857143

(rows: 5, time: 2.3s, 0B processed, job: job_iGHbgw9QWLI6r2Yg6LoLmyBXm2M)

In [25]:
%%bq query --name logloss

SELECT feature, AVG(-logloss) as logloss, COUNT(*) as count
FROM (
  SELECT feature, CASE WHEN correct=1 THEN LOG(prob) ELSE LOG(1-prob) END as logloss
  FROM (
    SELECT
      target as feature,
      CASE WHEN target=predicted THEN 1 ELSE 0 END as correct,
      target_prob as prob
    FROM flower.eval_results_local))
GROUP BY feature
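
Each example therefore contributes -LOG(target_prob) when the prediction is correct and -LOG(1 - target_prob) otherwise, so lower values are better. A plain-Python sketch of the per-example quantity:

import math

def example_logloss(target_prob, correct):
  # Mirrors the CASE expression in the query above.
  return -math.log(target_prob) if correct else -math.log(1.0 - target_prob)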

In [26]:
FeatureSliceView().plot(logloss)


Clean up


In [27]:
import shutil
import google.datalab.bigquery as bq

TensorBoard.stop(tb_id)
bq.Table('flower.eval_results_local').delete()
shutil.rmtree(worker_dir)

Recap

In this notebook, we covered local preprocessing, training, prediction, and evaluation. We started from data in GCS - CSV files plus images - used transfer learning for very fast training, and then used BigQuery to analyze model performance. In the next notebook, we will use the Cloud ML APIs, which scale much better to larger datasets. The syntax and analyses remain the same.

