In [ ]:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import mnist_utils
import utils
NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
NUM_HIDDEN = 100
NUM_TEST_IMAGES = 10000
def input_fn(mode, hparams):
    """Get input tensors."""
    # Load data into memory
    dataset_path = '../data/mnist/'
    mnist = input_data.read_data_sets(dataset_path, one_hot=True)
    if mode == 'TRAIN':
        min_queue_examples = 20000
        batch_size = hparams.train_batch_size
        input_, target = tf.train.shuffle_batch(
            [mnist.train.images, mnist.train.labels],
            batch_size,
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples,
            enqueue_many=True)
    elif mode == 'TEST':
        input_, target = tf.train.batch(
            [mnist.test.images, mnist.test.labels],
            hparams.test_batch_size,
            enqueue_many=True)
    return input_, target
In [ ]:
class Hparams(object):
    """Plain container for hyperparameter attributes."""
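Before wiring up the full model, it can help to sanity-check the input pipeline on its own. The following is a minimal sketch (batch size 8 is an arbitrary choice for illustration); it assumes the MNIST files can be read from or downloaded to ../data/mnist/.
In [ ]:
# Sanity check (sketch): fetch one shuffled training batch and inspect its shapes.
tf.reset_default_graph()
_hp = Hparams()
_hp.train_batch_size = 8
batch_x, batch_y = input_fn('TRAIN', _hp)
with tf.Session() as _sess:
    _coord = tf.train.Coordinator()
    _threads = tf.train.start_queue_runners(sess=_sess, coord=_coord)
    bx, by = _sess.run([batch_x, batch_y])
    print(bx.shape, by.shape)  # expect (8, 784) and (8, 10)
    _coord.request_stop()
    _coord.join(_threads)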
In [ ]:
def get_loss(input_, logits):
    assert input_.get_shape() == logits.get_shape()
    # Per-pixel sigmoid cross-entropy between the reconstruction logits and the
    # original image, averaged over pixels and batch. TF 1.x requires keyword args.
    all_losses = tf.nn.sigmoid_cross_entropy_with_logits(labels=input_, logits=logits)
    loss = tf.reduce_mean(all_losses, name='loss')
    return loss
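Since the pixel values lie in [0, 1], each pixel is treated as a soft Bernoulli target. For logits z and targets x, tf.nn.sigmoid_cross_entropy_with_logits computes the numerically stable form max(z, 0) - z*x + log(1 + exp(-|z|)); the sketch below checks get_loss against that closed form on a toy array.
In [ ]:
# Check get_loss against the closed-form sigmoid cross-entropy (sketch).
z = np.array([[-1.0, 0.5], [2.0, -0.3]], dtype=np.float32)  # logits
x = np.array([[0.0, 1.0], [1.0, 0.2]], dtype=np.float32)    # targets in [0, 1]
manual = np.mean(np.maximum(z, 0) - z * x + np.log1p(np.exp(-np.abs(z))))
with tf.Session() as _sess:
    tf_loss = _sess.run(get_loss(tf.constant(x), tf.constant(z)))
print(manual, tf_loss)  # the two values should agree to float32 precision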
In [ ]:
def model_fn(input_, mode, hparams):
    if mode == 'TEST':
        # Share all weights with the training graph.
        tf.get_variable_scope().reuse_variables()
    # Measurement matrix A: compresses each image into num_measurements values.
    A_scale = 1.0 / hparams.num_measurements
    A_init = np.random.normal(
        size=[hparams.n_input, hparams.num_measurements], scale=A_scale).astype(np.float32)
    A = tf.get_variable('A', initializer=A_init, trainable=hparams.is_A_trainable)
    y = tf.matmul(input_, A)
    # Decoder: fully connected ReLU layers mapping measurements back to pixels.
    hidden = y
    prev_hidden_size = hparams.num_measurements
    for i, hidden_size in enumerate(hparams.layer_sizes):
        layer_name = 'hidden{0}'.format(i)
        with tf.variable_scope(layer_name):
            weights = tf.get_variable('weights', shape=[prev_hidden_size, hidden_size])
            biases = tf.get_variable('biases', initializer=tf.zeros([hidden_size]))
            hidden = tf.nn.relu(tf.matmul(hidden, weights) + biases, name=layer_name)
        prev_hidden_size = hidden_size
    with tf.variable_scope('sigmoid_logits'):
        weights = tf.get_variable('weights', shape=[prev_hidden_size, hparams.n_input])
        biases = tf.get_variable('biases', initializer=tf.zeros([hparams.n_input]))
        logits = tf.add(tf.matmul(hidden, weights), biases, name='logits')
    prediction = tf.nn.sigmoid(logits)
    loss = get_loss(input_, logits)
    return loss, prediction
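The model is an end-to-end compressed-sensing autoencoder: the (optionally trainable) matrix A compresses each 784-pixel image into num_measurements values, and a small ReLU decoder maps those measurements back to pixel logits. A quick back-of-the-envelope parameter count, a sketch using the defaults set below (num_measurements=30, layer_sizes=[50, 200]):
In [ ]:
# Decoder parameter count for measurements=30, layer_sizes=[50, 200] (sketch).
sizes = [30, 50, 200, 784]  # measurements -> hidden layers -> output pixels
decoder_params = sum(m * n + n for m, n in zip(sizes[:-1], sizes[1:]))
print(decoder_params)  # 169334 weights and biases in the decoder
print(784 * 30)        # plus 23520 entries in A when it is trainable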
In [ ]:
hparams = Hparams()
hparams.dataset = 'mnist'
hparams.train_batch_size = 64
hparams.test_batch_size = 10
hparams.num_measurements = 30
hparams.n_input = 784
hparams.layer_sizes = [50, 200]
hparams.learning_rate = 0.001
hparams.momentum = 0.9
hparams.optimizer_type = 'adam'
hparams.measurement_type = 'learned'
hparams.model_types = ['e2e']
hparams.image_matrix = 1
hparams.summary_iter = 20
hparams.checkpoint_iter = 2000
hparams.max_train_epochs = 50
hparams.num_train_examples_per_epoch = 55000
# hparams.max_train_steps = hparams.max_train_epochs * hparams.num_train_examples_per_epoch / hparams.train_batch_size
# The formula above gives 50 * 55000 / 64 ≈ 42969 steps for 50 epochs; round up to a fixed 50000.
hparams.max_train_steps = 50000
hparams.is_A_trainable = True
base_dir_pattern = '../optimization/mnist-e2e/{0}/'
# Encode the run configuration in the directory name, e.g. 'adam_0.001_30_True/'.
dirname = '{0}_{1}_{2}_{3}/'.format(hparams.optimizer_type,
                                    hparams.learning_rate,
                                    hparams.num_measurements,
                                    hparams.is_A_trainable)
hparams.summary_dir = base_dir_pattern.format('summaries') + dirname
hparams.checkpoint_dir = base_dir_pattern.format('checkpoints') + dirname
In [ ]:
tf.reset_default_graph()
In [ ]:
train_input, _ = input_fn('TRAIN', hparams)
test_input, _ = input_fn('TEST', hparams)
In [ ]:
# Set up the model
train_loss, train_prediction = model_fn(train_input, 'TRAIN', hparams)
test_loss, test_prediction = model_fn(test_input, 'TEST', hparams)
train_summary = tf.summary.scalar('train_loss', train_loss)
# Set up the solver
global_step = tf.Variable(0, trainable=False, name='global_step')
opt = utils.get_optimizer(hparams.learning_rate, hparams)
train_op = opt.minimize(train_loss, global_step=global_step, name='train_op')
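utils.get_optimizer is defined outside this notebook. A hypothetical stand-in, consistent with how it is called here and with the optimizer_type and momentum hparams above, might look like the following; this is a sketch, not the actual helper.
In [ ]:
# Hypothetical stand-in for utils.get_optimizer (the real helper is not shown here).
def get_optimizer_sketch(learning_rate, hparams):
    if hparams.optimizer_type == 'adam':
        return tf.train.AdamOptimizer(learning_rate)
    elif hparams.optimizer_type == 'momentum':
        return tf.train.MomentumOptimizer(learning_rate, hparams.momentum)
    elif hparams.optimizer_type == 'sgd':
        return tf.train.GradientDescentOptimizer(learning_rate)
    raise ValueError('Unknown optimizer_type: {0}'.format(hparams.optimizer_type))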
In [ ]:
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# Summary writer setup
utils.set_up_dir(hparams.summary_dir)
summary_writer = tf.summary.FileWriter(hparams.summary_dir, sess.graph)
# Model checkpointing setup
utils.set_up_dir(hparams.checkpoint_dir)
model_saver = tf.train.Saver(tf.global_variables() + tf.local_variables())
# Load variables from checkpoint or pretrained model
ckpt_path = utils.get_checkpoint_path(hparams.checkpoint_dir)
if ckpt_path:  # if a previous checkpoint exists
    model_saver.restore(sess, ckpt_path)
    ckpt_global_step = int(ckpt_path.split('/')[-1].split('-')[-1])
    print('Successfully loaded model from {0} at step = {1}'.format(
        ckpt_path, ckpt_global_step))
    train_start_step = ckpt_global_step
else:
    train_start_step = 0
    init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    sess.run(init_ops)
for train_step in range(train_start_step + 1, hparams.max_train_steps):
    _ = sess.run([train_op])
    if train_step % hparams.summary_iter == 0:
        train_summary_str = sess.run(train_summary)
        summary_writer.add_summary(train_summary_str, train_step)
    if train_step % 1000 == 0:
        train_loss_val = sess.run(train_loss)
        test_input_val, test_pred_val, test_loss_val = sess.run(
            [test_input, test_prediction, test_loss])
        xs_dict = {i: test_input_val[i, :] for i in range(10)}
        x_hats_dict = {'e2e': {i: test_pred_val[i, :] for i in range(10)}}
        print('Train step = {0}, Train Loss = {1}, Test Loss = {2}'.format(
            train_step, train_loss_val, test_loss_val))
        print(sess.run(global_step))  # debug: current global step
        utils.image_matrix(xs_dict, x_hats_dict, mnist_utils.view_image, hparams)
    # Checkpointing
    if train_step % hparams.checkpoint_iter == 0:
        model_saver.save(sess, hparams.checkpoint_dir + 'snapshot.ckpt',
                         global_step=train_step)
# Final checkpoint
model_saver.save(sess, hparams.checkpoint_dir + 'snapshot.ckpt', global_step=hparams.max_train_steps-1)
coord.request_stop()
coord.join(threads)
sess.close()
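Because A is learned end to end, each of its columns is a 784-vector and can be viewed as a 28x28 measurement pattern. The sketch below (assuming matplotlib is available and a checkpoint has been written) restores A from the latest checkpoint and displays its first column:
In [ ]:
# Visualize one learned measurement vector (sketch; assumes matplotlib).
import matplotlib.pyplot as plt
A_var = [v for v in tf.global_variables() if v.name.startswith('A')][0]
with tf.Session() as _sess:
    tf.train.Saver([A_var]).restore(_sess, utils.get_checkpoint_path(hparams.checkpoint_dir))
    A_val = _sess.run(A_var)
plt.imshow(A_val[:, 0].reshape(28, 28), cmap='gray')
plt.title('First column of learned A')
plt.show()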
In [ ]:
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
model_saver = tf.train.Saver(tf.global_variables() + tf.local_variables())
# Load variables from checkpoint or pretrained model
ckpt_path = utils.get_checkpoint_path(hparams.checkpoint_dir)
if ckpt_path:  # if a previous checkpoint exists
    model_saver.restore(sess, ckpt_path)
    ckpt_global_step = int(ckpt_path.split('/')[-1].split('-')[-1])
    print('Successfully loaded model from {0} at step = {1}'.format(
        ckpt_path, ckpt_global_step))
else:
    init_ops = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    sess.run(init_ops)
# Evaluate one test batch; increase the range to inspect more batches.
for _ in range(1):
    train_loss_val = sess.run(train_loss)
    test_input_val, test_pred_val, test_loss_val = sess.run(
        [test_input, test_prediction, test_loss])
    xs_dict = {i: test_input_val[i, :] for i in range(10)}
    x_hats_dict = {'e2e': {i: test_pred_val[i, :] for i in range(10)}}
    print('Train Loss = {0}, Test Loss = {1}'.format(train_loss_val, test_loss_val))
    print(sess.run(global_step))  # debug: current global step
    utils.image_matrix(xs_dict, x_hats_dict, mnist_utils.view_image, hparams)
coord.request_stop()
coord.join(threads)
sess.close()