# In [1]:
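"""
Train an end-to-end speech recognition model using CTC.

For the standalone script, run ``python train.py --help`` for usage. In this
notebook, the data paths and language are set as plain variables in the
cells below.
"""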
from __future__ import absolute_import, division, print_function
import argparse
import logging
import os
from data_generator import DataGenerator
from model import compile_gru_model, compile_train_fn, compile_test_fn
from utils import configure_logging, save_model
# Module-level logger, named after this module so records can be filtered
# by origin.
logger = logging.getLogger(name=__name__)
# In [ ]:
# --- Notebook configuration -------------------------------------------------
# Language code; passed to DataGenerator(language=...) in the next cell.
language = "lt"
# JSON description files for the training and validation splits.
# NOTE(review): both files carry an "_en" suffix while `language` is "lt" --
# confirm these are the intended datasets.
train_desc_file = "/liepa_notebooks/data/liepa_train_en.json"
val_desc_file = "/liepa_notebooks/data/liepa_validation_en.json"
# Output directory for training artifacts; not referenced in the visible
# cells -- presumably used further down the notebook.
save_dir = "/liepa_notebooks/data/jupyter_baidu/save_data"
# In [ ]:
def _build_datagen(lang, train_json, val_json, n_norm_samples=100):
    """Create a DataGenerator and load the train/validation descriptions.

    After loading both splits, the generator is fitted on ``n_norm_samples``
    training samples. According to the original notes, this estimates the
    feature mean/variance so inputs to the network can be centred.
    """
    gen = DataGenerator(language=lang)
    gen.load_train_data(train_json)
    gen.load_validation_data(val_json)
    gen.fit_train(n_norm_samples)
    return gen


datagen = _build_datagen(language, train_desc_file, val_desc_file)

# GRU-based recurrent model: recur_layers=3, nodes=1000 (presumably units per
# layer), with batch normalisation. The original notes describe a 1D
# convolution front-end and one fully connected layer -- TODO confirm in
# model.compile_gru_model.
model = compile_gru_model(recur_layers=3, nodes=1000, batch_norm=True)

# Compiled CTC training function, then the validation/test function, both
# built from the same model.
train_fn, val_fn = compile_train_fn(model), compile_test_fn(model)