In [1]:
# Reload edited project modules (e.g. the deep_networks package imported below)
# before each cell runs, and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [2]:
# set up plotting
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
# sns.set() (re)applies its own `style` argument (default 'darkgrid'), so calling
# it after sns.set_style('white') silently discarded the white style. Passing both
# options in a single call keeps the intended white background and color codes.
sns.set(style='white', color_codes=True)
from IPython.display import clear_output
In [3]:
from tqdm import tnrange
In [4]:
# Make the parent directory (which contains the deep_networks package) importable.
# Guarded so that re-running this cell does not keep appending '..' to sys.path.
import sys
if '..' not in sys.path:
    sys.path.append('..')
In [5]:
import warnings
warnings.simplefilter('ignore', RuntimeWarning)
In [6]:
import numpy as np
import tensorflow as tf
# Compare (major, minor) as integers. A plain string comparison is lexicographic,
# so '1.10.0' >= '1.4.0' is False and newer releases would be wrongly rejected.
# NOTE(review): this notebook uses the TF1 API (tf.Session, tf.train.Saver);
# TF 2.x passes this check but will fail later — consider an upper bound too.
_tf_major_minor = tuple(int(part) for part in tf.__version__.split('.')[:2])
assert _tf_major_minor >= (1, 4), \
    'TensorFlow >= 1.4.0 is required, found {}'.format(tf.__version__)
In [7]:
from deep_networks import data_util
In [8]:
def plot_gaussian(
        ax,
        data,
        codes=None,
        color='w',
        size=5,
        color_palette=sns.color_palette('Set1', n_colors=8, desat=.5)):
    """Draw a 2-D point cloud as a filled KDE with the raw points overlaid.

    Args:
        ax: matplotlib Axes to draw on. Its aspect is set to 'equal' and both
            axis limits to ``(-size, size)``.
        data: array indexed as ``data[:, 0]`` (x) and ``data[:, 1]`` (y).
        codes: optional per-point integer class labels. When given, each point
            is coloured with ``color_palette[code]`` and ``color`` is ignored.
            Codes at or beyond ``len(color_palette)`` wrap around the palette,
            so more classes than colours still plot (with repeated colours).
        color: marker colour used when ``codes`` is None.
        size: half-width of the square plotting window.
        color_palette: indexable sequence of colours used for ``codes``.
    """
    ax.set_aspect('equal')
    ax.set_ylim((-size, size))
    ax.set_xlim((-size, size))
    ax.tick_params(labelsize=10)
    if codes is not None:
        # Wrap instead of raising IndexError when there are more classes than colours.
        num_colors = len(color_palette)
        c = [color_palette[int(code) % num_colors] for code in codes]
        # matplotlib's scatter should get either per-point `c` or a single `color`.
        color = None
    else:
        c = None
    # NOTE(review): positional x/y plus `shade` / `shade_lowest` follow the
    # pre-0.11 seaborn kdeplot API; newer seaborn renames or removes these —
    # confirm against the seaborn version this notebook is pinned to.
    sns.kdeplot(data[:, 0], data[:, 1],
                cmap='Blues', shade=True, shade_lowest=False, ax=ax)
    ax.scatter(data[:, 0], data[:, 1], linewidth=1, marker='+', c=c, color=color)
In [9]:
# Mixture-of-Gaussians experiment configuration.
gm_num_examples = 5000   # nominal dataset size; together with the batch size it defines an "epoch"
gm_num_classes = 7       # number of mixture components
gm_batch_size = 64
gm_num_batches = gm_num_examples // gm_batch_size  # batches per epoch (floor division)
gm_output_shape = (2,)   # each sample is a 2-D point
gm_scale = 3.0           # forwarded to data_util.gaussian_mixture as `scale`
gm_log_dir = 'logs/gm'
gm_checkpoint_dir = 'checkpoints/GM'


