In [0]:
    
from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
from IPython import display
    
In [4]:
    
# Which TFDS dataset to train on. Alternatives that also work here:
#   'fashion_mnist'
#   'cifar10'  (https://www.tensorflow.org/datasets/catalog/cifar10)
# Reference pre-processing code:
#   https://github.com/google/compare_gan/blob/master/compare_gan/datasets.py
dataname = 'mnist'

# Keep the raw feature dicts (as_supervised=False) so every sample has
# 'image' and 'label' keys, which the preprocessing functions below expect.
datasets, datasets_info = tfds.load(
    name=dataname, as_supervised=False, with_info=True)
print(datasets_info)

input_shape = datasets_info.features['image'].shape  # (H, W, C)
print(input_shape)
num_colors = input_shape[2]  # channel count
    
    
In [0]:
    
batchsize = 64  # examples per minibatch


def scale_pixels(sample):
  """Map a TFDS feature dict to an (image, label) pair.

  Args:
    sample: dict with an 'image' entry (assumed 8-bit pixels in [0, 255])
      and a 'label' entry.

  Returns:
    Tuple (image cast to float32 and divided by 255, label unchanged).
  """
  pixels = sample['image']
  return tf.cast(pixels, tf.float32) / 255., sample['label']
def scale_pixels_and_drop_label(sample):
  """Return only the 'image' of a feature dict, cast to float32 and divided by 255."""
  image = sample['image']
  return tf.cast(image, tf.float32) / 255.
  
# The VAE below only consumes images, so by default the labels are dropped.
drop_label = True
preprocess = scale_pixels_and_drop_label if drop_label else scale_pixels
  
def preprocess_celeba(features):
    """Returns 64x64x3 image and constant label.

    Center-crops (or pads) features["image"] to 160x160, resizes it to 64x64,
    and scales pixels into [0, 1] (assumes 8-bit input values). A constant
    int32 label of 0 is returned so the output has the same (image, label)
    structure as `scale_pixels`.

    Uses the TF 2.x image APIs (`tf.image.resize_with_crop_or_pad`,
    `tf.image.resize`). The TF1 names `resize_image_with_crop_or_pad` and
    `resize_images` are not available in TF 2.x, which this notebook targets.
    """
    image = features["image"]
    image = tf.image.resize_with_crop_or_pad(image, 160, 160)
    # Note: possibly consider using NumPy's imresize(image, (64, 64))
    # tf.image.resize returns float32 (bilinear by default).
    image = tf.image.resize(image, [64, 64])
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.constant(0, dtype=tf.int32)
    return image, label
    
# Input pipelines.
# Training: shuffle individual examples *before* batching so every epoch sees
# differently composed minibatches (shuffling after .batch() only permutes
# whole, fixed batches). Shuffling the raw samples before .map() keeps the
# buffer in the compact source dtype. Prefetch goes last so it overlaps the
# whole pipeline with training.
shuffle_buffer_size = 10000
train_dataset = (datasets['train']
                 .shuffle(shuffle_buffer_size)
                 .map(preprocess)
                 .batch(batchsize)
                 .prefetch(tf.data.experimental.AUTOTUNE))
# Evaluation: no shuffling needed.
test_dataset = (datasets['test']
                .map(preprocess)
                .batch(batchsize)
                .prefetch(tf.data.experimental.AUTOTUNE))
    
In [0]:
    
# Materialise the whole test split as in-memory numpy arrays so later cells
# can index arbitrary examples (x_test[idx], y_test[idx]).
test_ds = tfds.as_numpy(datasets['test'].map(scale_pixels_and_drop_label))
L = [image for image in test_ds]  # iterating forces every element to be produced
x_test = np.stack(L, axis=0)  # (n_test, H, W, C); 10k x 28 x 28 x 1 for MNIST

def extract_label(sample):
  """Keep only the 'label' entry of a feature dict."""
  return sample['label']

test_ds = tfds.as_numpy(datasets['test'].map(extract_label))
L = [label for label in test_ds]
y_test = np.stack(L, axis=0)  # (n_test,)
n_test = y_test.shape[0]
    
