In [ ]:
import sys
import yaml
import tensorflow as tf
import numpy as np
import pandas as pd
import functools
from pathlib import Path
from datetime import datetime
from tqdm import tqdm_notebook as tqdm
# Plotting
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
# Tell matplotlib's animation writers where the ffmpeg executable is. The value
# is used as the program name of the ffmpeg command, so it must be the path of
# one binary, not a space-separated list of paths.
# NOTE(review): '/usr/bin/ffmpeg' assumes a standard Linux install — adjust per machine.
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
%load_ext autoreload
%autoreload 2
from progan import ProGan
import gan_utils
from load_data import preprocess_images
from ds_utils.plot_utils import plot_sample_imgs
In [ ]:
# Root folder of the locally stored datasets (the CelebA cell reads from it).
data_folder = Path(Path.home(), "Documents/datasets")
In [ ]:
# Load the model/training configuration.
# yaml.safe_load only builds plain Python objects (dicts, lists, scalars).
# yaml.load without an explicit Loader raises TypeError on PyYAML >= 6, and on
# older versions it can construct arbitrary Python objects.
with open('configs/progan_celeba_config.yaml', 'r') as f:
    config = yaml.safe_load(f)
HIDDEN_DIM = config['data']['z_size']  # length of the latent (z) vector
IMG_SHAPE = config['data']['input_shape']  # presumably [height, width, channels] — confirm
BATCH_SIZE = config['training']['batch_size']
# Index 2 is treated as the channel axis: a single channel means grayscale.
IMG_IS_BW = IMG_SHAPE[2] == 1
# Grayscale images are plotted as 2-D arrays; colour images keep every axis.
PLOT_IMG_SHAPE = IMG_SHAPE[:2] if IMG_IS_BW else IMG_SHAPE
config
In [ ]:
# Fashion-MNIST train/test splits via Keras. The labels (y_*) are not used by
# any of the cells shown here.
fashion_mnist = tf.keras.datasets.fashion_mnist
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
In [ ]:
# Apply the project's preprocess_images helper (defined in load_data, not
# visible here) to both splits, then print quick sanity statistics.
X_train = preprocess_images(X_train)
X_test = preprocess_images(X_test)
print(X_train[0].shape)
print(X_train[0].max())
print(X_train[0].min())
print(X_train.shape)
# The loaded config is the CelebA one, so a shape mismatch with this dataset
# is likely. Report both shapes instead of raising a bare AssertionError.
expected_shape = tuple(config['data']['input_shape'])
assert X_train[0].shape == expected_shape, (
    "image shape {} does not match config input_shape {}".format(X_train[0].shape, expected_shape)
)
In [ ]:
# Wrap the in-memory arrays as tf.data pipelines, keeping small subsets
# (first 5000 training / 256 test images) for quick experiments.
train_ds, test_ds = (
    tf.data.Dataset.from_tensor_slices(images).take(limit)
    for images, limit in ((X_train, 5000), (X_test, 256))
)
In [ ]:
# Make modules in the parent of the current working directory importable
# (presumably where tmp_load_data lives — confirm). Re-running this cell
# appends another duplicate "../" entry to sys.path.
sys.path.append("../")
from tmp_load_data import load_imgs_tfdataset
In [ ]:
# Build CelebA pipelines from the aligned JPEGs with tanh_range=True.
# NOTE(review): both splits read the same folder and differ only in the 4th
# argument (500 vs 100). If that argument is an image count, the test images
# likely overlap the training images — confirm against load_imgs_tfdataset.
load_celeba = functools.partial(
    load_imgs_tfdataset,
    data_folder / 'img_align_celeba',
    '*.jpg',
    config,
    zipped=False,
    tanh_range=True,
)
train_ds = load_celeba(500)
test_ds = load_celeba(100)
In [ ]:
# Peek at the first element of the training pipeline: print its shape and
# value range, then display it shifted from the tanh range back to [0, 1].
# NOTE(review): plt.imshow needs one (H, W[, C]) image. If train_ds is batched,
# this element is 4-D and should be indexed (e.g. n_a[0]) — confirm.
for a in train_ds.take(1):
    n_a = a.numpy()
    print(n_a.shape)
    print(n_a.max())
    print(n_a.min())
    plt.imshow((n_a + 1) / 2)
