Demonstrate Seq2Seq Wrapper with CMUDict dataset


In [1]:
import tensorflow as tf
import numpy as np

# preprocessed data
from datasets.cmudict import data
import data_utils

In [2]:
# load data from pickle and npy files
data_ctl, idx_words, idx_phonemes = data.load_data(PATH='datasets/cmudict/')
(trainX, trainY), (testX, testY), (validX, validY) = data_utils.split_dataset(idx_phonemes, idx_words)
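The actual split ratios live in data_utils.py. For reference, a minimal sketch of what a split like this typically looks like, assuming a contiguous 70/15/15 train/test/validation split (the ratio and the slicing scheme are assumptions, not quoted from data_utils.py):

import numpy as np

def split_dataset(x, y, ratio=(0.7, 0.15, 0.15)):  # ratio is an assumption
    # slice the arrays into contiguous train/test/validation chunks
    data_len = len(x)
    lens = [int(data_len * f) for f in ratio]
    trainX, trainY = x[:lens[0]], y[:lens[0]]
    testX, testY = x[lens[0]:lens[0] + lens[1]], y[lens[0]:lens[0] + lens[1]]
    validX, validY = x[-lens[2]:], y[-lens[2]:]
    return (trainX, trainY), (testX, testY), (validX, validY)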

In [3]:
# parameters 
xseq_len = trainX.shape[-1]
yseq_len = trainY.shape[-1]
batch_size = 128
xvocab_size = len(data_ctl['idx2pho'])    # phoneme vocabulary
yvocab_size = len(data_ctl['idx2alpha'])  # letter vocabulary
emb_dim = 128
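
Before building the model, it helps to sanity-check the values just derived from the data (purely illustrative, no new API assumed):

print('sequence lengths (x, y) :', xseq_len, yseq_len)
print('vocab sizes (pho, alpha):', xvocab_size, yvocab_size)
print('train/test/valid sizes  :', len(trainX), len(testX), len(validX))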

Create an instance of the Wrapper


In [4]:
import seq2seq_wrapper

In [6]:
import importlib
importlib.reload(seq2seq_wrapper)


Out[6]:
<module 'seq2seq_wrapper' from '/home/suriya/_/tf/tf-seq2seq-wrapper/seq2seq_wrapper.py'>

In [7]:
model = seq2seq_wrapper.Seq2Seq(xseq_len=xseq_len,
                                yseq_len=yseq_len,
                                xvocab_size=xvocab_size,
                                yvocab_size=yvocab_size,
                                ckpt_path='ckpt/cmudict/',
                                emb_dim=emb_dim,
                                num_layers=3)


<log> Building Graph </log>
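
Instantiating the wrapper is what builds the computational graph. A minimal sketch of the core of such a constructor, using the TF 1.x legacy seq2seq API; the placeholder names and the choice of embedding_rnn_seq2seq here are assumptions about the implementation, not quoted from seq2seq_wrapper.py:

import tensorflow as tf

# per-timestep placeholders: xseq_len encoder steps, yseq_len label steps
enc_ip = [tf.placeholder(tf.int64, [None], name='ei_%i' % t) for t in range(xseq_len)]
labels = [tf.placeholder(tf.int64, [None], name='l_%i' % t) for t in range(yseq_len)]
# decoder inputs: GO symbol (index 0) followed by the labels, shifted right
dec_ip = [tf.zeros_like(labels[0], dtype=tf.int64, name='GO')] + labels[:-1]

# stacked LSTM cell shared by encoder and decoder (num_layers=3)
cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.BasicLSTMCell(emb_dim) for _ in range(3)])

# embed the symbol ids, encode the phoneme sequence, decode to letters
dec_op, _ = tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
    enc_ip, dec_ip, cell,
    num_encoder_symbols=xvocab_size,
    num_decoder_symbols=yvocab_size,
    embedding_size=emb_dim)

# weighted cross-entropy loss over the decoder output sequence
loss_weights = [tf.ones_like(l, dtype=tf.float32) for l in labels]
loss = tf.contrib.legacy_seq2seq.sequence_loss(dec_op, labels, loss_weights)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

Adam with a small learning rate is a common default for this kind of model; the wrapper may well use different hyperparameters.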

Create data generators

Read data_utils.py for more information


In [8]:
val_batch_gen = data_utils.rand_batch_gen(validX, validY, 16)
train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, 128)
  • The computational graph was built when the model was instantiated
  • Now all we need to do is train the model on the processed CMUDict dataset, via the data generators
  • Internally, a training loop runs for a fixed number of epochs
  • Evaluation on the validation set is done periodically (see the batch generator sketch below)
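
A random batch generator can be just a few lines; a sketch, assuming it samples rows with replacement and yields (input, target) pairs indefinitely, which is what an open-ended training loop needs:

import numpy as np

def rand_batch_gen(x, y, batch_size):
    # sample `batch_size` rows uniformly at random, forever;
    # transposed to time-major, assuming that is the layout the
    # graph's per-timestep placeholders expect
    while True:
        sample_idx = np.random.choice(len(x), batch_size)
        yield x[sample_idx].T, y[sample_idx].T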

Train


In [16]:
# sess1 here is a live session from an earlier (not shown) run;
# omit the sess argument to start training from a fresh session
sess = model.train(train_batch_gen, val_batch_gen, sess=sess1)


Model saved to disk at iteration #5000
val   loss : 0.428838

Model saved to disk at iteration #10000
val   loss : 0.352279

Model saved to disk at iteration #15000
val   loss : 0.302959

Model saved to disk at iteration #20000
val   loss : 0.290396

Model saved to disk at iteration #25000
val   loss : 0.250649

Model saved to disk at iteration #30000
val   loss : 0.239168

Model saved to disk at iteration #35000
val   loss : 0.198182

Model saved to disk at iteration #40000
val   loss : 0.203086

Model saved to disk at iteration #45000
val   loss : 0.213277

Model saved to disk at iteration #50000
val   loss : 0.208600

Model saved to disk at iteration #55000
val   loss : 0.228991

Model saved to disk at iteration #60000
val   loss : 0.205643
Interrupted by user at iteration 60001

Restore last saved session from disk


In [9]:
sess = model.restore_last_session()
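
Restoring presumably wraps the standard tf.train.Saver checkpoint lookup; a sketch under that assumption (the function body is illustrative, only the tf.train calls are real API):

import tensorflow as tf

def restore_last_session(ckpt_path='ckpt/cmudict/'):
    saver = tf.train.Saver()
    sess = tf.Session()
    # locate the most recent checkpoint in the directory, if any
    ckpt = tf.train.get_checkpoint_state(ckpt_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    return sess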

Predict


In [10]:
# grab the input half (phoneme sequences) of one validation batch
output = model.predict(sess, next(val_batch_gen)[0])
print(output.shape)


(16, 16)
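
The (16, 16) shape is (batch_size, yseq_len): one row of letter indices per input word. Internally, predict most likely feeds the batch into the encoder placeholders, evaluates the decoder logits, and takes an argmax over the vocabulary axis; a sketch under that assumption, reusing the hypothetical enc_ip/dec_op names from the graph sketch above (a real inference path would also build the decoder with feed_previous=True, so it consumes its own predictions):

import numpy as np

def predict(sess, batchX):
    # batchX is time-major: (xseq_len, batch_size)
    feed_dict = {enc_ip[t]: batchX[t] for t in range(xseq_len)}
    # dec_op_v: list of yseq_len arrays, each (batch_size, yvocab_size)
    dec_op_v = sess.run(dec_op, feed_dict)
    # pick the most likely symbol at every timestep -> (batch, yseq_len)
    return np.argmax(np.array(dec_op_v), axis=2).T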

In [11]:
output


Out[11]:
array([[12, 15, 14,  7, 23,  5, 12, 12,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 8, 15, 12, 19, 23, 15, 18, 20,  8,  0,  0,  0,  0,  0,  0,  0],
       [ 3,  1, 22,  1, 12,  9,  5, 18,  9,  0,  0,  0,  0,  0,  0,  0],
       [16, 18, 15, 19,  5, 12,  9, 20,  9, 26,  5,  0,  0,  0,  0,  0],
       [ 8,  9,  7,  1, 19,  8,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 4,  9, 19,  2,  9, 12,  5,  9,  6,  0,  0,  0,  0,  0,  0,  0],
       [11,  9, 19, 13,  5, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [21, 14,  3, 15, 20,  5,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [19, 21,  3,  8,  9,  1, 11,  9,  0,  0,  0,  0,  0,  0,  0,  0],
       [13,  5, 20, 18, 15,  3, 15, 12,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 4, 21, 18, 21, 19,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [13,  1,  7, 14, 15, 12,  9,  1,  0,  0,  0,  0,  0,  0,  0,  0],
       [12,  5, 20, 20,  1, 18, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 7,  1, 12, 12,  1, 23,  1, 25,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 2, 18,  9,  1, 14,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 8, 15, 18, 16,  9,  6,  9, 14,  9,  0,  0,  0,  0,  0,  0,  0]])

Let us decode and see the words


In [12]:
for oi in output:
    print(data_utils.decode(sequence=oi, lookup=data_ctl['idx2alpha'],
                           separator=''))


longwell
holsworth
cavalieri
proselitize
higashi
disbileif
kismet
uncoted
suchiaki
metrocol
durus
magnolia
lettart
gallaway
briane
horpifini
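
Most predictions are real or near-real spellings (kismet, magnolia, gallaway), with phonetically plausible misses like proselitize and disbileif. The decoding itself is a plain index-to-character lookup; a sketch, assuming index 0 is the PAD symbol:

def decode(sequence, lookup, separator=''):
    # map each non-zero index to its character; 0 is assumed to be PAD
    return separator.join([lookup[element] for element in sequence if element])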