MNIST Visualization Example

Real-time visualization of MNIST training on a CNN, using TensorFlow and TensorDebugger

The visualizations in this notebook won't show up on http://nbviewer.ipython.org. To view the widgets and interact with them, you will need to download this notebook and run it with a Jupyter Notebook server.

Step 1: Load TDB Notebook Extension


In [2]:
%%javascript
Jupyter.utils.load_extensions('tdb_ext/main')



In [1]:
#import sys
#sys.path.append('/home/evjang/thesis/tensor_debugger')
import os
import urllib

import tdb
from tdb.examples import mnist, viz
import matplotlib.pyplot as plt
import tensorflow as tf

Step 2: Build TensorFlow Model


In [3]:
# Unpack the graph nodes returned by mnist.build_model(), grouped by role.
(
    # data / label input nodes
    train_data_node, train_labels_node, validation_data_node, test_data_node,
    # predictions
    train_prediction, validation_prediction, test_prediction,
    # weights
    conv1_weights, conv2_weights, fc1_weights, fc2_weights,
    # training
    optimizer, loss, learning_rate, summaries,
) = mnist.build_model()

Step 3: Attach Plotting Ops


In [4]:
def viz_activations(ctx, m):
    """Draw the prediction matrix `m` as a grayscale image.

    `m` is transposed before plotting; given the axis labels, it is
    presumably laid out as (batch, digit) -- TODO confirm against the model.
    `ctx` is accepted but not used by this function.
    """
    # matshow opens a new figure, which then becomes the current axes.
    plt.matshow(m.T, cmap=plt.cm.gray)
    axes = plt.gca()
    axes.set_title("LeNet Predictions")
    axes.set_xlabel("Batch")
    axes.set_ylabel("Digit Activation")

In [5]:
# Plot op driven by the user-defined function 'viz_activations'.
p0 = tdb.plot_op(viz_activations, inputs=[train_prediction])

# Weight variables are of type tf.Variable, so look up the corresponding
# graph element (tf.Tensor) before passing them to plot_op.
g = tf.get_default_graph()
p1 = tdb.plot_op(viz.viz_conv_weights, inputs=[g.as_graph_element(conv1_weights)])
p2 = tdb.plot_op(viz.viz_conv_weights, inputs=[g.as_graph_element(conv2_weights)])
p3 = tdb.plot_op(viz.viz_fc_weights, inputs=[g.as_graph_element(fc1_weights)])
p4 = tdb.plot_op(viz.viz_fc_weights, inputs=[g.as_graph_element(fc2_weights)])

# Histogram of the conv1 weights. It has its own name so that it does not
# shadow the conv2 weights plot (p2); add `p_conv1_hist` to the list of evals
# passed to tdb.debug to render it.
p_conv1_hist = tdb.plot_op(viz.viz_conv_hist, inputs=[g.as_graph_element(conv1_weights)])

# Plot op that watches the training loss.
ploss = tdb.plot_op(viz.watch_loss, inputs=[loss])

Step 4: Download the MNIST dataset


In [6]:
# Fetch the four MNIST archives into download_dir.
base_url = 'http://yann.lecun.com/exdb/mnist/'
files = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]
download_dir = '/tmp/'
for f in files:
    print(f)
    target = os.path.join(download_dir, f)
    # Re-running the cell should not re-download archives that are already there.
    if os.path.exists(target):
        continue
    # Download under a temporary name and rename once complete, so an
    # interrupted transfer never leaves a truncated archive at `target`.
    partial = target + '.part'
    urllib.urlretrieve(base_url + f, partial)
    os.rename(partial, target)


train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz

Step 5: Debug + Visualize!

Upon evaluating the plot nodes p0, p1, p2, p3, p4, and ploss, plots will be generated in the Plot view on the right.


In [7]:
# Load the MNIST data and label sets from the downloaded archives.
# NOTE(review): these appear to be the data values themselves (they are sliced
# into batches and fed to the graph's input nodes), not TF nodes -- confirm
# against mnist.get_data.
train_data, train_labels, validation_data, validation_labels, test_data, test_labels = \
    mnist.get_data(download_dir)


