In [ ]:
import sys
import yaml
import tensorflow as tf
import numpy as np
import pandas as pd
import functools
from pathlib import Path
from datetime import datetime
from tqdm import tqdm_notebook as tqdm
# Plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
# Point matplotlib's animation writer at an ffmpeg binary. The conda-env path
# below is used when it exists; otherwise fall back to an ffmpeg found on PATH
# so the notebook also works on machines without that environment.
import shutil
_ffmpeg_path = Path.home() / "anaconda3/envs/image-processing/bin/ffmpeg"
if not _ffmpeg_path.exists() and shutil.which("ffmpeg"):
    _ffmpeg_path = Path(shutil.which("ffmpeg"))
plt.rcParams['animation.ffmpeg_path'] = str(_ffmpeg_path)
%load_ext autoreload
%autoreload 2
import dcgan
import gan_utils
from load_data import preprocess_images
from ds_utils.generative_utils import animate_latent_transition, gen_latent_linear, gen_latent_idx
from ds_utils.plot_utils import plot_sample_imgs
In [ ]:
# Root folder for local datasets; the CelebA images are read from its
# 'img_align_celeba' subfolder further down.
data_folder = Path.home() / "Documents/datasets"
In [ ]:
# Load the model/training configuration and derive the constants used below.
# yaml.safe_load is used because yaml.load() without an explicit Loader is
# deprecated (and raises TypeError on PyYAML >= 6.0); safe_load also refuses
# arbitrary Python object tags.
with open('configs/dcgan_celeba_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
HIDDEN_DIM = config['data']['z_size']        # size of the latent vectors fed to the generator
IMG_SHAPE = config['data']['input_shape']    # presumably (height, width, channels) - confirm
BATCH_SIZE = config['training']['batch_size']
IMG_IS_BW = IMG_SHAPE[2] == 1                # a single channel is treated as grayscale
# Grayscale images are plotted as 2-D (H, W); colour images keep all dims.
PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE
config
In [ ]:
# Fashion-MNIST train/test splits from Keras; the labels (y_*) are not used
# by the GAN.
# NOTE(review): this notebook trains on CelebA (see the config file and the
# load_imgs_tfdataset cell, which replaces train_ds/test_ds); these
# Fashion-MNIST cells look like leftovers - confirm they are still needed.
_fmnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = _fmnist.load_data()
In [ ]:
# Preprocess both splits, then print a few stats of the first training image
# and the full training-set shape as a sanity check.
X_train, X_test = preprocess_images(X_train), preprocess_images(X_test)
_first_img = X_train[0]
for _stat in (_first_img.shape, _first_img.max(), _first_img.min(), X_train.shape):
    print(_stat)
# NOTE(review): config comes from dcgan_celeba_config.yaml; unless its
# input_shape matches preprocessed Fashion-MNIST images, this assert fails.
assert _first_img.shape == tuple(config['data']['input_shape'])
In [ ]:
# Small tf.data subsets of the preprocessed arrays: 5000 train, 256 test samples.
train_ds, test_ds = (
    tf.data.Dataset.from_tensor_slices(X_train).take(5000),
    tf.data.Dataset.from_tensor_slices(X_test).take(256),
)
In [ ]:
# Make the parent directory importable so tmp_load_data can be found.
# Guarded so re-running this cell does not keep appending duplicate
# entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")
from tmp_load_data import load_imgs_tfdataset
In [ ]:
# Build the CelebA train/test pipelines from the same image folder
# (500 and 100 passed as the size argument, respectively).
# NOTE(review): both splits glob the same folder; confirm that
# load_imgs_tfdataset does not return overlapping images for train and test.
train_ds, test_ds = (
    load_imgs_tfdataset(data_folder / 'img_align_celeba', '*.jpg', config, n_imgs, zipped=False)
    for n_imgs in (500, 100)
)
In [ ]:
# Instantiate the DCGAN for the configured image shape. The object exposes
# .generator, .discriminator and the combined .gan model used in the cells below.
gan = dcgan.DCGan(IMG_SHAPE, config)
In [ ]:
# Smoke-test the generator: decode one batch of standard-normal latent
# vectors and display the output shape.
generator_out = gan.generator.predict(
    np.random.randn(BATCH_SIZE, HIDDEN_DIM)
)
generator_out.shape
In [ ]:
# Smoke-test the discriminator on the generator batch from the previous cell
# and display the shape of its predictions.
discriminator_out = gan.discriminator.predict(
    generator_out
)
discriminator_out.shape
In [ ]:
# Smoke-test the combined generator->discriminator model on random latents
# and display the largest output value.
np.max(gan.gan.predict(np.random.randn(BATCH_SIZE, HIDDEN_DIM)))
In [ ]:
# Decode a single random latent vector and display the resulting image
# (gray colormap for single-channel images).
plt.imshow(
    gan.generator.predict([np.random.randn(1, HIDDEN_DIM)])[0].reshape(PLOT_IMG_SHAPE),
    cmap='jet' if not IMG_IS_BW else 'gray',
)
plt.show()
In [ ]:
# Print the generator model's summary (layers and parameter counts).
gan.generator.summary()
In [ ]:
# Directory layout for this model under ~/Documents/models/tf_playground/gan:
#   <model_name>/export/             exported models / checkpoints (created here)
#   <model_name>/logs/<timestamp>/   tensorboard log dir for this run (path only)
model_name = "dcgan_celeba"
model_dir = Path.home() / "Documents/models/tf_playground/gan" / model_name
export_dir = model_dir / 'export'
log_dir = model_dir / "logs" / datetime.now().strftime("%Y%m%d-%H%M%S")
model_dir.mkdir(parents=True, exist_ok=True)
export_dir.mkdir(exist_ok=True)
In [ ]:
# Train the GAN; both splits are first passed through gan.setup_dataset.
nb_epochs = 1000
gan._train(
    train_ds=gan.setup_dataset(train_ds),
    validation_ds=gan.setup_dataset(test_ds),
    nb_epochs=nb_epochs,
    log_dir=log_dir,
    checkpoint_dir=export_dir,  # same directory the .h5 models are exported to
    is_tfdataset=True,
)
In [ ]:
# Export the generator and discriminator as Keras .h5 files into export_dir.
for _model, _fname in ((gan.generator, 'generator.h5'),
                       (gan.discriminator, 'discriminator.h5')):
    _model.save(str(export_dir / _fname))
In [ ]:
# Show a plot_side x plot_side grid of images decoded from random latents.
# The callback ignores its argument and always draws fresh latent vectors.
plot_side = 5
plot_sample_imgs(
    lambda x: gan.generator.predict(np.random.randn(plot_side ** 2, HIDDEN_DIM)),
    img_shape=PLOT_IMG_SHAPE,
    plot_side=plot_side,
    cmap='jet' if not IMG_IS_BW else 'gray',
)
In [ ]:
# Render matplotlib figures inline in the notebook.
# NOTE(review): earlier cells already plot; consider moving this magic to the first cell.
%matplotlib inline
In [ ]:
def gen_image_fun(latent_vectors):
    """Decode ``latent_vectors`` with the generator and return only the first
    generated image, reshaped to ``PLOT_IMG_SHAPE`` for plotting."""
    return gan.generator.predict(latent_vectors)[0].reshape(PLOT_IMG_SHAPE)
In [ ]:
# Smoke-test gen_image_fun on a freshly sampled latent vector. It does not use
# z_s, which is only defined in the animation cell below, so this cell also
# works when the notebook is run top to bottom.
z_sample = np.random.randn(1, HIDDEN_DIM)
img = gen_image_fun(z_sample)
In [ ]:
# Animate a walk through nb_samples random latent vectors (step generation is
# delegated to gen_latent_linear) and render the frames into render_dir.
render_dir = Path.home() / 'Documents/videos/gan' / "gan_celeba"
nb_samples = 10
nb_transition_frames = 10
# nb_transition_frames per gap between consecutive samples, capped at 2000 frames
nb_frames = min((nb_samples - 1) * nb_transition_frames, 2000)
z_s = np.random.randn(nb_samples, HIDDEN_DIM)
animate_latent_transition(
    latent_vectors=z_s,
    gen_image_fun=gen_image_fun,
    gen_latent_fun=lambda z_s, i: gen_latent_linear(z_s, i, nb_transition_frames),
    img_size=PLOT_IMG_SHAPE,
    nb_frames=nb_frames,
    render_dir=render_dir,
)
In [ ]:
# Latent traversal: starting from one random latent vector, render one
# animation per latent index z_idx, sweeping over `vals` (the per-frame latent
# is produced by gen_latent_idx).
# NOTE(review): the folder name says "fmnist" although this notebook trains on
# CelebA, and every iteration renders into the same render_dir - confirm that
# animate_latent_transition does not overwrite the previous dimension's output.
render_dir = Path.home() / 'Documents/videos/gan' / "gan_fmnist_test"
nb_transition_frames = 10
# z_idx presumably indexes a latent dimension; never go past the latent size.
nb_traversed_dims = min(20, HIDDEN_DIM)
z_start = np.random.randn(1, HIDDEN_DIM)
vals = np.linspace(-1., 1., nb_transition_frames)
for z_idx in range(nb_traversed_dims):
    animate_latent_transition(
        latent_vectors=z_start,
        gen_image_fun=gen_image_fun,
        # Bind z_idx as a default argument so the callback keeps this
        # iteration's index even if it is invoked after the loop advances.
        gen_latent_fun=lambda z_s, i, z_idx=z_idx: gen_latent_idx(z_s, i, z_idx, vals),
        img_size=PLOT_IMG_SHAPE,
        nb_frames=nb_transition_frames,
        render_dir=render_dir,
    )