This notebook showcases how to convert a TensorFlow.js (TF.js) Speech Commands model to the Python (tensorflow.keras) and TFLite formats. The TFLite format enables the model to be deployed to mobile environments such as Android phones.
The techniques outlined in this notebook are applicable to both the original Speech Commands model and transfer-learned variants of it (for example, models trained on Teachable Machine).
First, install the required tensorflow and tensorflowjs Python packages.
In [ ]:
# We need scipy for .wav file IO.
!pip install tensorflowjs==2.1.0 scipy==1.4.1
# TensorFlow 2.3.0 is required due to https://github.com/tensorflow/tensorflow/issues/38135
# TODO: Switch to 2.3.0 final release when it comes out.
!pip install tensorflow-cpu==2.3.0rc0
Below we download the files of the original or transfer-learned TF.js Speech Commands model. The code example here downloads the original model, but the approach is the same for a transfer-learned model downloaded from Teachable Machine, except that Teachable Machine delivers the files as a ZIP archive, which must be unzipped first (see the sketch after the download cell below).
In [1]:
!mkdir -p /tmp/tfjs-sc-model
!curl -o /tmp/tfjs-sc-model/metadata.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/metadata.json
!curl -o /tmp/tfjs-sc-model/model.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/model.json
!curl -o /tmp/tfjs-sc-model/group1-shard1of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard1of2
!curl -o /tmp/tfjs-sc-model/group1-shard2of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard2of2
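If you are working with a transfer-learned model from Teachable Machine instead, extract the downloaded ZIP archive before proceeding. A minimal sketch, assuming a hypothetical archive path (adjust it to wherever your browser saved the file):
In [ ]:
# Hypothetical example: extract a Teachable Machine ZIP download so that
# model.json, metadata.json, and the weight shards are available on disk.
import zipfile
with zipfile.ZipFile('/tmp/tfjs-sc-model/tm-my-audio-model.zip', 'r') as z:
    z.extractall('/tmp/tfjs-sc-model/')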
In [2]:
import json
import tensorflow as tf
import tensorflowjs as tfjs
In [3]:
# Specify the path to the TensorFlow.js Speech Commands model,
# either the original one or one transfer-learned on https://teachablemachine.withgoogle.com/.
tfjs_model_json_path = '/tmp/tfjs-sc-model/model.json'
# This is the main classifier model.
model = tfjs.converters.load_keras_model(tfjs_model_json_path)
As a required step, we download the audio preprocessing layer, which replicates the browser's WebAudio Fourier transform so that the model can run in non-browser environments such as Android phones.
In [4]:
!curl -o /tmp/tfjs-sc-model/sc_preproc_model.tar.gz -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/sc_preproc_model.tar.gz
!cd /tmp/tfjs-sc-model && tar xzvf ./sc_preproc_model.tar.gz
In [5]:
# Load the preprocessing layer (wrapped in a tf.keras Model).
preproc_model_path = '/tmp/tfjs-sc-model/sc_preproc_model'
preproc_model = tf.keras.models.load_model(preproc_model_path)
preproc_model.summary()
# From the input_shape of the preproc_model, we can determine the
# required length of the input audio snippet.
input_length = preproc_model.input_shape[-1]
print("Input audio length = %d" % input_length)
In [6]:
# Construct the new non-browser model by combining the preprocessing
# layer with the main classifier model.
combined_model = tf.keras.Sequential(name='combined_model')
combined_model.add(preproc_model)
combined_model.add(model)
combined_model.build([None, input_length])
combined_model.summary()
In order to quickly test that the converted model works, let's download a sample .wav file.
In [7]:
!curl -o /tmp/tfjs-sc-model/audio_sample_one_male_adult.wav -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/audio_sample_one_male_adult.wav
In [8]:
# Listen to the audio sample.
import IPython.display as ipd
wav_file_path = '/tmp/tfjs-sc-model/audio_sample_one_male_adult.wav'
ipd.Audio(wav_file_path)  # Play the .wav file.
Out[8]: (audio playback widget)
In [9]:
# Read the .wav file and truncate it to the input length
# required by the model.
from scipy.io import wavfile
# fs: sample rate in Hz; xs: the audio PCM samples.
fs, xs = wavfile.read(wav_file_path)
if len(xs) >= input_length:
    xs = xs[:input_length]
else:
    raise ValueError("Audio from .wav file is too short")
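If your own recordings may be shorter than the model's input length, an alternative to raising an error is to zero-pad them. A minimal sketch, not part of the original flow:
In [ ]:
# Optional alternative (a sketch): zero-pad audio shorter than input_length
# instead of rejecting it.
import numpy as np
if len(xs) < input_length:
    xs = np.pad(xs, (0, input_length - len(xs)))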
In [10]:
# Try running some examples through the combined model.
input_tensor = tf.constant(xs, shape=(1, input_length), dtype=tf.float32) / 32768.0
# The model outputs the probabilities for the classes (`probs`).
probs = combined_model.predict(input_tensor)
# Read class labels of the model.
metadata_json_path = '/tmp/tfjs-sc-model/metadata.json'
with open(metadata_json_path, 'r') as f:
    metadata = json.load(f)
class_labels = metadata["words"]
# Get sorted probabilities and their corresponding class labels.
probs_and_labels = list(zip(probs[0].tolist(), class_labels))
# Sort the probabilities in descending order.
probs_and_labels = sorted(probs_and_labels, key=lambda x: -x[0])
# Print the top-5 labels:
print('top-5 class probabilities:')
for i in range(5):
    prob, label = probs_and_labels[i]
    print('%20s: %.4e' % (label, prob))
In [11]:
# Save the model as a tflite file.
tflite_output_path = '/tmp/tfjs-sc-model/combined_model.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
with open(tflite_output_path, 'wb') as f:
    f.write(converter.convert())
print("Saved tflite file at: %s" % tflite_output_path)