In [ ]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import mnist_utils
import utils

NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
NUM_HIDDEN = 100

NUM_TEST_IMAGES = 10000

def input_fn(mode, hparams):
    """Get input tensors"""

    # Load data into memory
    dataset_path = '../data/mnist/'
    mnist = input_data.read_data_sets(dataset_path, one_hot=True)

    if mode == 'TRAIN':
        min_queue_examples = 20000
        batch_size = hparams.train_batch_size
        input_, target = tf.train.shuffle_batch([mnist.train.images, mnist.train.labels],
                                                batch_size,
                                                capacity=min_queue_examples + 3*batch_size,
                                                min_after_dequeue=min_queue_examples,
                                                enqueue_many=True)        
    elif mode == 'TEST':
        input_, target = tf.train.batch(
            [mnist.test.images, mnist.test.labels], hparams.test_batch_size, enqueue_many=True)
    else:
        raise ValueError('Unsupported mode: {0}'.format(mode))

    return input_, target
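
The full dataset is held in memory, so batching is done directly from the in-memory arrays: with enqueue_many=True the first dimension of the (images, labels) tensors is treated as the example axis, and tf.train.shuffle_batch samples each training batch from a pool of at least min_queue_examples shuffled examples.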

In [ ]:
class Hparams(object):
    """Empty container for hyperparameters, which are set as attributes below."""

In [ ]:
def get_loss(input_, logits):
    assert input_.get_shape() == logits.get_shape()
    all_losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=input_, logits=logits)
    loss = tf.reduce_mean(all_losses, name='loss')
    return loss
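
Each pixel is treated as an independent Bernoulli variable: for a logit l and target pixel x in [0, 1], tf.nn.sigmoid_cross_entropy_with_logits computes the numerically stable form max(l, 0) - l*x + log(1 + exp(-|l|)), and the loss is this quantity averaged over all pixels and examples.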

In [ ]:
def model_fn(input_, mode, hparams):
    if mode == 'TEST':
        tf.get_variable_scope().reuse_variables()

    # Measurement matrix A with i.i.d. Gaussian entries (std 1/num_measurements).
    A_scale = 1.0 / hparams.num_measurements
    A_init = np.random.normal(size=[hparams.n_input, hparams.num_measurements],
                              scale=A_scale).astype(np.float32)
    A = tf.get_variable('A', initializer=A_init, trainable=hparams.is_A_trainable)

    # Compress each image to num_measurements linear measurements: y = x A.
    y = tf.matmul(input_, A)
    hidden = y
    prev_hidden_size = hparams.num_measurements
    for i, hidden_size in enumerate(hparams.layer_sizes):
        layer_name = 'hidden{0}'.format(i)
        with tf.variable_scope(layer_name):
            weights = tf.get_variable('weights', shape=[prev_hidden_size, hidden_size])
            biases = tf.get_variable('biases', initializer=tf.zeros([hidden_size]))
            hidden = tf.nn.relu(tf.matmul(hidden, weights) + biases, name=layer_name)
        prev_hidden_size = hidden_size
        
    with tf.variable_scope('sigmoid_logits'):
        weights = tf.get_variable('weights', shape=[prev_hidden_size, hparams.n_input])
        biases = tf.get_variable('biases', initializer=tf.zeros([hparams.n_input]))
        logits = tf.add(tf.matmul(hidden, weights), biases, name='logits')
        
    prediction = tf.nn.sigmoid(logits)
    loss = get_loss(input_, logits)
        
    return loss, prediction
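
The model is an end-to-end ('e2e') compressed-sensing autoencoder: each 784-pixel image x is projected to num_measurements linear measurements y = xA, and a fully connected decoder (ReLU hidden layers of sizes hparams.layer_sizes, then a 784-unit sigmoid output) reconstructs the image from y. With is_A_trainable=True the measurement matrix A is learned jointly with the decoder. The first call to model_fn (mode='TRAIN') creates the variables; the second (mode='TEST') reuses them via reuse_variables(), so the two graphs share one set of weights.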

In [ ]:
hparams = Hparams()
hparams.dataset = 'mnist'
hparams.train_batch_size = 64
hparams.test_batch_size = 10
hparams.num_measurements = 30
hparams.n_input = 784
hparams.layer_sizes = [50, 200]

hparams.learning_rate = 0.001
hparams.momentum = 0.9
hparams.optimizer_type = 'adam'

hparams.measurement_type = 'learned'
hparams.model_types = ['e2e']
hparams.image_matrix = 1

hparams.summary_iter = 20
hparams.checkpoint_iter = 2000

hparams.max_train_epochs = 50
hparams.num_train_examples_per_epoch = 55000
# hparams.max_train_steps = hparams.max_train_epochs * hparams.num_train_examples_per_epoch / hparams.train_batch_size  (~= 42969)
hparams.max_train_steps = 50000

hparams.is_A_trainable = True

base_dir_pattern = '../optimization/mnist-e2e/{0}/'
dirname = '{0}_{1}_{2}_{3}/'.format(hparams.optimizer_type,
                                    hparams.learning_rate,
                                    hparams.num_measurements,
                                    hparams.is_A_trainable)

hparams.summary_dir = base_dir_pattern.format('summaries') + dirname
hparams.checkpoint_dir = base_dir_pattern.format('checkpoints') + dirname
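
With the settings above, dirname evaluates to 'adam_0.001_30_True/', so summaries are written to ../optimization/mnist-e2e/summaries/adam_0.001_30_True/ and checkpoints to ../optimization/mnist-e2e/checkpoints/adam_0.001_30_True/.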

In [ ]:
tf.reset_default_graph()

In [ ]:
train_input, _ = input_fn('TRAIN', hparams)
test_input, _ = input_fn('TEST', hparams)

In [ ]:
# Set up the model
train_loss, train_prediction = model_fn(train_input, 'TRAIN', hparams)
test_loss, test_prediction = model_fn(test_input, 'TEST', hparams)
train_summary = tf.summary.scalar('train_loss', train_loss)

# Set up the solver
global_step = tf.Variable(0, trainable=False, name='global_step')
opt = utils.get_optimizer(hparams.learning_rate, hparams)
train_op = opt.minimize(train_loss, global_step=global_step, name='train_op')
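
utils.get_optimizer comes from utils.py and is not shown in this notebook. A minimal sketch of what it plausibly does, assuming it only dispatches on hparams.optimizer_type (the name get_optimizer_sketch and the exact branches are assumptions):

In [ ]:
# Sketch only: assumed stand-in for utils.get_optimizer.
def get_optimizer_sketch(learning_rate, hparams):
    if hparams.optimizer_type == 'adam':
        return tf.train.AdamOptimizer(learning_rate)
    elif hparams.optimizer_type == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate, hparams.momentum)
    elif hparams.optimizer_type == 'sgd':
        return tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise ValueError('Unknown optimizer_type: {0}'.format(hparams.optimizer_type))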

Train a model
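
The training cell below also relies on utils.set_up_dir and utils.get_checkpoint_path from utils.py. Minimal sketches, assuming set_up_dir simply creates the directory and get_checkpoint_path wraps tf.train.latest_checkpoint (both names here are illustrative stand-ins):

In [ ]:
import os

# Sketch only: assumed stand-in for utils.set_up_dir.
def set_up_dir_sketch(path):
    # Create the directory if it does not already exist.
    if not os.path.exists(path):
        os.makedirs(path)

# Sketch only: assumed stand-in for utils.get_checkpoint_path.
def get_checkpoint_path_sketch(checkpoint_dir):
    # Path of the newest checkpoint in checkpoint_dir, or None if none exists.
    return tf.train.latest_checkpoint(checkpoint_dir)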


In [ ]:
sess = tf.Session()

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Summary writer setup
utils.set_up_dir(hparams.summary_dir)
summary_writer = tf.summary.FileWriter(hparams.summary_dir, sess.graph)

# Model checkpointing setup
utils.set_up_dir(hparams.checkpoint_dir)

model_saver = tf.train.Saver(tf.global_variables() + tf.local_variables())

# Load variables from checkpoint or pretrained model
ckpt_path = utils.get_checkpoint_path(hparams.checkpoint_dir)
if ckpt_path:  # if a previous checkpoint exists
    model_saver.restore(sess, ckpt_path)
    ckpt_global_step = int(ckpt_path.split('/')[-1].split('-')[-1])
    print('Successfully loaded model from {0} at step = {1}'.format(
        ckpt_path, ckpt_global_step))
    train_start_step = ckpt_global_step
else:
    train_start_step = 0
    init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    sess.run(init_ops)
        
    
for train_step in range(train_start_step+1, hparams.max_train_steps):
    _ = sess.run([train_op])
 
    if train_step % hparams.summary_iter == 0:
        train_summary_str = sess.run(train_summary)    
        summary_writer.add_summary(train_summary_str, train_step)

    if train_step % 1000 == 0:
        train_loss_val = sess.run(train_loss)
        test_input_val, test_pred_val, test_loss_val = sess.run([test_input, test_prediction, test_loss])
        xs_dict = {i: test_input_val[i, :] for i in range(10)}
        x_hats_dict = {'e2e': {i: test_pred_val[i, :] for i in range(10)}}
        print('Train step = {0}, Train Loss = {1}, Test Loss = {2}'.format(
            train_step, train_loss_val, test_loss_val))
        # Debug print: current value of one model variable (index 7 in tf.global_variables()).
        print(sess.run(tf.global_variables()[7]))
        utils.image_matrix(xs_dict, x_hats_dict, mnist_utils.view_image, hparams)

    # Checkpointing
    if train_step % hparams.checkpoint_iter == 0:
        model_saver.save(sess, hparams.checkpoint_dir + 'snapshot.ckpt', global_step=train_step)

# Final checkpoint
model_saver.save(sess, hparams.checkpoint_dir + 'snapshot.ckpt', global_step=hparams.max_train_steps-1)
        
coord.request_stop()
coord.join(threads)
sess.close()

Load a trained model and run it on test images
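
This cell reuses the graph built above: it opens a fresh session, restarts the queue runners, restores the weights from the latest checkpoint, and runs one evaluation pass on a test batch.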


In [ ]:
sess = tf.Session()

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

model_saver = tf.train.Saver(tf.global_variables() + tf.local_variables())

# Load variables from checkpoint or pretrained model
ckpt_path = utils.get_checkpoint_path(hparams.checkpoint_dir)
if ckpt_path:  # if a previous checkpoint exists
    model_saver.restore(sess, ckpt_path)
    ckpt_global_step = int(ckpt_path.split('/')[-1].split('-')[-1])
    print('Successfully loaded model from {0} at step = {1}'.format(
        ckpt_path, ckpt_global_step))
    train_start_step = ckpt_global_step
else:
    train_start_step = 0
    init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    sess.run(init_ops)
        
    
# Run a single evaluation pass; increase the range to view more test batches.
for _ in range(1):
    train_loss_val = sess.run(train_loss)
    test_input_val, test_pred_val, test_loss_val = sess.run([test_input, test_prediction, test_loss])
    xs_dict = {i: test_input_val[i, :] for i in range(10)}
    x_hats_dict = {'e2e': {i: test_pred_val[i, :] for i in range(10)}}
    print('Train Loss = {0}, Test Loss = {1}'.format(train_loss_val, test_loss_val))
    # Debug print: current value of one model variable (index 7 in tf.global_variables()).
    print(sess.run(tf.global_variables()[7]))
    utils.image_matrix(xs_dict, x_hats_dict, mnist_utils.view_image, hparams)
    
coord.request_stop()
coord.join(threads)
sess.close()