Playground for experiments with StyleGAN2 latents. Includes interactive style mixing, latent interpolation (morphing), and latent tweaking.
In [ ]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from datetime import datetime
from tqdm import tqdm
import imageio
# ffmpeg installation location, for creating videos
plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / "Documents/dev_tools/ffmpeg-20190623-ffa64a4-win64-static/bin/ffmpeg.exe")
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import display
from ipywidgets import Button
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
%load_ext autoreload
%autoreload 2
# StyleGAN Utils
from stylegan_utils import load_network, gen_image_fun, synth_image_fun, create_video
# StyleGAN2 Repo
sys.path.append(os.path.join(os.pardir, 'stylegan2encoder'))
import run_projector
import projector
import training.dataset
import training.misc
# Data Science Utils
sys.path.append(os.path.join(os.pardir, 'data-science-learning'))
from ds_utils import generative_utils
In [ ]:
# Root folder for everything this notebook reads (projections) and writes (renders).
res_dir = Path.home() / 'Documents' / 'generated_data' / 'stylegan'
In [ ]:
# Which pretrained network to load: <MODELS_DIR>/<MODEL_NAME>/<SNAPSHOT_NAME>.pkl
MODELS_DIR = Path("C:/Users/User/Documents/models/stylegan2")
MODEL_NAME = 'original_ffhq'
SNAPSHOT_NAME = 'stylegan2-ffhq-config-f'
Gs, Gs_kwargs, noise_vars = load_network(f"{MODELS_DIR / MODEL_NAME / SNAPSHOT_NAME}.pkl")
# Size of the generator's input vector (second entry of its input shape).
Z_SIZE = Gs.input_shape[1]
# Trailing dimensions of the generator's output shape.
IMG_SIZE = Gs.output_shape[2:]
IMG_SIZE
In [ ]:
# Latents behind the image currently shown by the style-mixing widget.
# generate_mix assigns it; the "Savefig" button handler writes it to disk.
current_displayed_latents: "np.ndarray | None" = None
In [ ]:
def load_latents(latents):
    """Return latents as a NumPy array, loading them first if needed.

    Args:
        latents: a NumPy array (returned as-is, apart from the squeeze below),
            or anything ``np.load`` accepts, e.g. a path to a ``.npy`` file.

    Returns:
        The latents array. A 3-D array whose leading axis has size 1
        (e.g. latents saved as ``[1, 16, 512]``) is returned without that axis.

    Raises:
        ValueError: if the latents are 3-D with a leading axis larger than 1.
    """
    # isinstance (not an exact type check) so ndarray subclasses such as
    # np.memmap are used directly instead of being handed to np.load.
    if not isinstance(latents, np.ndarray):
        latents = np.load(latents)
    # Temporary fix for latents saved with an extra leading axis, e.g. [1, 16, 512].
    if latents.ndim == 3:
        if latents.shape[0] != 1:
            raise ValueError(
                f"expected a leading axis of size 1, got shape {latents.shape}")
        latents = latents[0]
    return latents
In [ ]:
def generate_mix(latents_1, latents_2, style_layers_idxs, synth_image_fun, alpha=1):
    """Blend selected rows of ``latents_1`` into ``latents_2`` and render the result.

    Both inputs go through ``load_latents`` and must end up with the same shape.
    The rows selected by ``style_layers_idxs`` become
    ``alpha * latents_1 + (1 - alpha) * latents_2``. Every other row is taken
    from ``latents_2``. The inputs themselves are not modified.

    The mixed latents are stored in the global ``current_displayed_latents``
    so the widget's save button can export them.

    Args:
        latents_1: source latents (array or loadable path).
        latents_2: base latents (array or loadable path).
        style_layers_idxs: index expression selecting the rows to blend.
        synth_image_fun: callable that renders a batch of latents.
        alpha: blend weight given to ``latents_1`` on the selected rows.

    Returns:
        Whatever ``synth_image_fun`` returns for the mixed latents with a
        leading batch axis added.
    """
    global current_displayed_latents
    source = load_latents(latents_1)
    target = load_latents(latents_2)
    assert source.shape == target.shape
    # Crossover from source into a copy of target.
    mixed = target.copy()
    rows = style_layers_idxs
    mixed[rows] = alpha * source[rows] + (1 - alpha) * mixed[rows]
    current_displayed_latents = mixed
    return synth_image_fun(mixed[np.newaxis, :, :])
In [ ]:
# Set up the figure redrawn by the style-mixing widget, plus a "Savefig" button.
button = Button(description="Savefig")
dpi = 100
# figsize is in inches: divide the pixel size by dpi on BOTH axes
# (previously the second dimension was divided by 2, distorting the figure).
fig, ax = plt.subplots(dpi=dpi, figsize=(IMG_SIZE[0] / dpi, IMG_SIZE[1] / dpi))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0)
plt.axis('off')
# Placeholder image. z is sampled with randn, like every other cell in this notebook.
im = ax.imshow(gen_image_fun(Gs, np.random.randn(1, Z_SIZE), Gs_kwargs, noise_vars))
# Prevent any output for this cell.
plt.close()


def on_button_clicked(b):
    """Save the current figure (.png) and its latents (.npy) under a shared timestamp."""
    dest_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / "picked"
    # savefig / np.save do not create missing folders.
    dest_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fig.savefig(dest_dir / (timestamp + '.png'), bbox_inches='tight')
    # Nothing has been mixed yet: saving None would only write a pickled object array.
    if current_displayed_latents is None:
        print("No mixed latents to save yet; saved only the figure.")
        return
    np.save(dest_dir / (timestamp + '.npy'), current_displayed_latents)


