MNIST Hessian Spectral Density Calculator

This notebook trains a simple MLP on MNIST, runs the Lanczos algorithm on its full-batch Hessian, and then plots the Hessian's spectral density. It demonstrates how to use the Python TensorFlow `LanczosExperiment` class.


In [ ]:
import os
import sys

import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

import experiment_utils
import lanczos_experiment
import tensorflow_datasets as tfds

In [ ]:
# Make the `density` helper module in ../jax importable. The relative path is
# resolved against the current working directory. The membership check keeps
# re-runs of this cell from prepending duplicate entries to sys.path.
JAX_DIR = os.path.abspath("./../jax")
if JAX_DIR not in sys.path:
  sys.path.insert(0, JAX_DIR)
import density

In [ ]:
# All artifacts (training checkpoints and Lanczos outputs) live under this root.
COLAB_PATH = '/tmp/spectral-density'
TRAIN_PATH = os.path.join(COLAB_PATH, 'train')
LANCZOS_PATH = os.path.join(COLAB_PATH, 'lanczos')

# A bare os.makedirs raises if the directory already exists, which made this
# cell fail on every re-run. Create each directory only when it is missing.
# This form works on both Python 2 and 3 (exist_ok is Python 3 only).
for _dir in (TRAIN_PATH, LANCZOS_PATH):
  if not os.path.isdir(_dir):
    os.makedirs(_dir)

IMAGE_SIZE = 28   # Images are flattened to IMAGE_SIZE * IMAGE_SIZE features.
NUM_CLASSES = 10  # Width of the one-hot labels and of the output layer.

BATCH_SIZE = 32   # Minibatch size for both the training and Lanczos pipelines.
LEARNING_RATE = 0.02

NUM_TRAIN_STEPS = 10000     # Also tags the final checkpoint: model.ckpt-<N>.
NUM_SUMMARIZE_STEPS = 1000  # Checkpoint/log interval during training.
NUM_LANCZOS_STEPS = 90      # Lanczos iterations (size of the tridiagonal matrix).

def data_fn(num_epochs=None, shuffle=False, initializable=False):
  """Builds the MNIST training-split input pipeline.

  Args:
    num_epochs: number of passes over the training split; None repeats forever.
    shuffle: whether to shuffle examples through a 1024-element buffer.
    initializable: if True, use an initializable iterator and return its
      initializer so the pipeline can be restarted; otherwise use a one-shot
      iterator and return None in its place.

  Returns:
    A tuple (images, one_hot_labels, init_op). images are the raw pixels mapped
    through (x - 128) / 128 (roughly [-1, 1) assuming uint8 pixels — confirm).
    one_hot_labels come from tf.one_hot over NUM_CLASSES. init_op is the iterator
    initializer, or None.
  """
  mnist = tfds.load(name="mnist", split=tfds.Split.TRAIN).repeat(num_epochs)
  if shuffle:
    mnist = mnist.shuffle(buffer_size=1024)
  mnist = mnist.batch(BATCH_SIZE)

  if not initializable:
    iterator, init_op = mnist.make_one_shot_iterator(), None
  else:
    iterator = mnist.make_initializable_iterator()
    init_op = iterator.initializer

  batch = iterator.get_next()
  pixels = tf.to_float(batch['image'])
  images = (pixels - 128) / 128.0
  one_hot_labels = tf.one_hot(batch['label'], NUM_CLASSES)
  return images, one_hot_labels, init_op

def model_fn(features, one_hot_labels):
  """Builds an MLP for MNIST (two 256-unit ReLU layers) and computes its loss.

  Args:
    features: a [batch_size, height, width, channels] float32 tensor. Assumes
      a single channel, since it is flattened to IMAGE_SIZE * IMAGE_SIZE.
    one_hot_labels: a [batch_size, NUM_CLASSES] float tensor of one-hot labels
      (as produced by tf.one_hot). It must be a float type, not int:
      softmax_cross_entropy_with_logits_v2 requires the labels to share the
      logits' dtype.

  Returns:
    A scalar mean cross-entropy loss tensor, and a [batch_size, NUM_CLASSES]
    tensor of softmax class probabilities.
  """
  # -1 infers the batch dimension rather than hard-coding BATCH_SIZE. A smaller
  # final batch (e.g. in the single-epoch Lanczos pass when BATCH_SIZE does not
  # divide the dataset size) would otherwise fail to reshape.
  net = tf.reshape(features, [-1, IMAGE_SIZE * IMAGE_SIZE])
  net = tf.layers.dense(net, 256, activation=tf.nn.relu)
  net = tf.layers.dense(net, 256, activation=tf.nn.relu)
  logits = tf.layers.dense(net, NUM_CLASSES)

  loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
      logits=logits, labels=one_hot_labels))

  return loss, tf.nn.softmax(logits)

