This code comes with no warranty; please respect the accompanying license.
In [8]:
# Imports
%reload_ext autoreload
%autoreload 1
import os, sys
sys.path.append('../')
sys.path.append('../common')
sys.path.append('../GenerativeModels')
from tools_general import tf, np
from tools_config import data_dir
from tools_train import vis_square, get_train_params, plot_latent_variable, get_demo_data
from IPython.display import Image
import matplotlib.pyplot as plt
import imageio
from tensorflow.examples.tutorials.mnist import input_data
In [2]:
# Define parameters
networktype = 'CDAE_MNIST'
work_dir = '../trained_models/%s/' % networktype
if not os.path.exists(work_dir): os.makedirs(work_dir)
In [3]:
from CDAE import create_encoder, create_decoder, create_cdae_trainer
You can either download the fully trained models from Google Drive or train your own models using the CDAE.py script.
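If you train from scratch, one way to launch training from a notebook cell is sketched below. This is only a sketch: it assumes CDAE.py can be run directly with its built-in defaults and that it writes checkpoints (Model_Iter_*.ckpt) and reconstruction snapshots (Rec_Iter_*.jpg) into work_dir, so check the script for any required arguments before running.
In [ ]:
# Hypothetical invocation; CDAE.py's actual entry point and arguments may differ.
!python ../GenerativeModels/CDAE.py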
In [5]:
# Checkpoint and reconstruction snapshot saved at this training iteration
iter_num = 30030
best_model = work_dir + "Model_Iter_%.3d.ckpt" % iter_num
best_img = work_dir + 'Rec_Iter_%d.jpg' % iter_num
Image(filename=best_img)
Out[5]: [reconstruction snapshot Rec_Iter_30030.jpg]
In [6]:
latentD = 2
batch_size = 128

tf.reset_default_graph()
demo_sess = tf.InteractiveSession()

is_training = tf.placeholder(tf.bool, [], 'is_training')
Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

# Rebuild the encoder/decoder graph used during training
Xenc_op = create_encoder(Xph, is_training, latentD, reuse=False, networktype=networktype + '_Enc')
Xrec_op = create_decoder(Xenc_op, is_training, latentD, reuse=False, networktype=networktype + '_Dec')

tf.global_variables_initializer().run()

# Restore the trained encoder and decoder weights from the best checkpoint
Enc_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_Enc')
Dec_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_Dec')
saver = tf.train.Saver(var_list=Enc_varlist + Dec_varlist)
saver.restore(demo_sess, best_model)
In [11]:
# Get an equal number of demo samples from each label
spl = 800  # samples per label
data = input_data.read_data_sets(data_dir, one_hot=False, reshape=False)
Xdemo, Xdemo_labels = get_demo_data(data, spl)
# Random latent codes drawn from a standard normal
Zdemo = np.random.normal(size=[spl * 10, latentD], loc=0.0, scale=1.).astype(np.float32)
# Encode the demo images and plot their 2D latent codes, colored by label
encoded_data = demo_sess.run(Xenc_op, feed_dict={Xph: Xdemo, is_training: False})
plot_latent_variable(encoded_data, Xdemo_labels)
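Zdemo above holds random latent codes that are not used in this cell. As a rough sketch of what decoding them directly would look like, the cell below feeds a few of them through the decoder; it assumes create_decoder can be called again on a latent placeholder with reuse=True, which should be verified against its definition in CDAE.py.
In [ ]:
# Sketch (assumption): reuse the trained decoder on arbitrary latent codes.
Zph = tf.placeholder(tf.float32, [None, latentD])
Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True, networktype=networktype + '_Dec')
Xgen = demo_sess.run(Xgen_op, feed_dict={Zph: Zdemo[:16], is_training: False})
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for ax, img in zip(axes.flat, Xgen):
    ax.imshow(img.squeeze(), cmap='gray'); ax.axis('off')
plt.show()
Since the autoencoder was not trained to match any prior over its latent space, samples decoded this way typically do not look like clean digits, which motivates the remark below.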
Note that the CDAE is not a generative model per se, but sampling methods do exist that enable generating new data from its latent code; cf. Generalized Denoising Auto-Encoders as Generative Models (Bengio et al., 2013).
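As a rough illustration of that idea, the sketch below runs the corrupt-then-denoise Markov chain described in that paper, using the reconstruction op defined earlier. The Gaussian noise level is an assumption and should roughly match the corruption used when the model was trained.
In [ ]:
# Sketch: alternate corruption and reconstruction to sample from the denoising
# autoencoder (Bengio et al., 2013). The noise scale below is an assumption.
x = np.random.uniform(size=[16, 28, 28, 1]).astype(np.float32)  # arbitrary starting point
for step in range(50):
    x_corrupt = x + np.random.normal(scale=0.5, size=x.shape).astype(np.float32)
    x = demo_sess.run(Xrec_op, feed_dict={Xph: x_corrupt, is_training: False})
# Visualize the final state of each of the 16 chains
fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for ax, img in zip(axes.flat, x):
    ax.imshow(img.squeeze(), cmap='gray'); ax.axis('off')
plt.show()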