In [ ]:
from __future__ import print_function
from __future__ import division
import numpy as np
import time
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
sys.path.append('..')
import models.VAE as vae
import models.CVAE as cvae
import models.GAN as gan
import utils.DataReader as Data
# Seed NumPy's global RNG so NumPy-driven randomness is repeatable across runs.
# NOTE(review): this does not seed TensorFlow's own random ops; no TF seed is
# set anywhere in this notebook.
np.random.seed(0)
In [ ]:
# Build the data handler for the selected dataset and record the image
# geometry (height, width, channels, classes) used by the cells below.
data_type = 'cifar'  # 'mnist' | 'cifar'
data_dir = '/home/mattw/Dropbox/git/dreamscape/data/'
# Compare strings with ==, not `is`: identity of equal string literals is an
# interpreter implementation detail (and a SyntaxWarning since Python 3.8).
if data_type == 'mnist':
    data = Data.DataReaderMNIST(data_dir + 'mnist/', one_hot=True)
    IM_HEIGHT = 28
    IM_WIDTH = 28
    IM_DEPTH = 1
    NUM_CLASSES = 10
elif data_type == 'cifar':
    data = Data.DataReaderCIFAR(data_dir + 'cifar/', one_hot=True)
    IM_HEIGHT = 32
    IM_WIDTH = 32
    IM_DEPTH = 3
    NUM_CLASSES = 10
else:
    # Fail here with a clear message instead of a NameError on IM_HEIGHT below.
    raise ValueError('Invalid data_type: %r' % data_type)
# number of scalar values in one image (height * width * channels); used as
# the input/output layer size of the networks
PIX_TOTAL = IM_HEIGHT * IM_WIDTH * IM_DEPTH
In [ ]:
# Configure, build and train the selected generative network.
saving = False  # if True, write the trained model to save_dir after training
save_dir = '/home/mattw/Dropbox/git/dreamscape/tmp/'
net_type = 'vae'  # 'vae' | 'cvae' | 'gan'

# define training params; each key is passed to net.train as the matching
# epochs_* argument (see the call below)
batch_size = 128
epochs = {
    'training': 100,
    'disp': 5,
    'ckpt': None,
    'summary': 5,
}
# maximum number of GPUs TensorFlow may use (0 forces CPU-only)
use_gpu = 1

# Notes
# vae-mnist - learning_rate: 1e-3, training epochs: 20
# vae-cifar - learning_rate: 1e-4, training epochs: TODO (not recorded)
# gan-mnist - learning_rate: 1e-4, training epochs: 1000

# initialize network
# Compare strings with ==, not `is` (identity of string literals is not
# guaranteed and triggers a SyntaxWarning on Python >= 3.8).
if net_type == 'vae':
    layers_encoder = [PIX_TOTAL, 800, 400]
    layer_latent = 100
    layers_decoder = [400, 800, PIX_TOTAL]
    net = vae.VAE(
        layers_encoder=layers_encoder,
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        # per-dataset rate from the notes above: 1e-3 for mnist, 1e-4 for cifar
        learning_rate=1e-3 if data_type == 'mnist' else 1e-4)
elif net_type == 'cvae':
    layers_encoder = [PIX_TOTAL, 400, 400]
    layer_latent = 20
    layers_decoder = [400, 400, PIX_TOTAL]
    num_classes = NUM_CLASSES
    net = cvae.CVAE(
        layers_encoder=layers_encoder,
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        num_classes=num_classes,
        learning_rate=1e-3)
elif net_type == 'gan':
    layers_generator = [100, 400, PIX_TOTAL]
    layers_discriminator = [PIX_TOTAL, 400, 100, 1]
    net = gan.GAN(
        layers_gen=layers_generator,
        layers_disc=layers_discriminator,
        learning_rate=1e-4)
else:
    # `Error` is not a defined name; raising it would itself fail with NameError.
    raise ValueError('Invalid net_type: %r' % net_type)

# start the tensorflow session on the network's own graph
config = tf.ConfigProto(device_count={'GPU': use_gpu})
sess = tf.Session(config=config, graph=net.graph)
sess.run(net.init)

# train network and report wall-clock training time
time_start = time.time()
net.train(
    sess,
    data=data,
    batch_size=batch_size,
    epochs_training=epochs['training'],
    epochs_disp=epochs['disp'],
    epochs_ckpt=epochs['ckpt'],
    epochs_summary=epochs['summary'],
    output_dir=save_dir)
time_end = time.time()
print('time_elapsed: %g' % (time_end - time_start))

# save network
if saving:
    net.save_model(sess, save_dir)
