In [1]:
# Load graph and restore weights.
# Rebuild the inference graph from its exported MetaGraph inside a freshly
# reset default graph, then load the trained variable values from the
# checkpoint into an interactive session.
import tensorflow as tf

metagraph_fp = './train/infer/infer.meta'
ckpt_fp = './train/model.ckpt-6532'

tf.reset_default_graph()
# Handle on the (new) default graph; later cells look tensors up by name on it.
graph = tf.get_default_graph()
# Imports the MetaGraph's ops into the default graph (i.e. `graph`).
saver = tf.train.import_meta_graph(metagraph_fp)

# `graph` is the default graph here, so passing it explicitly is equivalent
# to omitting it; it just makes the session/graph pairing visible.
sess = tf.InteractiveSession(graph=graph)
saver.restore(sess, ckpt_fp)
In [2]:
# Generate disentangled latent codes.
# Feed the requested code counts into the graph's samp_zi_n / samp_zo_n
# inputs and fetch the sampled codes from samp_zi / samp_zo. The result dict
# `_samp_fetches` (keys 'zis', 'zos') is consumed by the next cell.
nids = 4  # number of zi codes to sample (fed to samp_zi_n)
nobs = 8  # number of zo codes to sample (fed to samp_zo_n)
samp_feeds = {
    graph.get_tensor_by_name('samp_zi_n:0'): nids,
    graph.get_tensor_by_name('samp_zo_n:0'): nobs,
}
samp_fetches = {
    'zis': graph.get_tensor_by_name('samp_zi:0'),
    'zos': graph.get_tensor_by_name('samp_zo:0'),
}
_samp_fetches = sess.run(samp_fetches, samp_feeds)
# Single-argument print(...) prints identically on Python 2 and also runs on
# Python 3 (the bare `print expr` statement is a SyntaxError there).
print(_samp_fetches['zis'].shape)
print(_samp_fetches['zos'].shape)
In [3]:
# Generate grid of images from latent codes.
# Feed the codes sampled in the previous cell into the zi / zo inputs and
# fetch the G_z_grid and G_z_grid_prev tensors. `_fetches` is consumed by the
# preview cell below.
feeds = {
    graph.get_tensor_by_name('zi:0'): _samp_fetches['zis'],
    graph.get_tensor_by_name('zo:0'): _samp_fetches['zos'],
}
fetches = {
    'G_z_grid': graph.get_tensor_by_name('G_z_grid:0'),
    'G_z_grid_prev': graph.get_tensor_by_name('G_z_grid_prev:0'),
}
_fetches = sess.run(fetches, feeds)
# Single-argument print(...) is valid and identical on Python 2 and 3.
print(_fetches['G_z_grid'].shape)
print(_fetches['G_z_grid_prev'].shape)
In [ ]:
# Preview image
from io import BytesIO
from IPython import display
import numpy as np
import PIL.Image


def display_img(a):
    """Encode an array as PNG in memory and render it inline in the notebook.

    Args:
      a: array handed directly to ``PIL.Image.fromarray``.
        NOTE(review): presumably a uint8 image (H x W or H x W x C) --
        confirm the dtype/shape of the ``G_z_grid_prev`` tensor, since
        ``fromarray`` rejects or misinterprets many float arrays.

    Returns:
      None; the image is shown via ``IPython.display.display``.
    """
    # io.BytesIO is the binary in-memory buffer on both Python 2.7 and 3
    # (cStringIO does not exist on Python 3); `with` releases it afterwards.
    with BytesIO() as f:
        PIL.Image.fromarray(a).save(f, 'png')
        png_bytes = f.getvalue()
    display.display(display.Image(data=png_bytes))


display_img(_fetches['G_z_grid_prev'])