In [7]:
    
# Inspect the dataset we just created: print shapes of the first two batches.
i = 0
for batch in train_dataset:
  if not drop_label:
    X, y = batch
    print(X.shape)
    print(y.shape)
  else:
    X = batch
    print(X.shape)
  i = i + 1
  if i >= 2:
    break
    
    
In [8]:
    
# Grab one training batch and keep its first three images as a tiny input
# for the smoke tests below. This can be slow: the shuffle buffer must fill
# before the first batch is emitted.
batch = train_dataset.take(1)
first_batch = [b for b in batch]
if not drop_label:
  X, y = first_batch[0]
else:
  X = first_batch[0]
print(X.shape)  # B, H, W, C
Xsmall = X[:3]
print(Xsmall.shape)
    
    
In [0]:
    
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
    
In [0]:
    
def make_encoder(input_dim, encoder_conv_filters, encoder_conv_kernel_size,
                 encoder_conv_strides, z_dim, use_batch_norm=False,
                 use_dropout=False):
  """Build the convolutional encoder of the VAE.

  A stack of Conv2D('same') -> [BatchNormalization] -> LeakyReLU -> [Dropout]
  blocks, flattened and fed to two linear Dense heads that output the mean
  and log-variance of the approximate posterior.

  Args:
    input_dim: image shape without the batch dimension, e.g. (H, W, C).
    encoder_conv_filters: filters per conv layer; its length sets the depth.
    encoder_conv_kernel_size: kernel size per conv layer (same indexing).
    encoder_conv_strides: stride per conv layer (same indexing).
    z_dim: latent dimensionality.
    use_batch_norm: add BatchNormalization after every conv.
    use_dropout: add Dropout(rate=0.25) after every activation.

  Returns:
    (encoder, shape_before_flattening): `encoder` is a Keras Model mapping
    images to (mu, log_var), each with last dimension z_dim;
    `shape_before_flattening` is the feature-map shape (excluding batch)
    right before Flatten, which the decoder reshapes back to.
  """
  encoder_input = Input(shape=input_dim, name='encoder_input')
  h = encoder_input
  for layer_idx, n_filters in enumerate(encoder_conv_filters):
    h = Conv2D(
        filters=n_filters,
        kernel_size=encoder_conv_kernel_size[layer_idx],
        strides=encoder_conv_strides[layer_idx],
        padding='same',
        name='encoder_conv_' + str(layer_idx))(h)
    if use_batch_norm:
      h = BatchNormalization()(h)
    h = LeakyReLU()(h)
    if use_dropout:
      h = Dropout(rate=0.25)(h)

  # Remembered so the decoder can mirror this spatial layout.
  shape_before_flattening = K.int_shape(h)[1:]
  h = Flatten()(h)
  mu = Dense(z_dim, name='mu')(h)            # linear: unconstrained mean
  log_var = Dense(z_dim, name='log_var')(h)  # linear: log-variance may be any real
  return Model(encoder_input, (mu, log_var)), shape_before_flattening
    
In [13]:
    
# Smoke test: a small encoder with a 2-D latent space, applied to Xsmall.
encoder, shape_before_flattening = make_encoder(
    input_dim=input_shape,
    encoder_conv_filters=[32, 64],
    encoder_conv_kernel_size=[3, 3],
    encoder_conv_strides=[2, 2],
    z_dim=2,
)
print(shape_before_flattening)  # feature-map shape fed into Flatten
print(Xsmall.shape)
M, V = encoder(Xsmall)  # posterior means and log-variances
print(M.shape)
print(V.shape)
    
    
In [0]:
    
