Copyright 2019 Deepmind Technologies Limited. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This notebook can load checkpoints and run the model. You need to download the data before you can run this (see README.md). It includes several plots to inspect the decomposition, iterations, and inputs to the refinement network, as well as prior samples, latent traversals and code to run on a custom example.
In [1]:
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt  # fixed: was imported twice in this cell
import warnings

from absl import logging

# Silence TF1 deprecation noise; must be configured before TensorFlow import.
logging._warn_preinit_stderr = 0
warnings.filterwarnings('ignore', module='.*tensorflow.*')

import tensorflow.compat.v1 as tf

tf.logging.set_verbosity(tf.logging.ERROR)

import sonnet as snt
from shapeguard import ShapeGuard
import seaborn as sns

from main import ex, load_checkpoint, build, get_train_step
from iodine.modules import plotting, utils

sns.set_style('whitegrid')
In [2]:
# Choose which dataset's checkpoint to load.
# Alternatives: 'multi_dsprites', 'tetrominoes'
configuration = 'clevr6'

# Overrides applied on top of the named configuration.
config_updates = {
    'batch_size': 1,              # single example -> faster interactive evaluation
    'data.shuffle_buffer': None,  # skip shuffling -> faster interactive evaluation
    # 'num_components': 4,        # uncomment to change the number of components
    # 'num_iters': 5,             # uncomment to change the number of iterations
}
In [3]:
# Interactive TF1 session; the sacred run makes the chosen configuration
# available to load_checkpoint.
sess = tf.InteractiveSession()
r = ex._create_run(
    named_configs=[configuration],
    config_updates=config_updates,
    options={'--force': True, '--unobserved': True},
)

# Restore the checkpoint and unpack the pieces used by the cells below.
restored = load_checkpoint(session=sess)
model, info, dataset, inputs = (
    restored[key] for key in ("model", "info", "dataset", "inputs")
)
In [4]:
# run evaluation and get the info dict
# Executes the eval graph once; `rinfo` holds the fetched numpy values
# (losses, metrics, latents, ...) consumed by all plotting cells below.
rinfo = sess.run(info)
In [5]:
# Decomposition at the final iteration (t=-1), with components masked
# by their predicted segmentation masks.
fig = plotting.example_plot(rinfo, t=-1, mask_components=True)
In [6]:
# Full decomposition across all refinement iterations.
fig = plotting.iterations_plot(rinfo, mask_components=True)
We can also plot error and metrics over iterations to see how they improve over the refinement steps:
In [7]:
b = 0  # sample index inside batch (also used by the traversal cells below)
fig, axes = plt.subplots(nrows=2, ncols=3, sharex=True, figsize=(12, 6))

# Unpack the 2x3 grid in row-major order so each panel has a readable name.
ax_recons, ax_mse, ax_kl, ax_ari, ax_ari_nobg, ax_factor = axes.flat

ax_recons.plot(rinfo["losses"]["recons"][b], color='r')
ax_recons.set_title("Reconstruction Loss")

ax_mse.plot(rinfo["metrics"]["mse"][b], color='orange')
ax_mse.set_title("MSE")

# Total KL (solid) with the individual per-object KL curves drawn faintly.
ax_kl.plot(rinfo["losses"]["kl"][b].sum(axis=1), color='purple')
ax_kl.plot(rinfo["losses"]["kl"][b], color='purple', alpha=0.2)
ax_kl.set_title("KL")

ax_ari.plot(rinfo["metrics"]["ari"][b])
ax_ari.set_title("ARI")
ax_ari.set_ylim((0, 1))
ax_ari.set_xlabel('Iteration')

ax_ari_nobg.plot(rinfo["metrics"]["ari_nobg"][b])
ax_ari_nobg.set_title("ARI (ignoring background)")
ax_ari_nobg.set_ylim((0, 1))
ax_ari_nobg.set_xlabel('Iteration')

# NOTE(review): unlike the panels above, the factor loss is not indexed by
# `b` — presumably it is already aggregated over the batch; confirm upstream.
ax_factor.plot(rinfo["losses"]["factor"], color='g')
ax_factor.set_title("Factor Regression Loss")
ax_factor.set_xlabel('Iteration')