def get_gm_data(batch_size=None, num_clusters=None, scale=None):
    """Return the ``(data, labels)`` pair from ``data_util.gaussian_mixture``.

    Every argument defaults to the matching notebook constant
    (``gm_batch_size``, ``gm_num_classes``, ``gm_scale``). The constants are read
    at call time rather than bound as defaults, so editing them in a later cell
    still takes effect, just as it did before these overrides existed.
    """
    if batch_size is None:
        batch_size = gm_batch_size
    if num_clusters is None:
        num_clusters = gm_num_classes
    if scale is None:
        scale = gm_scale
    return data_util.gaussian_mixture(batch_size=batch_size, scale=scale, num_clusters=num_clusters)
# Build the input pipeline in the default graph and preview ten batches of it.
gm_data, gm_labels = get_gm_data()
with tf.Session() as sess:
    gm_preview = np.vstack([sess.run(gm_data) for _ in range(10)])

# A square figure suits plot_gaussian's equal-aspect, square (-size, size) window.
fig, ax = plt.subplots(figsize=(6, 6))
ax.set_title('Mixture of Gaussians Dataset')
plot_gaussian(ax, gm_preview)
# pyplot.show() takes no figure argument (a positional value ends up in the
# backend's block/close parameter); show everything, then close this figure.
plt.show()
plt.close(fig)
In [10]:
gm_samples = {}


def sample_and_load(model, sample_step, checkpoint_dir, sample_fn):
    """Restore and sample every consecutive checkpoint that already exists.

    Steps through ``sample_step`` in order. At each step it calls
    ``model.load(checkpoint_dir, step)``; on success it calls
    ``sample_fn(model, step)``, and on the first failure it stops.

    Returns:
        The last step that loaded successfully, or None when nothing loaded
        (including an empty ``sample_step``).
    """
    last_loaded = None
    for step in sample_step:
        loaded, _ = model.load(checkpoint_dir, step)
        if not loaded:
            break
        sample_fn(model, step)
        last_loaded = step
    return last_loaded
def sample_GAN(samples, num_batches, num_samples):
    """Build a ``sample_fn(gan, step)`` that records and plots unconditional samples.

    Args:
        samples: list that receives one ``(epoch, data)`` tuple per call.
        num_batches: batches per epoch; the epoch is ``step // num_batches``.
        num_samples: number of points requested from ``gan.sample``.

    Returns:
        ``sample(gan, step)``. Each call draws points, appends them to ``samples``,
        clears the cell output and shows a KDE/scatter plot titled with the epoch.
    """
    def sample(gan, step):
        epoch = step // num_batches
        data = gan.sample(num_samples=num_samples)
        samples.append((epoch, data))
        clear_output()
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.set_title('Epoch #{}'.format(epoch))
        plot_gaussian(ax, data)
        # pyplot.show() accepts no figure argument (a positional value lands in the
        # backend's block/close parameter); close this figure explicitly instead.
        plt.show()
        plt.close(fig)
    return sample
def save_and_sample(checkpoint_dir, sample_fn):
    """Wrap ``sample_fn`` so that a checkpoint is saved before each sampling call.

    The returned callable takes ``(gan, step)``. It calls
    ``gan.save(checkpoint_dir, step)``, then ``sample_fn(gan, step)``, and
    returns None.
    """
    def _save_then_sample(gan, step):
        gan.save(checkpoint_dir, step)
        sample_fn(gan, step)

    return _save_then_sample
In [11]:
from deep_networks.models.gan import GAN
# Train (or resume) a vanilla GAN on the mixture, recording snapshots at
# epochs 10, 20, 50, 100 and 200 into gm_samples['GAN'].
# NOTE(review): every model cell shares gm_checkpoint_dir and gm_log_dir; unless
# GAN.save/load namespace files by model, checkpoints and TensorBoard runs of
# different models may collide — confirm.
with tf.Graph().as_default():  # fresh graph so this model's variables are isolated
    with tf.Session() as sess:
        samples = []  # (epoch, data) tuples appended by sample_fn
        sample_step = [i * gm_num_batches for i in (10, 20, 50, 100, 200)]
        data, _ = get_gm_data()
        gan = GAN(sess,
                  data,
                  num_examples=gm_num_examples,
                  output_shape=gm_output_shape,
                  batch_size=gm_batch_size)
        # Notebook progress-bar range; presumably used by train() for its loop.
        gan._trange = tnrange
        # max_to_keep=None: tf.train.Saver deletes no old checkpoints, so every
        # snapshot step stays loadable by sample_and_load on a later run.
        gan.init_saver(tf.train.Saver(max_to_keep=None))
        sample_fn = sample_GAN(samples, gm_num_batches, 700)  # 700 points per snapshot
        # Replay snapshots for checkpoints already on disk; the last step loaded
        # (or None) is handed to train() as resume_step.
        resume_step = sample_and_load(gan, sample_step, gm_checkpoint_dir, sample_fn)
        # save_and_sample saves a checkpoint at each sample step before sampling;
        # save_step=None presumably disables any other periodic saving — confirm.
        gan.train(num_epochs=200,
                  log_dir=gm_log_dir,
                  checkpoint_dir=gm_checkpoint_dir,
                  resume_step=resume_step,
                  sample_step=sample_step,
                  save_step=None,
                  sample_fn=save_and_sample(gm_checkpoint_dir, sample_fn))