Train an MNIST model.


In [ ]:
tf.reset_default_graph()
# Fix the graph-level seed so weight initialization and input shuffling are
# reproducible across Restart & Run All.
tf.set_random_seed(0)

images, one_hot_labels, _ = data_fn(num_epochs=None, shuffle=True, initializable=False)

loss, predictions = model_fn(images, one_hot_labels)

# Fraction of the current minibatch whose argmax prediction matches the argmax
# of its one-hot label.
accuracy = tf.reduce_mean(tf.to_float(tf.equal(
    tf.math.argmax(predictions, axis=1),
    tf.math.argmax(one_hot_labels, axis=1))))

train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)
# max_to_keep=None: never delete old checkpoints, so every summarized step stays
# on disk.
saver = tf.train.Saver(max_to_keep=None)

# Simple training loop that checkpoints and logs every NUM_SUMMARIZE_STEPS steps.
# It then writes a final checkpoint tagged NUM_TRAIN_STEPS, which later cells load.
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  for step in range(NUM_TRAIN_STEPS):
    summarize = step % NUM_SUMMARIZE_STEPS == 0
    if summarize:
      # Checkpoint the weights before this step's update is applied.
      saver.save(sess, os.path.join(TRAIN_PATH, 'model.ckpt'), global_step=step)

    step_loss, _ = sess.run([loss, train_op])

    if summarize:
      # A single pre-formatted argument prints identically on Python 2 and 3.
      print('Step: {} Loss: {}'.format(step, step_loss))

  # Save a final checkpoint.
  saver.save(sess, os.path.join(TRAIN_PATH, 'model.ckpt'),
             global_step=NUM_TRAIN_STEPS)

In [ ]:
# Check that the model fits the training data. Restore the final checkpoint and
# average its accuracy over NUM_EVAL_BATCHES shuffled training minibatches.
NUM_EVAL_BATCHES = 100
# Derive the checkpoint name from NUM_TRAIN_STEPS so it cannot drift from the
# step count used by the training cell.
final_checkpoint = os.path.join(TRAIN_PATH, 'model.ckpt-%d' % NUM_TRAIN_STEPS)

with tf.Session() as sess:
  saver.restore(sess, final_checkpoint)
  minibatch_accuracy = np.mean(
      [sess.run(accuracy) for _ in range(NUM_EVAL_BATCHES)])

print('Accuracy on training data: {}'.format(minibatch_accuracy))

Run Lanczos on the MNIST model.


In [ ]:
tf.reset_default_graph()

# Analyze the Hessian at the final training checkpoint. The name is derived from
# NUM_TRAIN_STEPS so it cannot drift from the step count used during training.
checkpoint_to_load = os.path.join(TRAIN_PATH, 'model.ckpt-%d' % NUM_TRAIN_STEPS)

# For Lanczos, the tf.data pipeline should have some very specific characteristics:
# 1. It should stop after a single epoch.
# 2. It should be deterministic (i.e., no data augmentation).
# 3. It should be initializable (we use it to restart the pipeline for each Lanczos iteration).
images, one_hot_labels, init = data_fn(num_epochs=1, shuffle=False, initializable=True)

loss, _ = model_fn(images, one_hot_labels)

# Setup for Lanczos mode: restore only the trainable variables (the model
# weights) from the checkpoint.
restore_specs = [
    experiment_utils.RestoreSpec(tf.trainable_variables(),
                                 checkpoint_to_load)]

