Copyright 2020 The TensorFlow Authors.
In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This is the notebook for step 12 of the codelab "Add Firebase to your TensorFlow Lite-powered app".
In this notebook, we will train an improved version of the handwritten digit classification model using data augmentation. Then we will upload the model to Firebase using the Firebase ML Model Management API.
|
Let's start by training the improved model.
We will not go into the details of the model training here, but if you are interested in learning more about why we apply data augmentation to this model, along with other details, check out this notebook.
In [0]:
# Import dependencies
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
print("TensorFlow version:", tf.__version__)

# Import the MNIST handwritten digit dataset.
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
# Cast to float32 before dividing: dividing the raw integer pixel arrays by
# 255.0 would materialize float64 copies of the whole dataset (twice the
# memory). float32 is Keras's default float type, so precision is unaffected
# for training.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# Add a trailing channel dimension to the "train" and "test" images so they
# match the model's (28, 28, 1) input shape and can be fed to Keras's data
# augmentation utilities, which expect rank-4 (batch, height, width, channels)
# input.
train_images = np.expand_dims(train_images, axis=-1)
test_images = np.expand_dims(test_images, axis=-1)
# Data augmentation: every batch drawn from a generator below is randomly
# rotated, shifted, sheared and zoomed, so the model sees distorted variants
# of each digit instead of the clean originals.
augmentation_params = dict(
    rotation_range=30,        # random rotation range, in degrees
    width_shift_range=0.25,   # horizontal shift, as a fraction of image width
    height_shift_range=0.25,  # vertical shift, as a fraction of image height
    shear_range=0.25,
    zoom_range=0.2,
)
datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation_params)

# Both splits stream through the same augmentation, so validation accuracy is
# also measured on distorted digits.
train_generator = datagen.flow(x=train_images, y=train_labels)
test_generator = datagen.flow(x=test_images, y=test_labels)
# Build the classifier layer by layer: two 3x3 convolutions, 2x2 max pooling,
# dropout for regularization, then a softmax over the 10 digit classes.
model = keras.Sequential()
model.add(keras.layers.InputLayer(input_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu))
model.add(keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))

# Labels are plain integer class ids, hence the *sparse* categorical loss.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train on the augmented stream, validating on the augmented test stream.
model.fit(train_generator, validation_data=test_generator, epochs=5)
# Convert the trained Keras model to TensorFlow Lite. Optimize.DEFAULT turns on
# post-training quantization, which shrinks the model file.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the serialized model bytes to mnist_v2.tflite in the working directory.
with open('mnist_v2.tflite', "wb") as tflite_file:
    tflite_file.write(tflite_model)
Step 1. Upload the private key (JSON file) for your service account and initialize Firebase Admin
In [0]:
import json
import os

from google.colab import files
import firebase_admin
from firebase_admin import ml

# Ask the user to upload the service account private key (a JSON file).
# files.upload() blocks until the upload finishes and returns a dict mapping
# each uploaded file name to its raw bytes.
uploaded = files.upload()

# Exactly one key is expected: with zero files nothing would be initialized,
# and a second file would make initialize_app() fail with "app already exists".
if len(uploaded) != 1:
    raise ValueError(
        'Expected exactly one service account key file, got {}.'.format(len(uploaded)))

fn, key_bytes = next(iter(uploaded.items()))
print('User uploaded file "{name}" with length {length} bytes'.format(
    name=fn, length=len(key_bytes)))

# initialize_app() is called without explicit credentials below, so it relies
# on Application Default Credentials, which read this environment variable.
# Resolve the key path against the working directory (where the uploaded file
# is expected to have been saved) instead of hard-coding Colab's '/content/'.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(fn)

# Service account keys carry their project in the "project_id" field, which is
# more reliable than parsing the file name (the user may have renamed the
# download). Fall back to the original file-name heuristic if it is missing.
projectID = json.loads(key_bytes).get("project_id") or fn.split("-firebase", 1)[0]

# initialize_app() raises ValueError if the default app already exists (e.g.
# when this cell is re-run), so only initialize when get_app() finds no app.
try:
    firebase_admin.get_app()
except ValueError:
    # NOTE(review): newer Firebase projects may use a
    # "<project>.firebasestorage.app" default bucket instead of
    # "<project>.appspot.com"; confirm on the Storage page of the console.
    firebase_admin.initialize_app(
        options={'projectId': projectID,
                 'storageBucket': projectID + '.appspot.com'})
else:
    print('Firebase Admin is already initialized; restart the runtime to '
          'switch to a different project.')
Step 2. Upload the model file to Cloud Storage
In [0]:
# Upload the quantized model file that the training cell wrote
# (mnist_v2.tflite) to the Storage bucket configured in initialize_app().
# TFLiteGCSModelSource.from_keras_model() is deliberately not used here: it
# converts `model` again on its own, without the Optimize.DEFAULT setting, and
# overwrites mnist_v2.tflite with that unquantized result before uploading.
source = ml.TFLiteGCSModelSource.from_tflite_model_file('mnist_v2.tflite')
print(source.gcs_tflite_uri)
Step 3. Deploy the model to Firebase
In [0]:
# Wrap the uploaded Cloud Storage file in a TFLite model format.
model_format = ml.TFLiteFormat(model_source=source)

# Describe the model that will be registered under the name "mnist_v2".
sdk_model_1 = ml.Model(display_name="mnist_v2", model_format=model_format)

# Make the Create API call to create the model in Firebase.
firebase_model_1 = ml.create_model(sdk_model_1)
print(firebase_model_1.as_dict())

# The created model may still be locked by a background operation (e.g.
# validation of the uploaded file). Wait for it to finish, then stop with a
# clear error rather than trying to publish a model that failed validation.
firebase_model_1.wait_for_unlocked()
if firebase_model_1.validation_error:
    raise RuntimeError(
        'Model validation failed: {}'.format(firebase_model_1.validation_error))

# Publish the model.
model_id = firebase_model_1.model_id
firebase_model_1 = ml.publish_model(model_id)
print('Published:', firebase_model_1.published)