In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In this tutorial, we'll use TensorFlow 1.15 to create an image classification model, train it with a flowers dataset, and convert it into the TensorFlow Lite format that's compatible with the Edge TPU (available in Coral devices).
The model will be based on a pre-trained version of MobileNet V2. We'll start by retraining only the classification layers, reusing MobileNet's pre-trained feature extractor layers. Then we'll fine-tune the model by also updating weights in some of the feature extractor layers. This type of transfer learning is much faster than training the entire model from scratch.
Once it's trained, we'll use post-training quantization to convert all the parameters to uint8 format, which increases inferencing speed and is required for compatibility with the Edge TPU.
For more details about how to create a model that's compatible with the Edge TPU, see the documentation at coral.ai.
Note: This tutorial requires TensorFlow 1.15. If you're using TF 2.0+, see .
To start running all the code in this tutorial, select Runtime > Run all in the Colab toolbar.
In [0]:
try:
  # This %tensorflow_version magic only works in Colab.
  %tensorflow_version 1.x
except Exception:
  pass

# For your non-Colab code, be sure you have tensorflow==1.15
import tensorflow as tf
assert tf.__version__.startswith('1')

tf.enable_eager_execution()

import os
import numpy as np
import matplotlib.pyplot as plt
First let's download and organize the flowers dataset we'll use to retrain the model (it contains 5 flower classes).
Pay attention to this part so you can reproduce it with your own images dataset. In particular, notice that the "flower_photos" directory contains an appropriately-named directory for each class. The following code then randomizes and divides up all these photos into training and validation sets, and generates the labels file.
In [0]:
_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
zip_file = tf.keras.utils.get_file(origin=_URL,
                                   fname="flower_photos.tgz",
                                   extract=True)
flowers_dir = os.path.join(os.path.dirname(zip_file), 'flower_photos')
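As a quick check (a small sketch, using the flowers_dir path set above), you can list the contents of the extracted directory—you should see one subdirectory per flower class:
In [0]:
# Quick check: each class should appear as its own subdirectory
# (plus a LICENSE.txt file that ships with the archive).
print(sorted(os.listdir(flowers_dir)))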
Next, we use ImageDataGenerator to rescale the image data into float values (divide by 255 so the tensor values are between 0 and 1), and call flow_from_directory() to create two generators: one for the training dataset and one for the validation dataset.
In [0]:
IMAGE_SIZE = 224
BATCH_SIZE = 64
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2)

train_generator = datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training')

val_generator = datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='validation')
On each iteration, these generators provide a batch of images by reading images from disk and processing them to the proper tensor size (224 x 224). The output is a tuple of (images, labels). For example, you can see the shapes here:
In [0]:
image_batch, label_batch = next(val_generator)
image_batch.shape, label_batch.shape
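To visually spot-check the data, you can also display one of the images in the batch along with its class name (a small sketch, using the matplotlib import from earlier):
In [0]:
# Sketch: display the first image in the validation batch with its class name.
class_names = {v: k for k, v in val_generator.class_indices.items()}
plt.imshow(image_batch[0])
plt.title(class_names[np.argmax(label_batch[0])])
plt.axis('off')
plt.show()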
Now save the class labels to a text file:
In [0]:
print(train_generator.class_indices)

labels = '\n'.join(sorted(train_generator.class_indices.keys()))

with open('flower_labels.txt', 'w') as f:
  f.write(labels)
In [0]:
!cat flower_labels.txt
Now we'll create a model that's capable of transfer learning on just the last fully-connected layer.
We'll start with MobileNet V2 from Keras as the base model, which is pre-trained with the ImageNet dataset (trained to recognize 1,000 classes). This provides us a great feature extractor for image classification and we can then simply train a new classification layer with our own dataset.
Note: Not all models from tf.keras.applications
are compatible with the Edge TPU. For details, read about quantizing Keras models.
When instantiating the MobileNetV2 model, we specify the include_top=False argument in order to load the network without the classification layers at the top. Then we set trainable to False to freeze all the weights in the base model. This effectively turns the model into a feature extractor, because all the pre-trained weights and biases in the lower layers are preserved when we begin training our new classification head.
In [0]:
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
# Create the base model from the pre-trained MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False
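As a quick sanity check (a small sketch, reusing the image_batch pulled from val_generator earlier), you can pass a batch through the frozen base model and inspect the shape of the extracted features; for 224 x 224 x 3 inputs, MobileNet V2 produces 7 x 7 x 1280 feature maps:
In [0]:
# Sanity check: the frozen base model converts a batch of 224x224x3 images
# into a batch of 7x7x1280 feature maps.
feature_batch = base_model(image_batch)
print(feature_batch.shape)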
Now we create a new Sequential
model and pass the frozen MobileNet model from above as the base of the graph, and append new classification layers so we can set the final output dimension to match the number of classes in our dataset (5 types of flowers).
In [0]:
model = tf.keras.Sequential([
  base_model,
  tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.GlobalAveragePooling2D(),
  tf.keras.layers.Dense(units=5, activation='softmax')
])
In [0]:
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
You can see a string summary of the final network with the summary()
method:
In [0]:
model.summary()
And because the majority of the model graph is frozen in the base model, weights from only the last convolution and dense layers are trainable:
In [0]:
print('Number of trainable weights = {}'.format(len(model.trainable_weights)))
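If you want to see exactly which weights those are (a small sketch, not required for the tutorial), you can list the trainable variables; with the base model frozen, you should see only the kernel and bias of the new Conv2D and Dense layers:
In [0]:
# Sketch: list the trainable weights. Only the new Conv2D and Dense
# kernels and biases should appear while the base model is frozen.
for weight in model.trainable_weights:
  print(weight.name, weight.shape)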
Now we can train the model using data provided by the train_generator
and val_generator
we created at the beginning.
This takes 5-10 minutes to finish.
In [0]:
history = model.fit_generator(train_generator,
                              epochs=10,
                              validation_data=val_generator)
In [0]:
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
So far, we've only trained the classification layers—the weights of the pre-trained network were not changed. The accuracy results aren't bad, but could be better.
One way we can increase the accuracy is to train (or "fine-tune") more layers from the pre-trained model. That is, we'll un-freeze some layers from the base model and adjust those weights (which were originally trained with 1,000 ImageNet classes) so they're better tuned for features found in our flowers dataset.
So instead of freezing the entire base model, we'll freeze individual layers.
First, let's see how many layers are in the base model:
In [0]:
print("Number of layers in the base model: ", len(base_model.layers))
Let's un-freeze the base model and then freeze just the bottom 100 layers, so only the layers above them get fine-tuned.
In [0]:
base_model.trainable = True
fine_tune_at = 100
# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False
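As a quick check (a small sketch), you can count how many base-model layers remain trainable after this loop:
In [0]:
# Sketch: count the trainable vs. frozen layers in the base model.
num_trainable = sum(1 for layer in base_model.layers if layer.trainable)
print('Trainable base-model layers:', num_trainable)
print('Frozen base-model layers:', len(base_model.layers) - num_trainable)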
In [0]:
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
In [0]:
model.summary()
In [0]:
print('Number of trainable weights = {}'.format(len(model.trainable_weights)))
Now start training all the trainable layers. This starts with the weights we already trained in the classification layers, so we don't need as many epochs.
In [0]:
history_fine = model.fit_generator(train_generator,
                                   epochs=5,
                                   validation_data=val_generator)
Now that we've done some fine-tuning on the MobileNet V2 base model, let's check the accuracy.
In [0]:
acc = history_fine.history['acc']
val_acc = history_fine.history['val_acc']
loss = history_fine.history['loss']
val_loss = history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
This is better, but it's not ideal.
The validation loss is much higher than the training loss, which suggests some overfitting during training. The overfitting may also stem from the new training set being relatively small, with less intra-class variance than the ImageNet dataset originally used to train MobileNet V2.
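One common way to reduce this kind of overfitting (not used in this tutorial) is to add data augmentation when building the training generator. Here's a minimal sketch using standard ImageDataGenerator options; you would recreate train_generator from this augmented generator and retrain:
In [0]:
# Sketch only: an augmented data generator for the training subset.
# All arguments shown are standard ImageDataGenerator options.
augmented_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    validation_split=0.2)

augmented_train_generator = augmented_datagen.flow_from_directory(
    flowers_dir,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    subset='training')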
This model isn't trained to a production-ready accuracy, but it works well enough as a demonstration, so let's move on and convert the model to be compatible with the Edge TPU.
Ordinarily, creating a TensorFlow Lite model is just a few lines of code using the TFLiteConverter. For example, this code creates a standard TensorFlow Lite model:
In [0]:
saved_keras_model = 'model.h5'
model.save(saved_keras_model)
converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open('mobilenet_v2_1.0_224.tflite', 'wb') as f:
  f.write(tflite_model)
However, this .tflite file isn't compatible with the Edge TPU: the DEFAULT optimizations flag quantizes only the weights, while the activation values remain in floating point. So we must fully quantize the model, using the int8 format for all parameter data (both weights and activations).
To fully quantize the model, we need to perform post-training quantization with a representative dataset, which requires a few more arguments for the TFLiteConverter, and a function that builds a dataset that's representative of the training dataset.
So let's convert the model again, this time using post-training quantization:
In [0]:
# A generator that provides a representative dataset
def representative_data_gen():
  dataset_list = tf.data.Dataset.list_files(flowers_dir + '/*/*')
  for i in range(100):
    image = next(iter(dataset_list))
    image = tf.io.read_file(image)
    image = tf.io.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    image = tf.cast(image / 255., tf.float32)
    image = tf.expand_dims(image, 0)
    yield [image]

saved_keras_model = 'model.h5'
model.save(saved_keras_model)

converter = tf.lite.TFLiteConverter.from_keras_model_file(saved_keras_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# This ensures that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# These set the input and output tensors to uint8
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
# And this sets the representative dataset so we can quantize the activations
converter.representative_dataset = representative_data_gen
tflite_model = converter.convert()

with open('mobilenet_v2_1.0_224_quant.tflite', 'wb') as f:
  f.write(tflite_model)
Note: An alternative technique to quantize the model is to use quantization-aware training. This typically results in better accuracy because the training accounts for the decreased parameter precision. However, quantization-aware training requires modifications to the model graph, which is beyond the scope of this tutorial.
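Before comparing accuracy, you can quickly confirm that the converter produced uint8 input and output tensors (a small check, assuming the quantized file written above):
In [0]:
# Quick check: the quantized model's input and output tensors should be uint8.
check_interpreter = tf.lite.Interpreter(model_path='mobilenet_v2_1.0_224_quant.tflite')
check_interpreter.allocate_tensors()
print('Input type:', check_interpreter.get_input_details()[0]['dtype'])
print('Output type:', check_interpreter.get_output_details()[0]['dtype'])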
So now we have a fully quantized model. To be sure the conversion went well, let's run some inferences using both the raw trained model and the new TensorFlow Lite model.
First check the accuracy of the raw model:
In [0]:
batch_images, batch_labels = next(val_generator)
logits = model(batch_images)
prediction = np.argmax(logits, axis=1)
truth = np.argmax(batch_labels, axis=1)
keras_accuracy = tf.keras.metrics.Accuracy()
keras_accuracy(prediction, truth)
print("Raw model accuracy: {:.3%}".format(keras_accuracy.result()))
Now let's check the accuracy of the .tflite
file, using the same dataset:
(This TensorFlow Lite code is a bit more complicated. For details, read the guide to TensorFlow Lite inference.)
In [0]:
def set_input_tensor(interpreter, input):
  input_details = interpreter.get_input_details()[0]
  tensor_index = input_details['index']
  input_tensor = interpreter.tensor(tensor_index)()[0]
  # Inputs for the TFLite model must be uint8, so we quantize our input data.
  # NOTE: This step is necessary only because we're receiving input data from
  # ImageDataGenerator, which rescaled all image data to float [0,1]. When using
  # bitmap inputs, they're already uint8 [0,255] so this can be replaced with:
  #   input_tensor[:, :] = input
  scale, zero_point = input_details['quantization']
  input_tensor[:, :] = np.uint8(input / scale + zero_point)

def classify_image(interpreter, input):
  set_input_tensor(interpreter, input)
  interpreter.invoke()
  output_details = interpreter.get_output_details()[0]
  output = interpreter.get_tensor(output_details['index'])
  # Outputs from the TFLite model are uint8, so we dequantize the results:
  scale, zero_point = output_details['quantization']
  output = scale * (output - zero_point)
  top_1 = np.argmax(output)
  return top_1

interpreter = tf.lite.Interpreter('mobilenet_v2_1.0_224_quant.tflite')
interpreter.allocate_tensors()

# Collect all inference predictions in a list
batch_prediction = []
batch_truth = np.argmax(batch_labels, axis=1)

for i in range(len(batch_images)):
  prediction = classify_image(interpreter, batch_images[i])
  batch_prediction.append(prediction)

# Compare all predictions to the ground truth
tflite_accuracy = tf.keras.metrics.Accuracy()
tflite_accuracy(batch_prediction, batch_truth)
print("Quant TF Lite accuracy: {:.3%}".format(tflite_accuracy.result()))
A small drop in accuracy is expected with post-training quantization. You might be able to improve this by refining the representative dataset used during quantization.
As mentioned earlier, you might also get better accuracy with quantization-aware training.
Finally, we're ready to compile the model for the Edge TPU.
First download the Edge TPU Compiler:
In [0]:
! curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
! echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list
! sudo apt-get update
! sudo apt-get install edgetpu-compiler
Then compile the model:
In [0]:
! edgetpu_compiler mobilenet_v2_1.0_224_quant.tflite
That's it.
The compiled model uses the same filename but with "_edgetpu" appended at the end.
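You can confirm the compiled file exists with a quick directory listing:
In [0]:
# Quick check: list the compiled Edge TPU model file.
! ls -lh *_edgetpu.tflite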
You can download the converted model and labels file from Colab like this:
In [0]:
from google.colab import files
files.download('mobilenet_v2_1.0_224_quant_edgetpu.tflite')
files.download('flower_labels.txt')
If you get a "Failed to fetch" error here, it's probably because the files weren't done saving. So just wait a moment and try again.
Also look out for a browser popup that might need approval to download the files.
You can now run the model on your Coral device with acceleration on the Edge TPU.
To get started, try using this code for image classification with the TensorFlow Lite API. Just follow the instructions on that page to set up your device, copy the mobilenet_v2_1.0_224_quant_edgetpu.tflite
and flower_labels.txt
files to your Coral Dev Board or device with a Coral Accelerator, and pass it a flower photo like this:
python3 classify_image.py \
--model mobilenet_v2_1.0_224_quant_edgetpu.tflite \
--labels flower_labels.txt \
--input flower.jpg
Check out more examples for running inference at coral.ai/examples.