In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer as xinit
from recurrence import *
Add an EOS (end of sentence) token to the data (target sequence)
In [2]:
import data_utils
metadata, idx_q, idx_a = data_utils.load_data('../data/')
In [3]:
# add special symbol
i2w = metadata['idx2w']
w2i = metadata['w2idx']
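A quick sanity check of the loaded arrays (a minimal sketch; the shapes and the printed sentence depend entirely on how data_utils padded and tokenized the corpus, so the output is only illustrative):
In [ ]:
# idx_q / idx_a are assumed to be zero-padded integer arrays of shape [num_examples, L]
print(idx_q.shape, idx_a.shape)
print(' '.join(i2w[w] for w in idx_q[0] if w > 0))  # first question, padding stripped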
In [26]:
B = 8                   # batch size
L = len(idx_q[0])       # (padded) sequence length
vocab_size = len(i2w)
enc_hdim = 150          # encoder hidden dimension
dec_hdim = enc_hdim     # decoder hidden dimension, kept equal to the encoder's
In [5]:
tf.reset_default_graph()
In [6]:
inputs = tf.placeholder(tf.int32, shape=[None, L], name='inputs')
targets = tf.placeholder(tf.int32, shape=[None, L], name='targets')
# prepend a GO token and drop the last target token, so the decoder input
# at step t is the target token from step t-1 (teacher forcing)
go_token = tf.reshape(tf.fill(tf.shape(inputs[:, 0]), w2i['GO']), [-1, 1])
decoder_inputs = tf.concat(
    values=[go_token, targets[:, :-1]],
    axis=1)
training = tf.placeholder(tf.bool, name='is_training')
batch_size = tf.shape(inputs)[0]  # infer batch size from whatever is fed in
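On a toy row, the shift looks like this (plain NumPy with made-up ids; pretend the GO id is 1):
In [ ]:
import numpy as np
toy_targets = np.array([[5, 8, 3, 0, 0]])                     # one padded answer
toy_go = np.full((1, 1), 1)                                   # stand-in for w2i['GO']
print(np.concatenate([toy_go, toy_targets[:, :-1]], axis=1))  # [[1 5 8 3 0]]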
In [7]:
decoder_inputs, targets, inputs, batch_size
Out[7]:
In [8]:
# Encoder
# Get sequence lengths
def seq_len(t):
    # count the non-padding (non-zero) tokens in each sequence
    return tf.reduce_sum(tf.cast(t > 0, tf.int32), axis=1)

enc_seq_len = seq_len(inputs)
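A standalone check of what seq_len computes (NumPy instead of the graph, with made-up token ids):
In [ ]:
import numpy as np
demo = np.array([[4, 9, 7, 0, 0],
                 [3, 0, 0, 0, 0]])
print((demo > 0).sum(axis=1))  # [3 1] -- number of non-padding tokens per row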
In [9]:
# Decoder
# Mask targets
# to be applied to cross entropy loss
padding_mask = tf.cast(targets > 0, tf.float32)
In [10]:
# embedding matrix shared by the encoder and decoder inputs
emb_mat = tf.get_variable('emb', shape=[vocab_size, enc_hdim], dtype=tf.float32,
                          initializer=xinit())
emb_enc_inputs = tf.nn.embedding_lookup(emb_mat, inputs)
emb_dec_inputs = tf.nn.embedding_lookup(emb_mat, decoder_inputs)
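embedding_lookup gathers one row of emb_mat per token id, turning a [batch, L] id matrix into a [batch, L, enc_hdim] tensor. The same gather on toy-sized NumPy data (sizes made up):
In [ ]:
import numpy as np
toy_emb = np.random.randn(10, 4)   # vocabulary of 10 words, embedding size 4
toy_ids = np.array([[2, 7, 0]])    # one sequence of three token ids
print(toy_emb[toy_ids].shape)      # (1, 3, 4)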
In [11]:
with tf.variable_scope('encoder'):
    # 3-layer bidirectional GRU encoder (bi_net and gru_n come from recurrence.py)
    (estates_f, estates_b), _ = bi_net(gru_n(num_layers=3, num_units=enc_hdim),
                                       gru_n(num_layers=3, num_units=enc_hdim),
                                       emb_enc_inputs,
                                       batch_size=B,
                                       timesteps=L,
                                       num_layers=3)
# concatenate forward and backward state sequences along the feature axis
estates = tf.concat([estates_f, estates_b], axis=-1)
In [12]:
# project the concatenated (2*enc_hdim) encoder states back down to enc_hdim
Ws = tf.get_variable('Ws', shape=[2*enc_hdim, enc_hdim], dtype=tf.float32)
estates = tf.reshape(tf.matmul(tf.reshape(estates, [-1, 2*enc_hdim]), Ws), [-1, L, enc_hdim])
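The forward and backward state sequences are concatenated along the feature axis and projected back to enc_hdim, so the decoder's attention works with vectors the same size as its own state. The same reshape-matmul-reshape pattern on random NumPy data (8, 25 and 150 are stand-ins for the batch size, L and enc_hdim):
In [ ]:
import numpy as np
f = np.random.randn(8, 25, 150)        # forward states   [B, L, H]
b = np.random.randn(8, 25, 150)        # backward states  [B, L, H]
cat = np.concatenate([f, b], axis=-1)  # [B, L, 2H]
W = np.random.randn(300, 150)          # plays the role of Ws
proj = (cat.reshape(-1, 300) @ W).reshape(8, 25, 150)
print(cat.shape, proj.shape)           # (8, 25, 300) (8, 25, 150)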
In [13]:
estates
Out[13]:
In [14]:
# switch the decoder inputs to time-major layout: [L, batch, enc_hdim]
emb_dec_inputs = tf.transpose(emb_dec_inputs, [1, 0, 2])
In [15]:
with tf.variable_scope('decoder') as scope:
    # training decoder: teacher forcing with the embedded decoder inputs
    decoder_outputs, _ = attentive_decoder(estates, batch_size, dec_hdim, L,
                                           inputs=emb_dec_inputs, reuse=False)
    tf.get_variable_scope().reuse_variables()
    # inference decoder: same weights, but each step is fed its previous prediction
    decoder_outputs_inf, _ = attentive_decoder(estates, batch_size, dec_hdim, L,
                                               inputs=emb_dec_inputs,
                                               reuse=True,
                                               feed_previous=True)
In [16]:
# project decoder states to vocabulary-sized logits
Wo = tf.get_variable('Wo', shape=[dec_hdim, vocab_size], dtype=tf.float32,
                     initializer=xinit())
bo = tf.get_variable('bo', shape=[vocab_size], dtype=tf.float32,
                     initializer=xinit())
proj_outputs = tf.matmul(tf.reshape(decoder_outputs, [-1, dec_hdim]), Wo) + bo
proj_outputs_inf = tf.matmul(tf.reshape(decoder_outputs_inf, [-1, dec_hdim]), Wo) + bo
In [17]:
# with probability ~0.5 train on the teacher-forced outputs, otherwise on the
# feed-previous outputs (a crude form of scheduled sampling)
logits = tf.cond(tf.random_normal(shape=()) > 0.,
                 lambda: tf.reshape(proj_outputs, [batch_size, L, vocab_size]),
                 lambda: tf.reshape(proj_outputs_inf, [batch_size, L, vocab_size]))
In [18]:
# inference-time probabilities from the feed-previous decoder
probs = tf.nn.softmax(tf.reshape(proj_outputs_inf, [batch_size, L, vocab_size]))
In [19]:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=logits,
    labels=targets)
# zero out the loss at padded positions
masked_cross_entropy = cross_entropy * padding_mask
# average over all positions (padded ones included) and over the batch
loss = tf.reduce_mean(masked_cross_entropy)
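Note that tf.reduce_mean divides the masked sum by batch_size * L, so padded positions still count in the denominator. If you prefer the loss normalized by the number of real tokens, one common variant (a sketch built on the tensors above, not what this notebook trains with) is:
In [ ]:
# divide by the number of non-padding target tokens instead of batch_size * L
loss_per_token = tf.reduce_sum(masked_cross_entropy) / tf.reduce_sum(padding_mask)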
In [20]:
optimizer = tf.train.AdagradOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss)
In [21]:
prediction = tf.argmax(probs, axis=-1)
In [22]:
config = tf.ConfigProto(allow_soft_placement = True)
sess = tf.InteractiveSession(config = config)
sess.run(tf.global_variables_initializer())
In [23]:
num_epochs = 20
In [25]:
for i in range(num_epochs):
    avg_loss = 0.
    for j in range(len(idx_q)//B):
        _, loss_v = sess.run([train_op, loss], feed_dict={
            inputs: idx_q[j*B:(j+1)*B],
            targets: idx_a[j*B:(j+1)*B]
        })
        avg_loss += loss_v
        if j and j % 30 == 0:
            # report the running average loss every 30 batches
            print('{}.{} : {}'.format(i, j, avg_loss/30))
            avg_loss = 0.
In [ ]:
ce, mce = sess.run([cross_entropy, masked_cross_entropy], feed_dict={
    inputs: idx_q[j*B:(j+1)*B],
    targets: idx_a[j*B:(j+1)*B]
})
In [ ]:
ce[0], mce[0]
In [ ]:
j = 117
pred_v = sess.run(prediction, feed_dict={
    inputs: idx_q[j*B:(j+1)*B],
    # targets: idx_a[j*B:(j+1)*B]
})
In [ ]:
pred_v[0], idx_a[j*B:(j+1)*B][0]
In [ ]:
def arr2sent(arr):
    # map a sequence of ids back to a space-separated string
    return ' '.join([i2w[item] for item in arr])
In [ ]:
for i in range(B):
    print(arr2sent(pred_v[i]))
In [ ]:
arr2sent( idx_a[j*B:(j+1)*B][11])