print("Errors and metrics over iterations")
In [8]:
# Inputs to the refinement network at the final iteration.
fig = plotting.inputs_plot(rinfo, t=-1)
In [11]:
# Build a decoder-only graph so arbitrary latents can be rendered to images.
sg = ShapeGuard()
# Record the latent tensor's axes: B=batch, T=iterations, K=components, Z=latent dim.
_ = sg.guard(rinfo['latent']['z'], "B, T, K, Z")
# Placeholder for a single iteration's latents (no T axis).
z_placeholder = tf.placeholder(tf.float32, shape=sg["B, K, Z"])
params, out_dist = model.decode(z_placeholder)
# Mean of the decoder's output distribution = the rendered image tensor.
img = out_dist.mean()
# Object with the largest KL at the final iteration — typically the most
# informative component to traverse.  NOTE: relies on `b` defined in the
# metrics-plot cell above.
highest_kl_obj_idx = rinfo['losses']['kl'][b, -1].argmax()
In [13]:
b = 0  # sample index inside batch
obj_idx = highest_kl_obj_idx
# Hand-picked latent dimensions that show interesting traversals.
interesting_latents = [3, 4, 26, 43, 56, 57, 61, 28]

fig, axes = plt.subplots(
    nrows=len(interesting_latents), ncols=9, sharex=True, sharey=True,
    figsize=(18, len(interesting_latents) * 2))

for row, lat in enumerate(interesting_latents):
    # Copy the final-iteration latents for sample `b`.  The b:b+1 slice keeps
    # a leading batch axis of length 1, so it must be indexed with 0 below
    # (the original indexed it with `b`, which only worked because b == 0).
    z_adjusted = rinfo['latent']['z'][b:b + 1, -1].copy()
    for col, val in enumerate(np.linspace(-1.25, 1.25, 9)):
        z_adjusted[0, obj_idx, lat] = val
        rimg = sess.run(img, feed_dict={z_placeholder: z_adjusted})
        plotting.show_img(rimg[0, 0], ax=axes[row, col])
    axes[row, 0].set_ylabel(lat)

plt.subplots_adjust(wspace=0.01, hspace=0.01)
In [8]:
# Draw sample images from the model prior; get_sample_images returns a
# tensor that is evaluated into `rsamples` below.
samples = model.get_sample_images()
rsamples = sess.run(samples)
# Show the first sampled grid without axis decorations; the `_ =` suppresses
# the cell's output repr.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(rsamples[0])
_ = ax.axis('off')
In [9]:
# Get a set of placeholders for the dataset and evaluate the model using that
# This builds an eval graph fed by placeholders instead of the data pipeline,
# so arbitrary (custom) images can be passed in via feed_dict.
input_ph = dataset.get_placeholders()
custom_info = model.eval(input_ph)
In [10]:
# Construct zero-valued fillers for all placeholders
# Only the image is replaced below; the ground-truth mask / factor
# placeholders just need *some* value of the right shape.
sg = ShapeGuard()  # NOTE: rebinds the `sg` from the decoder cell above
sg.guard(input_ph['image'], "B, 1, H, W, C")
sg.guard(input_ph['mask'], "B, 1, L, H, W, 1")  # presumably L = max object slots — confirm in dataset code
fillers = {
    input_ph['image']: np.zeros(sg['1, 1, H, W, C']),
    input_ph['mask']: np.zeros(sg['1, 1, L, H, W, 1']),
    input_ph['latent']['color']: np.zeros(sg["1, L, 1"]),
    input_ph['latent']['shape']: np.zeros(sg["1, L, 1"]),
    input_ph['latent']['size']: np.zeros(sg["1, L, 1"]),
    input_ph['latent']['position']: np.zeros(sg["1, L, 3"]),
    input_ph['latent']['rotation']: np.zeros(sg["1, L, 1"]),
    input_ph['visibility']: np.zeros(sg["1, L"]),
}
Now we can evaluate the model on any custom image, such as this real-world replica of a CLEVR scene
(from "Neural-Symbolic VQA: Disentangling Reasoning from Vision and Language Understanding" by Kexin Yi, Jiajun Wu, Chuang Gan, Antonio Torralba, Pushmeet Kohli, and Joshua B. Tenenbaum).
In [11]:
# Load the custom image and scale pixel values to [0, 1].
# Use a fresh name: `img` is already bound to the decoder's output tensor in
# the latent-traversal cells above, and rebinding it would break re-runs.
custom_img = np.array(Image.open('images/realworld_clevr.png')) / 255.
# Reshape into the placeholder's (1, 1, H, W, C) layout recorded by `sg`.
# NOTE(review): assumes the PNG has no alpha channel and already matches the
# dataset's H x W — confirm before swapping in other images.
fillers[input_ph['image']] = custom_img.reshape(sg['1, 1, H, W, C'])
custom_rinfo = sess.run(custom_info, feed_dict=fillers)
fig = plotting.iterations_plot(custom_rinfo)
The info dict returned by model.eval()
contains many variables from all components and iterations.
For a rough overview, see the shape printout produced by the cell below:
In [10]:
# Print the structure and shapes of everything in the info dict.
utils.print_shapes('out_info', info)