Copyright 2019 Deepmind Technologies Limited. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

IODINE Interactive Evaluation

This notebook can load checkpoints and run the model. You need to download the data before you can run this (see README.md). It includes several plots to inspect the decomposition, iterations, and inputs to the refinement network, as well as prior samples, latent traversals and code to run on a custom example.


In [1]:
# Imports: stdlib first, then third-party, then project-local.
# (Fixes a duplicated `import matplotlib.pyplot as plt` in the original cell.)
import pathlib
import warnings

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from absl import logging

# Ignore all tensorflow deprecation warnings.
# These filters must be installed BEFORE importing tensorflow.
logging._warn_preinit_stderr = 0
warnings.filterwarnings('ignore', module='.*tensorflow.*')
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)

import sonnet as snt
from shapeguard import ShapeGuard

from main import ex, load_checkpoint, build, get_train_step
from iodine.modules import plotting, utils

sns.set_style('whitegrid')

In [2]:
# Adjust this to load checkpoints for different datasets
# Pick which dataset's checkpoint to load.
configuration = 'clevr6'
#configuration = 'multi_dsprites'
#configuration = 'tetrominoes'

# Overrides applied on top of the named configuration.
config_updates = {
    'batch_size': 1,              # smaller batch speeds up interactive evaluation
    'data.shuffle_buffer': None,  # skip shuffling to speed up interactive evaluation
    # 'num_components': 4,  # uncomment to change the number of components
    # 'num_iters': 5,       # uncomment to change the number of iterations
}

In [3]:
# Interactive session so tensors can be evaluated ad hoc from later cells.
sess = tf.InteractiveSession()

# create a sacred run
# --force skips sacred's config-change validation; --unobserved prevents this
# interactive run from being recorded by any observers.
r = ex._create_run(named_configs=[configuration], config_updates=config_updates, options={'--force': True, '--unobserved': True})

# restore the checkpoint and get the model, data, etc.
restored = load_checkpoint(session=sess)
model = restored["model"]      # the trained IODINE model (used below for model.decode)
info = restored["info"]        # eval ops; sess.run(info) below yields losses/metrics/latents
dataset = restored["dataset"]  # the dataset object the checkpoint was restored for
inputs = restored["inputs"]    # input tensors fed to the model


Successfully restored Checkpoint "checkpoints/clevr6/model.ckpt"
Variable                                                 Shape      Type     Collections                            Device
encoder_net/cnn/conv_net_2d/conv_2d_0/b                  64         float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_0/w                  3x3x18x64  float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_1/b                  64         float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_1/w                  3x3x64x64  float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_2/b                  64         float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_2/w                  3x3x64x64  float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_3/b                  64         float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/conv_net_2d/conv_2d_3/w                  3x3x64x64  float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/mlp/linear_0/b                           256        float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/mlp/linear_0/w                           64x256     float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/mlp/linear_1/b                           256        float32  global_variables, trainable_variables  (legacy)
encoder_net/cnn/mlp/linear_1/w                           256x256    float32  global_variables, trainable_variables  (legacy)
factor_evaluator/repres_content/position_mean_var/M2     3          float32  global_variables                       (legacy)
factor_evaluator/repres_content/position_mean_var/mean   3          float32  global_variables                       (legacy)
factor_evaluator/repres_content/position_mean_var/total  3          float32  global_variables                       (legacy)
factor_evaluator/repres_content/predict_latents/b        19         float32  global_variables, trainable_variables  (legacy)
factor_evaluator/repres_content/predict_latents/w        64x19      float32  global_variables, trainable_variables  (legacy)
global_step                                                         int64    global_step, global_variables          (legacy)
model/iodine/initial_sample_distribution                 128        float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_counterfactual/beta   1          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_counterfactual/gamma  1          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dcomponents/beta      3          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dcomponents/gamma     3          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dmask/beta            1          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dmask/gamma           1          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dzp/beta              128        float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_dzp/gamma             128        float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_log_prob/beta         1          float32  global_variables, trainable_variables  (legacy)
model/iodine/preprocess/layer_norm_log_prob/gamma        1          float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_0/b     64         float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_0/w     3x3x67x64  float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_1/b     64         float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_1/w     3x3x64x64  float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_2/b     64         float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_2/w     3x3x64x64  float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_3/b     64         float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_3/w     3x3x64x64  float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_4/b     4          float32  global_variables, trainable_variables  (legacy)
pixel_decoder/broadcast_conv/conv_net_2d/conv_2d_4/w     3x3x64x4   float32  global_variables, trainable_variables  (legacy)
recurrent_net/lstm/lstm/b_gates                          1024       float32  global_variables, trainable_variables  (legacy)
recurrent_net/lstm/lstm/w_gates                          769x1024   float32  global_variables, trainable_variables  (legacy)
refinement_head/residual_head/linear/b                   128        float32  global_variables, trainable_variables  (legacy)
refinement_head/residual_head/linear/w                   256x128    float32  global_variables, trainable_variables  (legacy)