('Extracting', '/tmp/train-images-idx3-ubyte.gz')
('Extracting', '/tmp/train-labels-idx1-ubyte.gz')
('Extracting', '/tmp/t10k-images-idx3-ubyte.gz')
('Extracting', '/tmp/t10k-labels-idx1-ubyte.gz')

In [8]:
# Open the TensorFlow session used to evaluate the graph, then run the
# variable-initialization op in it.
s = tf.InteractiveSession()
s.run(tf.initialize_all_variables())

In [9]:
BATCH_SIZE = 64
NUM_EPOCHS = 5
TRAIN_SIZE = 10000

for step in xrange(NUM_EPOCHS * TRAIN_SIZE // BATCH_SIZE):
    # Walk through the first TRAIN_SIZE examples, wrapping around so that a
    # full batch is always available.
    offset = (step * BATCH_SIZE) % (TRAIN_SIZE - BATCH_SIZE)
    batch_rows = slice(offset, offset + BATCH_SIZE)
    batch_data = train_data[batch_rows, :, :, :]
    batch_labels = train_labels[batch_rows]
    feed_dict = {
        train_data_node: batch_data,
        train_labels_node: batch_labels,
    }

    # Every step: run the training node and the predictions plot.
    status, result = tdb.debug([optimizer, p0], feed_dict=feed_dict, session=s)

    # Only every 10th step: evaluate the loss and refresh the remaining plots.
    if step % 10 != 0:
        continue
    status, result = tdb.debug(
        [loss, p1, p2, p3, p4, ploss],
        feed_dict=feed_dict,
        breakpoints=None,
        break_immediately=False,
        session=s,
    )
    print('loss: %f' % (result[0]))


loss: 29.668428
loss: 15.983353
loss: 11.249242
loss: 10.028939
loss: 8.065391
loss: 9.335689
loss: 7.316875
loss: 8.376289
loss: 7.735221
loss: 8.383675
loss: 5.704120
loss: 6.037778
loss: 7.309663
loss: 7.349874
loss: 7.528041
loss: 8.209503
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-9-ccc202b40680> in <module>()
     12     }
     13     # run training node and visualization node
---> 14     status,result=tdb.debug([optimizer,p0], feed_dict=feed_dict, session=s)
     15     if step % 10 == 0:
     16         status,result=tdb.debug([loss,p1,p2,p3,p4,ploss], feed_dict=feed_dict, breakpoints=None, break_immediately=False, session=s)

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/interface.pyc in debug(evals, feed_dict, breakpoints, break_immediately, session)
     15         global _dbsession
     16         _dbsession=debug_session.DebugSession(session)
---> 17         return _dbsession.run(evals,feed_dict,breakpoints,break_immediately)
     18 
     19 def s():

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc in run(self, evals, feed_dict, breakpoints, break_immediately)
     59                         return self._break()
     60                 else:
---> 61                         return self.c()
     62 
     63         def s(self):

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc in c(self)
     85 
     86                 self.state = RUNNING
---> 87                 self._eval(node)
     88                 # increment to next node
     89                 self.step=i+1

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tdb/debug_session.pyc in _eval(self, node)
    168                 else: # is a TensorFlow node
    169                         if isinstance(node,tf.Tensor):
--> 170                                 result=self.session.run(node,self._cache)
    171                                 self._cache[node.name]=result
    172                         else:

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict)
    346     # Run request and get response.
    347     #pdb.set_trace()
--> 348     results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
    349 
    350     # User may have fetched the same tensor multiple times, but we

/gpfs/main/sys/shared/psfu/contrib/projects/tensorflow/tensorflow-venv/local/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict)
    405 
    406       return tf_session.TF_Run(self._session, feed_dict, fetch_list,
--> 407                                target_list)
    408 
    409     except tf_session.StatusNotOK as e:

KeyboardInterrupt: