In [14]:
import numpy as np, random
import tensorflow as tf

In [2]:
def as_bytes(n, final_size):
    res = []
    for _ in range(final_size):
        res.append(n%2)
        n = n//2
    return res

In [3]:
def generate_example(num_bits):
    a = random.randint(0, 2**(num_bits-1)-1)
    b = random.randint(0, 2**(num_bits-1)-1)
    c = a+b
    return (as_bytes(a, num_bits), as_bytes(b, num_bits),
            as_bytes(c, num_bits))

In [4]:
def generate_batch(num_bits, batch_size):
    x = np.empty([num_bits, batch_size, 2])
    y = np.empty([num_bits, batch_size, 1])
    
    for i in range(batch_size):
        a, b, s = generate_example(num_bits)
        
        x[:,i,0] = a
        x[:,i,1] = b
        y[:,i,0] = s
    return x, y

In [41]:
#####################################################################
#############             Graph Definition             ##############
#####################################################################
LOGDIR = '/home/louis/python/notebooks/.tensorflow_logs_dir/'

with tf.Graph().as_default() as graph:
    INPUT_SIZE = 2
    OUTPUT_SIZE = 1
    RNN_HIDDEN = 5
    LEARNING_RATE = 0.01
    

    # Definition of the inputs and outputs
    inputs = tf.placeholder(tf.float32, (None, None, INPUT_SIZE))
    labels = tf.placeholder(tf.float32, (None, None, OUTPUT_SIZE))

    # Definition of the cell
    cell = tf.contrib.rnn.BasicLSTMCell(num_units=RNN_HIDDEN, state_is_tuple=True)

    # Definition of the initial state
    batch_size = tf.shape(inputs)[1]
    initial_state = cell.zero_state(batch_size, tf.float32)

    # Computation of the outputs and states
    with tf.name_scope('states'):
        rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, inputs, dtype=np.float32, 
                                                    initial_state=initial_state, 
                                                    time_major=True)
        _ = tf.summary.histogram('lstm_states_histo', rnn_states)

    # Definition of the outputs
    final_projection = lambda x: tf.contrib.layers.linear(x, num_outputs=OUTPUT_SIZE, 
                                                          activation_fn=tf.nn.sigmoid)

    # Application of final projection to the outputs
    logits = tf.map_fn(final_projection, rnn_outputs)

    # Computation of the loss
    loss = -(labels*tf.log(logits) + (1.0 - labels)*tf.log(1.0 - logits))
    with tf.name_scope('loss'):
        loss = tf.reduce_mean(loss)
        _ = tf.summary.scalar('loss_function', loss)

    # train_optimizer
    train_optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

    # For validation purpose
    accuracy = tf.reduce_mean(tf.cast(abs(logits - labels) < 0.5, tf.float32))
    
    #### Summaries ####
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter(LOGDIR, graph)

In [42]:
###########################################################################
########                         Training Loop                     ########
###########################################################################

NUM_EPOCHS = 1000
ITERATONS_PER_EPOCH = 100
NUM_BITS = 10
BATCH_SIZE = 16

valid_x, valid_y = generate_batch(num_bits=NUM_BITS, batch_size=100)

with tf.Session(graph=graph) as session:
    session.run(tf.global_variables_initializer())
    for i in range(NUM_EPOCHS):
        epoch_loss = 0
        for j in range(ITERATONS_PER_EPOCH):
            x, y = generate_batch(num_bits=NUM_BITS, batch_size=BATCH_SIZE, )
            loss_summaries = []
            epoch_loss, _, summaries = session.run([loss, train_optimizer, merged], 
                                        feed_dict={inputs:x, 
                                                   labels:y})
            
            # Summaries
            ind = i*ITERATONS_PER_EPOCH + j
            if ind % 10 == 0: writer.add_summary(summaries, ind)
            
        epoch_loss /= ITERATONS_PER_EPOCH
        valid_accuracy = session.run([accuracy], 
                                     feed_dict={inputs:valid_x, labels:valid_y})[0]
        print('Iteration : %d, Epoch Loss = %.6f' % (i, epoch_loss))
        print('Accuracy = %.1f' % (valid_accuracy*100.))
        if valid_accuracy == 1: break


Iteration : 0, Epoch Loss = 0.006791
Accuracy = 57.9
Iteration : 1, Epoch Loss = 0.003834
Accuracy = 87.9
Iteration : 2, Epoch Loss = 0.000703
Accuracy = 100.0

In [ ]: