Converting a TensorFlow.js Speech-Commands Model to Python and TFLite formats

This notebook showcases how to convert a TensorFlow.js (TF.js) Speech Commands model to the Python (tensorflow.keras) and TFLite formats. The TFLite format enables the model to be deployed to mobile environments such as Android phones.

The techniques outlined in this notebook are applicable to:

  • the original Speech Commands models (including the 18w and directional4w variants),
  • transfer-learned models based on the original models, which can be trained and exported from Teachable Machine's Audio Project.

First, install the required tensorflow and tensorflowjs Python packages.


In [ ]:
# We need scipy for .wav file IO.
!pip install tensorflowjs==2.1.0 scipy==1.4.1
# TensorFlow 2.3.0 is required due to https://github.com/tensorflow/tensorflow/issues/38135
# TODO: Switch to 2.3.0 final release when it comes out.
!pip install tensorflow-cpu==2.3.0rc0

Below, we download the files of the original or a transfer-learned TF.js Speech Commands model. The code example here downloads the original model, but the approach is the same for a transfer-learned model downloaded from Teachable Machine, except that in the Teachable Machine case the files may come as a ZIP archive and hence require unzipping first, as sketched below.
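
For the Teachable Machine case, the unzipping step might look like the following sketch (the archive name converted_model.zip is hypothetical; substitute the name of the file you actually downloaded):


In [ ]:
import zipfile

# Hypothetical archive name; replace it with the name of the ZIP file
# you downloaded from Teachable Machine's Audio Project.
with zipfile.ZipFile('/tmp/tfjs-sc-model/converted_model.zip', 'r') as zf:
    zf.extractall('/tmp/tfjs-sc-model')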


In [1]:
!mkdir -p /tmp/tfjs-sc-model
!curl -o /tmp/tfjs-sc-model/metadata.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/metadata.json
!curl -o /tmp/tfjs-sc-model/model.json -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/model.json
!curl -o /tmp/tfjs-sc-model/group1-shard1of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard1of2
!curl -o /tmp/tfjs-sc-model/group1-shard2of2 -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/v0.3/browser_fft/18w/group1-shard2of2

In [2]:
import json

import tensorflow as tf
import tensorflowjs as tfjs

In [3]:
# Specify the path to the TensorFlow.js Speech Commands model, either the
# original one or a model transfer-learned on https://teachablemachine.withgoogle.com/.
tfjs_model_json_path = '/tmp/tfjs-sc-model/model.json'

# This is the main classifier model.
model = tfjs.converters.load_keras_model(tfjs_model_json_path)
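
Optionally, inspect the loaded classifier to confirm that the conversion succeeded (this check is an addition to the original flow):


In [ ]:
# Print the architecture of the converted classifier.
model.summary()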

As a required step, we download the audio preprocessing layer that replicates WebAudio's Fourier transform for non-browser environments such as Android phones.


In [4]:
!curl -o /tmp/tfjs-sc-model/sc_preproc_model.tar.gz -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/sc_preproc_model.tar.gz
!cd /tmp/tfjs-sc-model && tar xzvf ./sc_preproc_model.tar.gz


./sc_preproc_model/
./sc_preproc_model/assets/
./sc_preproc_model/variables/
./sc_preproc_model/variables/variables.data-00000-of-00001
./sc_preproc_model/variables/variables.index
./sc_preproc_model/saved_model.pb

In [5]:
# Load the preprocessing layer (wrapped in a tf.keras Model).
preproc_model_path = '/tmp/tfjs-sc-model/sc_preproc_model'
preproc_model = tf.keras.models.load_model(preproc_model_path)
preproc_model.summary()

# From the input_shape of the preproc_model, we can determine the
# required length of the input audio snippet.
input_length = preproc_model.input_shape[-1]
print("Input audio length = %d" % input_length)


WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Model: "audio_preproc"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
audio_preprocessing_layer (A (None, None, None, 1)     2048      
=================================================================
Total params: 2,048
Trainable params: 0
Non-trainable params: 2,048
_________________________________________________________________
Input audio length = 44032

In [6]:
# Construct the new non-browser model by combining the preprocessing
# layer with the main classifier model.

combined_model = tf.keras.Sequential(name='combined_model')
combined_model.add(preproc_model)
combined_model.add(model)
combined_model.build([None, input_length])
combined_model.summary()


Model: "combined_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
audio_preproc (Sequential)   (None, None, None, 1)     2048      
_________________________________________________________________
sequential (Sequential)      (None, 20)                1468684   
=================================================================
Total params: 1,470,732
Trainable params: 1,468,684
Non-trainable params: 2,048
_________________________________________________________________
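
As a quick shape check (an addition, not part of the original flow), you can feed the combined model a dummy waveform and confirm that it outputs one probability per class:


In [ ]:
import numpy as np

# A batch of one all-zeros "waveform" at the required input length.
dummy_waveform = np.zeros((1, input_length), dtype=np.float32)
# Expect (1, 20), since this model has 20 output classes (see summary above).
print(combined_model.predict(dummy_waveform).shape)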

In order to quickly test that the converted model works, let's download a sample .wav file.


In [7]:
!curl -o /tmp/tfjs-sc-model/audio_sample_one_male_adult.wav -fsSL https://storage.googleapis.com/tfjs-models/tfjs/speech-commands/conversion/audio_sample_one_male_adult.wav

In [8]:
# Listen to the audio sample.
wav_file_path = '/tmp/tfjs-sc-model/audio_sample_one_male_adult.wav'
import IPython.display as ipd
ipd.Audio(wav_file_path)  # Play the .wav file.


Out[8]:

In [9]:
# Read the .wav file and truncate it to the input length
# required by the model.
from scipy.io import wavfile

# fs: sample rate in Hz; xs: the audio PCM samples.
fs, xs = wavfile.read(wav_file_path)

if len(xs) >= input_length:
    xs = xs[:input_length]
else:
    raise ValueError("Audio from .wav file is too short")
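
If your own recording is shorter than input_length, an alternative to raising an error is to zero-pad the samples up to the required length. A minimal sketch of that variant (not part of the original flow):


In [ ]:
import numpy as np

# Hypothetical variant: pad short audio with trailing zeros instead of
# rejecting it.
if len(xs) < input_length:
    xs = np.pad(xs, (0, input_length - len(xs)))
else:
    xs = xs[:input_length]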

In [10]:
# Try running some examples through the combined model.
# Normalize the int16 PCM samples to the [-1, 1) range by dividing by 32768.
input_tensor = tf.constant(xs, shape=(1, input_length), dtype=tf.float32) / 32768.0
# The model outputs the probabilities for the classes (`probs`).
probs = combined_model.predict(input_tensor)

# Read class labels of the model.
metadata_json_path = '/tmp/tfjs-sc-model/metadata.json'

with open(metadata_json_path, 'r') as f:
    metadata = json.load(f)
    class_labels = metadata["words"]

# Pair each class probability with its label.
probs_and_labels = list(zip(probs[0].tolist(), class_labels))
# Sort the pairs in descending order of probability.
probs_and_labels = sorted(probs_and_labels, key=lambda x: -x[0])

# Print the top-5 labels:
print('top-5 class probabilities:')
for i in range(5):
    prob, label = probs_and_labels[i]
    print('%20s: %.4e' % (label, prob))


top-5 class probabilities:
                 one: 1.0000e+00
                nine: 5.0455e-19
           _unknown_: 1.0553e-20
                down: 4.0031e-26
                  no: 3.8358e-26

In [11]:
# Save the model as a tflite file.
tflite_output_path = '/tmp/tfjs-sc-model/combined_model.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model(combined_model)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
with open(tflite_output_path, 'wb') as f:
    f.write(converter.convert())
print("Saved tflite file at: %s" % tflite_output_path)


WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /usr/local/google/home/cais/venv_tfjs/lib/python3.7/site-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: /tmp/tmplb12fskv/assets
Saved tflite file at: /tmp/tfjs-sc-model/combined_model.tflite
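
As a final sanity check (an addition to the original flow), you can run the same audio sample through the exported .tflite file using the TFLite Python interpreter. Because the model was converted with SELECT_TF_OPS, this requires a runtime that bundles the Flex delegate; the interpreter that ships with the full tensorflow pip package includes it.


In [ ]:
# Load the converted TFLite model and run the sample waveform through it.
interpreter = tf.lite.Interpreter(model_path=tflite_output_path)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(input_details[0]['index'], input_tensor.numpy())
interpreter.invoke()
tflite_probs = interpreter.get_tensor(output_details[0]['index'])

# The top class should match the tf.keras model's prediction above ("one").
print(class_labels[tflite_probs[0].argmax()])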