Copyright 2019 Google LLC
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [ ]:
# Fetch the google-research repository; it is added to sys.path below so that kws_streaming can be imported.
!git clone https://github.com/google-research/google-research.git
In [ ]:
import sys
import os
import tarfile
import urllib
import zipfile
# Make the packages inside the cloned repository (e.g. kws_streaming) importable.
sys.path.append('./google-research')
The steps below are taken from model_train_eval, which has more tests of streaming, non-streaming, quantized and non-quantized models with TF and TFLite.
In [ ]:
# TF streaming
from kws_streaming.models import models
from kws_streaming.models import utils
from kws_streaming.layers.modes import Modes
In [ ]:
import tensorflow as tf
import numpy as np
import tensorflow.compat.v1 as tf1
import logging
from kws_streaming.models import model_params
from kws_streaming.train import model_flags
from kws_streaming.train import test
from kws_streaming.train import train
from kws_streaming.models import utils
from kws_streaming import data
# Switch off eager execution; the rest of the notebook works with tf1 graphs and sessions.
tf1.disable_eager_execution()
In [ ]:
# Session configuration: let TensorFlow grow its GPU memory use on demand.
config = tf1.ConfigProto(gpu_options=tf1.GPUOptions(allow_growth=True))
sess = tf1.Session(config=config)
In [ ]:
# general imports
import matplotlib.pyplot as plt
import os
import json
import numpy as np
import scipy as scipy
import scipy.io.wavfile as wav
import scipy.signal
In [ ]:
# Show the installed TensorFlow version.
tf.__version__
In [ ]:
# Start from an empty default graph and hand Keras a fresh session.
tf1.reset_default_graph()
sess = tf1.Session()  # NOTE: rebinds `sess`; no ConfigProto is passed to this session
tf1.keras.backend.set_session(sess)
# Learning phase 0 is the Keras "test" (inference) setting.
tf1.keras.backend.set_learning_phase(0)
In [ ]:
# Set the path to the data set (for example Speech Commands V2).
# It can be downloaded from
# https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
# If you have already run "00_check-data.ipynb", the folder "data2" should be
# located in the current directory.
current_dir = os.getcwd()
DATA_PATH = os.path.join(current_dir, "data2/")
In [ ]:
def waveread_as_pcm16(filename):
    """Load a wav file and return its samples and sample rate.

    Args:
        filename: path of the wav file; it is opened in binary mode through
            tf.io.gfile.GFile.

    Returns:
        Tuple (wave_data, samplerate). Note that this is the reverse of the
        (rate, data) order returned by scipy.io.wavfile.read. No dtype
        conversion is done here; the samples keep the type found in the file.
    """
    with tf.io.gfile.GFile(filename, 'rb') as wav_stream:
        rate_and_samples = wav.read(wav_stream)
    samplerate, wave_data = rate_and_samples
    return wave_data, samplerate
def wavread_as_float(filename, target_sample_rate=16000):
    """Read a wav file and return float32 samples at `target_sample_rate`.

    The file is loaded with waveread_as_pcm16. If its sample rate differs from
    `target_sample_rate`, the samples are resampled with scipy.signal.resample
    to a length that preserves the duration. When the rates already match, the
    samples are used as they are, which avoids a needless FFT round trip and
    the rounding noise it would add.

    Args:
        filename: path of the wav file to read.
        target_sample_rate: sample rate, in Hz, of the returned signal.

    Returns:
        Tuple (data, target_sample_rate) where data is a float32 numpy array
        obtained by dividing the samples by 32768.0.

    Note:
        The scaling constant assumes 16-bit integer samples, which map to
        [-1..1); other sample formats are divided by the same constant.
    """
    wave_data, samplerate = waveread_as_pcm16(filename)
    if samplerate != target_sample_rate:
        desired_length = int(
            round(float(len(wave_data)) / samplerate * target_sample_rate))
        wave_data = scipy.signal.resample(wave_data, desired_length)
    # Normalize short ints to floats in range [-1..1).
    # The local is not called `data`, to avoid shadowing the imported
    # kws_streaming `data` module.
    samples = np.array(wave_data, np.float32) / 32768.0
    return samples, target_sample_rate