In [ ]:
# instantiate GAN
# ProGan (project-local, imported from progan) is constructed from the loaded config dict.
gan = ProGan(config)
In [ ]:
# Smoke-test one generator: feed it a batch of standard-normal latent vectors
# and show the output shape.
# NOTE(review): what the [2][0] indices select is defined inside ProGan (not visible here).
generator = gan.generators[2][0]
latent_batch = np.random.randn(BATCH_SIZE, HIDDEN_DIM)
generator_out = generator.predict(latent_batch)
generator_out.shape
In [ ]:
# Peak value of the generated smoke-test batch (quick look at the output range).
generator_out.max(axis=None)
In [ ]:
# Smoke-test the discriminator at the same [2][0] position as the generator
# above: score the generated batch and show the output shape.
discriminator_group = gan.discriminators[2]
discriminator = discriminator_group[0]
discriminator_out = discriminator.predict(generator_out)
discriminator_out.shape
In [ ]:
# Generate one image from a random latent vector and display it.
# NOTE(review): no rescaling is applied. If the generator outputs values in
# [-1, 1], matplotlib clips colour images to [0, 1] — confirm the output range.
plot_img_shape = generator.output_shape[1:]  # everything after the batch axis
single_z = np.random.randn(1, HIDDEN_DIM)
sample = generator.predict([single_z])[0].reshape(plot_img_shape)
plt.imshow(sample, cmap='gray' if IMG_IS_BW else 'jet')
plt.show()
In [ ]:
# Per-model directory (for checkpoints and tensorboard logs):
#   ~/Documents/models/tf_playground/gan/<model_name>/export  -> exported models
#   ~/Documents/models/tf_playground/gan/<model_name>/logs/<timestamp> -> log_dir
# model_dir and export_dir are created here; log_dir itself is not.
model_name = "progan_celeba"
model_dir = Path(Path.home(), "Documents/models/tf_playground/gan", model_name)
export_dir = model_dir / 'export'
for required_dir in (model_dir, export_dir):
    required_dir.mkdir(parents=True, exist_ok=True)
run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = model_dir / "logs" / run_stamp
In [ ]:
# Train for the configured number of epochs on the tf.data pipelines, with
# logs written to log_dir.
# NOTE(review): checkpoint_dir is explicitly None although model_dir was set up
# "for checkpoint" — confirm whether ProGan.train should get a checkpoint folder.
nb_epochs = config['training']['nb_epochs']
fit_args = dict(
    train_ds=train_ds,
    validation_ds=test_ds,
    is_tfdataset=True,
    nb_epochs=nb_epochs,
    log_dir=log_dir,
    checkpoint_dir=None,
)
gan.train(**fit_args)
In [ ]:
# Export the generator and discriminator as Keras .h5 files in export_dir.
# NOTE(review): other cells use gan.generators[...] / gan.discriminators[...];
# confirm ProGan also exposes the single gan.generator / gan.discriminator saved here.
generator_h5 = export_dir / 'generator.h5'
discriminator_h5 = export_dir / 'discriminator.h5'
gan.generator.save(str(generator_h5))
gan.discriminator.save(str(discriminator_h5))
In [ ]:
# Show a plot_side x plot_side grid of samples from gan.generators[2][0].
plot_side = 5
plot_generator = gan.generators[2][0]
plot_img_shape = plot_generator.output_shape[1:]


def sample_grid_images(x):
    """Return plot_side**2 freshly generated images; ``x`` is ignored."""
    grid_z = np.random.randn(plot_side * plot_side, HIDDEN_DIM)
    return plot_generator.predict(grid_z)


