Lattice on MNIST

In this tutorial, we'll show how to use a lattice layer together with other layers, such as neural networks. We will construct a neural network with one hidden layer for classifying hand-written digits, and then feed the output of the neural network to a lattice layer to capture possible interactions between the neural network's outputs.


In [ ]:
# Import libraries
!pip install tensorflow_lattice
import tensorflow as tf
import tensorflow_lattice as tfl
# Note: this tutorial uses the TF 1.x API (placeholders and sessions), and
# the MNIST reader below ships with TensorFlow 1.x only.
from tensorflow.examples.tutorials.mnist import input_data

In [ ]:
# Define helper functions

# A linear layer: output = input_tensor * w + b.
def _linear_layer(input_tensor, input_dim, output_dim):
    w = tf.Variable(
        tf.random_normal([input_dim, output_dim], mean=0.0, stddev=0.1))
    b = tf.Variable(tf.zeros([output_dim]))
    return tf.matmul(input_tensor, w) + b

# The following function returns lattice parameters that represent the
# identity function f(x1, x2, ..., xn) = [x1, x2, ..., xn].
def identity_lattice(lattice_sizes, dim=10):
    # lattice_param_as_linear normalizes the linear weights by the number of
    # inputs and centers the output around 0, so we scale the k-th weight by
    # dim and later shift every parameter by 0.5 to recover f_k(x) = x_k
    # on [0, 1].
    linear_weights = []
    for cnt in range(dim):
        linear_weight = [0.0] * dim
        linear_weight[cnt] = float(dim)
        linear_weights.append(linear_weight)
    lattice_params = tfl.python.lib.lattice_layers.lattice_param_as_linear(
        lattice_sizes,
        dim,
        linear_weights=linear_weights)
    for cnt1 in range(len(lattice_params)):
        for cnt2 in range(len(lattice_params[cnt1])):
            lattice_params[cnt1][cnt2] += 0.5

    return lattice_params
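
As a quick sanity check, we can inspect the shape of the parameters returned by identity_lattice. This probe assumes that lattice_param_as_linear returns one parameter list per output dimension, with one value per lattice vertex (2**10 = 1024 for a 2 x 2 x ... x 2 lattice over 10 inputs); it is illustrative only and not part of the training pipeline.

In [ ]:
# Illustrative shape check (assumption: parameters are laid out as
# [output_dim][num_lattice_vertices] = [10][2**10]).
params = identity_lattice(lattice_sizes=[2] * 10, dim=10)
print(len(params))     # expected: 10
print(len(params[0]))  # expected: 1024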

In [ ]:
tf.reset_default_graph()

data_dir = '/tmp/tfl-data'

# The MNIST dataset contains 28 x 28 (784-pixel) images of hand-written
# digits and labels in one-hot representation, e.g., the digit "3" is
# encoded as [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]. Since there are 10 digits in
# total, each label is a 10-dim vector.
mnist = input_data.read_data_sets(data_dir, one_hot=True)
train_batch_size = 1000
learning_rate = 0.05
num_steps = 3000

# Placeholders for feeding the dataset.
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
                       
# First hidden layer has 100 hidden units.
hidden = tf.nn.relu(_linear_layer(x, 784, 100))
# From 100 hidden units to the final 10 dim output.
nn_y = _linear_layer(hidden, 100, 10)

# We also construct a lattice layer.
# We apply softmax to nn_y, which converts the output of the neural network
# into a probability distribution. So tf.nn.softmax(nn_y) lies in the
# 10-dimensional probability simplex. A 2 x 2 x ... x 2 lattice then uses
# this as its input and makes the final 10-dim prediction (see the
# interpolation sketch after this cell).
output_dim = 10
lattice_sizes = [2] * output_dim

# We initialize a lattice to be the identity function.
lattice_init = identity_lattice(lattice_sizes=lattice_sizes, dim=output_dim)

# Now we define 2 x 2 x ... x 2 lattice that uses tf.nn.softmax(nn_y) as an
# input. This is the additional non-linearity.
lattice_output, _, _, _ = tfl.lattice_layer(
    tf.nn.softmax(nn_y),
    lattice_sizes=lattice_sizes,
    output_dim=output_dim,
    lattice_initializer=lattice_init,
    interpolation_type='hypercube')

# Loss function for training the NN alone.
nn_cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=nn_y))

# Loss function for training the lattice + NN jointly.
lattice_cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=lattice_output))

# NN training step.
nn_train_step = tf.train.AdamOptimizer(learning_rate).minimize(nn_cross_entropy)

# Lattice + NN training step. Note the smaller learning rate (0.001) used
# for the joint fine-tuning.
lattice_train_step = tf.train.AdamOptimizer(0.001).minimize(
    lattice_cross_entropy)

sess = tf.Session()

sess.run(tf.global_variables_initializer())

train_ops = {'train_step': nn_train_step, 'loss': nn_cross_entropy}
lattice_train_ops = {'train_step': lattice_train_step,
                     'loss': lattice_cross_entropy}

print('Pre-training NN')
# Pre-train NN.
for cnt in range(num_steps):
    batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)
    value_dict = sess.run(train_ops, feed_dict={x: batch_xs, y_: batch_ys})
    if cnt % 1000 == 0:
        print('loss=%f' % value_dict['loss'])


# Accuracy. We evaluate through lattice_output; since the lattice was
# initialized to the identity, this matches the pre-trained NN's accuracy
# at this point.
correct_prediction = tf.equal(tf.argmax(lattice_output, 1),
                              tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('training accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.train.images,
                                    y_: mnist.train.labels}))
print('test accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))
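
For intuition about what the lattice layer computes, below is a minimal pure-Python sketch of 'hypercube' (multilinear) interpolation on a 2 x 2 x ... x 2 lattice. It is an illustrative re-implementation under an assumed vertex-indexing convention (bit k of the index is coordinate k), not tensorflow_lattice's actual code.

In [ ]:
# Minimal sketch of multilinear ('hypercube') interpolation on a 2^d lattice.
# Assumption: vertex v = (v_0, ..., v_{d-1}) in {0, 1}^d is indexed by the
# integer sum(v_k << k); tensorflow_lattice may use a different convention.
import itertools

def hypercube_interpolate(x, vertex_params):
    """Interpolates vertex_params (length 2**d) at a point x in [0, 1]^d."""
    d = len(x)
    result = 0.0
    for vertex in itertools.product([0, 1], repeat=d):
        # Weight of this corner: product over dims of x_k or (1 - x_k).
        weight = 1.0
        for x_k, v_k in zip(x, vertex):
            weight *= x_k if v_k == 1 else 1.0 - x_k
        index = sum(v_k << k for k, v_k in enumerate(vertex))
        result += weight * vertex_params[index]
    return result

# With vertex parameters equal to coordinate 0 of each vertex, interpolation
# reproduces x_0 exactly -- this is the identity initialization used above.
d = 3
params_for_x0 = [float(v & 1) for v in range(2**d)]
print(hypercube_interpolate([0.2, 0.7, 0.9], params_for_x0))  # ~0.2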

In [ ]:
print('Lattice + NN train')
# Jointly fine-tune the lattice + NN.
for cnt in range(num_steps):
    batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)
    value_dict = sess.run(lattice_train_ops, feed_dict={x: batch_xs,
                                                        y_: batch_ys})
    if cnt % 1000 == 0:
        print('loss=%f' % value_dict['loss']) 

print('training accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.train.images,
                                    y_: mnist.train.labels}))
print('test accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))