In [ ]:
# Path of the wav file to visualize (taken from the "left" folder of the data set).
wav_file = os.path.join(DATA_PATH, "left/012187a4_nohash_0.wav")
# Read the audio file as float samples, using wavread_as_float's default target_sample_rate of 16000.
wav_data, samplerate = wavread_as_float(wav_file)
In [ ]:
# Sanity check: wavread_as_float returns its target_sample_rate, which defaults to 16000.
assert samplerate == 16000
In [ ]:
# Plot the samples of the loaded audio file.
plt.plot(wav_data)
In [ ]:
# List the available model names; the model name selected next should be one of these keys.
model_params.HOTWORD_MODEL_PARAMS.keys()
In [ ]:
# Name of the model to train and the directory (<current dir>/models/<model name>/) used for it.
MODEL_NAME = 'svdf'
MODELS_PATH = os.path.join(current_dir, "models")
MODEL_PATH = os.path.join(MODELS_PATH, MODEL_NAME + "/")
MODEL_PATH  # display the resulting path
In [ ]:
# Create the model directory. exist_ok=True keeps this cell re-runnable:
# without it os.makedirs raises FileExistsError when the directory already exists.
os.makedirs(MODEL_PATH, exist_ok=True)
In [ ]:
# Get the toy model settings registered under MODEL_NAME; they are adjusted in the following cells.
FLAGS = model_params.HOTWORD_MODEL_PARAMS[MODEL_NAME]
In [ ]:
# Set the paths to the data and to the model (where the model will be stored).
FLAGS.data_dir = DATA_PATH
FLAGS.train_dir = MODEL_PATH
# Set the speech feature extractor properties.
FLAGS.mel_upper_edge_hertz = 7000
FLAGS.window_size_ms = 40.0
FLAGS.window_stride_ms = 20.0
FLAGS.mel_num_bins = 80
FLAGS.dct_num_features = 40
FLAGS.feature_type = 'mfcc_op'
FLAGS.preprocess = 'raw'
# Set the training settings.
FLAGS.train = 1
FLAGS.how_many_training_steps = '400,400,400,400' # reduced number of training steps for test only
# Four comma-separated learning rates, the same count as the entries of how_many_training_steps
# (presumably one rate per training stage -- TODO confirm in kws_streaming.train).
FLAGS.learning_rate = '0.001,0.0005,0.0001,0.00002'
FLAGS.lr_schedule = 'linear'
FLAGS.verbosity = logging.INFO
# Data shuffling config.
FLAGS.resample = 0.15
FLAGS.time_shift_ms = 100
In [ ]:
# Model parameters are different for every model; the values below are for the SVDF model.
FLAGS.model_name = MODEL_NAME
# Each svdf_* list setting is a comma-separated string with six entries
# (presumably one entry per SVDF layer -- TODO confirm in kws_streaming.models).
FLAGS.svdf_memory_size = "4,10,10,10,10,10"
FLAGS.svdf_units1 = "16,32,32,32,64,128"
FLAGS.svdf_act = "'relu','relu','relu','relu','relu','relu'"
FLAGS.svdf_units2 = "40,40,64,64,64,-1"
FLAGS.svdf_dropout = "0.0,0.0,0.0,0.0,0.0,0.0"
FLAGS.svdf_pad = 0
FLAGS.dropout1 = 0.0
# units2 and act2 are set to empty strings here.
FLAGS.units2 = ''
FLAGS.act2 = ''
In [ ]:
# Build the final `flags` object from FLAGS; `flags` is what the rest of the notebook passes around.
flags = model_flags.update_flags(FLAGS)
In [ ]:
# Display all settings held by `flags`.
flags.__dict__
In [ ]:
# Store the settings next to the model, as flags.json inside flags.train_dir.
with open(os.path.join(flags.train_dir, 'flags.json'), 'wt') as f:
    json.dump(vars(flags), f)
In [ ]:
# Visualize the model: build the model registered under flags.model_name and
# draw it with layer names, tensor shapes and nested models expanded.
model_non_stream_batch = models.MODELS[flags.model_name](flags)
tf.keras.utils.plot_model(
model_non_stream_batch,
show_shapes=True,
show_layer_names=True,
expand_nested=True)
In [ ]:
# Print the Keras summary of the model built above.
model_non_stream_batch.summary()
In [ ]:
# Model training, driven entirely by the settings in `flags`
# (flags.train_dir was set to MODEL_PATH above).
train.train(flags)
In [ ]:
# Measure the accuracy of the non-streaming TF model.
# NOTE(review): folder_name presumably names a sub-folder of flags.train_dir used by the test -- confirm in kws_streaming.train.test.
folder_name = 'tf'
test.tf_non_stream_model_accuracy(flags, folder_name)
More testing functions can be found in the `test` module (`kws_streaming.train.test`).
In [ ]: