In [0]:
# Copyright 2018 The TensorFlow GAN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
This colab walks you through the basics of using TensorFlow GAN (TF-GAN) on Tensor Processing Units (TPUs).
This notebook is hosted on GitHub. To view it in its original repository, after opening the notebook, select File > View on GitHub.
TPUs are chips optimized for machine learning training and inference. GAN training uses a lot of compute power, so TPUs expand the range of what we can realistically accomplish with GANs.
In this colab we'll use TPUs to train a GAN for a more difficult task than the MNIST task tackled in the GANEstimator colab. We'll use the CIFAR10 dataset to train a model to generate images of airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.
The training runs for 50000 steps and completes in roughly 10 minutes on TPUs.
Our model implements the Deep Convolutional Generative Adversarial Network (DCGAN) architecture. Convolutional Neural Networks (CNNs) are a common machine learning technique for dealing with images. DCGAN adapts a CNN architecture so that it can function as part of a GAN.
This notebook should be run in Colaboratory. If you are viewing this from GitHub, follow the GitHub instructions. If you are viewing this from Colaboratory, you should skip to the Colaboratory instructions.
GitHub
1. Navigate to https://colab.research.google.com and select the "GitHub" tab.
2. Under "Enter a GitHub URL or search by organization or user", put in the URL of this notebook in GitHub and click the magnifying glass icon next to it.
3. Select "Runtime > Change runtime type", set "Hardware accelerator" to "TPU", and click "Save".
4. Click "Connect" in the upper right corner and select "Connect to hosted runtime".
Colaboratory
Run the setup cells below (enable "use_tpu" and supply a bucket name as input). You can also run the cells manually with Shift-ENTER. First, you'll need to enable TPUs for the notebook.
Navigate to Edit→Notebook Settings, and select TPU from the Hardware Accelerator drop-down (you can also access Notebook Settings via the command palette: cmd/ctrl-shift-P).
Next, we'll check that we can connect to the TPU.
In [0]:
import os
import tensorflow as tf
import pprint
assert 'COLAB_TPU_ADDR' in os.environ, 'Did you forget to switch to TPU?'
tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
with tf.Session(tpu_address) as sess:
  devices = sess.list_devices()
  pprint.pprint(devices)

device_is_tpu = ['TPU' in str(x) for x in devices]
assert True in device_is_tpu, 'Did you forget to switch to TPU?'
To run on Google's free Cloud TPUs, you must set up a Google Cloud Storage bucket to store logs and checkpoints. Even though this may require entering payment information, running this notebook alone should fall under the free pricing tier and cost nothing.
In [0]:
import json
import os
import pprint
import re
import time
import tensorflow as tf
# Google Cloud Storage bucket for Estimator logs and storing
# the training dataset.
bucket = '' #@param {type:"string"}
assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))
model_dir = 'gs://{}/{}'.format(
    bucket, time.strftime('tpuestimator-tfgan/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(model_dir))
from google.colab import auth
auth.authenticate_user()
assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'
tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
# Upload credentials to TPU.
with tf.Session(tpu_address) as sess:
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
# Now credentials are set for all future sessions on this TPU.
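As an optional sanity check (not part of the original notebook), you can confirm that the authenticated runtime can write to the bucket before training starts. This is a minimal sketch assuming a TensorFlow version that exposes tf.io.gfile (1.14+); the file name is arbitrary.
In [0]:
# Optional: verify write access to the model directory in GCS. A failure here
# usually means the bucket name is wrong or the auth step above didn't succeed.
smoke_test_path = os.path.join(model_dir, 'smoke_test.txt')
with tf.io.gfile.GFile(smoke_test_path, 'w') as f:
  f.write('ok')
print('Wrote {}'.format(smoke_test_path))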
In [0]:
# Check that imports for the rest of the file work.
import os
import tensorflow as tf
!pip install tensorflow-gan
import tensorflow_gan as tfgan
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import numpy as np
import matplotlib.pyplot as plt
# Allow matplotlib images to render immediately.
%matplotlib inline
# Disable noisy outputs.
tf.logging.set_verbosity(tf.logging.ERROR)
tf.autograph.set_verbosity(0, False)
import warnings
warnings.filterwarnings("ignore")
In [0]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

dataset_dir = 'gs://{}/{}'.format(bucket, 'datasets')

def input_fn(mode, params):
  assert 'batch_size' in params
  assert 'noise_dims' in params
  bs = params['batch_size']
  nd = params['noise_dims']
  split = 'train' if mode == tf.estimator.ModeKeys.TRAIN else 'test'
  shuffle = (mode == tf.estimator.ModeKeys.TRAIN)
  just_noise = (mode == tf.estimator.ModeKeys.PREDICT)

  noise_ds = (tf.data.Dataset.from_tensors(0)
              .map(lambda _: tf.random_normal([bs, nd]))
              # If 'predict', just generate one batch.
              .repeat(1 if just_noise else None))

  if just_noise:
    return noise_ds

  def _preprocess(element):
    # Map [0, 255] to [-1, 1].
    images = (tf.cast(element['image'], tf.float32) - 127.5) / 127.5
    return images

  images_ds = (tfds.load('cifar10:3.*.*', split=split, data_dir=dataset_dir)
               .map(_preprocess, num_parallel_calls=4)
               .cache()
               .repeat())
  if shuffle:
    images_ds = images_ds.shuffle(
        buffer_size=10000, reshuffle_each_iteration=True)
  images_ds = (images_ds.batch(bs, drop_remainder=True)
               .prefetch(tf.data.experimental.AUTOTUNE))

  return tf.data.Dataset.zip((noise_ds, images_ds))

def noise_input_fn(params):
  """Input function for generating a fixed batch of noise for predictions."""
  np.random.seed(0)
  np_noise = np.random.randn(params['batch_size'], params['noise_dims'])
  return tf.data.Dataset.from_tensors(tf.constant(np_noise, dtype=tf.float32))
Download the data. TensorFlow Datasets will write the data once to your GCS bucket, then reuse it for future calls.
In [0]:
params = {'batch_size': 1, 'noise_dims': 1}
input_fn(tf.estimator.ModeKeys.EVAL, params)
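Optionally (this check is not in the original notebook), you can list the dataset directory to confirm that TensorFlow Datasets wrote the CIFAR10 files to your bucket; this assumes tf.io.gfile is available (TF 1.14+).
In [0]:
# Optional: list what TFDS wrote under the dataset directory in GCS.
print(tf.io.gfile.listdir(dataset_dir))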
Sanity check the inputs.
In [0]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
params = {'batch_size': 80, 'noise_dims': 64}
ds = input_fn(tf.estimator.ModeKeys.EVAL, params)
numpy_imgs = next(tfds.as_numpy(ds))[1]
image_grid = tfgan.eval.python_image_grid(numpy_imgs, grid_shape=(8, 10))

def _show_image_grid(image_grid):
  plt.axis('off')
  plt.imshow((image_grid + 1.0) / 2.0,  # [-1, 1] -> [0, 1]
             aspect='auto')
  plt.show()

_show_image_grid(image_grid)
As usual, our GAN has two separate networks: a generator that maps input noise to generated CIFAR images, and a discriminator that scores images as real or generated.
We define discriminator() and generator() builder functions that assemble these networks, along with several helper functions that build pieces of these networks. In the "Estimator" section below we pass the discriminator() and generator() functions to the TPUGANEstimator constructor.
In [0]:
def _leaky_relu(x):
  return tf.nn.leaky_relu(x, alpha=0.2)

def _batch_norm(x, is_training, name):
  return tf.layers.batch_normalization(
      x, momentum=0.9, epsilon=1e-5, training=is_training, name=name)

def _dense(x, channels, name):
  return tf.layers.dense(
      x, channels,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
      name=name)

def _conv2d(x, filters, kernel_size, stride, name):
  return tf.layers.conv2d(
      x, filters, [kernel_size, kernel_size],
      strides=[stride, stride], padding='same',
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
      name=name)

def _deconv2d(x, filters, kernel_size, stride, name):
  return tf.layers.conv2d_transpose(
      x, filters, [kernel_size, kernel_size],
      strides=[stride, stride], padding='same',
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
      name=name)
def discriminator(images, unused_conditioning, is_training=True,
                  scope='Discriminator'):
  """Discriminator for CIFAR images.

  Args:
    images: A Tensor of shape [batch size, width, height, channels], that can
      be either real or generated. It is the discriminator's goal to
      distinguish between the two.
    unused_conditioning: The TFGAN API can help with conditional GANs, which
      would require extra `condition` information to both the generator and
      the discriminator. Since this example is not conditional, we do not use
      this argument.
    is_training: If `True`, batch norm uses batch statistics. If `False`,
      batch norm uses the exponential moving average collected from population
      statistics.
    scope: A variable scope or string for the discriminator.

  Returns:
    A 1D Tensor of shape [batch size] representing the confidence that the
    images are real. The output can lie in [-inf, inf], with positive values
    indicating high confidence that the images are real.
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    x = _conv2d(images, 64, 5, 2, name='d_conv1')
    x = _leaky_relu(x)
    x = _conv2d(x, 128, 5, 2, name='d_conv2')
    x = _leaky_relu(_batch_norm(x, is_training, name='d_bn2'))
    x = _conv2d(x, 256, 5, 2, name='d_conv3')
    x = _leaky_relu(_batch_norm(x, is_training, name='d_bn3'))
    x = tf.reshape(x, [-1, 4 * 4 * 256])
    x = _dense(x, 1, name='d_fc_4')
    return x
def generator(noise, is_training=True, scope='Generator'):
  """Generator to produce CIFAR images.

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. Since this example
      does not use conditioning, this Tensor represents a noise vector of some
      kind that will be reshaped by the generator into CIFAR examples.
    is_training: If `True`, batch norm uses batch statistics. If `False`,
      batch norm uses the exponential moving average collected from population
      statistics.
    scope: A variable scope or string for the generator.

  Returns:
    A single Tensor with a batch of generated CIFAR images.
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    net = _dense(noise, 4096, name='g_fc1')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn1'))
    net = tf.reshape(net, [-1, 4, 4, 256])
    net = _deconv2d(net, 128, 5, 2, name='g_dconv2')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn2'))
    net = _deconv2d(net, 64, 4, 2, name='g_dconv3')
    net = tf.nn.relu(_batch_norm(net, is_training, name='g_bn3'))
    net = _deconv2d(net, 3, 4, 2, name='g_dconv4')
    net = tf.tanh(net)
    return net
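As a quick sanity check (not part of the original notebook), we can build both networks in a throwaway graph and confirm the static shapes line up: the generator upsamples 4x4 feature maps through three stride-2 deconvolutions to 32x32x3 images, and the discriminator reduces them back to per-image logits. The batch size and noise dimension here are arbitrary illustrative values.
In [0]:
# Hypothetical shape check: batch of 8, noise dimension 64. We only build the
# graph and inspect static shapes, so no session or variable init is needed.
with tf.Graph().as_default():
  test_noise = tf.random_normal([8, 64])
  fake_images = generator(test_noise, is_training=False)
  logits = discriminator(fake_images, None, is_training=False)
  print(fake_images.shape)  # Expected: (8, 32, 32, 3)
  print(logits.shape)       # Expected: (8, 1)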
In [0]:
import tensorflow as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
import numpy as np
eval_batch_size = 4000 #@param
images_per_batch = 2000 #@param
def get_real_image_logits(num_images, classifier_model):
  """Returns an array with logits from real images and a CIFAR classifier.

  We normally want many thousands of examples to run eval. However, we can't
  fit inference for all of them in memory at once. Instead, we use TF-GAN eval
  utils to more efficiently manage memory.

  Args:
    num_images: Total number of images to produce logits for.
    classifier_model: A Python function that takes images and produces logits.

  Returns:
    A numpy array of logits of shape [num_images, ?].
  """
  ds = input_fn(tf.estimator.ModeKeys.TRAIN,
                {'batch_size': images_per_batch, 'noise_dims': 1})
  iterator = tf.data.make_one_shot_iterator(ds)
  real_logits = tfgan.eval.sample_and_run_classifier_fn(
      sample_fn=lambda _: iterator.get_next()[1],
      sample_inputs=[0.0] * (num_images // images_per_batch),  # unused
      classifier_fn=classifier_model)
  with tf.train.MonitoredSession() as sess:
    logits = sess.run(real_logits)
  assert len(logits.shape) == 2
  assert logits.shape[0] == num_images
  return logits

def init_global_real_logits():
  """Initialize a global variable with classifier logits for real data."""
  # We can hold all the real logits in memory at once, since CIFAR10 isn't
  # that big. Be sure to calculate them only once.
  global real_logits
  try:
    real_logits is not None
  except NameError:
    with tf.Graph().as_default():
      classifier_model = hub.Module(
          "https://tfhub.dev/deepmind/ganeval-cifar10-convnet/1")
      real_logits = get_real_image_logits(eval_batch_size, classifier_model)
  assert real_logits.shape == (eval_batch_size, 10)

def calculate_real_data_classifier_score():
  """Calculate the classifier score on real data logits."""
  assert real_logits is not None
  classifier_score = tfgan.eval.classifier_score_from_logits(real_logits)
  with tf.train.MonitoredSession() as sess:
    cscore_real = sess.run(classifier_score)
  return cscore_real

def get_inception_score_and_fid(est):
  """Calculate our evaluation metrics."""
  global real_logits
  assert real_logits is not None
  tf.reset_default_graph()
  classifier_model = hub.Module(
      "https://tfhub.dev/deepmind/ganeval-cifar10-convnet/1")

  # We don't want to hold all the images and activations at once, so use a
  # memory-efficient utility.
  def sample_fn(_):
    predictions = np.array(
        [x['generated_data'] for x in est.predict(input_fn)])
    assert predictions.shape == (images_per_batch, 32, 32, 3)
    return predictions

  fake_logits = tfgan.eval.sample_and_run_classifier_fn(
      sample_fn=sample_fn,
      sample_inputs=[0.0] * (eval_batch_size // images_per_batch),  # unused
      classifier_fn=classifier_model)
  fake_logits.shape.assert_is_compatible_with([eval_batch_size, 10])
  classifier_score = tfgan.eval.classifier_score_from_logits(fake_logits)
  fid = tfgan.eval.frechet_classifier_distance_from_activations(
      real_logits, fake_logits)
  with tf.train.MonitoredSession() as sess:
    cscore_np, fid_np = sess.run([classifier_score, fid])
  return cscore_np, fid_np
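For intuition about what these metrics measure, here is a minimal NumPy sketch of the two quantities (an illustration, not the TF-GAN implementation): the classifier score is the exponential of the average KL divergence between each sample's predicted label distribution p(y|x) and the marginal p(y), and the Frechet distance fits a Gaussian to each set of activations and compares means and covariances. The function names here are ours, and we assume [num_samples, num_classes] float arrays as input.
In [0]:
# Illustrative NumPy versions of the eval metrics above (sketches only; the
# real computation happens inside tfgan.eval).
import numpy as np
import scipy.linalg

def classifier_score_np(logits):
  # p(y|x): softmax over classes for each sample (shifted for stability).
  shifted = logits - logits.max(axis=1, keepdims=True)
  p_yx = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
  # p(y): marginal label distribution over the whole sample.
  p_y = p_yx.mean(axis=0, keepdims=True)
  # Score = exp(mean KL(p(y|x) || p(y))); higher is better.
  kl = (p_yx * (np.log(p_yx + 1e-16) - np.log(p_y + 1e-16))).sum(axis=1)
  return np.exp(kl.mean())

def frechet_distance_np(real_acts, fake_acts):
  # Fit a Gaussian to each set of activations and compare the two.
  mu_r, mu_f = real_acts.mean(axis=0), fake_acts.mean(axis=0)
  cov_r = np.cov(real_acts, rowvar=False)
  cov_f = np.cov(fake_acts, rowvar=False)
  covmean = scipy.linalg.sqrtm(cov_r.dot(cov_f)).real
  # ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 sqrt(C_r C_f)); lower is better.
  return ((mu_r - mu_f) ** 2).sum() + np.trace(cov_r + cov_f - 2.0 * covmean)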
In [0]:
import os
import tensorflow as tf
import tensorflow_gan as tfgan
noise_dims = 1024 #@param
generator_lr = 0.0002 #@param
discriminator_lr = 0.0002 #@param
train_batch_size = 1024 #@param
config = tf.estimator.tpu.RunConfig(
    model_dir=model_dir,
    master=tpu_address,
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=images_per_batch))

est = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator,
    discriminator_fn=discriminator,
    generator_loss_fn=tfgan.losses.modified_generator_loss,
    discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(generator_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(discriminator_lr, 0.5),
    joint_train=True,  # Train G and D jointly instead of sequentially.
    train_batch_size=train_batch_size,
    predict_batch_size=images_per_batch,
    use_tpu=True,
    params={'noise_dims': noise_dims},
    config=config)
In [0]:
import time
import matplotlib.pyplot as plt
max_steps = 50000 #@param
steps_per_eval = 5000 #@param
cur_step = 0
start_time = time.time()
cscores, fids, steps = [], [], []
init_global_real_logits()
print('Initialized classifier logits for real data.')
classifier_score_real_data = calculate_real_data_classifier_score()
print('Calculated classifier score for real data.')
while cur_step < max_steps:
  # Train for a fixed number of steps.
  start_step = cur_step
  step_to_stop_at = min(cur_step + steps_per_eval, max_steps)
  start = time.time()
  est.train(input_fn, max_steps=step_to_stop_at)
  end = time.time()
  cur_step = step_to_stop_at

  # Print some performance statistics.
  steps_taken = step_to_stop_at - start_step
  time_taken = end - start
  steps_per_sec = steps_taken / time_taken
  min_since_start = (time.time() - start_time) / 60.0
  print("Current step: %i, %.4f steps / sec, time since start: %.1f min" % (
      cur_step, steps_per_sec, min_since_start))

  # Calculate some evaluation metrics.
  eval_start_time = time.time()
  cscore, fid = get_inception_score_and_fid(est)
  eval_time = (time.time() - eval_start_time)
  cscores.append(cscore)
  fids.append(fid)
  steps.append(cur_step)
  print("Classifier score: %.2f / %.2f, FID: %.1f, "
        "time to calculate eval: %.2f sec" % (
            cscore, classifier_score_real_data, fid, eval_time))

  # Generate and show some predictions.
  predictions = np.array(
      [x['generated_data'] for x in est.predict(noise_input_fn)])[:80]
  image_grid = tfgan.eval.python_image_grid(predictions, grid_shape=(8, 10))
  _show_image_grid(image_grid)

# Plot the metrics vs step.
plt.title('Frechet distance per step')
plt.plot(steps, fids)
plt.figure()
plt.title('Classifier Score per step')
plt.plot(steps, cscores)
plt.plot(steps, [classifier_score_real_data] * len(steps))
plt.figure()