TensorFlow Dev Summit, 2018.
This interactive notebook demonstrates autograph, an experimental source-code transformation library to automatically convert TF.Eager and Python code to TensorFlow graphs.
Note: this is pre-alpha software! The notebook works best with Python 2, for now.
In [0]:
# Install TensorFlow; note that Colab notebooks run remotely, on virtual
# instances provided by Google.
!pip install -U -q tf-nightly
In [0]:
import os
import time
import tensorflow as tf
from tensorflow.contrib import autograph
import matplotlib.pyplot as plt
import numpy as np
import six
from google.colab import widgets
TF.Eager gives you more flexibility while coding, but at the cost of losing the benefits of TensorFlow graphs. For example, Eager does not currently support distributed training, exporting models, and a variety of memory and computation optimizations.
Autograph gives you the best of both worlds: write your code in an Eager style, and we will automatically transform it into the equivalent TF graph code. The graph code can be executed eagerly (as a single op), included as part of a larger graph, or exported.
For example, autograph can convert a function like this:
In [0]:
def g(x):
  if x > 0:
    x = x * x
  else:
    x = 0
  return x
... into a TF graph-building function:
In [0]:
print(autograph.to_code(g))
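To make the idea concrete, here is a rough pure-Python sketch of the kind of rewrite involved when an if statement becomes graph code. It is not the code autograph actually generates; the to_code call above prints that. The cond helper is a hypothetical stand-in for tf.cond, which takes a predicate plus one function per branch and runs only the branch the predicate selects.

```python
def cond(pred, true_fn, false_fn):
  """Hypothetical stand-in for tf.cond: call exactly one branch function."""
  return true_fn() if pred else false_fn()


def g(x):
  # Original, imperative version.
  if x > 0:
    x = x * x
  else:
    x = 0
  return x


def g_functional(x):
  # Each branch becomes a function that returns the new value of every
  # variable the branch modifies (here, just x).
  def if_true():
    return x * x

  def if_false():
    return 0

  x = cond(x > 0, if_true, if_false)
  return x


for value in (9, -3, 0):
  print(value, g(value), g_functional(value))
```

Both versions agree on every input; the functional form is what makes it possible to express the branch as a single graph op.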
You can then use the converted function as you would any regular TF op. You can pass Tensor arguments, and it will return Tensors:
In [0]:
tf_g = autograph.to_graph(g)

with tf.Graph().as_default():
  g_ops = tf_g(tf.constant(9))

  with tf.Session() as sess:
    tf_g_result = sess.run(g_ops)

  print('g(9) = %s' % g(9))
  print('tf_g(9) = %s' % tf_g_result)
Autograph can convert a large chunk of the Python language into graph-equivalent code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in autograph. Autograph will automatically convert most Python control flow statements into their correct graph equivalent.
We support common statements like while, for, if, break, return and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:
In [0]:
def sum_even(numbers):
  s = 0
  for n in numbers:
    if n % 2 > 0:
      continue
    s += n
  return s


tf_sum_even = autograph.to_graph(sum_even)

with tf.Graph().as_default():
  with tf.Session() as sess:
    result = sess.run(tf_sum_even(tf.constant([10, 12, 15, 20])))
    print('Sum of even numbers: %s' % result)

# Uncomment the line below to print the generated graph code
# print(autograph.to_code(sum_even))
Try replacing the continue in the above code with break; Autograph supports that as well!
The Python code above is much more readable than the matching graph code. Autograph takes care of the tedious work of converting every piece of Python code into its TensorFlow graph equivalent, so that you can quickly write maintainable code and still benefit from the optimizations and deployment options of graphs.
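Conceptually, a loop that depends on tensor values has to become a functional loop whose state (here the index i and the running sum s) is passed explicitly from one iteration to the next. The pure-Python sketch below illustrates this. while_loop is a hypothetical stand-in that mimics the calling convention of tf.while_loop (a condition function, a body function, and a tuple of loop variables), and the functions are our own illustration rather than autograph's output. In the sketch, continue becomes a conditional update of s, and break becomes an extra boolean loop variable that the condition checks.

```python
def while_loop(cond, body, loop_vars):
  """Hypothetical stand-in for tf.while_loop: state is passed explicitly."""
  while cond(*loop_vars):
    loop_vars = body(*loop_vars)
  return loop_vars


def sum_even_functional(numbers):
  # `for n in numbers` becomes an index-driven loop over the state (i, s).
  def cond(i, s):
    return i < len(numbers)

  def body(i, s):
    n = numbers[i]
    # `continue` on odd numbers becomes "leave s unchanged".
    s = s if n % 2 > 0 else s + n
    return i + 1, s

  _, s = while_loop(cond, body, (0, 0))
  return s


def sum_until_odd_functional(numbers):
  # The same loop with `break` instead of `continue`: a `stop` flag is
  # carried through the loop and checked in the condition.
  def cond(i, s, stop):
    return i < len(numbers) and not stop

  def body(i, s, stop):
    n = numbers[i]
    stop = n % 2 > 0
    s = s if stop else s + n
    return i + 1, s, stop

  _, s, _ = while_loop(cond, body, (0, 0, False))
  return s


print(sum_even_functional([10, 12, 15, 20]))       # 42
print(sum_until_odd_functional([10, 12, 15, 20]))  # 22
```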
Let's try some other useful Python constructs, like print and assert. We automatically convert Python assert statements into the equivalent tf.Assert code.
In [0]:
def f(x):
  assert x != 0, 'Do not pass zero!'
  return x * x


tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session() as sess:
    try:
      print(sess.run(tf_f(tf.constant(0))))
    except tf.errors.InvalidArgumentError as e:
      print('Got error message: %s' % e.message)

# Uncomment the line below to print the generated graph code
# print(autograph.to_code(f))
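In a graph, an assertion is an op that fails only when it runs, and the computation it guards must be ordered after it. The sketch below imitates that with plain Python closures. assert_op and InvalidArgumentError are hypothetical stand-ins for tf.Assert and tf.errors.InvalidArgumentError, defined here only for illustration. Building f_graph(0) succeeds, and the error only appears when the result is evaluated, much like sess.run above.

```python
class InvalidArgumentError(Exception):
  """Hypothetical stand-in for tf.errors.InvalidArgumentError."""


def assert_op(condition, message):
  """Hypothetical stand-in for tf.Assert: returns a deferred check."""
  def run():
    if not condition():
      raise InvalidArgumentError(message)
  return run


def f_graph(x):
  # "Building the graph" only records what to do; nothing runs yet.
  check = assert_op(lambda: x != 0, 'Do not pass zero!')

  def result():
    check()  # control dependency: the assertion runs before the multiply
    return x * x
  return result


print(f_graph(3)())  # 9
try:
  f_graph(0)()
except InvalidArgumentError as e:
  print('Got error message: %s' % e)
```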
You can also use print functions in-graph:
In [0]:
def print_sign(n):
  if n >= 0:
    print(n, 'is positive!')
  else:
    print(n, 'is negative!')
  return n


tf_print_sign = autograph.to_graph(print_sign)
with tf.Graph().as_default():
  with tf.Session() as sess:
    sess.run(tf_print_sign(tf.constant(1)))

# Uncomment the line below to print the generated graph code
# print(autograph.to_code(print_sign))
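An in-graph print does not print while the graph is being built. Instead, it becomes an op that prints every time the graph runs. The toy below shows that difference between trace time and run time. The Graph class is hypothetical and has nothing to do with how TensorFlow actually stores ops. Also, n is a plain Python int here, so the if is resolved while tracing; with a Tensor, autograph would turn it into a tf.cond instead.

```python
class Graph:
  """Hypothetical toy graph: records print ops at trace time, runs them later."""

  def __init__(self):
    self.ops = []

  def print_op(self, *args):
    # Tracing only records the op; nothing is printed here.
    def op():
      message = ' '.join(str(a) for a in args)
      print(message)
      return message
    self.ops.append(op)

  def run(self):
    # Every run executes all recorded ops again.
    return [op() for op in self.ops]


def print_sign_traced(graph, n):
  if n >= 0:
    graph.print_op(n, 'is positive!')
  else:
    graph.print_op(n, 'is negative!')
  return n


graph = Graph()
print_sign_traced(graph, 1)  # builds the op; prints nothing
graph.run()                  # prints: 1 is positive!
graph.run()                  # prints it again
```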
We can convert lists to TensorArray, so appending to lists also works, with a few modifications:
In [0]:
def f(n):
  numbers = []
  # We ask you to tell us about the element dtype.
  autograph.utils.set_element_type(numbers, tf.int32)
  for i in range(n):
    numbers.append(i)
  return numbers.stack()  # Stack the list so that it can be used as a Tensor


tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(tf_f(tf.constant(5))))

# Uncomment the line below to print the generated graph code
# print(autograph.to_code(f))
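Graph tensors are immutable, so a list that grows inside a loop is modeled as a TensorArray. Each write produces a new array value, and stack() packs the elements into one tensor at the end. The element type must be known ahead of time, which is what set_element_type provides. The plain-Python class below is a hypothetical illustration of those semantics, not the TensorArray API itself.

```python
class ToyTensorArray:
  """Hypothetical immutable, list-like value with a fixed element type."""

  def __init__(self, dtype, elements=()):
    self.dtype = dtype
    self.elements = tuple(elements)

  def write(self, index, value):
    # Return a new array instead of mutating this one.
    if not isinstance(value, self.dtype):
      raise TypeError('expected %s, got %r' % (self.dtype.__name__, value))
    if index != len(self.elements):
      raise IndexError('this sketch only supports appending')
    return ToyTensorArray(self.dtype, self.elements + (value,))

  def stack(self):
    return list(self.elements)


def f(n):
  numbers = ToyTensorArray(int)  # plays the role of set_element_type
  for i in range(n):
    # `numbers.append(i)` becomes a write that rebinds the variable.
    numbers = numbers.write(len(numbers.elements), i)
  return numbers.stack()


print(f(5))  # [0, 1, 2, 3, 4]
```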
All of these features, and more, can be composed into more complex code:
In [0]:
def print_primes(n):
  """Prints all the prime numbers less than n."""
  assert n > 0

  primes = []
  autograph.utils.set_element_type(primes, tf.int32)
  for i in range(2, n):
    is_prime = True
    for k in range(2, i):
      if i % k == 0:
        is_prime = False
        break
    if not is_prime:
      continue
    primes.append(i)
  all_primes = primes.stack()

  print('The prime numbers less than', n, 'are:')
  print(all_primes)
  return tf.no_op()


tf_print_primes = autograph.to_graph(print_primes)
with tf.Graph().as_default():
  with tf.Session() as sess:
    n = tf.constant(50)
    sess.run(tf_print_primes(n))

# Uncomment the line below to print the generated graph code
# print(autograph.to_code(print_primes))
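For reference, here is the same algorithm written as ordinary Python on plain integers. It shows what the converted function should print for n = 50; no conversion is involved.

```python
def primes_below(n):
  """Plain-Python version of the print_primes loop, returning a list."""
  assert n > 0
  primes = []
  for i in range(2, n):
    is_prime = True
    for k in range(2, i):
      if i % k == 0:
        is_prime = False
        break
    if not is_prime:
      continue
    primes.append(i)
  return primes


print(primes_below(50))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
```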
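Next, we'll train a small neural network on MNIST, using a training loop written as plain Python and converted by autograph. The helper functions below download the MNIST files and wrap them in a tf.data pipeline.
In [0]: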
import gzip
import shutil

from six.moves import urllib


def download(directory, filename):
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filepath = filepath + '.gz'
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath


def dataset(directory, images_file, labels_file):
  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  def decode_image(image):
    # Normalize from [0, 255] to [0.0, 1.0]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [784])
    return image / 255.0

  def decode_label(label):
    label = tf.decode_raw(label, tf.uint8)
    label = tf.reshape(label, [])
    return tf.to_int32(label)

  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(decode_label)
  return tf.data.Dataset.zip((images, labels))


def mnist_train(directory):
  return dataset(directory, 'train-images-idx3-ubyte',
                 'train-labels-idx1-ubyte')


def mnist_test(directory):
  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
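The header_bytes values come from the MNIST IDX file format. An image file starts with a 16-byte header (magic number 2051, image count, rows and columns, each a big-endian 32-bit integer), and a label file starts with an 8-byte header (magic number 2049 and label count). After the header come fixed-size records of unsigned bytes: 28 * 28 per image and 1 per label, which is why FixedLengthRecordDataset fits. The standard-library sketch below parses a tiny synthetic file in that layout; the pixel and label bytes are made up for illustration.

```python
import struct


def parse_idx_images(data):
  """Parse IDX image bytes into a list of flat, normalized pixel lists."""
  magic, count, rows, cols = struct.unpack('>IIII', data[:16])
  assert magic == 2051, 'not an IDX image file'
  size = rows * cols
  body = data[16:]
  # Normalize from [0, 255] to [0.0, 1.0], as decode_image does.
  return [[b / 255.0 for b in body[i * size:(i + 1) * size]]
          for i in range(count)]


def parse_idx_labels(data):
  """Parse IDX label bytes into a list of integer labels."""
  magic, count = struct.unpack('>II', data[:8])
  assert magic == 2049, 'not an IDX label file'
  return list(data[8:8 + count])


# A synthetic file with two 2x2 "images" and matching labels.
images_bytes = struct.pack('>IIII', 2051, 2, 2, 2) + bytes([0, 255, 51, 0, 255, 255, 0, 0])
labels_bytes = struct.pack('>II', 2049, 2) + bytes([7, 3])

print(parse_idx_images(images_bytes))  # [[0.0, 1.0, 0.2, 0.0], [1.0, 1.0, 0.0, 0.0]]
print(parse_idx_labels(labels_bytes))  # [7, 3]
```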
First, we'll define a small three-layer neural network using the Keras API:
In [0]:
def mlp_model(input_shape):
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')])
  model.build()
  return model
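For a sense of scale, each Dense layer has (inputs * units) weights plus one bias per unit. With 784 input pixels, this network therefore has 89,610 trainable parameters. A quick check of that arithmetic:

```python
def dense_params(inputs, units):
  # Weight matrix (inputs x units) plus one bias per unit.
  return inputs * units + units


layer_sizes = [28 * 28, 100, 100, 10]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer)       # [78500, 10100, 1010]
print(sum(per_layer))  # 89610
```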
Let's connect the model definition (here abbreviated as m) to a loss function, so that we can train our model.
In [0]:
def predict(m, x, y):
  y_p = m(x)
  losses = tf.keras.losses.categorical_crossentropy(y, y_p)
  l = tf.reduce_mean(losses)
  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)
  accuracy = tf.reduce_mean(accuracies)
  return l, accuracy
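With one-hot labels, categorical crossentropy for one example reduces to minus the log of the probability the model assigns to the correct class. Categorical accuracy is 1 when the highest-probability class matches the label. The plain-Python sketch below computes the per-example quantities that predict then averages with tf.reduce_mean. Keras also renormalizes and clips the predicted probabilities before taking the log. The sketch skips renormalization because softmax outputs already sum to 1, and the clipping constant EPSILON is an assumption.

```python
import math

# Assumed clipping constant; keeps log() away from 0.
EPSILON = 1e-7


def categorical_crossentropy(y_true, y_pred):
  # -sum(y_true * log(y_pred)); with one-hot y_true this is -log(p_correct).
  clipped = [min(max(p, EPSILON), 1.0 - EPSILON) for p in y_pred]
  return -sum(t * math.log(p) for t, p in zip(y_true, clipped))


def argmax(values):
  return max(range(len(values)), key=lambda i: values[i])


def categorical_accuracy(y_true, y_pred):
  # 1.0 when the most probable predicted class is the labeled class.
  return float(argmax(y_true) == argmax(y_pred))


y = [[0, 1, 0], [1, 0, 0]]                # one-hot labels
y_p = [[0.1, 0.7, 0.2], [0.3, 0.6, 0.1]]  # made-up softmax outputs
losses = [categorical_crossentropy(t, p) for t, p in zip(y, y_p)]
accuracies = [categorical_accuracy(t, p) for t, p in zip(y, y_p)]
print(sum(losses) / len(losses))          # mean loss, like tf.reduce_mean
print(sum(accuracies) / len(accuracies))  # 0.5
```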
Now the final piece of the problem specification (before loading data and putting everything together) is backpropagating the loss through the model and optimizing the weights using the gradient.
In [0]:
def fit(m, x, y, opt):
  l, accuracy = predict(m, x, y)
  opt.minimize(l)
  return l, accuracy
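opt.minimize(l) computes the gradient of the loss with respect to the trainable variables and applies one update. The train function further down uses tf.train.MomentumOptimizer, whose documented update keeps an accumulator per variable: accumulation = momentum * accumulation + gradient, then variable -= learning_rate * accumulation. The one-parameter sketch below applies that rule to a toy loss (w - 3) ** 2. Its gradient, 2 * (w - 3), is written by hand instead of being computed by automatic differentiation; the toy loss is our assumption for illustration.

```python
def momentum_step(w, accumulation, grad, learning_rate=0.05, momentum=0.9):
  # Non-Nesterov momentum update, as described for tf.train.MomentumOptimizer.
  accumulation = momentum * accumulation + grad
  w = w - learning_rate * accumulation
  return w, accumulation


def loss_grad(w):
  # Gradient of the toy loss (w - 3) ** 2.
  return 2.0 * (w - 3.0)


w, acc = 0.0, 0.0
for step in range(300):
  w, acc = momentum_step(w, acc, loss_grad(w))
print(w)  # close to the minimum at w = 3
```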
These are some utility functions to download data and generate batches for training:
In [0]:
def setup_mnist_data(is_training, hp, batch_size):
  if is_training:
    ds = mnist_train('/tmp/autograph_mnist_data')
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = mnist_test('/tmp/autograph_mnist_data')
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds


def get_next_batch(ds):
  itr = ds.make_one_shot_iterator()
  image, label = itr.get_next()
  x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))
  y = tf.one_hot(tf.squeeze(label), 10)
  return x, y
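get_next_batch turns integer labels into one-hot vectors with tf.one_hot(..., 10), which by default puts 1.0 at the label index and 0.0 everywhere else. A plain-Python sketch of what that produces for a small batch of labels:

```python
def one_hot(labels, depth):
  # Each label becomes a row of `depth` zeros with 1.0 at the label index.
  return [[1.0 if i == label else 0.0 for i in range(depth)] for label in labels]


batch_labels = [3, 0, 9]
for row in one_hot(batch_labels, 10):
  print(row)
```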
This function specifies the main training loop. We instantiate the model (using the code above), instantiate an optimizer (here we'll use SGD with momentum, nothing too fancy), and we'll instantiate some lists to keep track of training and test loss and accuracy over time.
In the loop inside this function, we'll grab a batch of data, apply an update to the weights of our model to improve its performance, and then record its current training loss and accuracy. Every so often, we'll log some information about training as well.
In [0]:
def train(train_ds, test_ds, hp):
  m = mlp_model((28 * 28,))
  opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)

  train_losses = []
  train_losses = autograph.utils.set_element_type(train_losses, tf.float32)
  test_losses = []
  test_losses = autograph.utils.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  train_accuracies = autograph.utils.set_element_type(train_accuracies,
                                                      tf.float32)
  test_accuracies = []
  test_accuracies = autograph.utils.set_element_type(test_accuracies,
                                                     tf.float32)

  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)
    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % (hp.max_steps // 10) == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1
  return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),
          test_accuracies.stack())
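The logging condition i % (hp.max_steps // 10) == 0 fires about ten times per run. For max_steps=500, it logs at steps 0, 50, ..., 450. If max_steps is below 10, max_steps // 10 is 0 and the modulo would divide by zero (in plain Python that raises ZeroDivisionError). The sketch below reproduces the schedule and adds a max(1, ...) guard for that case; the guard is our addition, not part of the train function above.

```python
def logging_steps(max_steps, num_logs=10):
  # Same condition as in train(): log when i % (max_steps // num_logs) == 0.
  # max(1, ...) is an added guard so small max_steps do not divide by zero.
  interval = max(1, max_steps // num_logs)
  return [i for i in range(max_steps) if i % interval == 0]


print(logging_steps(500))  # [0, 50, 100, 150, 200, 250, 300, 350, 400, 450]
print(logging_steps(5))    # [0, 1, 2, 3, 4]
```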
Everything is ready to go. Let's train the model and plot its performance!
In [0]:
with tf.Graph().as_default():
  hp = tf.contrib.training.HParams(
      learning_rate=0.05,
      max_steps=500,
  )
  train_ds = setup_mnist_data(True, hp, 50)
  test_ds = setup_mnist_data(False, hp, 1000)
  tf_train = autograph.to_graph(train)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = tf_train(train_ds, test_ds, hp)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    (train_losses, test_losses, train_accuracies,
     test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,
                                  test_accuracies])

plt.title('MNIST train/test losses')
plt.plot(train_losses, label='train loss')
plt.plot(test_losses, label='test loss')
plt.legend()
plt.xlabel('Training step')
plt.ylabel('Loss')
plt.show()

plt.title('MNIST train/test accuracies')
plt.plot(train_accuracies, label='train accuracy')
plt.plot(test_accuracies, label='test accuracy')
plt.legend(loc='lower right')
plt.xlabel('Training step')
plt.ylabel('Accuracy')
plt.show()
In this exercise we build and train a model similar to the RNNColorbot model that was used in the main Eager notebook. The model is adapted for converting and training in graph mode.
To get started, we load the colorbot dataset. The code is identical to that used in the other exercise and its details are unimportant.
In [0]:
def parse(line):
  """Parses a line from the colors dataset.

  Args:
    line: A comma-separated string containing four items:
      color_name, red, green, and blue, representing the name and the
      RGB value of the color, each channel an integer between 0 and 255.

  Returns:
    A tuple of three tensors (rgb, chars, length), of shapes (3,),
    (sequence_length, 256) and (), respectively. After padded batching they
    become (batch_size, 3), (batch_size, max_sequence_length, 256) and
    (batch_size,).
  """
  items = tf.string_split([line], ",").values
  rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0
  color_name = items[0]
  chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)
  length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)
  return rgb, chars, length


def maybe_download(filename, work_directory, source_url):
  """Downloads the data from source url."""
  if not tf.gfile.Exists(work_directory):
    tf.gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not tf.gfile.Exists(filepath):
    temp_file_name, _ = six.moves.urllib.request.urlretrieve(source_url)
    tf.gfile.Copy(temp_file_name, filepath)
    with tf.gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath


def load_dataset(data_dir, url, batch_size, training=True):
  """Loads the colors data at url into a padded, batched tf.data.Dataset."""
  path = maybe_download(os.path.basename(url), data_dir, url)
  dataset = tf.data.TextLineDataset(path)
  dataset = dataset.skip(1)
  dataset = dataset.map(parse)
  dataset = dataset.cache()
  dataset = dataset.repeat()
  if training:
    dataset = dataset.shuffle(buffer_size=3000)
  dataset = dataset.padded_batch(batch_size, padded_shapes=([None], [None, None], []))
  return dataset


train_url = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
test_url = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"
data_dir = "tmp/rnn/data"
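To see what parse and padded_batch produce, here is a plain-Python sketch on two made-up CSV lines (illustrative rows, not taken from the real dataset). Each byte of the color name becomes a 256-wide one-hot row, the RGB values are scaled to [0, 1], and padded_batch pads every sequence in a batch with zeros up to the longest one, so padded time steps are all-zero rows.

```python
def parse_line(line):
  """Plain-Python analogue of parse(): returns (rgb, chars, length)."""
  items = line.split(',')
  rgb = [int(v) / 255.0 for v in items[1:]]
  name_bytes = items[0].encode('utf-8')
  # One 256-wide one-hot row per byte of the color name.
  chars = [[1.0 if i == b else 0.0 for i in range(256)] for b in name_bytes]
  return rgb, chars, len(chars)


def padded_batch(examples):
  """Pad every chars sequence with all-zero rows to the batch's max length."""
  max_len = max(length for _, _, length in examples)
  rgbs, chars_batch, lengths = [], [], []
  for rgb, chars, length in examples:
    padding = [[0.0] * 256 for _ in range(max_len - length)]
    rgbs.append(rgb)
    chars_batch.append(chars + padding)
    lengths.append(length)
  return rgbs, chars_batch, lengths


lines = ['red,255,0,0', 'teal,0,128,128']  # illustrative rows
rgbs, chars, lengths = padded_batch([parse_line(line) for line in lines])
print(rgbs[0], lengths)                             # [1.0, 0.0, 0.0] [3, 4]
print(len(chars), len(chars[0]), len(chars[0][0]))  # 2 4 256
```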
Next, we set up the RNNColorbot model, which is very similar to the one we used in the main exercise.
Autograph doesn't fully support classes yet (but it will soon!), so we'll write the model using simple functions.
In [0]:
def model_components():
  lower_cell = tf.contrib.rnn.LSTMBlockCell(256)
  lower_cell.build(tf.TensorShape((None, 256)))
  upper_cell = tf.contrib.rnn.LSTMBlockCell(128)
  upper_cell.build(tf.TensorShape((None, 256)))
  relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)
  relu_layer.build(tf.TensorShape((None, 128)))
  return lower_cell, upper_cell, relu_layer


def rnn_layer(chars, cell, batch_size, training):
  """A simple RNN layer.

  Args:
    chars: A Tensor of shape (max_sequence_length, batch_size, input_size)
    cell: An object of type tf.contrib.rnn.LSTMBlockCell
    batch_size: Int, the batch size to use
    training: Boolean, whether the layer is used for training

  Returns:
    A Tensor of shape (max_sequence_length, batch_size, output_size).
  """
  hidden_outputs = []
  autograph.utils.set_element_type(hidden_outputs, tf.float32)
  state, output = cell.zero_state(batch_size, tf.float32)
  n = tf.shape(chars)[0]
  i = 0
  while i < n:
    ch = chars[i]
    cell_output, (state, output) = cell.call(ch, (state, output))
    hidden_outputs.append(cell_output)
    i += 1
  hidden_outputs = hidden_outputs.stack()
  if training:
    hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)
  return hidden_outputs


def model(inputs, lower_cell, upper_cell, relu_layer, batch_size, training):
  """RNNColorbot model.

  The model consists of two RNN layers (made by lower_cell and upper_cell),
  followed by a fully connected layer with ReLU activation.

  Args:
    inputs: A tuple (chars, length)
    lower_cell: An object of type tf.contrib.rnn.LSTMBlockCell
    upper_cell: An object of type tf.contrib.rnn.LSTMBlockCell
    relu_layer: An object of type tf.layers.Dense
    batch_size: Int, the batch size to use
    training: Boolean, whether the layer is used for training

  Returns:
    A Tensor of shape (batch_size, 3): the model predictions.
  """
  (chars, length) = inputs
  chars_time_major = tf.transpose(chars, [1, 0, 2])
  chars_time_major.set_shape((None, batch_size, 256))

  hidden_outputs = rnn_layer(chars_time_major, lower_cell, batch_size, training)
  final_outputs = rnn_layer(hidden_outputs, upper_cell, batch_size, training)

  # Grab just the end-of-sequence from each output.
  indices = tf.stack([length - 1, range(batch_size)], axis=1)
  sequence_ends = tf.gather_nd(final_outputs, indices)
  return relu_layer(sequence_ends)


def loss_fn(labels, predictions):
  return tf.reduce_mean((predictions - labels) ** 2)
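The indices built in model pair each example's last valid time step (length - 1) with its batch position. tf.gather_nd then picks one output vector per example from the time-major (max_sequence_length, batch_size, output_size) tensor and ignores the padded steps. A plain-Python sketch of that lookup on tiny nested lists, with made-up values:

```python
def gather_nd(params, indices):
  """Minimal gather_nd for (time, batch) index pairs: params[t][b]."""
  return [params[t][b] for t, b in indices]


# Time-major outputs: 3 time steps x 2 examples x 1 feature.
final_outputs = [
    [[0.1], [0.4]],  # t = 0
    [[0.2], [0.5]],  # t = 1
    [[0.3], [9.9]],  # t = 2 (example 1 is padding here)
]
length = [3, 2]
batch_size = 2
indices = [[n_steps - 1, b] for n_steps, b in zip(length, range(batch_size))]
print(indices)                            # [[2, 0], [1, 1]]
print(gather_nd(final_outputs, indices))  # [[0.3], [0.5]]
```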
The train and test functions are also similar to the ones used in the Eager notebook. Since the network requires a fixed batch size, we'll train in a single shot, rather than by epoch.
In [0]:
def train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):
  iterator = train_data.make_one_shot_iterator()
  step = 0
  while step < num_steps:
    labels, chars, sequence_length = iterator.get_next()
    predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=True)
    loss = loss_fn(labels, predictions)
    optimizer.minimize(loss)
    if step % (num_steps // 10) == 0:
      print('Step', step, 'train loss', loss)
    step += 1
  return step


def test(eval_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps):
  total_loss = 0.0
  iterator = eval_data.make_one_shot_iterator()
  step = 0
  while step < num_steps:
    labels, chars, sequence_length = iterator.get_next()
    predictions = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, batch_size, training=False)
    total_loss += loss_fn(labels, predictions)
    step += 1
  print('Test loss', total_loss)
  return total_loss


def train_model(train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps):
  optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
  train(optimizer, train_data, lower_cell, upper_cell, relu_layer, batch_size, num_steps=tf.constant(train_steps))
  test(eval_data, lower_cell, upper_cell, relu_layer, 50, num_steps=tf.constant(2))

  print('Colorbot is ready to generate colors!\n\n')

  # In graph mode, every op needs to be a dependent of another op.
  # Here, we create a no_op that will drive the execution of all other code in
  # this function. Autograph will add the necessary control dependencies.
  return tf.no_op()
Next, we add code to run inference on a single color name, which we'll read from the keyboard.
Note the do_not_convert annotation, which lets us disable conversion for certain functions and run them as a py_func instead, so you can still call them from compiled code.
In [0]:
@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)
def draw_prediction(color_name, pred):
  pred = pred * 255
  pred = pred.astype(np.uint8)
  plt.axis('off')
  plt.imshow(pred)
  plt.title(color_name)
  plt.show()


def inference(color_name, lower_cell, upper_cell, relu_layer):
  _, chars, sequence_length = parse(color_name)
  chars = tf.expand_dims(chars, 0)
  sequence_length = tf.expand_dims(sequence_length, 0)
  pred = model((chars, sequence_length), lower_cell, upper_cell, relu_layer, 1, training=False)
  pred = tf.minimum(pred, 1.0)
  pred = tf.expand_dims(pred, 0)
  draw_prediction(color_name, pred)
  # Create an op that will drive the entire function.
  return tf.no_op()
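Because draw_prediction runs as a py_func, it receives NumPy arrays rather than Tensors, which is why it can call astype. The ReLU output of the model is non-negative but unbounded above, so inference applies tf.minimum(pred, 1.0) before the prediction is scaled to the 0 to 255 pixel range. The same arithmetic in plain Python for a few made-up channel values:

```python
def to_pixel(pred):
  """Clamp one predicted channel to [0, 1] and convert it to 0-255."""
  clamped = min(max(pred, 0.0), 1.0)  # ReLU already guarantees >= 0
  return int(clamped * 255)           # truncates, like astype(np.uint8)


print([to_pixel(v) for v in [0.0, 0.5, 1.3]])  # [0, 127, 255]
```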
Finally, we put everything together.
Note that the entire training and testing code is compiled into a single op (tf_train_model) that you only execute once! We still use a sess.run loop for the inference part, because that requires keyboard input.
In [0]:
def run_input_loop(sess, inference_ops, color_name_placeholder):
  """Helper function that reads from input and calls the inference ops in a loop."""

  tb = widgets.TabBar(["RNN Colorbot"])
  while True:
    with tb.output_to(0):
      try:
        color_name = six.moves.input("Give me a color name (or press 'enter' to exit): ")
      except (EOFError, KeyboardInterrupt):
        break
    if not color_name:
      break
    with tb.output_to(0):
      tb.clear_tab()
      sess.run(inference_ops, {color_name_placeholder: color_name})
      plt.show()


with tf.Graph().as_default():
  # Read the data.
  batch_size = 64
  train_data = load_dataset(data_dir, train_url, batch_size)
  eval_data = load_dataset(data_dir, test_url, 50, training=False)

  # Create the model components.
  lower_cell, upper_cell, relu_layer = model_components()
  # Create the helper placeholder for inference.
  color_name_placeholder = tf.placeholder(tf.string, shape=())

  # Compile the train / test code.
  tf_train_model = autograph.to_graph(train_model)
  train_model_ops = tf_train_model(
      train_data, eval_data, batch_size, lower_cell, upper_cell, relu_layer, train_steps=100)

  # Compile the inference code.
  tf_inference = autograph.to_graph(inference)
  inference_ops = tf_inference(color_name_placeholder, lower_cell, upper_cell, relu_layer)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Run training and testing.
    sess.run(train_model_ops)

    # Run the inference loop.
    run_input_loop(sess, inference_ops, color_name_placeholder)
Autograph is available in tensorflow.contrib, but it's still in its early stages. We're excited about the possibilities it brings: write your machine learning code in the flexible Eager style, but still enjoy all the benefits that come with running in graph mode. A beta version will be available soon, so stay tuned!