In [0]:
import tensorflow as tf
from tensorflow.contrib import autograph
import matplotlib.pyplot as plt
AutoGraph helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large and growing subset of the Python language. Please see this document for what we currently support, and what we're working on.
Here's a quick example of how it works:
In [0]:
# Autograph can convert functions like this...
def g(x):
  if x > 0:
    x = x * x
  else:
    x = 0.0
  return x

# ...into graph-building functions like this:
def tf_g(x):
  with tf.name_scope('g'):

    def if_true():
      with tf.name_scope('if_true'):
        x_1, = x,
        x_1 = x_1 * x_1
        return x_1,

    def if_false():
      with tf.name_scope('if_false'):
        x_1, = x,
        x_1 = 0.0
        return x_1,

    x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)
    return x
In [0]:
# You can run your plain-Python code in graph mode,
# and get the same results out, but with all the benefits of graphs:
print('Original value: %2.2f' % g(9.0))

# Generate a graph version of g and call it:
tf_g = autograph.to_graph(g)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_ops = tf_g(tf.constant(9.0))
  with tf.Session() as sess:
    print('Autograph value: %2.2f\n' % sess.run(g_ops))

# You can view, debug and tweak the generated code:
print(autograph.to_code(g))
AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph. AutoGraph will automatically convert most Python control flow statements into their correct graph equivalents.

We support common statements like while, for, if, break, return and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:
In [0]:
# Continue in a loop
def f(l):
  s = 0
  for c in l:
    if c % 2 > 0:
      continue
    s += c
  return s

print('Original value: %d' % f([10, 12, 15, 20]))

tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session():
    print('Graph value: %d\n\n' % tf_f(tf.constant([10, 12, 15, 20])).eval())

print(autograph.to_code(f))
Try replacing the continue in the above code with break -- AutoGraph supports that as well!
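For instance, here's one way the break variant might look (a minimal sketch; the f_break name is just illustrative):

In [0]:
# Break out of a loop: sum the even elements until the first odd one.
def f_break(l):
  s = 0
  for c in l:
    if c % 2 > 0:
      break
    s += c
  return s

print('Original value: %d' % f_break([10, 12, 15, 20]))

tf_f_break = autograph.to_graph(f_break)
with tf.Graph().as_default():
  with tf.Session():
    print('Graph value: %d\n' % tf_f_break(tf.constant([10, 12, 15, 20])).eval())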
Let's try some other useful Python constructs, like print and assert. We automatically convert Python assert statements into the equivalent tf.Assert code.
In [0]:
def f(x):
  assert x != 0, 'Do not pass zero!'
  return x * x

tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session():
    try:
      print(tf_f(tf.constant(0)).eval())
    except tf.errors.InvalidArgumentError as e:
      print('Got error message:\n%s' % e.message)
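For comparison, the conversion behaves roughly like this hand-written sketch using tf.Assert (illustrative only, not the exact code AutoGraph generates):

In [0]:
# Roughly equivalent by hand: a tf.Assert op that the result
# is made to depend on via control_dependencies.
def tf_f_manual(x):
  assert_op = tf.Assert(tf.not_equal(x, 0), ['Do not pass zero!'])
  with tf.control_dependencies([assert_op]):
    return x * x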
You can also use plain Python print functions in-graph:
In [0]:
def f(n):
  if n >= 0:
    while n < 5:
      n += 1
      print(n)
  return n

tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session():
    tf_f(tf.constant(0)).eval()
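By hand, you'd have to thread printing ops through the loop yourself, roughly like this sketch using tf.Print (illustrative only, and with the outer if dropped for brevity):

In [0]:
# A rough hand-written equivalent of the while loop above.
# tf.Print passes its first argument through unchanged and
# prints the tensors in its second argument as a side effect.
def tf_f_manual(n):
  def body(n):
    n += 1
    return tf.Print(n, [n])
  return tf.while_loop(lambda n: n < 5, body, [n])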
Appending to lists in loops also works (we create a TensorArray for you behind the scenes):
In [0]:
def f(n):
  z = []
  # We ask you to tell us the element dtype of the list
  z = autograph.utils.set_element_type(z, tf.int32)
  for i in range(n):
    z.append(i)
  # When you're done with the list, stack it
  # (this is just like np.stack)
  return autograph.stack(z)

tf_f = autograph.to_graph(f)
with tf.Graph().as_default():
  with tf.Session():
    print(tf_f(tf.constant(3)).eval())

print('\n\n' + autograph.to_code(f))
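For reference, the TensorArray plumbing this saves you looks roughly like the following hand-written sketch (illustrative only, not the code AutoGraph actually generates):

In [0]:
# A dynamically-sized TensorArray plays the role of the Python list.
def tf_f_manual(n):
  ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

  def body(i, ta):
    # "Append" i by writing it at position i.
    return i + 1, ta.write(i, i)

  _, ta = tf.while_loop(lambda i, ta: i < n, body, [0, ta])
  return ta.stack()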
In [0]:
def fizzbuzz(num):
  if num % 3 == 0 and num % 5 == 0:
    print('FizzBuzz')
  elif num % 3 == 0:
    print('Fizz')
  elif num % 5 == 0:
    print('Buzz')
  else:
    print(num)
  return num
In [0]:
tf_g = autograph.to_graph(fizzbuzz)

with tf.Graph().as_default():
  # The result works like a regular op: takes tensors in, returns tensors.
  # You can inspect the graph using tf.get_default_graph().as_graph_def()
  g_ops = tf_g(tf.constant(15))
  with tf.Session() as sess:
    sess.run(g_ops)

# You can view, debug and tweak the generated code:
print('\n')
print(autograph.to_code(fizzbuzz))
In [0]:
# See what happens when you turn AutoGraph off.
# Do you see the type or the value of x when you print it?

# @autograph.convert()
def square_log(x):
  x = x * x
  print('Squared value of x =', x)
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_log(tf.constant(4))))
In [0]:
def square_if_positive(x):
  x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)
  return x

with tf.Session() as sess:
  print(sess.run(square_if_positive(tf.constant(4))))
In [0]:
@autograph.convert()
def square_if_positive(x):
  ...  # <<< fill it in!

with tf.Session() as sess:
  print(sess.run(square_if_positive(tf.constant(4))))
In [0]:
# Simple cond
@autograph.convert()
def square_if_positive(x):
  if x > 0:
    x = x * x
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_if_positive(tf.constant(4))))
In [0]:
def nearest_odd_square(x):

  def if_positive():
    x1 = x * x
    x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)
    return x1,

  x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(nearest_odd_square(tf.constant(4))))
In [0]:
@autograph.convert()
def nearest_odd_square(x):
  ...  # <<< fill it in!

with tf.Session() as sess:
  print(sess.run(nearest_odd_square(tf.constant(4))))
In [0]:
@autograph.convert()
def nearest_odd_square(x):
  if x > 0:
    x = x * x
    if x % 2 == 0:
      x = x + 1
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(nearest_odd_square(tf.constant(4))))
In [0]:
# Convert a while loop
def square_until_stop(x, y):
  x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))
In [0]:
@autograph.convert()
def square_until_stop(x, y):
  ...  # fill it in!

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))
In [0]:
@autograph.convert()
def square_until_stop(x, y):
  while x < y:
    x = x * x
  return x

with tf.Graph().as_default():
  with tf.Session() as sess:
    print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))
In [0]:
@autograph.convert()
def argwhere_cumsum(x, threshold):
  current_sum = 0.0
  idx = 0
  for i in range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N / 2)))
    print(sess.run(idx))
In [0]:
@autograph.convert()
def argwhere_cumsum(x, threshold):
  ...

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N / 2)))
    print(sess.run(idx))
In [0]:
@autograph.convert()
def argwhere_cumsum(x, threshold):
  current_sum = 0.0
  idx = 0
  for i in range(len(x)):
    idx = i
    if current_sum >= threshold:
      break
    current_sum += x[i]
  return idx

N = 10
with tf.Graph().as_default():
  with tf.Session() as sess:
    idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N / 2)))
    print(sess.run(idx))
Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well!
Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph.
In [0]:
import gzip
import os
import shutil

from six.moves import urllib


def download(directory, filename):
  filepath = os.path.join(directory, filename)
  if tf.gfile.Exists(filepath):
    return filepath
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filepath = filepath + '.gz'
  print('Downloading %s to %s' % (url, zipped_filepath))
  urllib.request.urlretrieve(url, zipped_filepath)
  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath


def dataset(directory, images_file, labels_file):
  images_file = download(directory, images_file)
  labels_file = download(directory, labels_file)

  def decode_image(image):
    # Normalize from [0, 255] to [0.0, 1.0]
    image = tf.decode_raw(image, tf.uint8)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [784])
    return image / 255.0

  def decode_label(label):
    label = tf.decode_raw(label, tf.uint8)
    label = tf.reshape(label, [])
    return tf.to_int32(label)

  images = tf.data.FixedLengthRecordDataset(
      images_file, 28 * 28, header_bytes=16).map(decode_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_file, 1, header_bytes=8).map(decode_label)
  return tf.data.Dataset.zip((images, labels))


def mnist_train(directory):
  return dataset(directory, 'train-images-idx3-ubyte',
                 'train-labels-idx1-ubyte')


def mnist_test(directory):
  return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
In [0]:
def mlp_model(input_shape):
  model = tf.keras.Sequential((
      tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),
      tf.keras.layers.Dense(100, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')))
  model.build()
  return model


def predict(m, x, y):
  y_p = m(x)
  losses = tf.keras.losses.categorical_crossentropy(y, y_p)
  l = tf.reduce_mean(losses)
  accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)
  accuracy = tf.reduce_mean(accuracies)
  return l, accuracy


def fit(m, x, y, opt):
  l, accuracy = predict(m, x, y)
  opt.minimize(l)
  return l, accuracy


def setup_mnist_data(is_training, hp, batch_size):
  if is_training:
    ds = mnist_train('/tmp/autograph_mnist_data')
    ds = ds.shuffle(batch_size * 10)
  else:
    ds = mnist_test('/tmp/autograph_mnist_data')
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  return ds


def get_next_batch(ds):
  itr = ds.make_one_shot_iterator()
  image, label = itr.get_next()
  x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))
  y = tf.one_hot(tf.squeeze(label), 10)
  return x, y
In [0]:
def train(train_ds, test_ds, hp):
  m = mlp_model((28 * 28,))
  opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)

  # We'd like to save our losses to a list. In order for AutoGraph
  # to convert these lists into their graph equivalent,
  # we need to specify the element type of the lists.
  train_losses = []
  train_losses = autograph.utils.set_element_type(train_losses, tf.float32)
  test_losses = []
  test_losses = autograph.utils.set_element_type(test_losses, tf.float32)
  train_accuracies = []
  train_accuracies = autograph.utils.set_element_type(train_accuracies,
                                                      tf.float32)
  test_accuracies = []
  test_accuracies = autograph.utils.set_element_type(test_accuracies,
                                                     tf.float32)

  # This entire training loop will be run in-graph.
  i = tf.constant(0)
  while i < hp.max_steps:
    train_x, train_y = get_next_batch(train_ds)
    test_x, test_y = get_next_batch(test_ds)
    step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)
    step_test_loss, step_test_accuracy = predict(m, test_x, test_y)
    if i % (hp.max_steps // 10) == 0:
      print('Step', i, 'train loss:', step_train_loss, 'test loss:',
            step_test_loss, 'train accuracy:', step_train_accuracy,
            'test accuracy:', step_test_accuracy)
    train_losses.append(step_train_loss)
    test_losses.append(step_test_loss)
    train_accuracies.append(step_train_accuracy)
    test_accuracies.append(step_test_accuracy)
    i += 1

  # We've recorded our loss values and accuracies
  # to a list in a graph with AutoGraph's help.
  # In order to return the values as a Tensor,
  # we need to stack them before returning them.
  return (autograph.stack(train_losses), autograph.stack(test_losses),
          autograph.stack(train_accuracies), autograph.stack(test_accuracies))
In [0]:
with tf.Graph().as_default():
  hp = tf.contrib.training.HParams(
      learning_rate=0.05,
      max_steps=500,
  )
  train_ds = setup_mnist_data(True, hp, 50)
  test_ds = setup_mnist_data(False, hp, 1000)
  tf_train = autograph.to_graph(train)
  (train_losses, test_losses, train_accuracies,
   test_accuracies) = tf_train(train_ds, test_ds, hp)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    (train_losses, test_losses, train_accuracies,
     test_accuracies) = sess.run([train_losses, test_losses,
                                  train_accuracies, test_accuracies])

    plt.title('MNIST train/test losses')
    plt.plot(train_losses, label='train loss')
    plt.plot(test_losses, label='test loss')
    plt.legend()
    plt.xlabel('Training step')
    plt.ylabel('Loss')
    plt.show()

    plt.title('MNIST train/test accuracies')
    plt.plot(train_accuracies, label='train accuracy')
    plt.plot(test_accuracies, label='test accuracy')
    plt.legend(loc='lower right')
    plt.xlabel('Training step')
    plt.ylabel('Accuracy')
    plt.show()