In [1]:
"""
a toy implementation of seq2seq by tf v0.11
a translator between two meaningful number sequences, e.g., [1,2,3,4,5] -> [2,3,4,5,6]
"""
import numpy as np
# number of sequences in each training batch
batch_size = 64
# number of words in one sentence
seq_length = 5
# number of possible words
vocab_size = 7
# embedding dimension
embedding_dim = 50
# number of hidden units in the RNN cell
memory_dim = 100
def get_train_batch(batch_size):
    # sample seq_length distinct words per sequence (requires seq_length <= vocab_size)
    X = [np.random.choice(vocab_size, size=(seq_length,), replace=False)
         for _ in range(batch_size)]
    # label rule: shift every word up by one, modulo the vocabulary size
    Y = np.mod(np.array(X) + 1, vocab_size)
    # dimshuffle to seq_length x batch_size (the old seq2seq API consumes one timestep at a time)
    X = np.array(X).T
    Y = Y.T
    return X, Y
In [2]:
X, Y = get_train_batch(2)
print("Two data points:")
print(X)
print()
print("Two labels:")
print(Y)
In [3]:
import tensorflow as tf
# encoder input: a list of seq_length placeholders, one timestep each, shape (batch_size,)
encode_inputs = [tf.placeholder(tf.int32, shape=(None,),
                                name="inp%i" % t)
                 for t in range(seq_length)]
# target words, same layout as encode_inputs
labels = [tf.placeholder(tf.int32, shape=(None,),
                         name="labels%i" % t)
          for t in range(seq_length)]
# uniform loss weights: every timestep counts equally in the loss
weights = [tf.ones_like(labels_t, dtype=tf.float32)
           for labels_t in labels]
# decoder input: a "GO" token (all zeros) followed by encode_inputs without its last element
decode_inputs = ([tf.zeros_like(encode_inputs[0], dtype=tf.int32, name="GO")]
                 + encode_inputs[:-1])
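# e.g. if the encoder sees the column [1, 2, 3, 4, 5], the decoder is fed
# [GO, 1, 2, 3, 4] and trained to emit the labels [2, 3, 4, 5, 6];
# note this toy reuses word 0 as the GO symbol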
cell = tf.nn.rnn_cell.GRUCell(memory_dim)
# decode_outputs: a list of seq_length logit tensors, each of shape batch_size x vocab_size
decode_outputs, dec_memory = tf.nn.seq2seq.embedding_rnn_seq2seq(
    encode_inputs, decode_inputs, cell, vocab_size, vocab_size, embedding_dim)
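# roughly what this builds: words are embedded, a GRU encodes the input
# sequence, its final state seeds the decoder GRU, and each decoder output
# is projected to vocab_size logits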
# argmax over the vocab dimension gives the predicted word at each timestep
# (tf.argmax packs the list into a seq_length x batch_size x vocab_size tensor)
prediction = tf.argmax(decode_outputs, 2)
# in TF 0.11 sequence_loss takes (logits, targets, weights); the old num_decoder_symbols argument is gone
loss = tf.nn.seq2seq.sequence_loss(decode_outputs, labels, weights)
# unused diagnostic: magnitude of (part of) the final decoder state
magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_memory[1])))
learning_rate = 0.05
momentum = 0.9
train_op = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(loss)
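For intuition, sequence_loss with the default flags is a weighted softmax cross-entropy, averaged first across timesteps and then across the batch. Below is a minimal numpy sketch of that computation; toy_sequence_loss is an illustrative helper, not part of TensorFlow:

def toy_sequence_loss(logits, targets, weights):
    # logits: seq_length x batch_size x vocab_size
    # targets, weights: seq_length x batch_size
    shifted = logits - logits.max(axis=2, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=2, keepdims=True))
    t_idx, b_idx = np.indices(targets.shape)
    xent = -log_probs[t_idx, b_idx, targets]              # per-word cross-entropy
    per_example = (weights * xent).sum(axis=0) / weights.sum(axis=0)
    return per_example.mean()                             # average across the batch

Given the stacked logits and targets from a training step, this should roughly track the loss values printed during training below.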
In [4]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for step in range(500):
    X, Y = get_train_batch(batch_size)
    # feed one placeholder per timestep
    feed_dict = {encode_inputs[t]: X[t] for t in range(seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(seq_length)})
    _, predict_t, loss_t = sess.run([train_op, prediction, loss], feed_dict)
    if step % 100 == 0:
        # inspect the first sequence of the batch
        print('------ step', step, '-------')
        print('data   ', X[:, 0])
        print('label  ', Y[:, 0])
        print('predict', predict_t[:, 0])
        print(loss_t)
        print()
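At test time there is no shifted ground truth to feed into the decoder. A minimal greedy-decoding sketch, not part of the original notebook: rebuild the same graph with embedding_rnn_seq2seq's feed_previous=True flag (which uses only the first "GO" decoder input and feeds each later step its own previous output), reusing the trained variables via scope reuse:

with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    # same call as above, but the decoder now feeds back its own argmax output
    decode_outputs_test, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
        encode_inputs, decode_inputs, cell, vocab_size, vocab_size,
        embedding_dim, feed_previous=True)
prediction_test = tf.argmax(decode_outputs_test, 2)

X_test, Y_test = get_train_batch(1)
feed = {encode_inputs[t]: X_test[t] for t in range(seq_length)}
print('input  ', X_test[:, 0])
print('target ', Y_test[:, 0])
print('decoded', sess.run(prediction_test, feed)[:, 0])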