def make_decoder(shape_before_flattening, decoder_conv_t_filters,
                 decoder_conv_t_kernel_size, decoder_conv_t_strides, z_dim,
                 use_batch_norm=False, use_dropout=False):
  """Build the transposed-convolution decoder of the VAE.

  Dense -> Reshape(shape_before_flattening), then a stack of
  Conv2DTranspose('same') layers. Every layer except the last is followed by
  [BatchNormalization] -> LeakyReLU -> [Dropout]; the last layer is left
  linear, so the decoder outputs logits.

  Args:
    shape_before_flattening: feature-map shape produced by `make_encoder`.
    decoder_conv_t_filters: filters per layer; the last entry sets the number
      of output channels.
    decoder_conv_t_kernel_size: kernel size per layer (same indexing).
    decoder_conv_t_strides: stride per layer (same indexing).
    z_dim: latent dimensionality (size of the decoder input).
    use_batch_norm: add BatchNormalization after every non-final layer.
    use_dropout: add Dropout(rate=0.25) after every non-final activation.

  Returns:
    A Keras Model mapping (batch, z_dim) latent codes to image logits.
  """
  decoder_input = Input(shape=(z_dim,), name='decoder_input')
  h = Dense(np.prod(shape_before_flattening))(decoder_input)
  h = Reshape(shape_before_flattening)(h)
  last_idx = len(decoder_conv_t_filters) - 1
  for layer_idx, n_filters in enumerate(decoder_conv_t_filters):
    h = Conv2DTranspose(
        filters=n_filters,
        kernel_size=decoder_conv_t_kernel_size[layer_idx],
        strides=decoder_conv_t_strides[layer_idx],
        padding='same',
        name='decoder_conv_t_' + str(layer_idx))(h)
    if layer_idx == last_idx:
      # Final layer stays linear: logits out; ConvVAE.decode applies the
      # sigmoid when probabilities are wanted.
      continue
    if use_batch_norm:
      h = BatchNormalization()(h)
    h = LeakyReLU()(h)
    if use_dropout:
      h = Dropout(rate=0.25)(h)
  return Model(decoder_input, h)
    
In [15]:
    
# Smoke test: a decoder mirroring the encoder above, fed 5 random latent codes.
decoder = make_decoder(
    shape_before_flattening,
    decoder_conv_t_filters=[64, 32, num_colors],
    decoder_conv_t_kernel_size=[3, 3, 3],
    decoder_conv_t_strides=[2, 2, 1],
    z_dim=2,
)
Z = np.asarray(np.random.randn(5, 2), dtype=np.float32)
Xrecon = decoder(Z)  # logits, one image per latent code
print(Xrecon.shape)
    
    
In [0]:
    