# The session is intentionally left open: the plotting cells below reuse
# `sess`, and it is closed explicitly in its own cell.
In [ ]:
# Release the TensorFlow session's resources.
# NOTE(review): the plotting cells below still call net.generate /
# net.reconstruct with `sess`; running them after this cell presumably fails
# on a closed session, so run this cell last (or move it to the end).
sess.close()
In [ ]:
# Draw a num_rows x num_cols grid of images produced by net.generate.
num_cols = 5
num_rows = 3
# Image shape and colormap depend only on the dataset, so choose them once.
if data_type == 'mnist':
    im_shape = (IM_HEIGHT, IM_WIDTH)
    cmap = 'gray'
elif data_type == 'cifar':
    # RGB image; matplotlib ignores cmap for 3-channel data
    im_shape = (IM_HEIGHT, IM_WIDTH, IM_DEPTH)
    cmap = 'viridis'
else:
    raise ValueError('Invalid data_type: %r' % data_type)
# squeeze=False keeps `ax` 2-D even when num_rows or num_cols is 1
f, ax = plt.subplots(num_rows, num_cols, squeeze=False)
for i in range(num_rows):
    for j in range(num_cols):
        # each call yields exactly one image's worth of values (it is
        # reshaped into a single image below)
        gen = net.generate(sess)
        panel = ax[i, j]
        panel.imshow(np.reshape(gen, im_shape), interpolation='nearest',
                     cmap=cmap)
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
plt.show()
In [ ]:
# Show training images (left column) beside their reconstructions (right).
num_cols = 2
num_rows = 5
# Resolve the per-dataset image shape/colormap before doing any work.
if data_type == 'mnist':
    im_shape = (IM_HEIGHT, IM_WIDTH)
    cmap = 'gray'
elif data_type == 'cifar':
    # RGB image; matplotlib ignores cmap for 3-channel data
    im_shape = (IM_HEIGHT, IM_WIDTH, IM_DEPTH)
    cmap = 'viridis'
else:
    raise ValueError('Invalid data_type: %r' % data_type)
# presumably x is (images, labels); only x[0] (the images) is used here
x = data.train.next_batch(num_rows)
# NOTE(review): eps is assumed to be the latent sampling noise; zeros should
# give a deterministic reconstruction — confirm against net.reconstruct
eps = np.zeros((num_rows, net.num_lvs))
recon = net.reconstruct(sess, x[0], eps)
# squeeze=False keeps `ax` 2-D even when num_rows is 1
f, ax = plt.subplots(num_rows, num_cols, squeeze=False)
for i in range(num_rows):
    for col, img in enumerate((x[0][i, :], recon[i, :])):
        panel = ax[i, col]
        panel.imshow(np.reshape(img, im_shape), interpolation='nearest',
                     cmap=cmap)
        panel.get_xaxis().set_visible(False)
        panel.get_yaxis().set_visible(False)
plt.show()
In [ ]:
"""Notes
- only works for a model with latent space dimension of 2;
- generally terrible results with GANs
"""
nx = ny = 20
x_values = np.linspace(-3, 3, nx)
y_values = np.linspace(-3, 3, ny)
canvas = np.empty((IM_HEIGHT*ny, IM_WIDTH*nx))
for i, yi in enumerate(x_values):
for j, xi in enumerate(y_values):
z_mean = np.array([[xi, yi]])
x_mean = net.generate(sess, z_mean=z_mean)
if data_type == 'mnist':
canvas[(nx-i-1)*IM_HEIGHT:(nx-i)*IM_WIDTH,
j*IM_HEIGHT:(j+1)*IM_WIDTH] = x_mean[0].reshape(IM_HEIGHT, IM_WIDTH)
cmap = 'gray'
elif data_type == 'cifar':
canvas[(nx-i-1)*IM_HEIGHT:(nx-i)*IM_WIDTH,
j*IM_HEIGHT:(j+1)*IM_WIDTH, :] = \
x_mean[0].reshape(IM_HEIGHT, IM_WIDTH, IM_DEPTH)
cmap = 'viridis'
plt.figure(figsize=(8, 10))
Xi, Yi = np.meshgrid(x_values, y_values)
plt.imshow(canvas, origin='upper',
interpolation='nearest',
cmap=cmap)
plt.tight_layout()
plt.show()
In [ ]:
# Print environment metadata with the IPython `watermark` extension: the author
# and the numpy/tensorflow package versions, plus (per watermark's -d/-v/-m
# flags, presumably) the date, Python version and machine info.
%reload_ext watermark
%watermark -a "Matt Whiteway" -d -v -m -p numpy,tensorflow