LSTM CA

This version builds a 1D board with random stones on it and uses a single LSTM layer to unroll the board one stone at a time. Targets are one-hot encoded.

TODO: Use the TensorFlow sequence-to-sequence API
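
A minimal sketch of the data layout assumed below (illustration only: the real boards and targets come from board_utils.generate_1d_data, whose exact output this sketch does not reproduce):

In [ ]:
import numpy as np

# Hypothetical 1D board of length 8: each cell holds a stone class in [0, 8)
rng = np.random.RandomState(0)
board = rng.randint(0, 8, size=(8,))  # one random stone class per cell
print(board)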


In [ ]:
import tensorflow as tf
import numpy as np

from board_utils import generate_1d_data
from board_utils import build_1d_datasets

def reset_graph():
    '''
    reset graph so we can run it multiple times without 
    duplicating variables (causing nasty errors)
    '''
    if 'sess' in globals() and sess:
        sess.close()  # Close any open session
    tf.reset_default_graph()  # clear graph stack

def build_graph(state_size=4,  # doubles as embedding size and LSTM state size
                n_length=8,
                num_hidden=32,  # currently unused
                learning_rate=1e-3,
                batch_size=32,
                num_classes=8,
                num_steps=8):
    '''
    Constructs the tensorflow graph
    
    Returns:
        dictionary of graph components 
        
    TODO: translate into model class
    '''

    reset_graph()
    
    '''
    Placeholders
    '''
    
    x = tf.placeholder(tf.int32, [None, n_length])
    y = tf.placeholder(tf.int32, [None, n_length])

    '''
    Inputs
    '''
    
    # Translate the int32 placeholders into [batch, time, state_size] tensors.
    # The embedding matrix is indexed by stone class, so it needs one row per
    # class (num_classes and n_length both happen to be 8 here).
    embeddings = tf.get_variable('embedding_matrix', [num_classes, state_size])
    rnn_inputs = tf.nn.embedding_lookup(embeddings, x)  # get correct shape of input
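    # rnn_inputs shape: [batch, n_length, state_size] -- one embedding
    # vector per stone on the board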
#     print(rnn_inputs)  # DEBUG

    '''
    RNN
    '''
    
    cell = tf.nn.rnn_cell.LSTMCell(state_size)
    init_state = cell.zero_state(batch_size, tf.float32)

    # Add the rnn cell to the graph
    # dynamic_rnn expects inputs of shape [batch, time, state_size]
    # (or [time, batch, state_size] with time_major=True)
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state)
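    # rnn_outputs: [batch, num_steps, state_size], one output per stone;
    # final_state: an LSTMStateTuple of (c, h), each [batch, state_size]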
    
    '''
    Prediction loss and optimize
    '''
    
    with tf.variable_scope('softmax'):
        W = tf.get_variable('W', [state_size, num_classes])
        b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))
    
    # reshape rnn_outputs so a single matmul computes logits for every
    # position at once; [batch, time, state_size] -> [batch * time, state_size]
    rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
    y_reshaped = tf.reshape(y, [-1])
    logits = tf.matmul(rnn_outputs, W) + b
    
#     logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]
#     predictions = [tf.nn.softmax(logit) for logit in logits]
#     y_as_list = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(1, num_steps, y)]

    # sequence_loss_by_example lets us weight parts of the sequence in the
    # loss; all-ones weights score the full sequence (a sketch of partial
    # weighting follows the training run below). The function expects lists
    # of tensors, so the flattened logits and targets are wrapped in lists.
    # (In TF >= 1.0 it lives at tf.contrib.legacy_seq2seq.)
    loss_weights = tf.ones([batch_size * num_steps])
    losses = tf.nn.seq2seq.sequence_loss_by_example(
        [logits], [y_reshaped], [loss_weights])
    total_loss = tf.reduce_mean(losses)
    optimize = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
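    # With all-ones weights the above is essentially equivalent to
    # tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_reshaped,
    #                                                logits=logits)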

    return dict(
        x = x,
        y = y,
        init_state = init_state,
        final_state = final_state,
        total_loss = total_loss,
        optimize = optimize)
    
def train_network(g,
                  num_epochs,
                  num_steps,
                  batch_size=32,
                  verbose=True):
    '''
    Function to train a graph for num_steps batches

    num_epochs is currently unused
    TODO: real epoch handling and proper evaluation stats
    '''

    data, labels = generate_1d_data(8, one_hot=True)
    datasets = build_1d_datasets(data, labels)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        total_loss = 0
        for step in range(num_steps):
            xf, yf = datasets.train.next_batch(batch_size)

            _, training_loss = sess.run([g['optimize'], g['total_loss']], 
                         feed_dict={g['x']: xf, g['y']: yf})
            
            total_loss += training_loss
            if step % 50 == 0 and step > 0:
                if verbose:
                    print('step: {}, avg training loss: {:.3f}'.format(step, total_loss / step))

# Build the graph and run a short training loop
g = build_graph()
train_network(g, 1, 1000)
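
The loss weights above make it easy to score only part of the sequence. A sketch of weighting nothing but the last stone of each board, using the same flattened [batch * time] layout as the graph above (illustration only; the training code does not use this):

In [ ]:
import tensorflow as tf

batch_size, num_steps = 32, 8

# zero weight everywhere except the last stone of each sequence
mask = tf.tile(tf.one_hot([num_steps - 1], num_steps), [batch_size, 1])
last_step_weights = tf.reshape(mask, [-1])  # shape [batch_size * num_steps]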

In [ ]:
'''
Debugging input and output shapes
'''

reset_graph()

x = tf.placeholder(tf.int32, [None, 8])
y = tf.placeholder(tf.int32, [None, 8])
embeddings = tf.get_variable('emb_matrix', [8, 4])
rnn_inputs = tf.nn.embedding_lookup(embeddings, x)
print(rnn_inputs)  # bare expressions mid-cell do not display in a notebook

cell = tf.nn.rnn_cell.LSTMCell(4)
init_state = cell.zero_state(32, tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=init_state)
print(rnn_outputs)
rnn_outputs = tf.reshape(rnn_outputs, [-1, 4])
y_reshaped = tf.reshape(y, [-1])

print(rnn_outputs)
print(y_reshaped)
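
# Expected printed shapes: rnn_inputs (?, 8, 4); rnn_outputs (?, 8, 4)
# before the reshape and (?, 4) after; y_reshaped (?,)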