button.on_click(on_button_clicked)
In [ ]:
# One sub-folder per projected image. 'tfrecords' is not a projection result,
# so skip it (without failing when it is absent). Sort for a stable dropdown order.
data_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME
entries = sorted(p.name for p in data_dir.glob("*")
                 if p.is_dir() and p.name != 'tfrecords')
In [ ]:
%matplotlib inline
display(button)
@interact
def i_style_mixing(entry1 = entries, entry2 = entries,
from_layer = np.arange(0, 18), to_layer = np.arange(0, 18),
alpha = (-0.5, 1.5)):
assert from_layer <= to_layer
latents_1 =res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / '' / entry1 / "image_latents2000.npy"
latents_2 = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / '' / entry2 / "image_latents1000.npy"
gen_image = generate_mix(latents_1, latents_2,
style_layers_idxs=np.arange(from_layer, to_layer),
synth_image_fun=lambda dlatens : synth_image_fun(Gs, dlatens, Gs_kwargs, randomize_noise=True),
alpha=alpha)
im.set_data(gen_image)
display(fig)
In [ ]:
# Frame shape for rendered animations: generator output size plus a trailing axis of 3.
PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1]) + (3,)
In [ ]:
# Render transitions between freshly sampled z vectors, twice.
# NOTE(review): animate_latent_transition / gen_latent_linear are not imported in
# any visible cell (only the generative_utils module is) — confirm their origin.
# NOTE(review): both runs use the same render_dir and no file_prefix; check
# whether the second run overwrites the first.
render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / "explore_latent"
nb_samples = 2
nb_transition_frames = 450
nb_frames = min(450, nb_transition_frames * (nb_samples - 1))  # capped at 450
psi = 1
for i in range(2):
    # New random z batch for every run; passed as-is (no mapping step).
    z_s = np.random.randn(nb_samples, Z_SIZE)
    passed_latents = z_s
    animate_latent_transition(
        latent_vectors=passed_latents,
        gen_image_fun=lambda latents: gen_image_fun(Gs, latents, Gs_kwargs, truncation_psi=psi),
        gen_latent_fun=lambda z_s, i: gen_latent_linear(passed_latents, i, nb_transition_frames),
        img_size=PLOT_IMG_SHAPE,
        nb_frames=nb_frames,
        render_dir=render_dir / "transitions",
    )
In [ ]:
# Frame shape for rendered animations: generator output size plus a trailing axis of 3.
PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1]) + (3,)
In [ ]:
# Render the same random latent walk once for each truncation_psi value.
render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / 'explore_latent'
nb_samples = 20
nb_transition_frames = 24
nb_frames = min(450, nb_transition_frames * (nb_samples - 1))  # capped at 450
# One z batch shared by every psi value, so only psi differs between renders.
z_s = np.random.randn(nb_samples, Z_SIZE)
passed_latents = z_s
for psi in np.linspace(-0.5, 1.5, 9):
    animate_latent_transition(
        latent_vectors=passed_latents,
        gen_image_fun=lambda latents: gen_image_fun(Gs, latents, Gs_kwargs, truncation_psi=psi),
        gen_latent_fun=lambda z_s, i: gen_latent_linear(passed_latents, i, nb_transition_frames),
        img_size=PLOT_IMG_SHAPE,
        nb_frames=nb_frames,
        render_dir=render_dir / 'psi',
        # e.g. psi = -0.25 -> 'psi-0_25' (value text truncated to 5 characters)
        file_prefix=f"psi{str(psi).replace('.', '_')[:5]}",
    )
In [ ]:
# Frame shape for rendered animations: generator output size plus a trailing axis of 3.
PLOT_IMG_SHAPE = (IMG_SIZE[0], IMG_SIZE[1]) + (3,)
In [ ]:
# Sweep one latent index at a time through random offsets and render each sweep.
render_dir = res_dir / MODEL_NAME / SNAPSHOT_NAME / "explore_latent"
nb_transition_frames = 48
z_start = np.random.randn(1, Z_SIZE)
# `dlatents` was only defined by this previously commented-out line, and no other
# visible cell defines it, so the cell raised NameError. Restored.
dlatents = Gs.components.mapping.run(z_start, None, dlatent_broadcast=None)
# NOTE(review): with dlatent_broadcast=None, shape[0] may be the batch size rather
# than the number of style entries — confirm which axis was meant.
nb_styles = dlatents.shape[0]
# NOTE(review): offsets are sized by Z_SIZE; assumes the dlatent width equals Z_SIZE.
stylelatent_vals = (np.random.randn(nb_transition_frames, Z_SIZE)
                    + np.linspace(-1., 1., nb_transition_frames)[:, np.newaxis])
for z_idx in range(nb_styles):
    animate_latent_transition(
        latent_vectors=dlatents[0],
        # synth_image_fun takes (Gs, dlatents, Gs_kwargs, ...) as in the style-mixing
        # cell, so wrap it into the one-argument form the other cells pass here.
        # NOTE(review): fixed noise (randomize_noise=False) may suit animations better.
        gen_image_fun=lambda latents: synth_image_fun(Gs, latents, Gs_kwargs, randomize_noise=True),
        gen_latent_fun=lambda z_s, i: gen_latent_style_idx(dlatents[0], i, z_idx, stylelatent_vals),
        img_size=PLOT_IMG_SHAPE,
        nb_frames=nb_transition_frames,
        render_dir=render_dir / 'latent_indexes',
    )