In [ ]:
import sys
import yaml
import tensorflow as tf
import numpy as np
import pandas as pd
import functools
from pathlib import Path
from datetime import datetime
from tqdm import tqdm_notebook as tqdm

# Plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'  # path to the ffmpeg binary (adjust for your system)

%load_ext autoreload
%autoreload 2

from progan import ProGan
import gan_utils
from load_data import preprocess_images
from ds_utils.plot_utils import plot_sample_imgs

In [ ]:
data_folder = Path.home() / "Documents/datasets"

In [ ]:
# load model config
with open('configs/progan_celeba_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
HIDDEN_DIM = config['data']['z_size']
IMG_SHAPE = config['data']['input_shape']
BATCH_SIZE = config['training']['batch_size']
IMG_IS_BW = IMG_SHAPE[2] == 1
PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE
config
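
Only a handful of config keys are read in this notebook; a config with that structure could look like the following (the values are illustrative placeholders, not the actual contents of progan_celeba_config.yaml).

In [ ]:
# illustrative config structure (placeholder values only, not the real YAML contents)
example_config = {
    'data': {'z_size': 128, 'input_shape': [128, 128, 3]},
    'training': {'batch_size': 16, 'nb_epochs': 100},
}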

Data


In [ ]:
# load Fashion MNIST dataset (lightweight alternative to the CelebA images loaded below, handy for quick tests)
((X_train, y_train), (X_test, y_test)) = tf.keras.datasets.fashion_mnist.load_data()

In [ ]:
X_train = preprocess_images(X_train)
X_test = preprocess_images(X_test)

print(X_train[0].shape)
print(X_train[0].max())
print(X_train[0].min())

print(X_train.shape)

assert X_train[0].shape == tuple(config['data']['input_shape'])
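
preprocess_images comes from the project's load_data module; a minimal sketch of what it plausibly does, assuming it adds a channel axis and rescales pixels to the [-1, 1] tanh range used by the generator (the real implementation may differ), is shown below.

In [ ]:
# hypothetical sketch of a preprocess_images-style helper (assumed behaviour, not the actual load_data code)
def preprocess_images_sketch(imgs):
    imgs = imgs.astype('float32')
    imgs = (imgs / 127.5) - 1.0          # rescale [0, 255] -> [-1, 1] (tanh range)
    if imgs.ndim == 3:                   # (N, H, W) -> (N, H, W, 1)
        imgs = np.expand_dims(imgs, -1)
    return imgs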

In [ ]:
# keep only a subset of samples to speed up experimentation
train_ds = tf.data.Dataset.from_tensor_slices(X_train).take(5000)
test_ds = tf.data.Dataset.from_tensor_slices(X_test).take(256)

In [ ]:
sys.path.append("../")
from tmp_load_data import load_imgs_tfdataset

In [ ]:
# build train/test CelebA datasets (tanh_range=True scales pixels to [-1, 1])
train_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 500, zipped=False, tanh_range=True)
test_ds = load_imgs_tfdataset(data_folder/'img_align_celeba', '*.jpg', config, 100, zipped=False, tanh_range=True)
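
load_imgs_tfdataset is a project helper; a rough sketch of the kind of tf.data pipeline it presumably builds (file globbing, JPEG decoding, resizing to the config input shape, rescaling to [-1, 1]) is below. The actual helper's signature and behaviour may differ.

In [ ]:
# hypothetical sketch of an image tf.data pipeline similar to load_imgs_tfdataset (assumed, not the actual helper)
def load_imgs_sketch(img_dir, pattern, config, nb_imgs):
    img_shape = config['data']['input_shape']

    def _load(path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=img_shape[2])
        img = tf.image.resize(img, img_shape[:2])
        return (img / 127.5) - 1.0   # scale to [-1, 1] (tanh range)

    files = tf.data.Dataset.list_files(str(img_dir / pattern))
    return files.take(nb_imgs).map(_load, num_parallel_calls=tf.data.AUTOTUNE)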

In [ ]:
# sanity-check one sample: values should lie in the [-1, 1] tanh range
for a in train_ds:
    n_a = a.numpy()
    print(n_a.shape)
    print(n_a.max())
    print(n_a.min())

    # map back from [-1, 1] to [0, 1] for display
    plt.imshow((n_a+1)/2)
    break

Model


In [ ]:
# instantiate GAN
gan = ProGan(config)

In [ ]:
# test generator
generator = gan.generators[2][0]
generator_out = generator.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM))
generator_out.shape

In [ ]:
generator_out.max()

In [ ]:
# test discriminator
discriminator = gan.discriminators[2][0]
discriminator_out = discriminator.predict(generator_out)
discriminator_out.shape
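
In a progressive GAN the networks are built once per resolution stage; assuming gan.generators and gan.discriminators each hold one list of models per stage (as the [2][0] indexing above suggests), the resolution handled by each stage can be inspected like this.

In [ ]:
# inspect the resolution of each progressive-growing stage
# (assumes one list of models per stage, with the stable model at index 0)
for stage, (gens, discs) in enumerate(zip(gan.generators, gan.discriminators)):
    print(stage, gens[0].output_shape, discs[0].input_shape)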

In [ ]:
# plot random generated image
plot_img_shape = generator.output_shape[1:]
plt.imshow(generator.predict([np.random.randn(1, HIDDEN_DIM)])[0]
           .reshape(plot_img_shape), cmap='gray' if IMG_IS_BW else 'jet')
plt.show()

Training


In [ ]:
# setup model directory for checkpoint and tensorboard logs
model_name = "progan_celeba"
model_dir = Path.home() / "Documents/models/tf_playground/gan" / model_name
model_dir.mkdir(exist_ok=True, parents=True)
export_dir = model_dir / 'export'
export_dir.mkdir(exist_ok=True)
log_dir = model_dir / "logs" / datetime.now().strftime("%Y%m%d-%H%M%S")

In [ ]:
nb_epochs = config['training']['nb_epochs']
gan.train(train_ds=train_ds,
          validation_ds=test_ds,
          nb_epochs=nb_epochs,
          log_dir=log_dir,
          checkpoint_dir=None,
          is_tfdataset=True)
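
gan.train writes summaries under log_dir, so the run can be followed live in TensorBoard (which metrics get logged depends on the ProGan implementation).

In [ ]:
# follow training in TensorBoard by pointing it at this run's log directory,
# e.g. from a terminal: tensorboard --logdir <model_dir>/logs
print(log_dir)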

In [ ]:
# export Keras model (.h5)
gan.generator.save(str(export_dir / 'generator.h5'))
gan.discriminator.save(str(export_dir / 'discriminator.h5'))
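
The exported .h5 files can later be reloaded with tf.keras.models.load_model; if the ProGAN architecture uses custom layers (pixel normalization, fade-in weighted sum, etc.), they would additionally need to be passed via custom_objects. A minimal sketch:

In [ ]:
# reload the exported generator for inference (compile=False: no training config needed);
# pass any custom layers via custom_objects, e.g. custom_objects={'PixelNorm': PixelNorm}
reloaded_generator = tf.keras.models.load_model(str(export_dir / 'generator.h5'), compile=False)
reloaded_generator.predict(np.random.randn(1, HIDDEN_DIM)).shape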

In [ ]:
# plot generator results
plot_side = 5
plot_generator = gan.generators[2][0]
plot_img_shape = plot_generator.output_shape[1:]
plot_sample_imgs(lambda x: plot_generator.predict(np.random.randn(plot_side*plot_side, HIDDEN_DIM)), 
                 img_shape=plot_img_shape,
                 plot_side=plot_side,
                 cmap='gray' if IMG_IS_BW else 'jet')

Explore Latent Space


In [ ]:
%matplotlib inline

In [ ]:
render_dir = Path.home() / 'Documents/videos/gan' / "gan_celeba"

nb_samples = 30
nb_transition_frames = 10
nb_frames = min(2000, (nb_samples-1)*nb_transition_frames)

# random list of z vectors
z_s = np.random.randn(nb_samples, 1, HIDDEN_DIM)

# setup plot
dpi = 100
fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[0] / dpi, PLOT_IMG_SHAPE[1] / dpi))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
im = ax.imshow(gan.generator.predict(z_s[0])[0].reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')
plt.axis('off')

def animate(i, gan, z_s, nb_transition_frames):
    # linearly interpolate between consecutive z vectors, nb_transition_frames frames per pair
    z_start = z_s[i//nb_transition_frames]
    z_end = z_s[i//nb_transition_frames+1]
    z_diff = z_end - z_start
    cur_z = z_start + (z_diff/nb_transition_frames)*(i%nb_transition_frames)
    im.set_data(gan.generator.predict(cur_z)[0].reshape(PLOT_IMG_SHAPE))

ani = animation.FuncAnimation(fig, animate, frames=nb_frames, interval=1, 
                              fargs=[gan, z_s, nb_transition_frames])

if render_dir:
    render_dir.mkdir(parents=True, exist_ok=True)
    ani.save(str(render_dir / (datetime.now().strftime("%Y%m%d-%H%M%S") + '.mp4')), 
             animation.FFMpegFileWriter(fps=30))
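
Linear interpolation is the simplest latent walk; spherical linear interpolation (slerp) is a common alternative for Gaussian latent spaces because it keeps the interpolated vectors at a more typical norm. A small helper that could be swapped into the animation above (an addition for illustration, not part of the original notebook):

In [ ]:
# spherical linear interpolation between two latent vectors, t in [0, 1]
def slerp(z_start, z_end, t):
    z_start_n = z_start / np.linalg.norm(z_start)
    z_end_n = z_end / np.linalg.norm(z_end)
    omega = np.arccos(np.clip(np.dot(z_start_n.ravel(), z_end_n.ravel()), -1.0, 1.0))
    if np.isclose(omega, 0.0):
        return z_start
    return (np.sin((1.0 - t) * omega) * z_start + np.sin(t * omega) * z_end) / np.sin(omega)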

In [ ]:
render_dir = Path.home() / 'Documents/videos/gan' / "gan_fmnist_idxs"

nb_transition_frames = 150

# single random starting z vector and a sweep of values for the scanned latent dimensions
z_start = np.random.randn(1, HIDDEN_DIM)
vals = np.linspace(-1., 1., nb_transition_frames)

# setup plot
dpi = 100
fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[0] / dpi, PLOT_IMG_SHAPE[1] / dpi))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
im = ax.imshow(gan.generator.predict(z_start)[0].reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')
plt.axis('off')

def animate(i, gan, z_start, idx, vals):
    # sweep 10 consecutive latent components starting at idx through the current value
    z_start[0][idx:idx+10] = vals[i]
    im.set_data(gan.generator.predict(z_start)[0].reshape(PLOT_IMG_SHAPE))

for z_idx in range(100):
    ani = animation.FuncAnimation(fig, animate, frames=nb_transition_frames, interval=10, 
                                  fargs=[gan, z_start.copy(), z_idx, vals])

    if render_dir:
        render_dir.mkdir(parents=True, exist_ok=True)
        ani.save(str(render_dir / 'idx{}.mp4'.format(z_idx)), animation.FFMpegFileWriter(fps=30))