# Callback that restarts the tf.data pipeline for each Lanczos iteration on each
# worker (the chief has a slightly different callback). To follow the
# computation, check the logs: a new phase of the Lanczos iteration is announced
# with "New phase i", and each worker's local steps with "Local step j".
def end_of_input(sess, train_op):
  """Runs one step of `train_op`, rewinding the input pipeline once it runs dry.

  Args:
    sess: the session to run in.
    train_op: the op to run for this step.

  Returns:
    True if the dataset was exhausted (and `init` was run to restart it),
    otherwise False.
  """
  exhausted = False
  try:
    sess.run(train_op)
  except tf.errors.OutOfRangeError:
    # The single-epoch pipeline is used up. Re-initialize the iterator so the
    # next pass starts again from the first batch.
    sess.run(init)
    exhausted = True
  return exhausted

# This object stores the state for the phases of the Lanczos iteration. It is
# constructed from `loss`, the loss whose Hessian spectrum is being estimated.
experiment = lanczos_experiment.LanczosExperiment(
    loss,
    worker=0,  # These two flags will change when the number of workers > 1.
    num_workers=1,
    save_path=LANCZOS_PATH,
    end_of_input=end_of_input,  # Runs one step; rewinds the input at end of epoch.
    lanczos_steps=NUM_LANCZOS_STEPS,  # Number of Lanczos iterations.
    num_draws=1,  # One draw; more draws make the density estimate more accurate.
    # NOTE(review): both paths are LANCZOS_PATH here, and a later cell reads
    # 'tridiag_1' from it; which argument controls that file is not shown — confirm.
    output_address=LANCZOS_PATH)

# For distributed training, there are a few options:
# Multi-gpu single worker: Partition the tf.data per tower of the model, and pass the aggregate
#   loss to the LanczosExperiment class.
# Multi-gpu multi worker: Set num_workers in LanczosExperiment to be equal to the number of workers.

# These have to be called in exactly this order.
train_op = experiment.get_train_op()
saver = experiment.get_saver(checkpoint_to_load, restore_specs)
init_fn = experiment.get_init_fn()
train_fn = experiment.get_train_fn()
# Local initialization also runs the iterator initializer `init`, so the input
# pipeline starts from the first batch.
local_init_op = tf.group(tf.local_variables_initializer(), init)

# No extra keyword arguments are passed through to train_fn.
train_step_kwargs = {}

# The LanczosExperiment class is designed with slim in mind since it gives us
# very specific control of the main training loop: each "training step" is
# delegated to train_fn from the experiment.
tf.contrib.slim.learning.train(
    train_op,
    train_step_kwargs=train_step_kwargs,
    train_step_fn=train_fn,
    logdir=LANCZOS_PATH,
    is_chief=True,
    init_fn=init_fn,
    local_init_op=local_init_op,
    global_step=tf.zeros([], dtype=tf.int64),  # Dummy global step (constant zero tensor).
    saver=saver,
    save_interval_secs=0,  # The LanczosExperiment class controls saving.
    summary_op=None,  # DANGER DANGER: Do not change this; it must stay None.
    summary_writer=None)

# This cell takes a little time to run: maybe 7 mins.

Visualize the Hessian eigenvalue density.


In [ ]:
# Outputs are saved as numpy files. The most interesting ones are
# 'tridiag_1' and 'lanczos_vec_1'.
with open(os.path.join(LANCZOS_PATH, 'tridiag_1'), 'rb') as f:
  # For legacy reasons, the saved array must be squeezed to drop singleton
  # dimensions.
  tridiagonal = np.squeeze(np.load(f))

# Note that the output shape is [NUM_LANCZOS_STEPS, NUM_LANCZOS_STEPS].
# The shape is left as the last expression so Jupyter displays it; this avoids
# a Python-2-only print statement.
tridiagonal.shape

In [ ]:
# The function tridiag_to_density computes the density (i.e., trace estimator 
# the standard Gaussian c * exp(-(x - t)**2.0 / 2 sigma**2.0) where t is 
# from a uniform grid. Passing a reasonable sigma**2.0 to this function is 
# important -- somewhere between 1e-3 and 1e-5 seems to work best.
density, grids = density.tridiag_to_density([tridiagonal])

In [ ]:
# We add a small epsilon to make the plot not ugly.
plt.semilogy(grids, density + 1.0e-7)
plt.xlabel('$\lambda$')
plt.ylabel('Density')
plt.title('MNIST hessian eigenvalue density at step 10000')

Note that this estimate comes from a single draw, so the individual peaks are not all exactly the same height. Taking more draws makes the estimate more accurate.

Exercise left to the reader: run multiple draws (e.g., by increasing `num_draws` in `LanczosExperiment`) and see what the density looks like!