gm_samples['GAN'] = samples
In [12]:
from deep_networks.models.acgan import ACGAN
def sample_ACGAN(samples, num_batches, num_samples):
    """Build a ``sample_fn(gan, step)`` that records and plots class-conditional samples.

    Args:
        samples: list that receives one ``(epoch, data, c)`` tuple per call, where
            ``c`` holds the class codes used to generate ``data``.
        num_batches: batches per epoch; the epoch is ``step // num_batches``.
        num_samples: number of latent vectors and class codes drawn per call.

    Returns:
        ``sample(gan, step)``. Each call draws ``z = gan.sample_z(num_samples)``
        and ``c = gan.sample_c(num_samples)``, generates ``gan.sample(z=z, c=c)``,
        appends the result to ``samples``, and plots it coloured by ``c``.
    """
    def sample(gan, step):
        epoch = step // num_batches
        z = gan.sample_z(num_samples)
        c = gan.sample_c(num_samples)
        data = gan.sample(z=z, c=c)
        samples.append((epoch, data, c))
        clear_output()
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.set_title('Epoch #{}'.format(epoch))
        plot_gaussian(ax, data, codes=c)
        # pyplot.show() accepts no figure argument (a positional value lands in the
        # backend's block/close parameter); close this figure explicitly instead.
        plt.show()
        plt.close(fig)
    return sample
# Train (or resume) the class-conditional ACGAN (it receives the mixture
# labels and num_classes), recording snapshots at epochs 10, 20, 50, 100 and
# 200 into gm_samples['ACGAN'].
# NOTE(review): shares gm_checkpoint_dir / gm_log_dir (and the same snapshot
# steps) with the other models; confirm ACGAN.load cannot pick up another
# model's checkpoint.
with tf.Graph().as_default():  # fresh graph so this model's variables are isolated
    with tf.Session() as sess:
        samples = []  # (epoch, data, codes) tuples appended by sample_fn
        sample_step = [i * gm_num_batches for i in (10, 20, 50, 100, 200)]
        data, labels = get_gm_data()
        gan = ACGAN(sess,
                    data,
                    labels,
                    num_classes=gm_num_classes,
                    num_examples=gm_num_examples,
                    output_shape=gm_output_shape,
                    batch_size=gm_batch_size)
        # Notebook progress-bar range; presumably used by train() for its loop.
        gan._trange = tnrange
        # max_to_keep=None: tf.train.Saver deletes no old checkpoints.
        gan.init_saver(tf.train.Saver(max_to_keep=None))
        sample_fn = sample_ACGAN(samples, gm_num_batches, 700)  # 700 points per snapshot
        # Replay snapshots for checkpoints already on disk; resume after the last one.
        resume_step = sample_and_load(gan, sample_step, gm_checkpoint_dir, sample_fn)
        gan.train(num_epochs=200,
                  log_dir=gm_log_dir,
                  checkpoint_dir=gm_checkpoint_dir,
                  resume_step=resume_step,
                  sample_step=sample_step,
                  save_step=None,
                  sample_fn=save_and_sample(gm_checkpoint_dir, sample_fn))