In [4]:
# run evaluation and get the info dict
# Executes the model on one batch; `rinfo` holds the numpy results
# (losses, metrics, latents) consumed by all the plotting cells below.
rinfo = sess.run(info)

Decomposition

Let's plot the decomposition at the last timestep. Set mask_components=False to see the full reconstruction for each component


In [5]:
# Decomposition at the last timestep (t=-1); set mask_components=False to see
# the full reconstruction for each component instead of the masked view.
fig = plotting.example_plot(rinfo, t=-1, mask_components=True)


Iterations

This plots the decompositions for all timesteps and an additional row with the ground truth segmentation, as well as the predicted mask-logits.


In [6]:
# Decompositions for all refinement timesteps, plus a ground-truth
# segmentation row and the predicted mask-logits.
fig = plotting.iterations_plot(rinfo, mask_components=True)


We can also plot error and metrics over iterations to see how they improve over the refinement steps:


In [7]:
# Plot each loss/metric as a function of refinement iteration for one sample.
b = 0  # sample index inside batch

# 2x3 grid: top row = losses/errors, bottom row = segmentation metrics.
fig, axes = plt.subplots(nrows=2, ncols=3, sharex=True, figsize=(12, 6))

axes[0, 0].plot(rinfo["losses"]["recons"][b], color='r')
axes[0, 0].set_title("Reconstruction Loss")

axes[0, 1].plot(rinfo["metrics"]["mse"][b], color='orange')
axes[0, 1].set_title("MSE")

# Plot total KL and faintly also the individual object KLs
axes[0, 2].plot(rinfo["losses"]["kl"][b].sum(axis=1), color='purple')
axes[0, 2].plot(rinfo["losses"]["kl"][b], color='purple', alpha=0.2)
axes[0, 2].set_title("KL")

axes[1, 0].plot(rinfo["metrics"]["ari"][b])
axes[1, 0].set_title("ARI")
axes[1, 0].set_ylim((0, 1))  # ARI is bounded by 1; fix the axis for comparability
axes[1, 0].set_xlabel('Iteration')

axes[1, 1].plot(rinfo["metrics"]["ari_nobg"][b])
axes[1, 1].set_title("ARI (ignoring background)")
axes[1, 1].set_ylim((0, 1))
axes[1, 1].set_xlabel('Iteration')

# NOTE(review): unlike the other panels this is not indexed by `b` --
# presumably the factor loss is already aggregated over the batch; confirm.
axes[1, 2].plot(rinfo["losses"]["factor"], color='g')
axes[1, 2].set_title("Factor Regression Loss")
axes[1, 2].set_xlabel('Iteration')

print("Errors and metrics over iterations")


Errors and metrics over iterations

Inputs to the Refinement Network

This plots all the spatial inputs that are being fed to the refinement network for a given timestep.


In [8]:
# All spatial inputs fed to the refinement network at the last timestep.
fig = plotting.inputs_plot(rinfo, t=-1)


Latent Traversal


In [11]:
# Build a decoder graph that can be driven with hand-modified latents.
b = 0  # sample index inside batch (defined here explicitly rather than relying
       # on the value left over from an earlier cell -- hidden-state hazard)
sg = ShapeGuard()
# Record the latent tensor's shape template; per the guard string the axes are
# (B)atch, (T) iterations, (K) components, (Z) latent dimension.
_ = sg.guard(rinfo['latent']['z'], "B, T, K, Z")
# Placeholder for latents of a single iteration: shape (B, K, Z).
z_placeholder = tf.placeholder(tf.float32, shape=sg["B, K, Z"])
params, out_dist = model.decode(z_placeholder)
img = out_dist.mean()  # render the mean of the output distribution
# Component with the largest KL at the final iteration -- presumably the most
# informative slot to traverse.
highest_kl_obj_idx = rinfo['losses']['kl'][b, -1].argmax()

In [13]:
b = 0  # sample index inside batch
obj_idx = highest_kl_obj_idx  # traverse latents of the highest-KL component
interesting_latents = [3, 4, 26, 43, 56, 57, 61, 28]  # hand-picked dimensions


# One row per traversed latent dimension, 9 columns sweeping its value.
fig, axes = plt.subplots(nrows=len(interesting_latents), ncols=9, sharex=True, sharey=True, figsize=(18, len(interesting_latents)*2))
for row, lat in enumerate(interesting_latents):
    # Latents at the final iteration for sample `b`; slicing with b:b+1 keeps a
    # length-1 batch axis so the array matches z_placeholder's (B, K, Z) shape.
    z_adjusted = rinfo['latent']['z'][b:b+1, -1].copy()
    for col, val in enumerate(np.linspace(-1.25, 1.25, 9)):
        # Index 0, not b: the batch axis was already sliced down to length 1
        # above (indexing with b would raise IndexError for any b >= 1).
        z_adjusted[0, obj_idx, lat] = val
        rimg = sess.run(img, feed_dict={z_placeholder: z_adjusted})
        plotting.show_img(rimg[0, 0], ax=axes[row, col])
    axes[row, 0].set_ylabel(lat)
plt.subplots_adjust(wspace=0.01, hspace=0.01)