STATIC GRAPH


In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit
from recurrence import *

TODO

  • [x] Dynamic batch size
  • [x] Mask padding while calculating loss
  • [x] Deal with variable length input sequence
  • [x] Add EOS (end of sentence) token to data (target sequence)

DATA


In [2]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')

In [3]:
# add special symbol
i2w = metadata['idx2w']
w2i = metadata['w2idx']

Parameters


In [26]:
B = 8
L = len(idx_q[0])
vocab_size = len(i2w)
enc_hdim = 150
dec_hdim = enc_hdim

Graph


In [5]:
tf.reset_default_graph()

Placeholders


In [6]:
inputs = tf.placeholder(tf.int32, shape=[None,L], name='inputs')
targets = tf.placeholder(tf.int32, shape=[None,L], name='targets')
go_token = tf.reshape(tf.fill(tf.shape(inputs[:,0]), w2i['GO']), [-1,1])
decoder_inputs = tf.concat(
    values=[go_token, targets[:, 1:]],
    axis=1)
training = tf.placeholder(tf.bool, name='is_training')
batch_size = tf.shape(inputs)[0] # infer batch size

In [7]:
decoder_inputs, targets, inputs, batch_size


Out[7]:
(<tf.Tensor 'concat:0' shape=(?, 21) dtype=int32>,
 <tf.Tensor 'targets:0' shape=(?, 21) dtype=int32>,
 <tf.Tensor 'inputs:0' shape=(?, 21) dtype=int32>,
 <tf.Tensor 'strided_slice_2:0' shape=() dtype=int32>)

Sequence Length


In [8]:
# Encoder
#  Get sequence lengths
def seq_len(t):
    return tf.reduce_sum(tf.cast(t>0, tf.int32), axis=1)

enc_seq_len = seq_len(inputs)

In [9]:
# Decoder
#   Mask targets
#    to be applied to cross entropy loss
padding_mask = tf.cast(targets > 0, tf.float32 )

Embedding


In [10]:
emb_mat = tf.get_variable('emb', shape=[vocab_size, enc_hdim], dtype=tf.float32, 
                         initializer=xinit())
emb_enc_inputs = tf.nn.embedding_lookup(emb_mat, inputs)
emb_dec_inputs = tf.nn.embedding_lookup(emb_mat, decoder_inputs)

Encoder


In [11]:
with tf.variable_scope('encoder'):
    (estates_f, estates_b), _ = bi_net(gru_n(num_layers=3, num_units=enc_hdim),
                                       gru_n(num_layers=3, num_units=enc_hdim),
                                       emb_enc_inputs,
                                       batch_size=B,
                                       timesteps=L,
                                       num_layers=3
                                      )
    estates = tf.concat([estates_f, estates_b], axis=-1)

In [12]:
Ws = tf.get_variable('Ws', shape=[2*enc_hdim, enc_hdim], dtype=tf.float32)
estates = tf.reshape(tf.matmul(tf.reshape(estates, [-1, 2*enc_hdim]), Ws), [-1, L, enc_hdim])

In [13]:
estates


Out[13]:
<tf.Tensor 'Reshape_2:0' shape=(8, 21, 150) dtype=float32>

Decoder


In [14]:
emb_dec_inputs = tf.transpose(emb_dec_inputs, [1,0,2])

In [15]:
with tf.variable_scope('decoder') as scope:
    decoder_outputs, _ = attentive_decoder(estates, batch_size, dec_hdim, L,
                                         inputs=emb_dec_inputs, reuse=False)
    
    tf.get_variable_scope().reuse_variables()
    
    decoder_outputs_inf, _ = attentive_decoder(estates, batch_size, dec_hdim, L,
                                             inputs=emb_dec_inputs,
                                             reuse=True,
                                             feed_previous=True)

Logits and Probabilities


In [16]:
Wo = tf.get_variable('Wo', shape=[dec_hdim, vocab_size], dtype=tf.float32, 
                         initializer=xinit())
bo = tf.get_variable('bo', shape=[vocab_size], dtype=tf.float32, 
                         initializer=xinit())