def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of a diagonal Gaussian, summed over axis `raxis`.

  Computes sum(-0.5 * ((sample - mean)^2 * exp(-logvar) + logvar + log(2*pi))).

  Args:
    sample: points at which to evaluate the density.
    mean: Gaussian mean (broadcastable against `sample`).
    logvar: log-variance (broadcastable against `sample`).
    raxis: axis reduced by the sum (the latent dimension for 2-D inputs).

  Returns:
    Tensor of log-densities with `raxis` reduced away.
  """
  log2pi = tf.math.log(2. * np.pi)
  scaled_sq_dist = (sample - mean) ** 2. * tf.exp(-logvar)
  per_dim = -.5 * (scaled_sq_dist + logvar + log2pi)
  return tf.reduce_sum(per_dim, axis=raxis)
def sample_gauss(mean, logvar):
  """Draw one reparameterised sample from N(mean, exp(logvar)).

  z = mean + eps * exp(logvar / 2) with eps ~ N(0, I), so gradients flow to
  `mean` and `logvar`.

  The noise shape is taken from `tf.shape(mean)` (the runtime shape) rather
  than `mean.shape` (the static shape). Inside a `tf.function` traced with an
  unknown batch dimension, the static shape contains `None`, which
  `tf.random.normal` cannot use. The noise dtype follows `mean` so
  non-float32 inputs do not cause a dtype mismatch.
  """
  eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
  return eps * tf.exp(logvar * .5) + mean
    
In [0]:
    
class ConvVAE(tf.keras.Model):
  """Convolutional variational autoencoder.

  Holds an encoder built by `make_encoder` (`inference_net`: images ->
  (mean, logvar)) and a decoder built by `make_decoder` (`generative_net`:
  latent codes -> pixel logits). It provides the `compute_loss` /
  `compute_gradients` methods used by the training loop.

  Args:
    input_dim: image shape, e.g. (H, W, C).
    encoder_conv_filters, encoder_conv_kernel_size, encoder_conv_strides:
      per-layer settings of the encoder.
    decoder_conv_t_filters, decoder_conv_t_kernel_size,
      decoder_conv_t_strides: per-layer settings of the decoder; the last
      filter count sets the number of output channels.
    z_dim: latent dimensionality.
    use_batch_norm, use_dropout: forwarded to both sub-networks.
    recon_loss_scaling: weight on the reconstruction log-likelihood term.
    kl_loss_scaling: weight on the (negative) KL term.
    use_mse_loss: if True, score reconstructions with -0.5 * per-pixel mean
      squared error on sigmoid outputs; otherwise with sigmoid cross-entropy
      on logits, summed over pixels.
  """

  def __init__(self,
               input_dim,
               encoder_conv_filters,
               encoder_conv_kernel_size,
               encoder_conv_strides,
               decoder_conv_t_filters,
               decoder_conv_t_kernel_size,
               decoder_conv_t_strides,
               z_dim,
               use_batch_norm=False,
               use_dropout=False,
               recon_loss_scaling=1,
               kl_loss_scaling=1,
               use_mse_loss=False):
    super(ConvVAE, self).__init__()
    self.latent_dim = z_dim
    self.recon_loss_scaling = recon_loss_scaling
    self.kl_loss_scaling = kl_loss_scaling
    self.use_mse_loss = use_mse_loss
    # Build the encoder first: its pre-flatten shape defines the decoder input.
    self.inference_net, self.shape_before_flattening = make_encoder(
        input_dim, encoder_conv_filters, encoder_conv_kernel_size,
        encoder_conv_strides, z_dim, use_batch_norm, use_dropout)
    self.generative_net = make_decoder(
        self.shape_before_flattening, decoder_conv_t_filters,
        decoder_conv_t_kernel_size, decoder_conv_t_strides, z_dim,
        use_batch_norm, use_dropout)

  @tf.function
  def sample(self, nsamples=1):
    """Decode `nsamples` draws from the N(0, I) prior into pixel probabilities."""
    prior_draws = tf.random.normal(shape=(nsamples, self.latent_dim))
    return self.decode(prior_draws, apply_sigmoid=True)

  def encode_stochastic(self, x):
    """Encode `x` and return one reparameterised sample of its latent code."""
    mu, log_var = self.inference_net(x)
    return sample_gauss(mu, log_var)

  def decode(self, z, apply_sigmoid=True):
    """Map latent codes to pixel probabilities (default) or raw logits."""
    logits = self.generative_net(z)
    return tf.sigmoid(logits) if apply_sigmoid else logits

  @tf.function
  def compute_loss(self, x):
    """Negative weighted ELBO, averaged over the batch (one MC sample each)."""
    mu, log_var = self.inference_net(x)
    z = sample_gauss(mu, log_var)
    if not self.use_mse_loss:
      logits = self.decode(z, apply_sigmoid=False)
      # -sum_{c=0}^1 p_c log q_c per pixel, summed over H, W, C.
      bce = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x)
      log_px_given_z = -tf.reduce_sum(bce, axis=[1, 2, 3])
    else:
      recon = self.decode(z, apply_sigmoid=True)
      # log exp(-0.5 (x - mu)^2), using the mean over H, W, C.
      log_px_given_z = -0.5 * tf.reduce_mean((x - recon) ** 2, axis=[1, 2, 3])
    # Single-sample Monte Carlo estimate of -KL(q(z|x) || p(z)), with the
    # prior p(z) = N(0, I) (mean 0, logvar 0).
    neg_kl = log_normal_pdf(z, 0., 0.) - log_normal_pdf(z, mu, log_var)
    weighted_elbo = (self.recon_loss_scaling * log_px_given_z
                     + self.kl_loss_scaling * neg_kl)
    return -tf.reduce_mean(weighted_elbo)

  @tf.function
  def compute_gradients(self, x):
    """Gradients of `compute_loss(x)` w.r.t. `self.trainable_variables`."""
    with tf.GradientTape() as tape:
      loss = self.compute_loss(x)
    return tape.gradient(loss, self.trainable_variables)
    
In [0]:
    
# A smaller alternative configuration with a 50-D latent space
# (not referenced later in this notebook).
model_old = ConvVAE(
    input_dim=input_shape,
    encoder_conv_filters=[32, 64],
    encoder_conv_kernel_size=[3, 3],
    encoder_conv_strides=[2, 2],
    decoder_conv_t_filters=[64, 32, num_colors],
    decoder_conv_t_kernel_size=[3, 3, 3],
    decoder_conv_t_strides=[2, 2, 1],
    z_dim=50,
)

# Settings matching
# https://github.com/davidADSP/GDL_code/blob/master/03_03_vae_digits_train.ipynb
# 2-D latent space (so it can be plotted directly) and an MSE reconstruction
# term weighted 1000x relative to the KL term.
model = ConvVAE(
    input_dim=input_shape,
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    decoder_conv_t_filters=[64, 64, 32, num_colors],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    z_dim=2,
    recon_loss_scaling=1000,
    use_mse_loss=True,
)
    
In [42]:
    
# Smoke test of the assembled model on the tiny batch: encoder heads, one
# stochastic encoding, and decoding it back to images.
# (The duplicated print(Z.shape) after decoding has been removed.)
M, V = model.inference_net(Xsmall)
print(M.shape)
print(V.shape)
Z = model.encode_stochastic(Xsmall)
print(Z.shape)
predictions = model.decode(Z)
print(predictions.shape)
    
    
In [43]:
    
# Smoke test of the loss and its gradients on the tiny batch.
L = model.compute_loss(Xsmall)
print(L)  # scalar negative (weighted) ELBO
g = model.compute_gradients(Xsmall)
print(g[0].shape)  # first conv kernel; 3,3,1,32 for MNIST
    
    
In [46]:
    
# Callback helper
def generate_images(model, epoch, noise_vector):
  """Decode a batch of latent vectors and plot them on a square grid.

  The grid side is ceil(sqrt(N)), where N is the number of decoded images.
  It is derived from `noise_vector` itself, so any count fits; a non-square
  or larger batch no longer overflows a grid sized from a global. Images
  with a single channel are drawn in grayscale.

  Args:
    model: object whose `decode(z)` returns images shaped (N, H, W, C)
      (ConvVAE.decode returns sigmoid probabilities by default).
    epoch: unused; kept so the signature matches existing callers.
    noise_vector: (N, latent_dim) latent codes to decode.
  """
  predictions = model.decode(noise_vector)
  n_images = predictions.shape[0]
  n = int(np.ceil(np.sqrt(n_images)))
  plt.figure(figsize=(n, n))
  for i in range(n_images):
    plt.subplot(n, n, i + 1)
    image = predictions[i]
    if image.shape[-1] == 1:
      plt.imshow(image[:, :, 0], cmap='gray')
    else:
      plt.imshow(image)
    plt.axis('off')
  
  
# A fixed set of latent vectors, reused at every epoch so the generated
# samples are directly comparable as training progresses. A perfect square
# fills the plotting grid exactly.
num_examples_to_generate = 25
random_vector_for_generation = tf.random.normal(
    shape=(num_examples_to_generate, model.latent_dim))
# Per-epoch training callback. Relies on `model.compute_loss(batch)` and on
# the globals `test_dataset` and `random_vector_for_generation`.
def callback(model, epoch, elapsed_time):
  """Print the test loss for `epoch`, then plot, save and show a sample grid.

  The reported test loss is the unweighted mean of per-batch losses over
  `test_dataset`. The figure is saved as image_at_epoch_{epoch:04d}.png.
  """
  mean_loss = tf.keras.metrics.Mean()
  for test_batch in test_dataset:
    mean_loss(model.compute_loss(test_batch))
  test_loss = mean_loss.result()
  print('Epoch {}, Test loss: {:0.5f}, time {:0.2f}'.format(
      epoch, test_loss, elapsed_time))
  generate_images(model, epoch, random_vector_for_generation)
  plt.suptitle('epoch {}, loss {:0.5f}'.format(epoch, test_loss),
               fontsize='x-large')
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  
  
# Smoke test of the callback before training. It writes
# image_at_epoch_0000.png, which epoch 0 of training later overwrites.
callback(model, epoch=0, elapsed_time=42)
    
    
    
In [0]:
    
# Generic (model-agnostic) training loop.
def train_model(model, optimizer, train_dataset,
                epochs, callback=None, print_every_n_epochs=1):
  """Train `model` for `epochs` passes over `train_dataset`.

  The model must provide `compute_gradients(batch)`, returning one gradient
  per entry of `model.trainable_variables` (e.g. a tf.keras.Model subclass
  such as ConvVAE). The optimizer must provide `apply_gradients(pairs)`.

  Args:
    model: model to train (updated in place).
    optimizer: optimizer that applies the gradients.
    train_dataset: iterable of batches, iterated once per epoch.
    epochs: number of passes over the data.
    callback: optional `callback(model, epoch, elapsed_time)`; elapsed_time
      is the wall-clock seconds spent on that epoch's updates.
    print_every_n_epochs: invoke `callback` on epochs 0, n, 2n, ... and always
      on the final epoch. The default of 1 invokes it after every epoch.
      (Previously this parameter was accepted but ignored.)

  Returns:
    The trained `model`.

  Raises:
    ValueError: if `print_every_n_epochs` is smaller than 1.
  """
  if print_every_n_epochs < 1:
    raise ValueError(
        'print_every_n_epochs must be >= 1, got {}'.format(print_every_n_epochs))
  for epoch in range(epochs):
    start_time = time.time()
    for batch in train_dataset:
      gradients = model.compute_gradients(batch)
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    elapsed_time = time.time() - start_time
    is_report_epoch = (epoch % print_every_n_epochs == 0
                       or epoch == epochs - 1)
    if callback and is_report_epoch:
      callback(model, epoch, elapsed_time)
  return model
    
In [48]:
    
# Train with Adam; `callback` reports the test loss and saves a sample grid
# after each epoch.
epochs = 5
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model = train_model(model, optimizer, train_dataset, epochs, callback=callback)
    
    
    
    
    
    
    
    
    
    
    
In [0]:
    
def display_image(epoch_no):
  """Open the PNG sample grid that `callback` saved for epoch `epoch_no`.

  `import PIL` on its own does not guarantee that the `PIL.Image` submodule
  is loaded. Import it explicitly rather than relying on another library
  having done so as a side effect.
  """
  import PIL.Image
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
    
In [50]:
    
# Replay the sample grid that `callback` saved at each epoch.
for i in range(epochs):
  plt.imshow(display_image(i))
  plt.axis('off')
  plt.show()
    
    
    
    
    
    
In [0]:
    
# Thin wrappers used by the plotting cells below.
def vae_encode(model, x):
  """Deterministic encoding: the posterior mean for `x` (log-variance discarded)."""
  mu, _unused_logvar = model.inference_net(x)
  return mu
def vae_decode(model, z_points):
  """Decode latent points into images with `model.decode` (its default output mode)."""
  images = model.decode(z_points)
  return images
    
In [52]:
    
# Reconstruct a few random test images: originals on the top row, decodings
# of their posterior means on the bottom row.
n_to_show = 10
np.random.seed(42)
example_idx = np.random.choice(n_test, n_to_show)
example_images = x_test[example_idx]
z_points = vae_encode(model, example_images)
reconst_images = vae_decode(model, z_points)
fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for row, images in enumerate((example_images, reconst_images)):
  for i in range(n_to_show):
    img = images[i, :, :, 0]
    sub = fig.add_subplot(2, n_to_show, row * n_to_show + i + 1)
    sub.axis('off')
    sub.imshow(img, cmap='gray_r')
    
    
In [53]:
    
# Show the 2-D latent embedding (posterior means) of random test images.
n_to_show = 5000
figsize = 8
np.random.seed(42)
example_idx = np.random.choice(range(n_test), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]
z_points = vae_encode(model, example_images)
# Extent of the latent cloud, computed with one vectorised numpy reduction
# instead of the builtin min()/max(), which iterate the tensor element by
# element in Python.
z_array = np.asarray(z_points)
min_x, min_y = z_array.min(axis=0)
max_x, max_y = z_array.max(axis=0)
plt.figure(figsize=(figsize, figsize))
plt.scatter(z_array[:, 0], z_array[:, 1], c='black', alpha=0.5, s=2)
plt.show()
    
    
In [54]:
    
# Decode random points drawn from N(0, I): mark them (numbered, in red) on
# top of the latent cloud from the previous cell, then show their images.
figsize = 8
plt.figure(figsize=(figsize, figsize))
plt.scatter(z_points[:, 0], z_points[:, 1], c='black', alpha=0.5, s=2)
grid_size = 15
grid_depth = 2
figsize = 15
np.random.seed(42)
x = np.random.normal(size=grid_size * grid_depth)
y = np.random.normal(size=grid_size * grid_depth)
z_grid = np.column_stack((x, y))
reconst = vae_decode(model, z_grid)
plt.scatter(z_grid[:, 0], z_grid[:, 1], c='red', alpha=1, s=20)
n = z_grid.shape[0]
for i, (x, y) in enumerate(z_grid):
  plt.text(x, y, i)
plt.show()

fig = plt.figure(figsize=(figsize, grid_depth))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size * grid_depth):
  ax = fig.add_subplot(grid_depth, grid_size, i + 1)
  ax.axis('off')
  ax.text(0.5, -0.35, str(i))
  ax.imshow(reconst[i, :, :, 0], cmap='Greys')
    
    
    
    
    
In [55]:
    
# Colour the latent codes of random test images by class label, both in raw
# latent coordinates and after mapping each coordinate through the standard
# normal CDF (which squashes N(0, 1) values into [0, 1]).
from scipy.stats import norm
n_to_show = 5000
grid_size = 15
figsize = 8
np.random.seed(42)
example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]
z_points = vae_encode(model, example_images)
p_points = norm.cdf(z_points)
for points, marker_size in ((z_points, 2), (p_points, 5)):
  plt.figure(figsize=(figsize, figsize))
  plt.scatter(points[:, 0], points[:, 1], cmap='rainbow', c=example_labels,
              alpha=0.5, s=marker_size)
  plt.colorbar()
plt.show()
    
    
    
In [56]:
    
# Generate images from a regular 2-D grid of latent points. Grid points are
# evenly spaced in probability (the 1%..99% quantiles of N(0, 1) on each
# axis), so they cover the prior's mass rather than a uniform box in z.
# `norm` is imported here so this cell does not depend on an earlier cell.
from scipy.stats import norm
n_to_show = 5000
grid_size = 20
figsize = 8
np.random.seed(42)
example_idx = np.random.choice(range(len(x_test)), n_to_show)
example_images = x_test[example_idx]
example_labels = y_test[example_idx]
z_points = vae_encode(model, example_images)
plt.figure(figsize=(figsize, figsize))
plt.scatter(z_points[:, 0], z_points[:, 1], cmap='rainbow', c=example_labels,
            alpha=0.5, s=2)
plt.colorbar()

# Same quantile positions on both axes, so compute them once.
axis_quantiles = norm.ppf(np.linspace(0.01, 0.99, grid_size))
xv, yv = np.meshgrid(axis_quantiles, axis_quantiles)
z_grid = np.column_stack((xv.ravel(), yv.ravel()))  # (grid_size**2, 2)
reconst = vae_decode(model, z_grid)
plt.scatter(z_grid[:, 0], z_grid[:, 1], c='black', alpha=1, s=5)
plt.show()

fig = plt.figure(figsize=(figsize, figsize))
fig.subplots_adjust(hspace=0.4, wspace=0.4)
for i in range(grid_size ** 2):
  ax = fig.add_subplot(grid_size, grid_size, i + 1)
  ax.axis('off')
  ax.imshow(reconst[i, :, :, 0], cmap='Greys')
plt.show()