gm_samples['ACGAN'] = samples
In [13]:
from deep_networks.models.wgan import WGAN
# Train (or resume) a WGAN on the mixture. It runs a much longer schedule
# (3000 epochs) than the other models, with snapshots at epochs 250, 500,
# 1500, 2500 and 3000 recorded into gm_samples['WGAN'].
# NOTE(review): shares gm_checkpoint_dir / gm_log_dir with the other models;
# confirm checkpoints are namespaced per model.
with tf.Graph().as_default():  # fresh graph so this model's variables are isolated
    with tf.Session() as sess:
        samples = []  # (epoch, data) tuples appended by sample_fn
        sample_step = [i * gm_num_batches for i in (250, 500, 1500, 2500, 3000)]
        data, _ = get_gm_data()
        gan = WGAN(sess,
                   data,
                   num_examples=gm_num_examples,
                   output_shape=gm_output_shape,
                   batch_size=gm_batch_size)
        # Notebook progress-bar range; presumably used by train() for its loop.
        gan._trange = tnrange
        # max_to_keep=None: tf.train.Saver deletes no old checkpoints.
        gan.init_saver(tf.train.Saver(max_to_keep=None))
        sample_fn = sample_GAN(samples, gm_num_batches, 700)  # 700 points per snapshot
        # Replay snapshots for checkpoints already on disk; resume after the last one.
        resume_step = sample_and_load(gan, sample_step, gm_checkpoint_dir, sample_fn)
        gan.train(num_epochs=3000,
                  log_dir=gm_log_dir,
                  checkpoint_dir=gm_checkpoint_dir,
                  resume_step=resume_step,
                  sample_step=sample_step,
                  save_step=None,
                  sample_fn=save_and_sample(gm_checkpoint_dir, sample_fn))
gm_samples['WGAN'] = samples
In [14]:
from deep_networks.models.dragan import DRAGAN
# Train (or resume) DRAGAN on the mixture, recording snapshots at epochs
# 10, 20, 50, 100 and 200 into gm_samples['DRAGAN'].
# NOTE(review): same checkpoint dir AND same snapshot steps as the GAN cell;
# unless DRAGAN.load namespaces by model, sample_and_load could try to restore
# the GAN's checkpoints here — confirm.
with tf.Graph().as_default():  # fresh graph so this model's variables are isolated
    with tf.Session() as sess:
        samples = []  # (epoch, data) tuples appended by sample_fn
        sample_step = [i * gm_num_batches for i in (10, 20, 50, 100, 200)]
        data, _ = get_gm_data()
        # NOTE(review): reg_const=0.0 — its meaning depends on DRAGAN's constructor;
        # confirm it is not zeroing the gradient penalty this model is named for.
        gan = DRAGAN(sess,
                     data,
                     num_examples=gm_num_examples,
                     output_shape=gm_output_shape,
                     batch_size=gm_batch_size,
                     reg_const=0.0)
        # Notebook progress-bar range; presumably used by train() for its loop.
        gan._trange = tnrange
        # max_to_keep=None: tf.train.Saver deletes no old checkpoints.
        gan.init_saver(tf.train.Saver(max_to_keep=None))
        sample_fn = sample_GAN(samples, gm_num_batches, 700)  # 700 points per snapshot
        # Replay snapshots for checkpoints already on disk; resume after the last one.
        resume_step = sample_and_load(gan, sample_step, gm_checkpoint_dir, sample_fn)
        gan.train(num_epochs=200,
                  log_dir=gm_log_dir,
                  checkpoint_dir=gm_checkpoint_dir,
                  resume_step=resume_step,
                  sample_step=sample_step,
                  save_step=None,
                  sample_fn=save_and_sample(gm_checkpoint_dir, sample_fn))
gm_samples['DRAGAN'] = samples
In [15]:
# Compare models side by side: one row per model, one column per snapshot.
# The first recorded snapshot of each model is skipped ([1:]), so with the
# default five sample steps the columns show the 2nd to 5th recorded epochs.
headers = ['GAN', 'ACGAN', 'WGAN', 'DRAGAN']
num_rows = len(headers)
num_cols = 4
# squeeze=False keeps `axes` 2-D even if `headers` is trimmed to one model.
fig, axes = plt.subplots(num_rows, num_cols,
                         figsize=(num_cols * 4, 4 * num_rows),
                         squeeze=False)
