By deploying or using this software you agree to comply with the AI Hub Terms of Service and the Google APIs Terms of Service. To the extent of a direct conflict of terms, the AI Hub Terms of Service will control.

Overview

This notebook provides an example workflow for using the RetinaNet ML container to train an object detection ML model.

Dataset

To preprocess an object detection dataset, see this COCO example.

Objective

The goal of this notebook is to go through a common training workflow:

  • Train an ML model using the AI Platform Training service
  • Monitor the training job with TensorBoard
  • Identify if the model was trained successfully by looking at the generated "Run Report"
  • Deploy the model for serving using the AI Platform Prediction service
  • Use the endpoint for online predictions

Costs

This tutorial uses billable components of Google Cloud Platform (GCP):

  • Cloud AI Platform
  • Cloud Storage

Learn about Cloud AI Platform pricing and Cloud Storage pricing, and use the Pricing Calculator to generate a cost estimate based on your projected usage.

Set up your local development environment

If you are using Colab or AI Platform Notebooks, your environment already meets all the requirements to run this notebook. You can skip this step.

Otherwise, make sure your environment meets this notebook's requirements. You need the following:

  • The Google Cloud SDK
  • Git
  • Python 3
  • virtualenv
  • Jupyter notebook running in a virtual environment with Python 3

The Google Cloud guide to Setting up a Python development environment and the Jupyter installation guide provide detailed instructions for meeting these requirements. The following steps provide a condensed set of instructions:

  1. Install and initialize the Cloud SDK.

  2. Install Python 3.

  3. Install virtualenv and create a virtual environment that uses Python 3.

  4. Activate that environment and run pip install jupyter in a shell to install Jupyter.

  5. Run jupyter notebook in a shell to launch Jupyter.

  6. Open this notebook in the Jupyter Notebook Dashboard.

Set up your GCP project

The following steps are required, regardless of your notebook environment.

  1. Select or create a GCP project. When you first create an account, you get a $300 free credit towards your compute/storage costs.

  2. Make sure that billing is enabled for your project.

  3. Enable the AI Platform APIs and Compute Engine APIs.

  4. Enter your project ID in the cell below. Then run the cell to make sure the Cloud SDK uses the right project for all the commands in this notebook.

Note: Jupyter runs lines prefixed with ! as shell commands, and it interpolates Python variables prefixed with $ into these commands.


In [ ]:
PROJECT_ID = "[your-project-id]" #@param {type:"string"}
! gcloud config set project $PROJECT_ID

Authenticate your GCP account

If you are using AI Platform Notebooks, your environment is already authenticated. Skip this step.

If you are using Colab, run the cell below and follow the instructions when prompted to authenticate your account via OAuth.

Otherwise, follow these steps:

  1. In the GCP Console, go to the Create service account key page.

  2. From the Service account drop-down list, select New service account.

  3. In the Service account name field, enter a name.

  4. From the Role drop-down list, select Machine Learning Engine > AI Platform Admin and Storage > Storage Object Admin.

  5. Click Create. A JSON file that contains your key downloads to your local environment.

  6. Enter the path to your service account key as the GOOGLE_APPLICATION_CREDENTIALS variable in the cell below and run the cell.


In [ ]:
import sys

# If you are running this notebook in Colab, run this cell and follow the
# instructions to authenticate your GCP account. This provides access to your
# Cloud Storage bucket and lets you submit training jobs and prediction
# requests.

if 'google.colab' in sys.modules:
  from google.colab import auth as google_auth
  google_auth.authenticate_user()

# If you are running this notebook locally, replace the string below with the
# path to your service account key and run this cell to authenticate your GCP
# account.
else:
  %env GOOGLE_APPLICATION_CREDENTIALS ''

Create a Cloud Storage bucket

The following steps are required, regardless of your notebook environment.

You need to have a "workspace" bucket that will hold the dataset and the output from the ML Container. Set the name of your Cloud Storage bucket below. It must be unique across all Cloud Storage buckets.

You may also change the REGION variable, which is used for operations throughout the rest of this notebook. Make sure to choose a region where Cloud AI Platform services are available. You may not use a Multi-Regional Storage bucket for training with AI Platform.


In [ ]:
BUCKET_NAME = "[your-bucket-name]" #@param {type:"string"}
REGION = 'us-central1' #@param {type:"string"}
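Before creating the bucket, you can catch obvious naming mistakes locally. The sketch below is a hypothetical helper, check_bucket_name, that covers only the basic Cloud Storage naming rules as we understand them (allowed characters, length limits, and the "goog"/"google" restriction). It does not check every rule, so treat the Cloud Storage documentation as authoritative.

```python
import re

# Lowercase letters, digits, '-', '_' and '.'; first and last character must be a letter or digit.
_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9._-]*[a-z0-9]$")


def check_bucket_name(name):
    """Return a list of rule violations for a bucket name (empty if none were found)."""
    problems = []
    if not _NAME_RE.match(name):
        problems.append("use only lowercase letters, digits, '-', '_' or '.', "
                        "starting and ending with a letter or digit")
    # Names containing dots may be longer, but each dot-separated part is still capped at 63.
    max_len = 222 if "." in name else 63
    if not 3 <= len(name) <= max_len:
        problems.append("length must be between 3 and {} characters".format(max_len))
    if any(len(part) > 63 for part in name.split(".")):
        problems.append("each dot-separated part must be at most 63 characters")
    if name.startswith("goog") or "google" in name:
        problems.append("names may not start with 'goog' or contain 'google'")
    return problems
```

For example, check_bucket_name(BUCKET_NAME) returns a non-empty list while BUCKET_NAME still holds the "[your-bucket-name]" placeholder, because square brackets are not allowed.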

Only if your bucket doesn't already exist, run the following cell to create your Cloud Storage bucket.


In [ ]:
! gsutil mb -l $REGION gs://$BUCKET_NAME

Finally, validate access to your Cloud Storage bucket by examining its contents:


In [ ]:
! gsutil ls -al gs://$BUCKET_NAME

Import libraries and define constants


In [ ]:
import time
from IPython.core.display import HTML
import tensorflow as tf
import os
import requests
import base64
from googleapiclient import discovery

Cloud training

Accelerator and distribution support

GPU   Multi-GPU Node   TPU   Workers   Parameter Server
Yes   Yes              Yes   No        No

To train on a TPU, as the cell below does, set the machine types like this:

    --master-machine-type standard \
    --worker-machine-type cloud_tpu \

In [ ]:
output_location = os.path.join('gs://', BUCKET_NAME, 'output')

job_name = "retinanet_{}".format(time.strftime("%Y%m%d%H%M%S"))
!gcloud beta ai-platform jobs submit training $job_name \
    --master-image-uri gcr.io/aihub-c2t-containers/kfp-components/oob_algorithm/retinanet:latest \
    --region $REGION \
    --scale-tier CUSTOM \
    --master-machine-type standard \
    --worker-machine-type cloud_tpu \
    --worker-count 1 \
    --tpu-tf-version 1.14 \
    -- \
    --output-location $output_location \
    --training-data "gs://{BUCKET_NAME}/coco_data/train-*.tfrecord" \
    --validation-data "gs://{BUCKET_NAME}/coco_data/val-*.tfrecord" \
    --validation-json-file gs://{BUCKET_NAME}/coco_data/instances_val2017.json \
    --number-of-classes 80 \
    --use-half-precision True \
    --training-epochs 50 \
    --use-pretrained True \
    --batch-size 128
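The training command mixes two kinds of arguments: the flags before the bare -- configure the AI Platform job, and everything after it is forwarded to the RetinaNet container. The sketch below, a hypothetical build_training_command helper, assembles the same command as an argument list that you could pass to subprocess.run instead of using the ! shell syntax. It assumes the dataset lives under gs://BUCKET/coco_data and that the container takes full gs:// paths, as the runtime arguments in the Run Report suggest.

```python
import subprocess  # only needed if you actually submit the job
import time

IMAGE_URI = "gcr.io/aihub-c2t-containers/kfp-components/oob_algorithm/retinanet:latest"


def build_training_command(bucket, region, job_name=None):
    """Return the training command from the cell above as an argv list (sketch)."""
    job_name = job_name or "retinanet_{}".format(time.strftime("%Y%m%d%H%M%S"))
    data_dir = "gs://{}/coco_data".format(bucket)
    # Flags consumed by gcloud itself.
    gcloud_args = [
        "gcloud", "beta", "ai-platform", "jobs", "submit", "training", job_name,
        "--master-image-uri", IMAGE_URI,
        "--region", region,
        "--scale-tier", "CUSTOM",
        "--master-machine-type", "standard",
        "--worker-machine-type", "cloud_tpu",
        "--worker-count", "1",
        "--tpu-tf-version", "1.14",
    ]
    # Flags forwarded unchanged to the RetinaNet container.
    container_args = [
        "--output-location", "gs://{}/output".format(bucket),
        "--training-data", data_dir + "/train-*.tfrecord",
        "--validation-data", data_dir + "/val-*.tfrecord",
        "--validation-json-file", data_dir + "/instances_val2017.json",
        "--number-of-classes", "80",
        "--use-half-precision", "True",
        "--training-epochs", "50",
        "--use-pretrained", "True",
        "--batch-size", "128",
    ]
    return gcloud_args + ["--"] + container_args


def submit(bucket, region):
    # Passing a list (no shell) keeps the "*" patterns from being glob-expanded locally.
    subprocess.run(build_training_command(bucket, region), check=True)
```

Calling submit(BUCKET_NAME, REGION) would launch the same billable job as the cell above.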

Monitor the training with TensorBoard


In [ ]:
try:
  %load_ext tensorboard
  %tensorboard --logdir {output_location}
except:
  !tensorboard --logdir {output_location}

Inspect the Run Report

Note that this model needs more training.


In [29]:
if not tf.io.gfile.exists(os.path.join(output_location, 'report.html')):
  raise RuntimeError('The file report.html was not found. Did the training job finish?')

with tf.io.gfile.GFile(os.path.join(output_location, 'report.html')) as f:
  display(HTML(f.read()))



Runtime arguments

value
batch_size 128
eval_batch_size 8
first_lr_drop 8
gcp_project ee2a81c09470c949f-ml
gpus_per_node 0
image_size 640
learning_rate 0.08
num_cores 8
num_examples_per_epoch None
number_of_classes 80
output_location gs://aihub-content-test/retinanet
random_flip True
remainder None
second_lr_drop 11
tpu cmle-training-7771763438152407002-tpu
tpu_zone us-central1-b
training_data gs://aihub-content-test-data/davidlicause/retinanet/coco_data/train-00255-of-00256.tfrecord
training_epochs 1
use_half_precision True
use_pretrained True
use_tpu True
validation_data gs://aihub-content-test-data/davidlicause/retinanet/coco_data/val-00031-of-00032.tfrecord
validation_json_file gs://aihub-content-test-data/davidlicause/retinanet/coco_data/instances_val2017.json
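The runtime arguments include learning_rate 0.08 with first_lr_drop 8 and second_lr_drop 11. A plausible reading, assuming the container follows the step schedule common in RetinaNet training (as in the open-source TPU RetinaNet reference), is that the rate drops tenfold at epoch 8 and again at epoch 11. The hypothetical learning_rate_at_epoch sketch below shows that schedule; it leaves out any warmup phase the container may also apply.

```python
def learning_rate_at_epoch(epoch, base_lr=0.08, first_lr_drop=8, second_lr_drop=11):
    """Step schedule sketch: base_lr, then 10x lower from first_lr_drop, 100x lower from second_lr_drop.

    The defaults copy the Run Report values; the 10x drop factors are an assumption.
    """
    if epoch < first_lr_drop:
        return base_lr
    if epoch < second_lr_drop:
        return base_lr * 0.1
    return base_lr * 0.01
```

Under this assumption, the one-epoch run in this report (training_epochs 1) never reaches either drop.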

Tensorboard snippet

To see the training progress, you may need to install the latest TensorBoard with the command pip install -U tensorboard, and then run one of the following commands.

Local tensorboard

tensorboard --logdir gs://aihub-content-test/retinanet

Publicly shared tensorboard

tensorboard dev upload --logdir gs://aihub-content-test/retinanet

Datasets

Data reading snippet

import os
import tensorflow as tf  # TF 1.x API (tf.Session, one-shot iterators)


def _parse_example(example_proto):
  # Describe the features stored in each tf.Example record.
  features = {
      'image/height': tf.io.FixedLenFeature([], tf.int64),
      'image/width': tf.io.FixedLenFeature([], tf.int64),
      'image/filename': tf.io.FixedLenFeature([], tf.string),
      'image/source_id': tf.io.FixedLenFeature([], tf.string),
      'image/key/sha256': tf.io.FixedLenFeature([], tf.string),
      'image/encoded': tf.io.FixedLenFeature([], tf.string),
      'image/format': tf.io.FixedLenFeature([], tf.string),
      # Per-object fields have one entry per box, so they are variable-length.
      'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
      'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
      'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
      'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
      'image/object/class/text': tf.io.VarLenFeature(tf.string),
      'image/object/is_crowd': tf.io.VarLenFeature(tf.int64),
      'image/object/area': tf.io.VarLenFeature(tf.float32),
  }
  row = tf.io.parse_single_example(example_proto, features)
  # Decode the JPEG bytes into an image tensor.
  row['image/decoded'] = tf.io.decode_image(row['image/encoded'])
  return row

file_pattern = 'gs://aihub-content-test-data/davidlicause/retinanet/coco_data/train-00255-of-00256.tfrecord'

dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
dataset = dataset.map(_parse_example)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  row = sess.run(next_element)
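In these records the box corners (image/object/bbox/xmin and so on) are stored as fractions of the image size, as the sample values between 0 and 1 in the tables below show, while image/width and image/height give the size in pixels. A minimal sketch for converting one record's boxes to pixel corners; boxes_to_pixels is a hypothetical helper that takes plain sequences. With the TF 1.x snippet above, VarLenFeature columns come back from sess.run as sparse values, so you would pass their .values.

```python
def boxes_to_pixels(xmins, ymins, xmaxs, ymaxs, width, height):
    """Scale normalized box corners to pixel (x1, y1, x2, y2) tuples."""
    return [
        (x1 * width, y1 * height, x2 * width, y2 * height)
        for x1, y1, x2, y2 in zip(xmins, ymins, xmaxs, ymaxs)
    ]

# Example with the `row` read above (not executed here):
# boxes_to_pixels(row['image/object/bbox/xmin'].values, row['image/object/bbox/ymin'].values,
#                 row['image/object/bbox/xmax'].values, row['image/object/bbox/ymax'].values,
#                 row['image/width'], row['image/height'])
```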

Deploy for serving snippet

MODEL_NAME='REPLACE_WITH_YOUR_MODEL_NAME'
MODEL_VERSION='v1'

# create the model
gcloud ai-platform models create $MODEL_NAME

# create a version of the model from the exported SavedModel
gcloud ai-platform versions create $MODEL_VERSION \
  --model $MODEL_NAME \
  --origin gs://aihub-content-test/retinanet/export/1581396002 \
  --runtime-version=1.15 \
  --framework=tensorflow \
  --python-version=3.7

Training data sample

image/decoded image/encoded image/filename image/format image/height image/key/sha256 image/object/area image/object/bbox/xmax image/object/bbox/xmin image/object/bbox/ymax image/object/bbox/ymin image/object/class/text image/object/is_crowd image/source_id image/width
0 [[[92, 75, 55], [88, 73, 54], [97, 82, 63], [95, 80, 59], [88, 70, 48], [96, 75, 54], [88, 67, 48], [96, 77, 60], [95, 72, 54], [98, 75, 61], [99, 77, 63], [86, 68, 46], [71, 55, 32], [80, 63, 45]... <JPEG bytes> 000000477919.jpg jpeg 378 914b2ff8ae2c20ee24de115c37c116ee6e60884642611aa93651e1e5d0113394 [ 7518.0176 5904.4673 1716.6152 867.6513 1941.1285 105.32045\n 368.917 1416.3861 51.60835 1984.9585 1056.1414 1989.2743\n 401.65704 4066.238 1029.1202 1501.3108 ... [0.20490626 0.2795 0.1005 0.10176563 0.34628126 0.3988906\n 0.6242969 0.68271875 0.74401563 0.8710156 0.91884375 0.321875\n 0.38734376 0.4849375 0.66648436 0.579625 0.57214063 0.99843... [0.09375 0.18434376 0.0165625 0.05125 0.27907813 0.37946874\n 0.58715624 0.62317187 0.736375 0.80757815 0.8604219 0.248875\n 0.344625 0.4001406 0.61564064 0.5139844 0.47110936 0. ... [0.8344709 0.9187566 0.62521166 0.460873 0.51648146 0.51137567\n 0.5314815 0.6128307 0.5036508 0.5682804 0.52719575 0.98452383\n 0.7871429 0.9391005 0.73486775 0.5987037 0.9159524 0.97... [0.3061111 0.3774074 0.49875662 0.34462962 0.31941798 0.47002646\n 0.47208995 0.4301058 0.46592593 0.36595237 0.40629628 0.8101058\n 0.68822753 0.6887302 0.42481482 0.38084656 0.75738096 0.293... ['person' 'person' 'banana' 'banana' 'banana' 'banana' 'banana' 'banana'\n 'banana' 'banana' 'banana' 'banana' 'banana' 'banana' 'person' 'person'\n 'banana' 'banana'] [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1] 477919 640
1 [[[108, 139, 19], [121, 154, 24], [125, 159, 21], [109, 142, 9], [95, 129, 6], [91, 124, 11], [90, 122, 13], [85, 117, 6], [80, 111, 7], [73, 106, 2], [71, 103, 4], [68, 102, 7], [57, 92, 2], [47,... <JPEG bytes> 000000108982.jpg jpeg 427 aee06a9f8ceb72b631c127a64363421a1cbc8a76267847cd3c91d704f27977fd [7488.1963] [0.5323125] [0.41179687] [0.825808] [0.40859485] ['bird'] [0] 108982 640
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
98 [[[255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 2... <JPEG bytes> 000000129699.jpg jpeg 480 1ce8028eb4910eb765f4a926e141ecfbba64138069dd5103dc9c877786527e33 [1080.732 1107.0106 7524.269 8806.366 ] [0.66165626 0.61901563 0.307625 0.47921875] [0.6233594 0.5820781 0.13242188 0.31753126] [0.9002917 0.91564584 0.9995833 0.9896042 ] [0.7898333 0.799625 0.8311667 0.7844167] ['traffic light' 'traffic light' 'traffic light' 'traffic light'] [0 0 0 0] 129699 640
99 [[[119, 123, 86], [109, 111, 71], [115, 117, 68], [122, 123, 65], [126, 128, 63], [125, 130, 66], [122, 127, 73], [134, 137, 90], [128, 134, 74], [121, 123, 73], [119, 117, 78], [122, 118, 80], [1... <JPEG bytes> 000000037038.jpg jpeg 480 65479ddc3d635c46becc10ab188700cae11788dde2af1a6e657c345f1c4547d0 [2256.8208 379.43225] [0.244625 0.2569375] [0.13501562 0.14560938] [0.75825 0.7825 ] [0.59079164 0.7134167 ] ['person' 'surfboard'] [0 0] 37038 640

100 rows × 15 columns

Validation data sample

image/decoded image/encoded image/filename image/format image/height image/key/sha256 image/object/area image/object/bbox/xmax image/object/bbox/xmin image/object/bbox/ymax image/object/bbox/ymin image/object/class/text image/object/is_crowd image/source_id image/width
0 [[[44, 47, 78], [43, 48, 78], [36, 42, 74], [32, 43, 75], [37, 50, 84], [25, 38, 73], [27, 39, 79], [30, 39, 82], [39, 48, 91], [34, 56, 95], [27, 50, 91], [27, 55, 94], [28, 64, 100], [24, 54, 90... <JPEG bytes> 000000099053.jpg jpeg 559 a4e8405077a8f93115116f548e7d59b29a092a6ac2f749732a525e7b82c2f35a [ 9687.258 206335.97 2792.7808 1852.6361 1612.6714 1433.9683\n 1501.8083 357614.38 ] [1. 0.97354686 0.28667188 0.42751563 0.7903594 0.6264531\n 0.682125 1. ] [0.63114065 0.05889063 0.14740625 0.35423437 0.7 0.5617969\n 0.6130625 0. ] [0.54384613 0.9745975 0.66676205 0.8116458 0.6084973 0.8344723\n 0.8131306 1. ] [0. 0.14538461 0.5575313 0.7122719 0.53044724 0.7372451\n 0.7019857 0. ] ['fork' 'bowl' 'broccoli' 'broccoli' 'broccoli' 'broccoli' 'broccoli'\n 'dining table'] [0 0 0 0 0 0 0 0] 99053 640
1 [[[160, 176, 173], [163, 179, 179], [164, 179, 186], [167, 183, 198], [168, 186, 200], [164, 183, 190], [162, 181, 177], [161, 179, 167], [157, 178, 173], [156, 177, 170], [156, 176, 167], [156, 1... <JPEG bytes> 000000079031.jpg jpeg 427 b82c8d3cc12b0bd3de1481db37e53932ed2d116805d2f9eff726bd2a01370123 [42510.64 41248.6 ] [0.5176719 0.9416875] [0.09610938 0. ] [0.86662763 0.98665106] [0.00084309 0.6128103 ] ['person' 'surfboard'] [0 0] 79031 640
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
98 [[[173, 172, 170], [174, 173, 169], [173, 172, 168], [170, 169, 167], [172, 170, 171], [171, 169, 170], [175, 174, 170], [171, 172, 164], [174, 173, 171], [176, 175, 173], [172, 171, 169], [172, 1... <JPEG bytes> 000000184321.jpg jpeg 480 5018ca72abedb01bd8634a238a941d7f6b99a0112435adbfd4a2ae0e338162e3 [9336.486] [0.4185469] [0.25385937] [0.598125] [0.36116666] ['train'] [0] 184321 640
99 [[[141, 169, 193], [140, 166, 193], [143, 164, 195], [147, 163, 197], [148, 164, 198], [147, 163, 197], [147, 163, 196], [147, 166, 196], [144, 168, 196], [142, 163, 192], [145, 164, 194], [148, 1... <JPEG bytes> 000000266409.jpg jpeg 480 6f29db689ebb18bdf8fbb06524077f6e88e058e7d8d114483ec8dc8d43f76c65 [22075.46 979.4432] [0.7061875 0.6671875] [0.51067185 0.58248436] [0.9417292 1. ] [0.31475 0.91758335] ['person' 'skis'] [0 0] 266409 640

100 rows × 15 columns

Validation JSON file

info

contributor date_created description url version year
value COCO Consortium 2017/09/01 COCO 2017 Dataset http://cocodataset.org 1.0 2017

categories

id name supercategory
0 1 person person
1 2 bicycle vehicle
... ... ... ...
78 89 hair drier indoor
79 90 toothbrush indoor

80 rows × 3 columns
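The categories table shows that the 80 COCO classes use ids 1 to 90 with gaps (row 79 has id 90). The detection_classes values in the prediction samples below appear to be these COCO ids rather than 0 to 79 indices: for example, class 52 (banana) dominates the predictions for the first training sample, whose labels are mostly 'banana'. This is an inference from the samples, not documented container behavior. Under that assumption, a lookup by id names the predictions; category_lookup and class_names are hypothetical helpers.

```python
import json


def category_lookup(coco_json_text):
    """Map COCO category id -> name, read from the "categories" list of a COCO JSON file."""
    data = json.loads(coco_json_text)
    return {category["id"]: category["name"] for category in data["categories"]}


def class_names(detection_classes, lookup):
    """Translate detection_classes (floats such as 52.0) into category names."""
    return [lookup.get(int(c), "unknown") for c in detection_classes]
```

In the notebook you could read the validation JSON file with tf.io.gfile.GFile and pass its contents to category_lookup.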

licenses

id name url
0 1 Attribution-NonCommercial-ShareAlike License http://creativecommons.org/licenses/by-nc-sa/2.0/
1 2 Attribution-NonCommercial License http://creativecommons.org/licenses/by-nc/2.0/
... ... ... ...
6 7 No known copyright restrictions http://flickr.com/commons/usage/
7 8 United States Government Work http://www.usa.gov/copyright.shtml

8 rows × 3 columns

images

coco_url date_captured file_name flickr_url height license width
id
397133 http://images.cocodataset.org/val2017/000000397133.jpg 2013-11-14 17:02:52 000000397133.jpg http://farm7.staticflickr.com/6116/6255196340_da26cf2c9e_z.jpg 427 4 640
37777 http://images.cocodataset.org/val2017/000000037777.jpg 2013-11-14 20:55:31 000000037777.jpg http://farm9.staticflickr.com/8429/7839199426_f6d48aa585_z.jpg 230 1 352
... ... ... ... ... ... ... ...
394940 http://images.cocodataset.org/val2017/000000394940.jpg 2013-11-24 13:47:05 000000394940.jpg http://farm9.staticflickr.com/8227/8566023505_e9e9f997bc_z.jpg 640 3 426
15335 http://images.cocodataset.org/val2017/000000015335.jpg 2013-11-25 14:00:10 000000015335.jpg http://farm6.staticflickr.com/5533/10257288534_c916fafd78_z.jpg 480 2 640

5000 rows × 7 columns

annotations

area bbox category_id id image_id iscrowd segmentation
0 702.10575 [473.07, 395.93, 38.65, 28.67] 18 1768 289343 0 [[510.66, 423.01, 511.72, 420.03, 510.45, 416.0, 510.34, 413.02, 510.77, 410.26, 510.77, 407.5, 510.34, 405.16, 511.51, 402.83, 511.41, 400.49, 510.24, 398.16, 509.39, 397.31, 504.61, 399.22, 502....
1 27718.4763 [272.1, 200.23, 151.97, 279.77] 18 1773 61471 0 [[289.74, 443.39, 302.29, 445.32, 308.09, 427.94, 310.02, 416.35, 304.23, 405.73, 300.14, 385.01, 298.23, 359.52, 295.04, 365.89, 282.3, 362.71, 275.29, 358.25, 277.2, 346.14, 280.39, 339.13, 284....
... ... ... ... ... ... ... ...
36779 27277 [10, 41, 403, 152] 52 905200050149 50149 1 {u'counts': [3912, 10, 363, 18, 356, 23, 301, 10, 25, 10, 5, 27, 296, 16, 19, 16, 1, 30, 292, 20, 15, 50, 289, 24, 11, 53, 287, 26, 9, 55, 285, 29, 6, 57, 283, 32, 3, 59, 281, 34, 1, 60, 261, 11, ...
36780 220834 [0, 34, 639, 388] 1 900100250282 250282 1 {u'counts': [179, 27, 392, 41, 380, 51, 371, 59, 363, 67, 356, 73, 350, 79, 129, 9, 207, 82, 124, 10, 208, 84, 121, 10, 209, 85, 121, 8, 209, 88, 336, 89, 154, 1, 181, 90, 154, 2, 178, 73, 2, 16, ...

36781 rows × 7 columns
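Unlike the normalized corners in the TFRecords, the bbox field in the COCO JSON is [x, y, width, height] in pixels, measured from the top-left corner of the image. A small sketch with two hypothetical helpers converts it to corner form and to the normalized layout of the TFRecords. Note also that COCO's area field is the area of the segmentation mask, not of the box: the first annotation above has a 38.65 x 28.67 box (about 1108 square pixels) but an area of 702.1.

```python
def coco_bbox_to_corners(bbox):
    """Convert a COCO [x, y, width, height] box to [x1, y1, x2, y2] pixel corners."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


def coco_bbox_to_normalized(bbox, image_width, image_height):
    """Same box as normalized (xmin, ymin, xmax, ymax) fractions of the image size."""
    x1, y1, x2, y2 = coco_bbox_to_corners(bbox)
    return (x1 / image_width, y1 / image_height, x2 / image_width, y2 / image_height)
```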

Predictions

Local predictions snippet

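import tensorflow as tf  # TF 1.x: tf.contrib is not available in TF 2

saved_model = 'gs://aihub-content-test/retinanet/export/1581396002'
predict_fn = tf.contrib.predictor.from_saved_model(saved_model)

# encoded_image holds the JPEG-encoded bytes of one image (raw bytes, not base64).
# The serving graph appears to rescale the image itself: image_info in the
# online prediction below reports the scale it applied.
with tf.io.gfile.GFile('/tmp/image.jpeg', 'rb') as f:
  encoded_image = f.read()
predictions = predict_fn({'input': [encoded_image]})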

Training predictions sample

Raw predictions sample

detection_boxes detection_classes detection_scores image_info num_detections
0 [[[270.48395 222.41402 364.2963 259.88705 ]\n [197.34203 59.090675 533.2446 133.81287 ]\n [437.53397 255.67345 600.25354 307.52985 ]\n [270.8843 250.56946 376.75946 295.2223 ]\n... [[52. 1. 52. 52. 52. 52. 52. 31. 52. 52. 52. 52. 52. 52. 52. 52. 1. 52.\n 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52. 52.\n 52. 52. 52. 52. 52. 52. 1. 52. 1. 1. 1. ... [[0.9787756 0.9642609 0.92659736 0.9250956 0.91004455 0.86180663\n 0.8302079 0.82389635 0.774523 0.77180094 0.7135634 0.6842082\n 0.68034685 0.66667515 0.63720423 0.6305675 0.6078194 0.... [[640. 640. 1. 640. 640.]] [44.]
1 [[[ 2.5968207e+02 2.6232297e+02 5.1560278e+02 3.4221704e+02]\n [ 2.8382983e+02 2.6770715e+02 3.7501160e+02 3.3270718e+02]\n [ 1.6483125e+02 1.3887802e+02 2.1936662e+02 1.7144524e+02]\n ... [[16. 16. 13. 16. 16. 1. 1. 1. 1. 16. 16. 1. 1. 16. 38. 1. 1. 1.\n 1. 16. 52. 16. 67. 10. 34. 25. 1. 1. 16. 53. 16. 16. 1. 56. 22. 50.\n 1. 16. 16. 1. 1. 1. 34. 53. 13. 1. 25. ... [[0.97964925 0.69946814 0.40357143 0.3263348 0.31309614 0.2728908\n 0.23948859 0.22215372 0.2137502 0.18123256 0.17839244 0.17738923\n 0.16209462 0.1545194 0.13625634 0.12849598 0.12794474 0.... [[640. 640. 1. 640. 640.]] [74.]
... ... ... ... ... ...
98 [[[506.31207 394.98352 576.11676 422.4469 ]\n [512.21857 375.833 581.41864 396.79205 ]\n [515.68066 387.548 574.36566 405.87527 ]\n [516.2678 411.274 565.8817 423.888 ]\n... [[10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10.\n 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 10. 28. 1. 5. 1.\n 10. 10. 1. 10. 10. 1. 28. 1. 1. 5. 10. ... [[0.99954164 0.99095005 0.9693987 0.9212367 0.8982083 0.83216363\n 0.7716321 0.6401021 0.6003213 0.59995717 0.5983422 0.54898846\n 0.53858197 0.47997394 0.4647947 0.4299825 0.42432255 0... [[640. 640. 1. 640. 640.]] [99.]
99 [[[379.80038 88.21208 495.48685 150.51465 ]\n [461.58478 95.801895 500.62262 163.15907 ]\n [472.59723 120.01411 502.56503 144.76631 ]\n [466.72278 107.479546 495.96338 129.1238 ]\n... [[ 1. 42. 42. 42. 42. 1. 42. 42. 1. 42. 42. 1. 42. 1. 1. 42. 32. 1.\n 32. 15. 27. 1. 31. 1. 1. 36. 42. 42. 1. 50. 3. 1. 1. 1. 42. 1.\n 36. 35. 1. 36. 42. 1. 1. 27. 48. 42. 3. ... [[1. 0.9997545 0.82412744 0.7188085 0.6965194 0.6826002\n 0.6724939 0.63843644 0.5005969 0.41778132 0.3901946 0.3672312\n 0.27838463 0.2423769 0.20527342 0.17730622 0.16993716 0.1... [[640. 640. 1. 640. 640.]] [100.]

100 rows × 5 columns

Validation predictions sample

Raw predictions sample

detection_boxes detection_classes detection_scores image_info num_detections
0 [[[ 62.341217 25.172241 603.9596 619.16016 ]\n [421.0427 432.94052 465.55576 474.9627 ]\n [ 53.016922 377.92346 293.328 631.5149 ]\n [ 28.39679 17.524475 603.3229 625.19324 ]\n... [[51. 57. 50. 67. 57. 50. 57. 57. 56. 1. 56. 57. 57. 57. 50. 57. 57. 57.\n 57. 57. 57. 56. 57. 52. 51. 52. 57. 57. 56. 57. 56. 50. 50. 1. 1. 1.\n 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. ... [[0.9883074 0.94395024 0.7881914 0.6692203 0.5218953 0.48460603\n 0.32663575 0.3265929 0.3025253 0.28862762 0.27942786 0.26685145\n 0.26411286 0.26306617 0.26225322 0.23553012 0.23465069 0... [[640. 640. 1. 640. 640.]] [33.]
1 [[[ -4.181488 70.52251 552.6395 339.47955 ]\n [529.2012 483.1694 588.0945 502.2153 ]\n [531.0291 77.08261 607.07404 461.94266 ]\n [283.82797 80.36987 316.14975 ... [[ 1. 1. 42. 1. 1. 1. 21. 8. 37. 21. 37. 1. 42. 42. 1. 1. 1. 65.\n 25. 1. 1. 1. 34. 1. 1. 1. 1. 37. 1. 1. 42. 37. 3. 67. 1. 20.\n 21. 42. 37. 37. 9. 37. 1. 1. 47. 1. 1. ... [[0.9999976 0.2849249 0.17309882 0.15023737 0.14535728 0.12500623\n 0.12449827 0.11545011 0.11445364 0.10737909 0.09450947 0.09210314\n 0.08857374 0.08807714 0.08721705 0.08694089 0.08609616 0... [[640. 640. 1. 640. 640.]] [100.]
... ... ... ... ... ...
98 [[[239.43962 162.66379 363.9946 266.5545 ]\n [243.10648 185.37146 338.47534 246.67822 ]\n [253.36205 225.07515 344.55096 265.93887 ]\n [275.8944 205.04645 306.06177 ... [[ 7. 1. 1. 1. 5. 1. 7. 5. 76. 1. 31. 1. 7. 31. 10. 10. 1. 1.\n 10. 31. 7. 10. 10. 37. 1. 31. 10. 10. 1. 1. 1. 1. 10. 7. 7. 10.\n 14. 1. 10. 10. 10. 1. 10. 7. 10. 5. 31. ... [[0.99996066 0.46129432 0.37923187 0.3591942 0.35027328 0.3439324\n 0.34049925 0.33507195 0.31934923 0.3126775 0.27593192 0.27057907\n 0.2524932 0.23842539 0.20470785 0.19551855 0.19437541 0.... [[640. 640. 1. 640. 640.]] [100.]
99 [[[196.96309 326.94247 594.12067 455.046 ]\n [260.1598 356.7415 310.90103 416.54865]\n [586.4139 360.06927 628.5669 384.85114]\n [585.1925 364.73367 606.53296 408.6659 ]\n [272.5752 374... [[ 1. 27. 35. 35. 27. 35. 35. 35. 35. 35. 27. 27. 27. 35. 27. 27. 35. 27.\n 27. 35. 35. 27. 27. 35. 35. 35. 27. 35. 35. 35. 35. 35. 27. 35. 35. 27.\n 27. 35. 1. 1. 1. 1. 1. 1. 1. 1. 1. ... [[0.9999926 0.99920124 0.9679566 0.95606655 0.94658136 0.9233593\n 0.9211176 0.9087194 0.8822673 0.75636506 0.7542323 0.68389237\n 0.66232204 0.6093652 0.589024 0.52577317 0.5107726 0.... [[640. 640. 1. 640. 640.]] [38.]

100 rows × 5 columns
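Each prediction carries image_info alongside the boxes. In the online prediction at the end of this notebook, image_info is [426.0, 640.0, 3.2, 1365.0, 2048.0]; since 2048 / 3.2 = 640 and 1365 / 3.2 is about 426, it appears to hold [scaled height, scaled width, scale, original height, original width]. Assuming that layout, and assuming detection_boxes are in the scaled image's pixel coordinates, the hypothetical postprocess sketch below keeps detections above a score threshold and maps their boxes back to the original image by multiplying by the scale.

```python
def postprocess(prediction, score_threshold=0.5):
    """Return (class_id, score, box) tuples for the confident detections of one image.

    Assumes image_info = [scaled_h, scaled_w, scale, orig_h, orig_w] and that boxes
    are given in scaled-image pixels; both are inferences from the outputs above.
    """
    scale = prediction["image_info"][2]
    count = prediction.get("num_detections", len(prediction["detection_scores"]))
    if isinstance(count, (list, tuple)):
        count = count[0]
    count = int(count)
    results = []
    for box, cls, score in zip(prediction["detection_boxes"][:count],
                               prediction["detection_classes"][:count],
                               prediction["detection_scores"][:count]):
        if score >= score_threshold:
            # Scaling every coordinate works regardless of the coordinate order.
            results.append((int(cls), score, [v * scale for v in box]))
    return results
```

For the online prediction, postprocess(response['predictions'][0]) would apply this to the single returned instance.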

Training data

Mean Average Precision (mAP) Mean Correct Localization (CorLoc)
value 0.5143 0.6983

Validation data

Mean Average Precision (mAP) Mean Correct Localization (CorLoc)
value 0.3311 0.5333
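CorLoc is commonly defined per class as the fraction of images containing that class in which the highest-scoring detection of that class overlaps some ground-truth box with an intersection over union (IoU) of at least 0.5; the mean CorLoc averages this over classes. That is the definition used, for example, by the TensorFlow Object Detection API evaluation utilities, and we assume the Run Report follows it. A sketch for a single class, with hypothetical helper names and boxes as [x1, y1, x2, y2] (any consistent corner order works):

```python
def iou(box_a, box_b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def correctly_localized(detections, gt_boxes, threshold=0.5):
    """True if the top-scoring detection (score, box) hits any ground-truth box."""
    if not detections or not gt_boxes:
        return False
    _, best_box = max(detections, key=lambda d: d[0])
    return max(iou(best_box, g) for g in gt_boxes) >= threshold


def corloc(per_image):
    """CorLoc for one class; per_image is a list of (detections, gt_boxes) pairs."""
    relevant = [(d, g) for d, g in per_image if g]  # only images containing the class
    if not relevant:
        return 0.0
    return sum(correctly_localized(d, g) for d, g in relevant) / len(relevant)
```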

Distribution of individual CorLoc

Ground truth VS prediction

Training data

Best predictions

Worst predictions

Ground truth VS prediction

Validation data

Best predictions

Worst predictions

Deployment parameters


In [ ]:
#@markdown ---
model = 'retinanet' #@param {type:"string"}
version = 'v1' #@param {type:"string"}
#@markdown ---

In [ ]:
# model_uri.txt holds the exact location of the exported model
with tf.io.gfile.GFile(os.path.join(output_location, 'model_uri.txt')) as f:
  model_uri = f.read().strip()  # drop any trailing newline

# create a model
!gcloud ai-platform models create $model --regions $REGION

# create a version
!gcloud beta ai-platform versions create $version \
  --model $model \
  --origin $model_uri \
  --runtime-version 1.15 \
  --project $PROJECT_ID

Get one image for test prediction

Download and encode one image


In [41]:
# download a sample image to /tmp/image.jpeg
!wget --output-document /tmp/image.jpeg \
  https://static01.nyt.com/images/2018/12/22/us/00carattack1-promo/00carattack1-promo-superJumbo-v2.jpg

# read the raw JPEG bytes and base64-encode them for the JSON request
with tf.io.gfile.GFile('/tmp/image.jpeg', 'rb') as f:
  encoded_image = f.read()

encoded_image = base64.b64encode(encoded_image).decode()
encoded_image[:200]


--2020-02-11 10:37:06--  https://static01.nyt.com/images/2018/12/22/us/00carattack1-promo/00carattack1-promo-superJumbo-v2.jpg
Resolving static01.nyt.com (static01.nyt.com)... 151.101.189.164
Connecting to static01.nyt.com (static01.nyt.com)|151.101.189.164|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 463251 (452K) [image/jpeg]
Saving to: ‘/tmp/image.jpeg’

/tmp/image.jpeg     100%[===================>] 452.39K   978KB/s    in 0.5s    

2020-02-11 10:37:07 (978 KB/s) - ‘/tmp/image.jpeg’ saved [463251/463251]

Out[41]:
'/9j/4AAQSkZJRgABAQEBLAEsAAD/4gxYSUNDX1BST0ZJTEUAAQEAAAxITGlubwIQAABtbnRyUkdCIFhZWiAHzgACAAkABgAxAABhY3NwTVNGVAAAAABJRUMgc1JHQgAAAAAAAAAAAAAAAAAA9tYAAQAAAADTLUhQICAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA'
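The next cell sends the base64 string inside a {'b64': ...} wrapper, the AI Platform convention for binary inputs in JSON prediction requests. If you want to build the same body without TensorFlow, for example to send it with another HTTP client, a standard-library sketch follows; make_predict_body is a hypothetical helper, and 'input' is the input name this model uses in the next cell.

```python
import base64


def make_predict_body(image_bytes, input_name="input"):
    """Build an online-prediction request body for one image (sketch)."""
    # Binary values travel as base64 text inside {"b64": ...}.
    encoded = base64.b64encode(image_bytes).decode("ascii")
    # "instances" is a list with one entry per image.
    return {"instances": [{input_name: {"b64": encoded}}]}
```

json.dumps(make_predict_body(data)) gives the JSON payload; with the REST API it would be POSTed to the version's :predict method (projects/PROJECT/models/MODEL/versions/VERSION:predict).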

Make online prediction


In [49]:
service = discovery.build('ml', 'v1')
name = 'projects/{project}/models/{model}/versions/{version}'.format(project=PROJECT_ID,
                                                                    model=model,
                                                                    version=version)
# "instances" is a list with one entry per image
body = {'instances': [{'input': {'b64': encoded_image}}]}

response = service.projects().predict(name=name, body=body).execute()
if 'error' in response:
    raise RuntimeError(response['error'])

response['predictions'][0]


/Users/evo/Library/Python/3.7/lib/python/site-packages/google/auth/_default.py:66: UserWarning:

Your application has authenticated using end user credentials from Google Cloud SDK. We recommend that most server applications use service accounts instead. If your application continues to use end user credentials from Cloud SDK, you might receive a "quota exceeded" or "API not enabled" error. For more information about service accounts, see https://cloud.google.com/docs/authentication/

Out[49]:
{'image_info': [426.0, 640.0, 3.200000047683716, 1365.0, 2048.0],
 'detection_classes': [3.0,
  10.0,
  10.0,
  10.0,
  1.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  1.0,
  10.0,
  10.0,
  3.0,
  10.0,
  10.0,
  1.0,
  10.0,
  10.0,
  3.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  14.0,
  10.0,
  13.0,
  3.0,
  10.0,
  3.0,
  8.0,
  3.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  10.0,
  1.0,
  1.0,
  10.0,
  3.0,
  10.0,
  10.0,
  10.0,
  10.0,
  3.0,
  10.0,
  3.0,
  10.0,
  28.0,
  10.0,
  1.0,
  14.0,
  10.0,
  10.0,
  3.0,
  3.0,
  10.0,
  10.0,
  49.0,
  1.0,
  10.0,
  10.0,
  1.0,
  28.0,
  1.0,
  3.0,
  10.0,
  1.0,
  10.0,
  10.0,
  10.0,
  3.0,
  13.0,
  13.0,
  3.0,
  1.0,
  14.0,
  19.0,
  10.0,
  1.0,
  10.0,
  10.0,
  10.0,
  1.0,
  10.0],
 'detection_boxes': [[188.1590118408203,
   374.0584716796875,
   247.447021484375,
   534.239990234375],
  [16.917407989501953,
   280.8747863769531,
   60.792598724365234,
   302.5542907714844],
  [38.823333740234375, 363.8931579589844, 79.69244384765625, 382.700439453125],
  [50.61260986328125, 370.7249450683594, 70.44306945800781, 389.5875549316406],
  [187.580078125, 418.3533630371094, 239.88604736328125, 464.27587890625],
  [59.181365966796875,
   359.8604431152344,
   77.23545837402344,
   376.2574462890625],
  [40.089874267578125, 321.7099914550781, 78.4954833984375, 383.5168151855469],
  [121.20897674560547,
   453.4537658691406,
   150.49844360351562,
   471.1507873535156],
  [48.04679870605469, 353.8089599609375, 85.882080078125, 410.3874206542969],
  [27.976261138916016, 276.7470703125, 44.31047058105469, 294.6986083984375],
  [50.67173385620117,
   354.4557800292969,
   76.87645721435547,
   366.34124755859375],
  [185.27716064453125, 448.1981506347656, 237.580810546875, 490.7494201660156],
  [44.068321228027344,
   345.8328857421875,
   59.32551956176758,
   371.1339111328125],
  [30.794706344604492, 269.388916015625, 66.41658020019531, 322.2486877441406],
  [181.45864868164062,
   433.30804443359375,
   193.30384826660156,
   470.66644287109375],
  [43.76359939575195,
   285.40618896484375,
   60.87775421142578,
   315.1212463378906],
  [70.18024444580078, 358.3095397949219, 81.32141876220703, 375.4552001953125],
  [174.607421875, 399.7996826171875, 259.16741943359375, 509.2102355957031],
  [126.9947738647461,
   453.3807373046875,
   147.2378692626953,
   460.05474853515625],
  [46.24774169921875,
   280.4414367675781,
   58.979949951171875,
   296.0444030761719],
  [178.97415161132812,
   442.48956298828125,
   186.17001342773438,
   482.14471435546875],
  [31.957719802856445,
   299.21063232421875,
   44.73707962036133,
   318.9381103515625],
  [21.36283302307129,
   291.0984191894531,
   54.077857971191406,
   315.5933532714844],
  [39.22830581665039, 360.7502746582031, 65.58700561523438, 372.4872131347656],
  [21.08956527709961, 276.9336853027344, 35.40372848510742, 293.2171630859375],
  [10.958246231079102,
   240.03839111328125,
   52.19831848144531,
   305.70062255859375],
  [38.098087310791016,
   364.9984130859375,
   50.908653259277344,
   384.5873718261719],
  [65.51311492919922,
   362.63995361328125,
   81.48748779296875,
   393.5404968261719],
  [39.69805145263672, 299.20489501953125, 52.6159782409668, 326.01513671875],
  [60.86628341674805,
   374.2838134765625,
   74.61965942382812,
   399.80975341796875],
  [12.043745040893555, 277.4474182128906, 54.80329895019531, 350.08544921875],
  [119.03958129882812, 466.705078125, 130.37466430664062, 475.3836669921875],
  [130.13690185546875,
   461.8191833496094,
   147.67759704589844,
   474.83831787109375],
  [33.567081451416016,
   274.42242431640625,
   58.36017990112305,
   284.61444091796875],
  [187.0061492919922,
   443.7926940917969,
   248.8065643310547,
   496.35284423828125],
  [33.56477737426758, 347.0362854003906, 59.35630798339844, 393.912353515625],
  [155.1269073486328, 72.69804382324219, 176.91758728027344, 91.406005859375],
  [174.42259216308594, 434.3402404785156, 189.6849822998047, 499.3359375],
  [120.9566650390625, 459.2337341308594, 130.6550750732422, 468.952392578125],
  [178.2853240966797,
   453.3275146484375,
   185.28721618652344,
   467.2874755859375],
  [186.58139038085938, 374.8846435546875, 241.64425659179688, 529.73046875],
  [172.7978057861328,
   446.1040954589844,
   180.87852478027344,
   470.66229248046875],
  [54.114341735839844,
   342.21087646484375,
   67.57201385498047,
   362.1702880859375],
  [117.20278930664062,
   441.75445556640625,
   167.0687255859375,
   479.3235778808594],
  [30.487014770507812, 275.9310607910156, 42.384002685546875, 282.3876953125],
  [145.31735229492188, 458.09710693359375, 153.1826171875, 469.8982849121094],
  [16.646974563598633,
   284.87640380859375,
   28.750518798828125,
   305.2222595214844],
  [146.0789794921875,
   452.6148376464844,
   158.36871337890625,
   460.42767333984375],
  [63.52558898925781, 347.9367370605469, 85.10527038574219, 370.0069885253906],
  [124.08424377441406,
   447.02069091796875,
   146.3446502685547,
   452.59759521484375],
  [117.92158508300781,
   472.3658752441406,
   129.84088134765625,
   481.0884094238281],
  [183.56466674804688,
   455.6319274902344,
   191.54368591308594,
   467.436279296875],
  [182.65359497070312,
   444.47235107421875,
   191.6577606201172,
   461.1174621582031],
  [53.18244552612305,
   279.9129943847656,
   63.012908935546875,
   300.9374694824219],
  [181.04061889648438,
   436.6422424316406,
   186.76156616210938,
   452.74627685546875],
  [76.61975860595703, 357.08734130859375, 88.0002212524414, 378.1725769042969],
  [145.13958740234375,
   455.32366943359375,
   152.04811096191406,
   462.7490539550781],
  [121.34233856201172, 455.24700927734375, 129.8720703125, 461.1498107910156],
  [159.155517578125, 455.6427307128906, 169.41253662109375, 469.6439514160156],
  [183.03285217285156,
   455.6417541503906,
   194.16656494140625,
   489.782958984375],
  [48.667640686035156,
   382.6463623046875,
   65.74028778076172,
   394.2112121582031],
  [184.89797973632812, 432.4736022949219, 204.1837615966797, 487.5517578125],
  [43.597389221191406,
   311.49029541015625,
   58.03738021850586,
   362.9372253417969],
  [172.7978057861328,
   446.1040954589844,
   180.87852478027344,
   470.66229248046875],
  [27.80367088317871,
   371.58123779296875,
   66.5971908569336,
   408.97979736328125],
  [186.32461547851562,
   467.91351318359375,
   257.72003173828125,
   533.4468994140625],
  [192.9571533203125, 411.4344177246094, 254.4733123779297, 469.5745849609375],
  [122.54256439208984,
   474.1783752441406,
   147.54771423339844,
   479.46832275390625],
  [155.71401977539062,
   457.4963073730469,
   163.13836669921875,
   464.5789794921875],
  [180.04515075683594,
   468.1459655761719,
   186.41122436523438,
   483.033447265625],
  [177.1296844482422,
   447.2373352050781,
   182.19415283203125,
   458.67620849609375],
  [145.4827117919922,
   469.8309020996094,
   154.59701538085938,
   480.5066223144531],
  [168.89169311523438,
   452.16680908203125,
   173.12709045410156,
   459.25970458984375],
  [351.2250061035156, 624.50244140625, 396.10546875, 640.6241455078125],
  [195.22314453125, 445.1014099121094, 203.7427215576172, 460.197265625],
  [155.19442749023438, 459.8017578125, 163.66128540039062, 475.0662536621094],
  [62.93824768066406, 344.9513244628906, 80.5408935546875, 357.32196044921875],
  [187.51211547851562, 437.7054443359375, 204.66061401367188, 463.6455078125],
  [178.4820556640625,
   450.7328186035156,
   186.04022216796875,
   485.2181396484375],
  [198.00936889648438, 419.75286865234375, 209.38291931152344, 430.037109375],
  [198.34487915039062,
   498.6209411621094,
   212.03504943847656,
   527.7868041992188],
  [135.9821014404297,
   475.14984130859375,
   146.61538696289062,
   484.0533752441406],
  [577.2827758789062, 544.5910034179688, 639.1688232421875, 583.7003173828125],
  [119.00767517089844,
   437.28070068359375,
   147.58509826660156,
   456.4928894042969],
  [145.56883239746094,
   444.40972900390625,
   159.18209838867188,
   452.35650634765625],
  [138.71499633789062,
   451.0059814453125,
   146.90463256835938,
   455.67755126953125],
  [189.1619873046875,
   449.1416320800781,
   200.88711547851562,
   466.1269226074219],
  [161.87548828125, 59.308929443359375, 181.87794494628906, 83.16561889648438],
  [168.80027770996094, 75.48798370361328, 183.68997192382812, 84.86767578125],
  [180.45782470703125,
   481.7185974121094,
   185.98170471191406,
   502.1627197265625],
  [313.37359619140625,
   60.03468322753906,
   381.5538330078125,
   114.58316040039062],
  [119.92286682128906,
   450.6955261230469,
   154.1692657470703,
   474.1802978515625],
  [208.02760314941406,
   214.12022399902344,
   266.3639221191406,
   247.16156005859375],
  [157.30230712890625,
   466.5044860839844,
   164.14637756347656,
   482.12982177734375],
  [196.1655731201172,
   452.8328552246094,
   208.94357299804688,
   468.4895935058594],
  [163.31520080566406, 454.28857421875, 170.89198303222656, 463.4341735839844],
  [154.32815551757812,
   423.2593688964844,
   163.5664825439453,
   429.92535400390625],
  [37.72755813598633,
   333.83111572265625,
   58.04801940917969,
   348.9486999511719],
  [345.7443542480469, 264.3921813964844, 402.7663879394531, 299.1131286621094],
  [154.24925231933594,
   441.95941162109375,
   162.48114013671875,
   452.1764221191406]],
 'detection_scores': [0.9999340772628784,
  0.999077320098877,
  0.997527539730072,
  0.8563366532325745,
  0.7562338709831238,
  0.7509637475013733,
  0.6871128678321838,
  0.5894297361373901,
  0.5790871381759644,
  0.5674752593040466,
  0.5669094324111938,
  0.5564807057380676,
  0.5409089922904968,
  0.5178970694541931,
  0.4852903187274933,
  0.45581668615341187,
  0.45336097478866577,
  0.4525772035121918,
  0.42608749866485596,
  0.4029306471347809,
  0.4004065990447998,
  0.3873329758644104,
  0.3413333594799042,
  0.33326637744903564,
  0.33302563428878784,
  0.3282148838043213,
  0.3233325779438019,
  0.3100794851779938,
  0.29875513911247253,
  0.2904152274131775,
  0.2837875485420227,
  0.2546415627002716,
  0.2536107003688812,
  0.24663248658180237,
  0.2450152486562729,
  0.24359434843063354,
  0.24266840517520905,
  0.23771117627620697,
  0.22317026555538177,
  0.21985377371311188,
  0.21930308640003204,
  0.2188199758529663,
  0.21622031927108765,
  0.2147449404001236,
  0.21154329180717468,
  0.20214411616325378,
  0.20131146907806396,
  0.20094634592533112,
  0.1969105452299118,
  0.1967000663280487,
  0.1944190114736557,
  0.1884787380695343,
  0.18735317885875702,
  0.18551942706108093,
  0.18177415430545807,
  0.17505097389221191,
  0.17340728640556335,
  0.17087136209011078,
  0.168310284614563,
  0.16697771847248077,
  0.16432581841945648,
  0.1577453464269638,
  0.15604306757450104,
  0.15328916907310486,
  0.15326720476150513,
  0.15173810720443726,
  0.15035898983478546,
  0.1485171765089035,
  0.14785583317279816,
  0.1478416472673416,
  0.14562185108661652,
  0.1455327272415161,
  0.14224261045455933,
  0.1402326077222824,
  0.13977549970149994,
  0.13923518359661102,
  0.13868077099323273,
  0.1380036622285843,
  0.1354478895664215,
  0.1331036239862442,
  0.13161826133728027,
  0.13076505064964294,
  0.12773340940475464,
  0.1270059496164322,
  0.12579387426376343,
  0.12561452388763428,
  0.1250024288892746,
  0.12493589520454407,
  0.1243327185511589,
  0.1227402538061142,
  0.12210544943809509,
  0.12197118997573853,
  0.11971386522054672,
  0.11860451847314835,
  0.11849794536828995,
  0.11847437918186188,
  0.11788228154182434,
  0.11787869036197662,
  0.11708483844995499,
  0.11599892377853394],
 'num_detections': 100.0}
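The response above contains 100 candidate detections, most with low confidence. The sketch below shows how you might keep only the confident ones. It assumes the prediction has already been parsed into a Python dict with the keys shown (`detection_boxes`, `detection_scores`, `num_detections`). The helper name `filter_detections` and the 0.5 default threshold are hypothetical choices, not part of the container's API. This output does not state the coordinate order of each box, so the sketch passes the boxes through unchanged.

```python
from typing import Any, Dict, List, Tuple


def filter_detections(
    prediction: Dict[str, Any], min_score: float = 0.5
) -> List[Tuple[List[float], float]]:
    """Return (box, score) pairs whose score is at least `min_score`.

    Hypothetical helper: assumes `prediction` has the keys seen in the
    response above, with boxes and scores aligned index by index.
    """
    # num_detections is returned as a float (e.g. 100.0), so cast it to int.
    n = int(prediction["num_detections"])
    boxes = prediction["detection_boxes"][:n]
    scores = prediction["detection_scores"][:n]
    # Keep each box together with its score when the score clears the threshold.
    return [(box, score) for box, score in zip(boxes, scores) if score >= min_score]
```

For example, calling `filter_detections(prediction, min_score=0.5)` on the response printed above would keep only the entries whose `detection_scores` value is 0.5 or higher. Picking the threshold is a trade-off between recall and precision, so you may want to tune it on your own validation data.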