Licensed under the Apache License, Version 2.0 (the "License")


In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Retrain a classification model for Edge TPU using post-training quantization (with TF1)

In this tutorial, we'll use TensorFlow 1.15 to create an image classification model, train it with a flowers dataset, and convert it into the TensorFlow Lite format that's compatible with the Edge TPU (available in Coral devices).

The model will be based on a pre-trained version of MobileNet V2. We'll start by retraining only the classification layers, reusing MobileNet's pre-trained feature extractor layers. Then we'll fine-tune the model by also updating weights in some of the feature extractor layers. This type of transfer learning is much faster than training the entire model from scratch.

Once it's trained, we'll use post-training quantization to convert all the parameters to uint8 format, which increases inferencing speed and is required for compatibility with the Edge TPU.

For more details about how to create a model that's compatible with the Edge TPU, see the documentation at coral.ai.

Note: This tutorial requires TensorFlow 1.15. If you're using TF 2.0+, see the TF2 version of this tutorial instead.


To start running all the code in this tutorial, select Runtime > Run all in the Colab toolbar.

Import the required libraries


In [0]:
try:
  # This %tensorflow_version magic only works in Colab.
  %tensorflow_version 1.x
except Exception:
  pass
# For your non-Colab code, be sure you have tensorflow==1.15
import tensorflow as tf
assert tf.__version__.startswith('1')

tf.enable_eager_execution()

import os
import numpy as np
import matplotlib.pyplot as plt

Prepare the training data

First let's download and organize the flowers dataset we'll use to retrain the model (it contains 5 flower classes).

Pay attention to this part so you can reproduce it with your own image dataset. In particular, notice that the "flower_photos" directory contains an appropriately-named subdirectory for each class. The following cells then split these photos into training and validation sets and generate the labels file.


In [0]:
_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

zip_file = tf.keras.utils.get_file(origin=_URL, 
                                   fname="flower_photos.tgz", 
                                   extract=True)

flowers_dir = os.path.join(os.path.dirname(zip_file), 'flower_photos')
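
If you're adapting this notebook to your own image dataset, it can help to confirm the directory layout matches what flow_from_directory() expects: one subdirectory per class, named after the class. Here's a quick, optional sanity check that just lists the class folders:


In [0]:
# List the class subdirectories (one per flower class) inside flower_photos
for name in sorted(os.listdir(flowers_dir)):
  if os.path.isdir(os.path.join(flowers_dir, name)):
    print(name)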

Next, we use ImageDataGenerator to rescale the image data into float values (divide by 255 so the tensor values are between 0 and 1), and call flow_from_directory() to create two generators: one for the training dataset and one for the validation dataset.


In [0]:
IMAGE_SIZE = 224
BATCH_SIZE = 64

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255, 
    validation_split=0.2)

train_generator = datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE, 
    subset='training')

val_generator = datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE, 
    subset='validation')

On each iteration, these generators provide a batch of images by reading images from disk and processing them to the proper tensor size (224 x 224). The output is a tuple of (images, labels). For example, you can see the shapes here:


In [0]:
image_batch, label_batch = next(val_generator)
image_batch.shape, label_batch.shape

Now save the class labels to a text file:


In [0]:
print(train_generator.class_indices)

labels = '\n'.join(sorted(train_generator.class_indices.keys()))

with open('flower_labels.txt', 'w') as f:
  f.write(labels)

In [0]:
!cat flower_labels.txt

Build the model

Now we'll create a model that performs transfer learning by training only the new classification layers we add on top of the pre-trained base.

We'll start with MobileNet V2 from Keras as the base model, which is pre-trained with the ImageNet dataset (trained to recognize 1,000 classes). This provides us a great feature extractor for image classification and we can then simply train a new classification layer with our own dataset.

Note: Not all models from tf.keras.applications are compatible with the Edge TPU. For details, read about quantizing Keras models.

Create the base model

When instantiating the MobileNetV2 model, we specify the include_top=False argument in order to load the network without the classification layers at the top. Then we set trainable to False to freeze all the weights in the base model. This effectively turns the base model into a feature extractor, because all the pre-trained weights and biases in the lower layers are preserved when we train our new classification head.


In [0]:
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

# Create the base model from the pre-trained MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False

Add a classification head

Now we create a new Sequential model that uses the frozen MobileNet model as the base of the graph and appends new classification layers, so the final output dimension matches the number of classes in our dataset (5 types of flowers).


In [0]:
model = tf.keras.Sequential([
  base_model,
  tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(units=5, activation='softmax')
])

Configure the model

Although this method is called compile(), it's basically a configuration step that's required before we can start training.


In [0]:
model.compile(optimizer=tf.keras.optimizers.Adam(), 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])

You can see a string summary of the final network with the summary() method:


In [0]:
model.summary()

And because the majority of the model graph is frozen in the base model, weights from only the last convolution and dense layers are trainable:


In [0]:
print('Number of trainable weights = {}'.format(len(model.trainable_weights)))

Train the model

Now we can train the model using data provided by the train_generator and val_generator we created at the beginning.

This takes 5-10 minutes to finish.


In [0]:
history = model.fit_generator(train_generator,
                              epochs=10,
                              validation_data=val_generator)

Review the learning curves


In [0]:
acc = history.history['acc']
val_acc = history.history['val_acc']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Fine tune the base model

So far, we've only trained the classification layers—the weights of the pre-trained network were not changed. The accuracy results aren't bad, but could be better.

One way we can increase the accuracy is to train (or "fine-tune") more layers from the pre-trained model. That is, we'll un-freeze some layers from the base model and adjust those weights (which were originally trained with 1,000 ImageNet classes) so they're better tuned for features found in our flowers dataset.

Un-freeze more layers

So instead of freezing the entire base model, we'll freeze individual layers.

First, let's see how many layers are in the base model:


In [0]:
print("Number of layers in the base model: ", len(base_model.layers))

Let's try freezing just the bottom 100 layers.


In [0]:
base_model.trainable = True
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False

Reconfigure the model

Now configure the model again, but this time with a lower learning rate (the default is 0.001).


In [0]:
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [0]:
model.summary()

In [0]:
print('Number of trainable weights = {}'.format(len(model.trainable_weights)))

Continue training

Now start training all the trainable layers. This starts with the weights we already trained in the classification layers, so we don't need as many epochs.


In [0]:
history_fine = model.fit_generator(train_generator,
                                   epochs=5,
                                   validation_data=val_generator)

Review the new learning curves

Now that we've done some fine-tuning on the MobileNet V2 base model, let's check the accuracy.


In [0]:
acc = history_fine.history['acc']
val_acc = history_fine.history['val_acc']

loss = history_fine.history['loss']
val_loss = history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

This is better, but it's not ideal.

The validation loss is much higher than the training loss, so there may be some overfitting during training. The overfitting is likely made worse by the fact that the new training set is relatively small and has less intra-class variance than the ImageNet dataset originally used to train MobileNet V2.
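
One optional way to combat this overfitting is to add data augmentation to the ImageDataGenerator, so each epoch sees slightly varied versions of the same photos. Here's a minimal sketch you could experiment with, reusing flowers_dir, IMAGE_SIZE, and BATCH_SIZE from earlier; the augmented_datagen name and the specific augmentation parameters are just illustrative, and the rest of this tutorial doesn't depend on it:


In [0]:
# A sketch of an augmented training generator (optional).
# Random rotations, shifts, and flips add variety to the small flowers dataset.
augmented_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    validation_split=0.2)

augmented_train_generator = augmented_datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training')

You would then pass augmented_train_generator to fit_generator() in place of train_generator and retrain.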

This model isn't trained to production-ready accuracy, but it works well enough as a demonstration, so let's move on and convert the model to be compatible with the Edge TPU.

Convert to TFLite

Ordinarily, creating a TensorFlow Lite model is just a few lines of code using the TFLiteConverter. For example, this code creates a standard TensorFlow Lite model:


In [0]:
saved_keras_model = 'model.h5'
model.save(saved_keras_model)

converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:
  f.write(tflite_model)

However, this .tflite file isn't compatible with the Edge TPU: although the DEFAULT optimizations flag quantizes the weights, the activation values are still in floating-point. So we must fully quantize the model so that all parameter data (both weights and activations) uses 8-bit integer format.

To fully quantize the model, we need to perform post-training quantization with a representative dataset, which requires a few more arguments for the TFLiteConverter, and a function that builds a dataset that's representative of the training dataset.

So let's convert the model again, this time using post-training quantization:


In [0]:
# A generator that provides a representative dataset for calibration
def representative_data_gen():
  dataset_list = tf.data.Dataset.list_files(flowers_dir + '/*/*')
  # Take 100 sample images and preprocess them the same way as the training data
  for image_path in dataset_list.take(100):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    image = tf.cast(image / 255., tf.float32)
    image = tf.expand_dims(image, 0)
    yield [image]

saved_keras_model = 'model.h5'
model.save(saved_keras_model)

converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This ensures that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# These set the input and output tensors to uint8
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
# And this sets the representative dataset so we can quantize the activations
converter.representative_dataset = representative_data_gen
tflite_model = converter.convert()

with open('mobilenet_v2_1.0_224_quant.tflite', 'wb') as f:
  f.write(tflite_model)

Note: An alternative technique to quantize the model is to use quantization-aware training. This typically results in better accuracy because the training accounts for the decreased parameter precision. However, quantization-aware training requires modifications to the model graph, which is beyond the scope of this tutorial.

Compare the accuracy

So now we have a fully quantized model. To be sure the conversion went well, let's run some inferences using both the raw trained model and the new TensorFlow Lite model.

First check the accuracy of the raw model:


In [0]:
batch_images, batch_labels = next(val_generator)

logits = model(batch_images)
prediction = np.argmax(logits, axis=1)
truth = np.argmax(batch_labels, axis=1)

keras_accuracy = tf.keras.metrics.Accuracy()
keras_accuracy(prediction, truth)

print("Raw model accuracy: {:.3%}".format(keras_accuracy.result()))

Now let's check the accuracy of the .tflite file, using the same dataset:

(This TensorFlow Lite code is a bit more complicated. For details, read the guide to TensorFlow Lite inference.)


In [0]:
def set_input_tensor(interpreter, input):
  input_details = interpreter.get_input_details()[0]
  tensor_index = input_details['index']
  input_tensor = interpreter.tensor(tensor_index)()[0]
  # Inputs for the TFLite model must be uint8, so we quantize our input data.
  # NOTE: This step is necessary only because we're receiving input data from
  # ImageDataGenerator, which rescaled all image data to float [0,1]. When using
  # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:
  #   input_tensor[:, :] = input
  scale, zero_point = input_details['quantization']
  input_tensor[:, :] = np.uint8(input / scale + zero_point)

def classify_image(interpreter, input):
  set_input_tensor(interpreter, input)
  interpreter.invoke()
  output_details = interpreter.get_output_details()[0]
  output = interpreter.get_tensor(output_details['index'])
  # Outputs from the TFLite model are uint8, so we dequantize the results:
  scale, zero_point = output_details['quantization']
  output = scale * (output - zero_point)
  top_1 = np.argmax(output)
  return top_1

interpreter = tf.lite.Interpreter('mobilenet_v2_1.0_224_quant.tflite')
interpreter.allocate_tensors()

# Collect all inference predictions in a list
batch_prediction = []
batch_truth = np.argmax(batch_labels, axis=1)

for i in range(len(batch_images)):
  prediction = classify_image(interpreter, batch_images[i])
  batch_prediction.append(prediction)

# Compare all predictions to the ground truth
tflite_accuracy = tf.keras.metrics.Accuracy()
tflite_accuracy(batch_prediction, batch_truth)
print("Quant TF Lite accuracy: {:.3%}".format(tflite_accuracy.result()))

A small drop in accuracy is expected with post-training quantization. You might be able to improve this by refining the representative dataset used during quantization.
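
For example, one simple refinement is to draw more sample images, so the calibration data is more likely to cover all five classes. This is just a sketch, assuming the same preprocessing as representative_data_gen above; the function name and sample count are arbitrary:


In [0]:
# A sketch of a refined representative dataset that samples more images
def refined_representative_data_gen():
  # list_files shuffles the file list by default, so take() draws a random sample
  dataset_list = tf.data.Dataset.list_files(flowers_dir + '/*/*')
  for image_path in dataset_list.take(300):
    image = tf.io.read_file(image_path)
    image = tf.io.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    image = tf.cast(image / 255., tf.float32)
    yield [tf.expand_dims(image, 0)]

To try it, set converter.representative_dataset = refined_representative_data_gen and run convert() again.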

As mentioned earlier, you might also get better accuracy with quantization-aware training.

Compile for the Edge TPU

Finally, we're ready to compile the model for the Edge TPU.

First download the Edge TPU Compiler:


In [0]:
! curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -

! echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list

! sudo apt-get update

! sudo apt-get install edgetpu-compiler

Then compile the model:


In [0]:
! edgetpu_compiler mobilenet_v2_1.0_224_quant.tflite

That's it.

The compiled model uses the same filename, but with "_edgetpu" appended before the .tflite extension (so here it's mobilenet_v2_1.0_224_quant_edgetpu.tflite).
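
You can verify the output with a quick directory listing (the compiler also writes a .log file summarizing which ops were mapped to the Edge TPU):


In [0]:
! ls mobilenet_v2_1.0_224_quant*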

Download the model

You can download the converted model and labels file from Colab like this:


In [0]:
from google.colab import files

files.download('mobilenet_v2_1.0_224_quant_edgetpu.tflite')
files.download('flower_labels.txt')

If you get a "Failed to fetch" error here, it's probably because the files weren't done saving. So just wait a moment and try again.

Also look out for a browser popup that might ask for permission to download the files.

Run the model on the Edge TPU

You can now run the model on your Coral device with acceleration on the Edge TPU.

To get started, try using this code for image classification with the TensorFlow Lite API. Just follow the instructions on that page to set up your device, copy the mobilenet_v2_1.0_224_quant_edgetpu.tflite and flower_labels.txt files to your Coral Dev Board or device with a Coral Accelerator, and pass it a flower photo like this:

python3 classify_image.py \
  --model mobilenet_v2_1.0_224_quant_edgetpu.tflite \
  --labels flower_labels.txt \
  --input flower.jpg

Check out more examples for running inference at coral.ai/examples.