for row_axes, header in zip(axes, headers):
    snapshots = gm_samples[header][1:]
    row_axes[0].set_ylabel(header, rotation=90, size=25)
    # zip stops at the shorter sequence, so a model with fewer than num_cols
    # snapshots (e.g. interrupted training) leaves blank panels instead of
    # raising IndexError.
    for ax, snapshot in zip(row_axes, snapshots):
        epoch, points = snapshot[0], snapshot[1]
        # ACGAN snapshots carry a third element: the class codes used for colouring.
        codes = snapshot[2] if len(snapshot) > 2 else None
        ax.set_title('Epoch #{}'.format(epoch))
        plot_gaussian(ax, points, codes=codes)
fig.tight_layout()
# pyplot.show() takes no figure argument; show, then close this figure explicitly.
plt.show()
plt.close(fig)
In [16]:
from deep_networks.models.discogan import DiscoGAN
gm_num_classes_even = 8


def sample_DiscoGAN(sess, samples, num_batches, num_samples):
    """Build a ``sample_fn(gan, step)`` that records and plots DiscoGAN translations.

    Domain X is drawn from mixture components ``[0, half)`` and domain Y from
    components ``[half, 2 * half)``, where ``half = gm_num_classes_even // 2``.
    Each component is requested separately through ``minval``/``maxval``.

    The per-component real-data ops are created once, when this factory is
    called (in the caller's default graph). Previously they were created inside
    every ``sample`` call, which kept adding new ops to the training graph.

    Args:
        sess: session used to evaluate the real-data ops.
        samples: list that receives one ``(epoch, cols)`` tuple per call, where
            ``cols`` is ``[labels_x, data_x, gen_y, recon_x,
            labels_y, data_y, gen_x, recon_y]``.
        num_batches: batches per epoch; the epoch is ``step // num_batches``.
        num_samples: each component contributes
            ``num_samples // gm_num_classes_even`` real points.

    Returns:
        ``sample(gan, step)``, which records ``cols`` and shows a 1x6 figure:
        X, X->Y, X->Y->X (coloured by X labels), then Y, Y->X, Y->X->Y
        (coloured by Y labels).
    """
    half = gm_num_classes_even // 2
    per_component = num_samples // gm_num_classes_even

    def _component_op(index):
        # Real points from a single component; presumably minval/maxval bound
        # the component index that gaussian_mixture draws — TODO confirm.
        return data_util.gaussian_mixture(batch_size=per_component,
                                          scale=3.0,
                                          num_clusters=gm_num_classes_even,
                                          minval=index,
                                          maxval=index + 1)

    x_ops = [_component_op(i) for i in range(half)]
    y_ops = [_component_op(i + half) for i in range(half)]

    def sample(gan, step):
        epoch = step // num_batches
        labels_x, data_x, gen_y, recon_x = [], [], [], []
        labels_y, data_y, gen_x, recon_y = [], [], [], []
        for x_op, y_op in zip(x_ops, y_ops):
            batch_x, batch_labels_x = sess.run(x_op)
            # Unpacked as (x, translated to Y, reconstructed X).
            batch_x, batch_gen_y, batch_recon_x = gan.sample_y(x=batch_x)
            labels_x.append(batch_labels_x)
            data_x.append(batch_x)
            gen_y.append(batch_gen_y)
            recon_x.append(batch_recon_x)

            batch_y, batch_labels_y = sess.run(y_op)
            # Unpacked as (y, translated to X, reconstructed Y).
            batch_y, batch_gen_x, batch_recon_y = gan.sample_x(y=batch_y)
            labels_y.append(batch_labels_y)
            data_y.append(batch_y)
            gen_x.append(batch_gen_x)
            recon_y.append(batch_recon_y)

        cols = [
            np.concatenate(labels_x),
            np.vstack(data_x),
            np.vstack(gen_y),
            np.vstack(recon_x),
            np.concatenate(labels_y),
            np.vstack(data_y),
            np.vstack(gen_x),
            np.vstack(recon_y),
        ]
        samples.append((epoch, cols))

        num_rows = 1
        num_cols = 6
        clear_output()
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 4, 4 * num_rows))
        fig.suptitle('Epoch #{}'.format(epoch))
        for i in range(3):
            plot_gaussian(axes[i], cols[i + 1], codes=cols[0])
            plot_gaussian(axes[i + 3], cols[i + 5], codes=cols[4])
        # pyplot.show() takes no figure argument; show, then close this figure explicitly.
        plt.show()
        plt.close(fig)

    return sample