proj_outputs = tf.matmul(tf.reshape(decoder_outputs, [-1, dec_hdim]), Wo) + bo
proj_outputs_inf = tf.matmul(tf.reshape(decoder_outputs_inf, [-1, dec_hdim]), Wo) + bo

In [17]:
logits = tf.cond(tf.random_normal(shape=()) > 0.,
    lambda : tf.reshape(proj_outputs, [batch_size, L, vocab_size]),
    lambda : tf.reshape(proj_outputs_inf, [batch_size, L, vocab_size])
                )

In [18]:
probs = tf.nn.softmax(tf.reshape(proj_outputs_inf, [batch_size, L, vocab_size]))

Loss


In [19]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits =  logits,
    labels = targets)
# apply mask
masked_cross_entropy = cross_entropy * padding_mask
# average across sequence, batch
loss = tf.reduce_mean(masked_cross_entropy)

Optimization


In [20]:
optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)

Inference


In [21]:
prediction = tf.argmax(probs, axis=-1)

TRAINING


In [22]:
config = tf.ConfigProto(allow_soft_placement = True)
sess = tf.InteractiveSession(config = config)
sess.run(tf.global_variables_initializer())

Training parameters


In [23]:
num_epochs = 20

Start Training


In [25]:
for i in range(num_epochs):
    avg_loss = 0.
    for j in range(len(idx_q)//B):
        _, loss_v = sess.run([train_op, loss], feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })
        avg_loss += loss_v
        if j and j%30==0:
            print('{}.{} : {}'.format(i,j,avg_loss/30))
            avg_loss = 0.


0.30 : 4.484886995951334
0.60 : 4.475595307350159
0.90 : 4.349940570195516
0.120 : 3.9550580898920695
0.150 : 4.015570902824402
0.180 : 3.972534743944804
0.210 : 3.762335006395976
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-25-b3f69adef234> in <module>()
      4         _, loss_v = sess.run([train_op, loss], feed_dict = {
      5             inputs : idx_q[j*B:(j+1)*B],
----> 6             targets : idx_a[j*B:(j+1)*B]
      7         })
      8         avg_loss += loss_v

/home/suriyadeepan/Desktop/env/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    776     try:
    777       result = self._run(None, fetches, feed_dict, options_ptr,
--> 778                          run_metadata_ptr)
    779       if run_metadata:
    780         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/suriyadeepan/Desktop/env/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    980     if final_fetches or final_targets:
    981       results = self._do_run(handle, final_targets, final_fetches,
--> 982                              feed_dict_string, options, run_metadata)
    983     else:
    984       results = []

/home/suriyadeepan/Desktop/env/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1030     if handle is None:
   1031       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1032                            target_list, options, run_metadata)
   1033     else:
   1034       return self._do_call(_prun_fn, self._session, handle, feed_dict,

/home/suriyadeepan/Desktop/env/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1037   def _do_call(self, fn, *args):
   1038     try:
-> 1039       return fn(*args)
   1040     except errors.OpError as e:
   1041       message = compat.as_text(e.message)

/home/suriyadeepan/Desktop/env/tf/lib/python3.5/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1019         return tf_session.TF_Run(session, options,
   1020                                  feed_dict, fetch_list, target_list,
-> 1021                                  status, run_metadata)
   1022 
   1023     def _prun_fn(session, handle, feed_dict, fetch_list):

KeyboardInterrupt: 

Compare CE and masked CE


In [ ]:
ce, mce = sess.run([cross_entropy, masked_cross_entropy], feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            targets : idx_a[j*B:(j+1)*B]
        })

In [ ]:
ce[0], mce[0]

Test inference


In [ ]:
j = 117
pred_v = sess.run(prediction, feed_dict = {
            inputs : idx_q[j*B:(j+1)*B],
            #targets : idx_a[j*B:(j+1)*B]
        })

In [ ]:
pred_v[0], idx_a[j*B:(j+1)*B][0]

In [ ]:
def arr2sent(arr):
    return ' '.join([i2w[item] for item in arr])

In [ ]:
for i in range(B):
    print(arr2sent(pred_v[i]))

In [ ]:
arr2sent( idx_a[j*B:(j+1)*B][11])