In [0]:
#@title ##### Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
# Globally Enable XLA.
# tf.config.optimizer.set_jit(True)
try:
  physical_devices = tf.config.list_physical_devices('GPU')
  tf.config.experimental.set_memory_growth(physical_devices[0], True)
except:
  # Invalid device or cannot modify virtual devices once initialized.
  pass
tfb = tfp.bijectors
tfd = tfp.distributions
tfn = tfp.experimental.nn
In [0]:
[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='binarized_mnist',
    split=['train', 'test'],
    with_info=True,
    shuffle_files=True)

def _preprocess(sample):
  return tf.cast(sample['image'], tf.float32)
train_size = datasets_info.splits['train'].num_examples
batch_size = 32
train_dataset = tfn.util.tune_dataset(
    train_dataset,
    batch_size=batch_size,
    shuffle_size=int(train_size / 7),
    preprocess_fn=_preprocess)
eval_dataset = tfn.util.tune_dataset(
    eval_dataset,
    repeat_count=1,
    preprocess_fn=_preprocess)
x = next(iter(eval_dataset.batch(10)))
tfn.util.display_imgs(x);
In [0]:
input_shape = datasets_info.features['image'].shape
encoded_size = 16
base_depth = 32
In [0]:
prior = tfd.Sample(tfd.Normal(loc=0, scale=1), sample_shape=encoded_size)
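In [0]:
# Quick sanity check (a sketch only): the prior is a fully factorized
# standard normal over the `encoded_size`-dimensional latent code, so samples
# have shape [num_samples, encoded_size] and log-probs reduce over the event.
z0 = prior.sample(3)
print(z0.shape)                  # (3, 16)
print(prior.log_prob(z0).shape)  # (3,)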
In [0]:
Conv = functools.partial(
    tfn.Convolution,
    init_kernel_fn=tf.initializers.he_uniform())  # Better for leaky_relu.

encoder = tfn.Sequential([
    lambda x: 2. * tf.cast(x, tf.float32) - 1.,  # Center.
    Conv(1, 1 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 1 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 2 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(2 * base_depth, 2 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    Conv(2 * base_depth, 4 * encoded_size, 7, strides=1, padding='valid'),
    tf.nn.leaky_relu,
    tfn.util.flatten_rightmost(ndims=3),
    tfn.Affine(64, encoded_size + encoded_size * (encoded_size + 1) // 2),
    lambda x: tfd.MultivariateNormalTriL(
        loc=x[..., :encoded_size],
        scale_tril=tfb.FillScaleTriL()(x[..., encoded_size:]))
], name='encoder')
print(encoder.summary())
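In [0]:
# Shape check (a sketch; it reuses the 10-image eval batch `x` displayed
# above): the encoder maps a batch of images to a full-covariance Gaussian
# over the 16-dimensional latent code.
q = encoder(x)
print(q.event_shape)     # [16]
print(q.sample().shape)  # (10, 16)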
In [0]:
DeConv = functools.partial(
    tfn.ConvolutionTranspose,
    init_kernel_fn=tf.initializers.he_uniform())  # Better for leaky_relu.

decoder = tfn.Sequential([
    lambda x: x[..., tf.newaxis, tf.newaxis, :],
    DeConv(encoded_size, 2 * base_depth, 7, strides=1, padding='valid'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, 2 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, 2 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    DeConv(2 * base_depth, base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    DeConv(1 * base_depth, 1 * base_depth, 5, strides=2, padding='same'),
    tf.nn.leaky_relu,
    DeConv(1 * base_depth, 1 * base_depth, 5, strides=1, padding='same'),
    tf.nn.leaky_relu,
    Conv(1 * base_depth, 1, 5, strides=1, padding='same'),
    tfn.util.flatten_rightmost(ndims=3),
    tfb.Reshape(input_shape),
    lambda x: tfd.Independent(tfd.Bernoulli(logits=x),
                              reinterpreted_batch_ndims=len(input_shape)),
], name='decoder')
print(decoder.summary())
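In [0]:
# Shape check (a sketch; `prior` is defined above): decoding latent draws
# should yield a pixel-wise Bernoulli whose event shape matches the images.
p0 = decoder(prior.sample(2))
print(p0.event_shape)  # [28, 28, 1], i.e. `input_shape`
print(p0.batch_shape)  # [2]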
In [0]:
def compute_loss(x, beta=1.):
  q = encoder(x)
  z = q.sample()
  p = decoder(z)
  kl = tf.reduce_mean(q.log_prob(z) - prior.log_prob(z), axis=-1)
  # Note: we could use the exact KL divergence, e.g.:
  #   kl = tf.reduce_mean(tfd.kl_divergence(q, prior))
  # however we generally find that using the Monte Carlo approximation has
  # lower variance.
  nll = -tf.reduce_mean(p.log_prob(x), axis=-1)
  loss = nll + beta * kl
  return loss, (nll, kl), (q, z, p)
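In [0]:
# Sketch of the analytic-KL alternative mentioned in the comment above.
# `tfd.kl_divergence` needs a registered KL for the pair of distribution
# types, so this sketch (an assumption, not part of the training below)
# re-expresses the same standard-normal prior as a diagonal MVN, for which
# a closed-form KL against `MultivariateNormalTriL` exists.
mvn_prior = tfd.MultivariateNormalDiag(loc=tf.zeros(encoded_size))

def compute_loss_exact_kl(x, beta=1.):
  q = encoder(x)
  z = q.sample()
  p = decoder(z)
  kl = tf.reduce_mean(tfd.kl_divergence(q, mvn_prior))
  nll = -tf.reduce_mean(p.log_prob(x), axis=-1)
  return nll + beta * kl, (nll, kl), (q, z, p)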
In [0]:
train_iter = iter(train_dataset)
def loss():
  x = next(train_iter)
  loss, (nll, kl), _ = compute_loss(x)
  return loss, (nll, kl)
opt = tf.optimizers.Adam(learning_rate=1e-3)
fit = tfn.util.make_fit_op(
    loss,
    opt,
    decoder.trainable_variables + encoder.trainable_variables,
    grad_summary_fn=lambda gs: tf.nest.map_structure(tf.norm, gs))
In [0]:
eval_iter = iter(eval_dataset.batch(5000).repeat())
@tfn.util.tfcompile
def eval():
  x = next(eval_iter)
  loss, (nll, kl), _ = compute_loss(x)
  return loss, (nll, kl)
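In [0]:
# Optional baseline (a sketch): evaluating before training gives a reference
# point. For comparison, predicting every pixel with probability 0.5 costs
# -log(0.5) * 28 * 28 ≈ 543 nats of reconstruction NLL per image.
base_loss, (base_nll, base_kl) = eval()
print('baseline loss: {:.2f}  nll: {:.2f}  kl: {:.2f}'.format(
    base_loss.numpy(), base_nll.numpy(), base_kl.numpy()))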
In [0]:
DEBUG_MODE = False
tf.config.experimental_run_functions_eagerly(DEBUG_MODE)
In [0]:
num_train_epochs = 1. # @param { isTemplate: true}
num_evals = 200 # @param { isTemplate: true}
dur_sec = dur_num = 0
num_train_steps = int(num_train_epochs * train_size)
for i in range(num_train_steps):
  start = time.time()
  trn_loss, (trn_nll, trn_kl), g = fit()
  stop = time.time()
  dur_sec += stop - start
  dur_num += 1
  if i % int(num_train_steps / num_evals) == 0 or i == num_train_steps - 1:
    tst_loss, (tst_nll, tst_kl) = eval()
    f, x = zip(*[
        ('it:{:5}', opt.iterations),
        ('ms/it:{:6.4f}', dur_sec / max(1., dur_num) * 1000.),
        ('trn_loss:{:6.4f}', trn_loss),
        ('tst_loss:{:6.4f}', tst_loss),
        ('tst_nll:{:6.4f}', tst_nll),
        ('tst_kl:{:6.4f}', tst_kl),
        ('sum_norm_grad:{:6.4f}', sum(g)),
    ])
    print(' '.join(f).format(*[getattr(x_, 'numpy', lambda: x_)()
                               for x_ in x]))
    sys.stdout.flush()
    dur_sec = dur_num = 0
  # if i % 1000 == 0 or i == maxiter - 1:
  #   encoder.save('/tmp/encoder.npz')
  #   decoder.save('/tmp/decoder.npz')
In [0]:
# We'll examine a hundred random test digits.
x = next(iter(eval_dataset.batch(100)))
xhat = decoder(encoder(x).sample())
assert isinstance(xhat, tfd.Distribution)
In [0]:
print('Originals:')
tfn.util.display_imgs(x);
print('Decoded Random Samples:')
tfn.util.display_imgs(xhat.sample());
print('Decoded Modes:')
tfn.util.display_imgs(xhat.mode());
print('Decoded Means:')
tfn.util.display_imgs(xhat.mean());
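In [0]:
# Finally, since this is a VAE we can also generate entirely new digits by
# decoding draws from the prior (a sketch, not part of the evaluation above):
print('Randomly Generated Samples:')
tfn.util.display_imgs(decoder(prior.sample(10)).mean());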