# Train (or resume) DiscoGAN to translate between two halves of an
# 8-component mixture, recording snapshots at epochs 5, 10, 20 and 50
# into gm_samples['DiscoGAN'].
with tf.Graph().as_default():  # fresh graph so this model's variables are isolated
    with tf.Session() as sess:
        samples = []  # (epoch, cols) tuples appended by sample_fn
        sample_step = [i * gm_num_batches for i in (5, 10, 20, 50)]
        # Domain X: presumably components [0, 4); domain Y: presumably [4, 8)
        # (maxval/minval semantics live in data_util — TODO confirm).
        data_x, _ = data_util.gaussian_mixture(batch_size=gm_batch_size,
                                               scale=3.0,
                                               num_clusters=gm_num_classes_even,
                                               maxval=4)
        data_y, _ = data_util.gaussian_mixture(batch_size=gm_batch_size,
                                               scale=3.0,
                                               num_clusters=gm_num_classes_even,
                                               minval=4)
        gan = DiscoGAN(sess,
                       data_x,
                       data_y,
                       num_examples=gm_num_examples,
                       x_output_shape=gm_output_shape,
                       y_output_shape=gm_output_shape,
                       batch_size=gm_batch_size)
        # Notebook progress-bar range; presumably used by train() for its loop.
        gan._trange = tnrange
        # max_to_keep=None: tf.train.Saver deletes no old checkpoints.
        gan.init_saver(tf.train.Saver(max_to_keep=None))
        # 150 real points per domain budget, split evenly across components.
        sample_fn = sample_DiscoGAN(sess, samples, gm_num_batches, 150)
        # Replay snapshots for checkpoints already on disk; resume after the last one.
        resume_step = sample_and_load(gan, sample_step, gm_checkpoint_dir, sample_fn)
        # NOTE(review): unlike the other training cells, train() is not given
        # checkpoint_dir=gm_checkpoint_dir here (saving still happens through
        # save_and_sample). Confirm whether DiscoGAN.train needs it, e.g. for resuming.
        gan.train(num_epochs=50,
                  log_dir=gm_log_dir,
                  resume_step=resume_step,
                  sample_step=sample_step,
                  save_step=None,
                  sample_fn=save_and_sample(gm_checkpoint_dir, sample_fn))
gm_samples['DiscoGAN'] = samples
In [17]:
# DiscoGAN translations, one row per recorded epoch. Columns follow the `cols`
# list recorded by the DiscoGAN sampler: X, X->Y, X->Y->X (coloured by the X
# labels), then Y, Y->X, Y->X->Y (coloured by the Y labels).
rows = [('Epoch #{}'.format(epoch), cols) for epoch, cols in gm_samples['DiscoGAN']]
if not rows:
    raise RuntimeError('No DiscoGAN samples recorded; run the DiscoGAN training cell first.')
num_rows = len(rows)
num_cols = 6
# squeeze=False keeps `axes` 2-D, so a single recorded epoch still indexes as axes[r][c].
fig, axes = plt.subplots(num_rows, num_cols,
                         figsize=(num_cols * 4, 4 * num_rows),
                         squeeze=False)
titles = ['X', '>Y', '>X', 'Y', '>X', '>Y']
for title, ax in zip(titles, axes[0]):
    ax.set_title(title, fontsize=20)
for row_axes, (row_label, cols) in zip(axes, rows):
    row_axes[0].set_ylabel(row_label, rotation=90, size=25)
    for i in range(3):
        plot_gaussian(row_axes[i], cols[i + 1], codes=cols[0])
        plot_gaussian(row_axes[i + 3], cols[i + 5], codes=cols[4])
fig.tight_layout()
# pyplot.show() takes no figure argument; show, then close this figure explicitly.
plt.show()
plt.close(fig)