plot_sample_imgs(
    sample_grid_images,
    img_shape=plot_img_shape,
    plot_side=plot_side,
    cmap='gray' if IMG_IS_BW else 'jet',
)
In [ ]:
%matplotlib inline
In [ ]:
# Latent-space interpolation video: settings and figure setup.
render_dir = Path.home() / 'Documents/videos/gan' / "gan_celeba"
nb_samples = 30
nb_transition_frames = 10
# (nb_samples - 1) transitions between consecutive z vectors, capped at 2000 frames.
nb_frames = min(2000, (nb_samples-1)*nb_transition_frames)
# random list of z vectors, each of shape (1, HIDDEN_DIM) so it can go straight into predict()
z_s = np.random.randn(nb_samples, 1, HIDDEN_DIM)
# setup plot: figure size in pixels equals the image size, with no margins.
# figsize is (width, height), so width comes from the column count (axis 1)
# and height from the row count (axis 0).
dpi = 100
fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[1] / dpi, PLOT_IMG_SHAPE[0] / dpi))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
# NOTE(review): assumes gan.generator's output reshapes to PLOT_IMG_SHAPE (the config's input shape).
im = ax.imshow(gan.generator.predict(z_s[0])[0].reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')
plt.axis('off')
def animate(i, gan, z_s, nb_transition_frames):
    """Draw frame ``i`` of the interpolation video into the global ``im``.

    Frame ``i`` belongs to transition ``k = i // nb_transition_frames`` (from
    ``z_s[k]`` towards ``z_s[k + 1]``). It sits at step
    ``i % nb_transition_frames`` along the straight line between the two vectors.
    """
    segment, step = divmod(i, nb_transition_frames)
    z_from, z_to = z_s[segment], z_s[segment + 1]
    per_step = (z_to - z_from) / nb_transition_frames
    frame = gan.generator.predict(z_from + per_step * step)[0]
    im.set_data(frame.reshape(PLOT_IMG_SHAPE))
# Render the interpolation: animate() draws each of the nb_frames frames,
# receiving the fargs entries as extra positional arguments. The result is
# saved as a timestamped mp4 through ffmpeg.
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=nb_frames,
    interval=1,
    fargs=[gan, z_s, nb_transition_frames],
)
if render_dir:  # a Path is always truthy, so this always runs
    render_dir.mkdir(parents=True, exist_ok=True)
    video_name = datetime.now().strftime("%Y%m%d-%H%M%S") + '.mp4'
    ani.save(str(render_dir / video_name), animation.FFMpegFileWriter(fps=30))
In [ ]:
# Latent-dimension sweep videos: settings and figure setup.
# NOTE(review): the folder name says "fmnist" although the loaded config is the
# CelebA one — confirm the intended output location.
render_dir = Path.home() / 'Documents/videos/gan' / "gan_fmnist_idxs"
nb_transition_frames = 150
# a single random base z vector of shape (1, HIDDEN_DIM)
z_start = np.random.randn(1, HIDDEN_DIM)
# values written into the swept latent entries, one per frame
vals = np.linspace(-1., 1., nb_transition_frames)
# setup plot: figure size in pixels equals the image size, with no margins.
# figsize is (width, height), i.e. (columns, rows).
dpi = 100
fig, ax = plt.subplots(dpi=dpi, figsize=(PLOT_IMG_SHAPE[1] / dpi, PLOT_IMG_SHAPE[0] / dpi))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
# Initialise from this cell's own base vector rather than z_s from the
# interpolation cell. The cell then runs on its own, and for grayscale images
# the colour limits imshow picks here come from the sweep's own base image.
im = ax.imshow(gan.generator.predict(z_start)[0].reshape(PLOT_IMG_SHAPE), cmap='gray' if IMG_IS_BW else 'jet')
plt.axis('off')
def animate(i, gan, z_start, idx, vals):
    """Draw frame ``i`` of a latent sweep into the global ``im``.

    Overwrites the latent entries ``idx`` .. ``idx + 9`` (fewer near the end of
    the vector) of ``z_start`` in place with ``vals[i]``. It then draws the
    generator's image for the modified vector.
    """
    window = slice(idx, idx + 10)
    z_start[0][window] = vals[i]
    generated = gan.generator.predict(z_start)
    im.set_data(generated[0].reshape(PLOT_IMG_SHAPE))
# Render one sweep video per starting latent index. Each video starts from a
# fresh copy of the base vector, so one sweep does not leak into the next.
# At most 100 indices are swept, and never past HIDDEN_DIM: for an index
# >= HIDDEN_DIM the slice written by animate() is empty, so the video would be static.
nb_swept_dims = min(100, HIDDEN_DIM)
if render_dir:
    render_dir.mkdir(parents=True, exist_ok=True)  # loop-invariant, done once
for z_idx in range(nb_swept_dims):
    ani = animation.FuncAnimation(fig, animate, frames=nb_transition_frames, interval=10,
                                  fargs=[gan, z_start.copy(), z_idx, vals])
    if render_dir:
        ani.save(str(render_dir / 'idx{}.mp4'.format(z_idx)), animation.FFMpegFileWriter(fps=30))