Setup


In [ ]:
from __future__ import print_function
from __future__ import division

import numpy as np
import time
import matplotlib.pyplot as plt
import tensorflow as tf

import sys
sys.path.append('..')
import models.VAE as vae
import models.CVAE as cvae
import models.GAN as gan
import utils.DataReader as Data

np.random.seed(0)

In [ ]:
# get data handler
data_type = 'cifar'
data_dir = '/home/mattw/Dropbox/git/dreamscape/data/'

if data_type == 'mnist':
    data = Data.DataReaderMNIST(data_dir + 'mnist/', one_hot=True)
    IM_HEIGHT = 28
    IM_WIDTH = 28
    IM_DEPTH = 1
    NUM_CLASSES = 10
elif data_type == 'cifar':
    data = Data.DataReaderCIFAR(data_dir + 'cifar/', one_hot=True)
    IM_HEIGHT = 32
    IM_WIDTH = 32
    IM_DEPTH = 3
    NUM_CLASSES = 10
    
PIX_TOTAL = IM_HEIGHT * IM_WIDTH * IM_DEPTH
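
A quick sanity check on the data handler. The sketch below pulls one small batch and displays the first image; it assumes, as in the reconstruction cell further down, that next_batch returns an (images, labels) tuple with images flattened to PIX_TOTAL pixels.


In [ ]:
# sanity check: grab one batch and display the first image
# (assumes next_batch returns (images, labels) with flattened images)
batch = data.train.next_batch(4)
print('images:', batch[0].shape, 'labels:', batch[1].shape)

if data_type == 'mnist':
    img = np.reshape(batch[0][0, :], (IM_HEIGHT, IM_WIDTH))
    plt.imshow(img, interpolation='nearest', cmap='gray')
else:
    img = np.reshape(batch[0][0, :], (IM_HEIGHT, IM_WIDTH, IM_DEPTH))
    plt.imshow(img, interpolation='nearest')
plt.axis('off')
plt.show()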

Define and Train a Generative Model


In [ ]:
saving = False
save_dir = '/home/mattw/Dropbox/git/dreamscape/tmp/'
net_type = 'vae' # 'vae' | 'cvae' | 'gan'

# define training params
batch_size = 128
epochs = {
    'training': 100,
    'disp': 5,
    'ckpt': None,
    'summary': 5
}
use_gpu = 1

# Notes
# vae-mnist - learning_rate: 1e-3, training epochs: 20
# vae-cifar - learning_rate: 1e-4, training epochs: 
# gan-mnist - learning_rate: 1e-4, training epochs: 1000

# initialize network
if net_type == 'vae':
    layers_encoder = [PIX_TOTAL, 800, 400]
    layer_latent = 100
    layers_decoder = [400, 800, PIX_TOTAL]
    net = vae.VAE(
        layers_encoder=layers_encoder, 
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        learning_rate=1e-4) # 1e-3 for mnist, 1e-4 for cifar
elif net_type == 'cvae':
    layers_encoder = [PIX_TOTAL, 400, 400]
    layer_latent = 20
    layers_decoder = [400, 400, PIX_TOTAL]
    num_classes = NUM_CLASSES
    net = cvae.CVAE(
        layers_encoder=layers_encoder, 
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        num_classes=num_classes,
        learning_rate=1e-3)
elif net_type == 'gan':
    layers_generator = [100, 400, PIX_TOTAL]
    layers_discriminator = [PIX_TOTAL, 400, 100, 1]
    net = gan.GAN(
        layers_gen=layers_generator, 
        layers_disc=layers_discriminator,
        learning_rate=1e-4)
else:
    raise ValueError('Invalid net_type: %s' % net_type)

# start the tensorflow session
config = tf.ConfigProto(device_count={'GPU': use_gpu})
sess = tf.Session(config=config, graph=net.graph)
sess.run(net.init)

# train network
time_start = time.time()
net.train(
    sess, 
    data=data,
    batch_size=batch_size,
    epochs_training=epochs['training'],
    epochs_disp=epochs['disp'],
    epochs_ckpt=epochs['ckpt'],
    epochs_summary=epochs['summary'],
    output_dir=save_dir)
time_end = time.time()
print('time_elapsed: %g' % (time_end - time_start))

# save network
if saving:
    net.save_model(sess, save_dir)

# keep the session open for the visualization cells below;
# it is closed in a separate cell once visualization is done
# sess.close()

In [ ]:
sess.close()

Visualize Model

Generated Samples Visualization (all models)


In [ ]:
num_cols = 5
num_rows = 3
f, ax = plt.subplots(num_rows, num_cols)
for i in range(num_rows):
    for j in range(num_cols):
        gen = net.generate(sess)
        if data_type == 'mnist':
            to_plot = np.reshape(gen, (IM_HEIGHT, IM_WIDTH))
            interpolation='nearest'
            cmap='gray'
        elif data_type == 'cifar':
            to_plot = np.reshape(gen, (IM_HEIGHT, IM_WIDTH, IM_DEPTH))
            interpolation='nearest'
            cmap='viridis'
        ax[i,j].imshow(to_plot, interpolation=interpolation, cmap=cmap)
        ax[i,j].axes.get_xaxis().set_visible(False)
        ax[i,j].axes.get_yaxis().set_visible(False)
plt.show()
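
The dataset-dependent reshape and colormap branching above reappears in the cells below; a small helper like the hypothetical show_flat_image sketched here could consolidate it. This is only an illustration of the pattern, not part of the models package.


In [ ]:
# optional helper (hypothetical): wrap the dataset-dependent reshape/colormap
# logic shared by the visualization cells
def show_flat_image(ax_handle, flat_img):
    """Display a flattened image on the given axes and hide the ticks."""
    if data_type == 'mnist':
        ax_handle.imshow(np.reshape(flat_img, (IM_HEIGHT, IM_WIDTH)),
                         interpolation='nearest', cmap='gray')
    else:
        ax_handle.imshow(np.reshape(flat_img, (IM_HEIGHT, IM_WIDTH, IM_DEPTH)),
                         interpolation='nearest')
    ax_handle.axes.get_xaxis().set_visible(False)
    ax_handle.axes.get_yaxis().set_visible(False)

# example usage: redraw the sample grid with the helper
f, ax = plt.subplots(num_rows, num_cols)
for i in range(num_rows):
    for j in range(num_cols):
        show_flat_image(ax[i, j], net.generate(sess))
plt.show()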

Reconstruction Visualization (autoencoders only)


In [ ]:
num_cols = 2
num_rows = 5

x = data.train.next_batch(num_rows)
eps = np.zeros((num_rows, net.num_lvs))
recon = net.reconstruct(sess, x[0], eps)

f, ax = plt.subplots(num_rows, num_cols)
for i in range(num_rows):
    if data_type == 'mnist':
        to_plot_1 = np.reshape(x[0][i,:], (IM_HEIGHT, IM_WIDTH))
        to_plot_2 = np.reshape(recon[i,:], (IM_HEIGHT, IM_WIDTH))
        interpolation='nearest'
        cmap='gray'
    elif data_type == 'cifar':
        to_plot_1 = np.reshape(x[0][i,:], (IM_HEIGHT, IM_WIDTH, IM_DEPTH))
        to_plot_2 = np.reshape(recon[i,:], (IM_HEIGHT, IM_WIDTH, IM_DEPTH))
        interpolation='nearest'
        cmap='viridis'
    ax[i,0].imshow(to_plot_1, interpolation=interpolation, cmap=cmap)
    ax[i,0].axes.get_xaxis().set_visible(False)
    ax[i,0].axes.get_yaxis().set_visible(False)
    ax[i,1].imshow(to_plot_2, interpolation=interpolation, cmap=cmap)
    ax[i,1].axes.get_xaxis().set_visible(False)
    ax[i,1].axes.get_yaxis().set_visible(False)
plt.show()
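
To put a number on the reconstructions, the cell below computes a per-image mean squared error between the inputs and their reconstructions. This is a minimal sketch assuming x[0] and recon are both arrays of shape (num_rows, PIX_TOTAL), as in the cell above.


In [ ]:
# per-image mean squared reconstruction error (inputs vs. reconstructions)
mse = np.mean((x[0] - recon) ** 2, axis=1)
for i, err in enumerate(mse):
    print('image %d: MSE = %g' % (i, err))
print('mean MSE over batch: %g' % np.mean(mse))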

Latent Space Visualization (all models)


In [ ]:
"""Notes
- only works for a model with latent space dimension of 2;
- generally terrible results with GANs
"""
nx = ny = 20
x_values = np.linspace(-3, 3, nx)
y_values = np.linspace(-3, 3, ny)

# build the canvas up front (3 channels for cifar) and pick a colormap
if data_type == 'mnist':
    canvas = np.empty((IM_HEIGHT * ny, IM_WIDTH * nx))
    cmap = 'gray'
elif data_type == 'cifar':
    canvas = np.empty((IM_HEIGHT * ny, IM_WIDTH * nx, IM_DEPTH))
    cmap = 'viridis'

for i, yi in enumerate(y_values):
    for j, xi in enumerate(x_values):
        z_mean = np.array([[xi, yi]])
        x_mean = net.generate(sess, z_mean=z_mean)
        rows = slice((ny - i - 1) * IM_HEIGHT, (ny - i) * IM_HEIGHT)
        cols = slice(j * IM_WIDTH, (j + 1) * IM_WIDTH)
        if data_type == 'mnist':
            canvas[rows, cols] = x_mean[0].reshape(IM_HEIGHT, IM_WIDTH)
        elif data_type == 'cifar':
            canvas[rows, cols, :] = x_mean[0].reshape(IM_HEIGHT, IM_WIDTH, IM_DEPTH)
            
plt.figure(figsize=(8, 10))
plt.imshow(canvas, origin='upper',
           interpolation='nearest',
           cmap=cmap)
plt.tight_layout()
plt.show()
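
A related visualization that works for any latent dimensionality is a linear interpolation between two random latent codes. The sketch below assumes the model exposes num_lvs (used in the reconstruction cell) and that generate accepts a z_mean of shape (1, num_lvs), as in the cell above.


In [ ]:
# decode a straight line between two random points in latent space
num_steps = 8
z_start = np.random.randn(1, net.num_lvs)
z_end = np.random.randn(1, net.num_lvs)

f, ax = plt.subplots(1, num_steps, figsize=(2 * num_steps, 2))
for k, alpha in enumerate(np.linspace(0.0, 1.0, num_steps)):
    z = (1.0 - alpha) * z_start + alpha * z_end
    gen = net.generate(sess, z_mean=z)
    if data_type == 'mnist':
        ax[k].imshow(np.reshape(gen, (IM_HEIGHT, IM_WIDTH)),
                     interpolation='nearest', cmap='gray')
    else:
        ax[k].imshow(np.reshape(gen, (IM_HEIGHT, IM_WIDTH, IM_DEPTH)),
                     interpolation='nearest')
    ax[k].axis('off')
plt.show()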

In [ ]:
%reload_ext watermark
%watermark -a "Matt Whiteway" -d -v -m -p numpy,tensorflow