In [ ]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import experiment_utils
import lanczos_experiment
import tensorflow_datasets as tfds
In [ ]:
sys.path.insert(0, os.path.abspath("./../jax"))
import density
In [ ]:
COLAB_PATH = '/tmp/spectral-density'
TRAIN_PATH = os.path.join(COLAB_PATH, 'train')
LANCZOS_PATH = os.path.join(COLAB_PATH, 'lanczos')
os.makedirs(TRAIN_PATH, exist_ok=True)
os.makedirs(LANCZOS_PATH, exist_ok=True)
IMAGE_SIZE = 28
NUM_CLASSES = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.02
NUM_TRAIN_STEPS = 10000
NUM_SUMMARIZE_STEPS = 1000
NUM_LANCZOS_STEPS = 90

def data_fn(num_epochs=None, shuffle=False, initializable=False):
  """Returns MNIST image and label tensors, plus an optional iterator init op."""
  dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)
  dataset = dataset.repeat(num_epochs)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=1024)
  dataset = dataset.batch(BATCH_SIZE)

  if initializable:
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
  else:
    iterator = dataset.make_one_shot_iterator()
    init_op = None

  output = iterator.get_next()
  images = (tf.to_float(output['image']) - 128) / 128.0
  one_hot_labels = tf.one_hot(output['label'], NUM_CLASSES)
  return images, one_hot_labels, init_op

def model_fn(features, one_hot_labels):
  """Builds an MLP for MNIST and computes the loss.

  Args:
    features: A [batch_size, height, width, channels] float32 tensor.
    one_hot_labels: A [batch_size, NUM_CLASSES] int tensor.

  Returns:
    A scalar loss tensor and a [batch_size, NUM_CLASSES] prediction tensor.
  """
  net = tf.reshape(features, [BATCH_SIZE, IMAGE_SIZE * IMAGE_SIZE])
  net = tf.layers.dense(net, 256, activation=tf.nn.relu)
  net = tf.layers.dense(net, 256, activation=tf.nn.relu)
  net = tf.layers.dense(net, NUM_CLASSES)

  loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
      logits=net, labels=one_hot_labels))
  return loss, tf.nn.softmax(net)
In [ ]:
tf.reset_default_graph()
images, one_hot_labels, _ = data_fn(num_epochs=None, shuffle=True, initializable=False)
loss, predictions = model_fn(images, one_hot_labels)
accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.math.argmax(predictions, axis=1),
                                               tf.math.argmax(one_hot_labels, axis=1))))
train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)
saver = tf.train.Saver(max_to_keep=None)
# Simple training loop that saves the model checkpoint every 1000 steps.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  for i in range(NUM_TRAIN_STEPS):
    if i % NUM_SUMMARIZE_STEPS == 0:
      saver.save(sess, os.path.join(TRAIN_PATH, 'model.ckpt'), global_step=i)

    outputs = sess.run([loss, train_op])
    if i % NUM_SUMMARIZE_STEPS == 0:
      print('Step:', i, 'Loss:', outputs[0])

  # Save a final checkpoint.
  saver.save(sess, os.path.join(TRAIN_PATH, 'model.ckpt'),
             global_step=NUM_TRAIN_STEPS)
In [ ]:
# Check that the model fits the training data.
with tf.Session() as sess:
  saver.restore(sess, os.path.join(TRAIN_PATH, 'model.ckpt-10000'))

  minibatch_accuracy = 0.0
  for i in range(100):
    minibatch_accuracy += sess.run(accuracy) / 100

  print('Accuracy on training data:', minibatch_accuracy)
In [ ]:
tf.reset_default_graph()
checkpoint_to_load = os.path.join(TRAIN_PATH, 'model.ckpt-10000')
# For Lanczos, the tf.data pipeline should have some very specific characteristics:
# 1. It should stop after a single epoch.
# 2. It should be deterministic (i.e., no data augmentation).
# 3. It should be initializable (we use it to restart the pipeline for each Lanczos iteration).
images, one_hot_labels, init = data_fn(num_epochs=1, shuffle=False, initializable=True)
loss, _ = model_fn(images, one_hot_labels)
# Setup for Lanczos mode.
restore_specs = [
    experiment_utils.RestoreSpec(tf.trainable_variables(),
                                 checkpoint_to_load)]
# This callback is used to restart the tf.data pipeline for each Lanczos
# iteration on each worker (the chief has a slightly different callback). You
# can check the logs to see the status of the computation: new
# phases of Lanczos iteration are indicated by "New phase i", and local steps
# per worker are logged with "Local step j".
def end_of_input(sess, train_op):
  try:
    sess.run(train_op)
  except tf.errors.OutOfRangeError:
    sess.run(init)
    return True
  return False
# This object stores the state for the phases of the Lanczos iteration.
experiment = lanczos_experiment.LanczosExperiment(
    loss,
    worker=0,  # These two flags will change when the number of workers > 1.
    num_workers=1,
    save_path=LANCZOS_PATH,
    end_of_input=end_of_input,
    lanczos_steps=NUM_LANCZOS_STEPS,
    num_draws=1,
    output_address=LANCZOS_PATH)
# For distributed training, there are a few options:
#   Multi-GPU, single worker: partition the tf.data pipeline per tower of the model, and
#     pass the aggregate loss to the LanczosExperiment class.
#   Multi-GPU, multi-worker: set num_workers in LanczosExperiment equal to the number of
#     workers (see the sketch after this cell).
# The following calls (get_train_op, get_saver, get_init_fn, get_train_fn) must be made
# in this order.
train_op = experiment.get_train_op()
saver = experiment.get_saver(checkpoint_to_load, restore_specs)
init_fn = experiment.get_init_fn()
train_fn = experiment.get_train_fn()
local_init_op = tf.group(tf.local_variables_initializer(), init)
train_step_kwargs = {}
# The LanczosExperiment class is designed with slim in mind since it gives us
# very specific control of the main training loop.
tf.contrib.slim.learning.train(
    train_op,
    train_step_kwargs=train_step_kwargs,
    train_step_fn=train_fn,
    logdir=LANCZOS_PATH,
    is_chief=True,
    init_fn=init_fn,
    local_init_op=local_init_op,
    global_step=tf.zeros([], dtype=tf.int64),  # Dummy global step.
    saver=saver,
    save_interval_secs=0,  # The LanczosExperiment class controls saving.
    summary_op=None,  # DANGER DANGER: Do not change this.
    summary_writer=None)
# This cell takes a little time to run: maybe 7 mins.
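
For the multi-worker option mentioned above, the per-worker setup might look like the sketch below. This is an assumption-laden illustration rather than part of the original run: make_distributed_experiment is a hypothetical helper, and worker_index and num_workers would come from your cluster configuration. Non-chief workers also pass is_chief=False to tf.contrib.slim.learning.train and use the slightly different chief/worker callbacks noted above.
In [ ]:
# Hypothetical multi-worker setup: worker_index and num_workers come from the cluster
# configuration; everything else matches the single-worker call above.
def make_distributed_experiment(worker_index, num_workers):
  return lanczos_experiment.LanczosExperiment(
      loss,
      worker=worker_index,
      num_workers=num_workers,
      save_path=LANCZOS_PATH,
      end_of_input=end_of_input,
      lanczos_steps=NUM_LANCZOS_STEPS,
      num_draws=1,
      output_address=LANCZOS_PATH)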
In [ ]:
# Outputs are saved as numpy saved files. The most interesting ones are
# 'tridiag_1' and 'lanczos_vec_1'.
with open(os.path.join(LANCZOS_PATH, 'tridiag_1'), 'rb') as f:
  tridiagonal = np.load(f)

# For legacy reasons, we need to squeeze tridiagonal.
tridiagonal = np.squeeze(tridiagonal)

# Note that the output shape is [NUM_LANCZOS_STEPS, NUM_LANCZOS_STEPS].
print(tridiagonal.shape)
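
The eigenvalues of this tridiagonal matrix are the Ritz values of the Lanczos iteration, which approximate the extreme eigenvalues of the Hessian. A minimal sketch of inspecting them directly (plain numpy; this cell is not part of the original notebook):
In [ ]:
# Ritz values: eigenvalues of the symmetric Lanczos tridiagonal matrix. The largest and
# smallest ones are good approximations to the extreme Hessian eigenvalues.
ritz_values = np.linalg.eigvalsh(tridiagonal)
print('Smallest Ritz values:', ritz_values[:5])
print('Largest Ritz values:', ritz_values[-5:])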
In [ ]:
# The function tridiag_to_density computes the smoothed spectral density, i.e., the
# stochastic trace estimator convolved with a Gaussian kernel
# c * exp(-(x - t)**2.0 / (2 * sigma**2.0)), evaluated at points t on a uniform grid.
# Passing a reasonable sigma**2.0 to this function is important -- somewhere between
# 1e-3 and 1e-5 seems to work best.
# Name the result something other than 'density' so the density module is not shadowed.
eig_density, grids = density.tridiag_to_density([tridiagonal])
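# A sketch of passing the smoothing parameter explicitly. The keyword name
# sigma_squared is an assumption about tridiag_to_density's signature, so check the
# density module before uncommenting:
# eig_density, grids = density.tridiag_to_density([tridiagonal], sigma_squared=1e-5)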
In [ ]:
# We add a small epsilon to make the plot not ugly.
plt.semilogy(grids, eig_density + 1.0e-7)
plt.xlabel(r'$\lambda$')
plt.ylabel('Density')
plt.title('MNIST Hessian eigenvalue density at step 10000')
Note that this is only one draw, so the individual peaks are not all exactly the same height; we can make the estimate more accurate by taking more draws.
Exercise left to the reader: run multiple draws and see what the density looks like!
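
As a starting point, here is a sketch of how the multi-draw version might be assembled: set num_draws greater than 1 in the Lanczos cell above, then load each saved tridiagonal matrix and pass the whole list to tridiag_to_density. The file names tridiag_2, tridiag_3, ... are an assumption mirroring tridiag_1 above, so check the output directory for the actual names.
In [ ]:
# Hypothetical multi-draw aggregation (assumes the Lanczos cell was re-run with
# num_draws=3 and that each draw is saved as 'tridiag_<k>').
num_draws = 3
tridiags = []
for k in range(1, num_draws + 1):
  with open(os.path.join(LANCZOS_PATH, 'tridiag_%d' % k), 'rb') as f:
    tridiags.append(np.squeeze(np.load(f)))

# tridiag_to_density accepts a list of tridiagonal matrices (one per draw).
multi_density, multi_grids = density.tridiag_to_density(tridiags)
plt.semilogy(multi_grids, multi_density + 1.0e-7)
plt.xlabel(r'$\lambda$')
plt.ylabel('Density')
plt.title('MNIST Hessian eigenvalue density, averaged over %d draws' % num_draws)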