In this tutorial, we'll show how to use a lattice layer together with other layers, such as neural networks. We'll construct a neural network with one hidden layer for classifying handwritten digits, and then feed the network's output to a lattice layer to capture possible interactions among the network's outputs.
In [ ]:
# Import libraries
!pip install tensorflow_lattice
import tensorflow as tf
import tensorflow_lattice as tfl
from tensorflow.examples.tutorials.mnist import input_data
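This tutorial uses the TF 1.x graph API (placeholders and sessions) together with the original tensorflow_lattice API, so it will not run as-is on TF 2.x; a quick version check can save some confusion:
In [ ]:
# Sanity check: tf.placeholder and tf.Session below require a 1.x release.
print(tf.__version__)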
In [ ]:
# Define helper functions.
# A linear layer computes output = input_tensor * w + bias.
def _linear_layer(input_tensor, input_dim, output_dim):
  w = tf.Variable(
      tf.random_normal([input_dim, output_dim], mean=0.0, stddev=0.1))
  b = tf.Variable(tf.zeros([output_dim]))
  return tf.matmul(input_tensor, w) + b
# The following function returns lattice parameters for the identity function
# f(x1, x2, ..., xn) = [x1, x2, ..., xn].
def identity_lattice(lattice_sizes, dim=10):
  linear_weights = []
  for cnt in range(dim):
    linear_weight = [0.0] * dim
    linear_weight[cnt] = float(dim)
    linear_weights.append(linear_weight)
  lattice_params = tfl.python.lib.lattice_layers.lattice_param_as_linear(
      lattice_sizes,
      dim,
      linear_weights=linear_weights)
  # lattice_param_as_linear returns a zero-centered linear function; shifting
  # every parameter by 0.5 turns it into the identity f(x) = x.
  for cnt1 in range(len(lattice_params)):
    for cnt2 in range(len(lattice_params[cnt1])):
      lattice_params[cnt1][cnt2] += 0.5
  return lattice_params
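As a quick sanity check, we can build a tiny 2 x 2 lattice from these parameters and confirm it reproduces its input. This is a sketch using the same tfl.lattice_layer call as in the main code below; the graph it builds is discarded by the tf.reset_default_graph() that follows.
In [ ]:
# A 2 x 2 lattice initialized with identity_lattice should map
# [x1, x2] to approximately [x1, x2].
check_sizes = [2, 2]
check_input = tf.constant([[0.2, 0.8]], dtype=tf.float32)
check_output, _, _, _ = tfl.lattice_layer(
    check_input,
    lattice_sizes=check_sizes,
    output_dim=2,
    lattice_initializer=identity_lattice(check_sizes, dim=2),
    interpolation_type='hypercube')
with tf.Session() as check_sess:
  check_sess.run(tf.global_variables_initializer())
  print(check_sess.run(check_output))  # expect ~[[0.2, 0.8]]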
In [ ]:
tf.reset_default_graph()
data_dir = '/tmp/tfl-data'
# The MNIST dataset contains 28 x 28 (784-pixel) images of handwritten digits
# and labels in one-hot representation, i.e., label[0] == 1 means the image
# shows a "0", etc. Since there are 10 digits in total, each label is a
# 10-dim vector.
mnist = input_data.read_data_sets(data_dir, one_hot=True)
train_batch_size = 1000
learning_rate = 0.05
num_steps = 3000
# Placeholders for feeding the dataset.
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])
# First hidden layer has 100 hidden units.
hidden = tf.nn.relu(_linear_layer(x, 784, 100))
# From 100 hidden units to the final 10 dim output.
nn_y = _linear_layer(hidden, 100, 10)
# We also construct a lattice layer.
# We apply softmax to nn_y, which converts the network's output into a
# probability distribution, so tf.nn.softmax(nn_y) lies in the 10-dim
# probability simplex. A 2 x 2 x ... x 2 lattice then takes this as input and
# makes the final 10-dim prediction.
output_dim = 10
lattice_sizes = [2] * output_dim
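# A 2 x 2 x ... x 2 lattice over 10 inputs has 2**10 = 1024 vertices, so with
# output_dim == 10 the layer holds 10 * 1024 = 10240 parameters.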
# We initialize a lattice to be the identity function.
lattice_init = identity_lattice(lattice_sizes=lattice_sizes, dim=output_dim)
# Now we define a 2 x 2 x ... x 2 lattice that uses tf.nn.softmax(nn_y) as
# its input. This is the additional non-linearity.
lattice_output, _, _, _ = tfl.lattice_layer(
    tf.nn.softmax(nn_y),
    lattice_sizes=lattice_sizes,
    output_dim=output_dim,
    lattice_initializer=lattice_init,
    interpolation_type='hypercube')
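# Besides the output tensor, tfl.lattice_layer also returns the lattice
# parameter tensor, a projection op, and a regularizer; none of them are
# needed here, hence the discarded return values.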
# Loss function for training the NN.
nn_cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=nn_y))
# Loss function for training lattice + NN jointly.
lattice_cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=lattice_output))
# NN training step.
nn_train_step = tf.train.AdamOptimizer(learning_rate).minimize(nn_cross_entropy)
# Lattice + NN training step.
lattice_train_step = tf.train.AdamOptimizer(0.001).minimize(
    lattice_cross_entropy)
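# Note that joint training uses a much smaller learning rate (0.001) than NN
# pre-training (0.05), presumably so that the identity-initialized lattice and
# the pre-trained NN are only gently perturbed.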
sess = tf.Session()
sess.run(tf.global_variables_initializer())
train_ops = {'train_step': nn_train_step, 'loss': nn_cross_entropy}
lattice_train_ops = {'train_step': lattice_train_step,
                     'loss': lattice_cross_entropy}
print('Pre-training NN')
# Pre-train NN.
for cnt in range(num_steps):
  batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)
  value_dict = sess.run(train_ops, feed_dict={x: batch_xs, y_: batch_ys})
  if cnt % 1000 == 0:
    print('loss=%f' % value_dict['loss'])
# Accuracy. Note that we evaluate accuracy through lattice_output; since the
# lattice is initialized to the identity function, this equals the NN's
# accuracy before the lattice is trained.
correct_prediction = tf.equal(tf.argmax(lattice_output, 1),
                              tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print('training accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.train.images,
                                    y_: mnist.train.labels}))
print('test accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))
In [ ]:
print('Lattice train')
# Lattice + NN train.
for cnt in range(num_steps):
  batch_xs, batch_ys = mnist.train.next_batch(train_batch_size)
  value_dict = sess.run(lattice_train_ops, feed_dict={x: batch_xs,
                                                      y_: batch_ys})
  if cnt % 1000 == 0:
    print('loss=%f' % value_dict['loss'])
print('training accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.train.images,
                                    y_: mnist.train.labels}))
print('test accuracy')
print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                    y_: mnist.test.labels}))
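To gauge how much the lattice layer adds on top of the jointly trained network, it can also be useful to score the raw NN logits on the test set. This is a minimal sketch that reuses the session and tensors defined above:
In [ ]:
# Accuracy of the NN alone (argmax over its logits) after joint training.
nn_correct = tf.equal(tf.argmax(nn_y, 1), tf.argmax(y_, 1))
nn_accuracy = tf.reduce_mean(tf.cast(nn_correct, tf.float32))
print('NN-only test accuracy')
print(sess.run(nn_accuracy, feed_dict={x: mnist.test.images,
                                       y_: mnist.test.labels}))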