Copyright 2018 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Figures

This notebook contains the code used to generate the figures and tables in the paper "Understanding and Improving Interpolation in Autoencoders via an Adversarial Regularizer". It is provided primarily as an example and may require modification to run in a different environment.
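
The code below assumes the ACAI codebase (its `lib` package and `all_aes` module) and TensorFlow 1.x are importable. If the repository is checked out elsewhere, a cell along these lines (the path is purely illustrative) puts it on the Python path first:


In [ ]:
import sys
sys.path.append('/path/to/acai')  # hypothetical location of the ACAI checkout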


In [ ]:
import collections
import glob
import os

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import scipy.ndimage
import tensorflow as tf
from absl import flags

import all_aes
import lib.data
import lib.eval
import lib.utils

FLAGS = flags.FLAGS
# absl treats the first argv entry as the program name, so include a dummy
# entry to make sure the flag value is actually parsed.
FLAGS(['figures', '--lr', '0.0001'])

if not os.path.exists('figures'):
    os.makedirs('figures')
    
def flatten_lines(lines, padding=2):
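    # Pad each image with `padding` rows above and below, stack the padded
    # images vertically, then transpose so they form one horizontal strip.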
    pad = np.ones((lines.shape[0], padding) + lines.shape[2:])
    lines = np.concatenate([pad, lines, pad], 1)
    lines = np.concatenate(lines, 0)
    return np.transpose(lines, [1, 0] + list(range(2, lines.ndim)))

def get_final_value_median(values, steps, N=20):
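    # Metrics are noisy from step to step, so report the median of the last N
    # logged values (ordered by training step) as the final value for a run.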
    step_order = np.argsort(steps)
    values = np.array(values)[step_order]
    return np.median(values[-N:])

HEIGHT = 32
WIDTH = 32
N_LINES = 16

START_ANGLE = 5*np.pi/7
END_ANGLE = 3*np.pi/2.

Example line interpolations

Samples


In [ ]:
example_lines = np.zeros((N_LINES, HEIGHT, WIDTH))
# Cover the space of angles somewhat evenly
angles = np.linspace(0, 2*np.pi - np.pi/N_LINES, N_LINES)
np.random.shuffle(angles)
for n, angle in enumerate(angles):
    example_lines[n] = lib.data.draw_line(angle, HEIGHT, WIDTH)[..., 0]

fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(example_lines), cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_samples.pdf', aspect='normal')

Correct interpolation


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
angles = np.linspace(START_ANGLE, END_ANGLE, N_LINES)

for n in range(N_LINES):
    line_interpolation[n] = lib.data.draw_line(angles[n], HEIGHT, WIDTH)[..., 0]

fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(line_interpolation), cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_correct_interpolation.pdf', aspect='normal')
print(lib.eval.line_eval(line_interpolation[np.newaxis, ..., np.newaxis]))

Data-space interpolation


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
start_line = lib.data.draw_line(START_ANGLE, HEIGHT, WIDTH)[..., 0]
end_line = lib.data.draw_line(END_ANGLE, HEIGHT, WIDTH)[..., 0]
weights = np.linspace(1, 0, N_LINES)

for n in range(N_LINES):
    line_interpolation[n] = weights[n]*start_line + (1 - weights[n])*end_line

fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(line_interpolation), cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_data_interpolation.pdf', aspect='normal')
print(lib.eval.line_eval(line_interpolation[np.newaxis, ..., np.newaxis]))

Abrupt change


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
start_line = lib.data.draw_line(START_ANGLE, HEIGHT, WIDTH)[..., 0]
end_line = lib.data.draw_line(END_ANGLE, HEIGHT, WIDTH)[..., 0]

for n in range(N_LINES):
    line_interpolation[n] = start_line if n < N_LINES // 2 else end_line

fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(line_interpolation), cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_abrupt_interpolation.pdf', aspect='normal')
print(lib.eval.line_eval(line_interpolation[np.newaxis, ..., np.newaxis]))

Overshooting


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))

angles = np.linspace(START_ANGLE, END_ANGLE - 2*np.pi, N_LINES)

for n in range(N_LINES):
    line_interpolation[n] = lib.data.draw_line(angles[n], HEIGHT, WIDTH)[..., 0]

fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(line_interpolation), cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_overshooting_interpolation.pdf', aspect='normal')
print(lib.eval.line_eval(line_interpolation[np.newaxis, ..., np.newaxis]))

Unrealistic


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
angles = np.linspace(START_ANGLE, END_ANGLE, N_LINES)
blur = np.sin(np.linspace(0, np.pi, N_LINES))

for n in range(N_LINES):
    line = lib.data.draw_line(angles[n], HEIGHT, WIDTH)[..., 0]
    line_interpolation[n] = scipy.ndimage.gaussian_filter(line + np.sqrt(blur[n]), blur[n]*1.5)
        
fig = plt.figure(figsize=(15, 1))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(flatten_lines(line_interpolation), cmap=plt.cm.gray, interpolation='nearest', vmin=-1, vmax=1)
plt.gca().set_axis_off()

plt.savefig('figures/line_unrealistic_interpolation.pdf', aspect='normal')

Line results table


In [ ]:
RESULTS_PATH = '/home/craffel/data/dberth/RERUNS/*/lines32'
experiments = collections.defaultdict(list)
for run_path in glob.glob(RESULTS_PATH):
    for path in glob.glob(os.path.join(run_path, '*')):
        experiments[os.path.split(path)[-1]].append(os.path.join(path, 'tf', 'summaries'))

In [ ]:
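# Map display names to the experiment directory names used on disk.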
ALGS = collections.OrderedDict([
    ('Baseline', 'AEBaseline_depth16_latent16_scales4'),
    ('Dropout', 'AEDropout_depth16_dropout0.5_latent16_scales4'),
    ('Denoising', 'AEDenoising_depth16_latent16_noise1.0_scales4'),
    ('VAE', 'VAE_beta1.0_depth16_latent16_scales4'),
    ('AAE', 'AAE_adversary_lr0.0001_depth16_disc_layer_sizes100,100_latent16_scales4'),
    ('VQ-VAE', 'AEVQVAE_advdepth16_advweight0.0_beta10.0_depth16_emaTrue_latent16_noise0.0_num_blocks1_num_latents10_num_residuals1_reg0.5_scales3_z_log_size14'),
    ('ACAI', 'ARAReg_advdepth16_advweight0.5_depth16_latent16_reg0.2_scales4'),
])

In [ ]:
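# Nested mapping: experiment_results[experiment][run][tag] holds
# {'step': [...], 'value': [...]} pairs read from the event files.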
experiment_results = collections.defaultdict(
    lambda: collections.defaultdict(
        lambda: collections.defaultdict(
            lambda: collections.defaultdict(list))))
for experiment_key, experiment_paths in experiments.items():
    for n, experiment_path in enumerate(experiment_paths):
        print('Getting results for', experiment_key, n)
        for events_file in glob.glob(os.path.join(experiment_path, 'events*')):
            try:
                for e in tf.train.summary_iterator(events_file):
                    for v in e.summary.value:
                        experiment_results[experiment_key][n][v.tag]['step'].append(e.step)
                        experiment_results[experiment_key][n][v.tag]['value'].append(v.simple_value)
            except Exception as exception:
                print(exception)

In [ ]:
mean_distance = collections.defaultdict(list)
mean_smoothness = collections.defaultdict(list)

for experiment_name, events_lists in experiment_results.items():
    for events in events_lists.values():
        mean_distance[experiment_name].append(get_final_value_median(
            events['mean_distance_1']['value'], events['mean_distance_1']['step']))
        mean_smoothness[experiment_name].append(get_final_value_median(
            events['mean_smoothness_1']['value'], events['mean_smoothness_1']['step']))

In [ ]:
print('Metric & ' + ' & '.join(ALGS.keys()) + r' \\')
print(r'Mean Distance ($\times 10^{-3}$) & ' + ' & '.join(
    [r'{:.2f}$\pm${:.2f}'.format(np.mean(mean_distance[alg_name])*10**3,
                                 np.std(mean_distance[alg_name])*10**3)
     for alg_name in ALGS.values()]) + r' \\')
print('Mean Smoothness & ' + ' & '.join(
    [r'{:.2f}$\pm${:.2f}'.format(np.mean(mean_smoothness[alg_name]),
                                 np.std(mean_smoothness[alg_name]))
     for alg_name in ALGS.values()]) + r' \\')

Real line interpolation examples


In [ ]:
line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
start_line = lib.data.draw_line(START_ANGLE, HEIGHT, WIDTH)[..., 0]
end_line = lib.data.draw_line(END_ANGLE, HEIGHT, WIDTH)[..., 0]

DATASET = 'lines32'
BATCH = 64

for alg_name, alg_path in ALGS.items():

    ae_path = os.path.join(RESULTS_PATH.replace('*', 'RUN3'), alg_path)
    ae, _ = lib.utils.load_ae(ae_path, DATASET, BATCH, all_aes.ALL_AES)
    with lib.utils.HookReport.disable():
        ae.eval_mode()

    input_lines = np.concatenate([
        start_line[np.newaxis, ..., np.newaxis],
        end_line[np.newaxis, ..., np.newaxis]])
    start_latent, end_latent = ae.eval_sess.run(ae.eval_ops.encode, {ae.eval_ops.x: input_lines})
    weights = np.linspace(1, 0, N_LINES).reshape(-1, 1, 1, 1)
    interped_latents = weights*start_latent[np.newaxis] + (1 - weights)*end_latent[np.newaxis]
    output_interp = ae.eval_sess.run(ae.eval_ops.decode, {ae.eval_ops.h: interped_latents})

    fig = plt.figure(figsize=(15, 1))
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(flatten_lines(output_interp[..., 0]), cmap=plt.cm.gray, interpolation='nearest')
    plt.gca().set_axis_off()

    plt.savefig('figures/line_{}_example.pdf'.format(alg_name.lower()), aspect='normal')

Real data interpolations


In [ ]:
BATCH = 64

DBERTH_RESULTS_PATH = '/home/craffel/data/dberth/RERUNS/RUN2'
DATASETS_DEPTHS = collections.OrderedDict([('mnist32', 16), ('svhn32', 64), ('celeba32', 64)])
LATENTS = [2, 16]
ALGS_FORMAT = collections.OrderedDict([
    ('Baseline', 'AEBaseline_depth{depth}_latent{latent}_scales3'),
    ('Dropout', 'AEDropout_depth{depth}_dropout0.5_latent{latent}_scales3'),
    ('Denoising', 'AEDenoising_depth{depth}_latent{latent}_noise1.0_scales3'),
    ('VAE', 'VAE_beta1.0_depth{depth}_latent{latent}_scales3'),
    ('AAE', 'AAE_adversary_lr0.0001_depth{depth}_disc_layer_sizes100,100_latent{latent}_scales3'),
    ('VQ-VAE', 'AEVQVAE_beta10.0_depth{depth}_latent{latent}_num_latents10_run1_scales3_z_log_size14'),
    ('ACAI', 'ARAReg_advdepth{depth}_advweight0.5_depth{depth}_latent{latent}_reg0.2_scales3'),
])
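# Per-dataset display ranges used to rescale decoded images into [0, 1].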
DATASETS_MINS = {'mnist32': -1, 'celeba32': -1.2, 'svhn32': -1}
DATASETS_MAXS = {'mnist32': 1, 'celeba32': 1.2, 'svhn32': 1}

N_IMAGES_PER_INTERPOLATION = 16
N_IMAGES = 4

In [ ]:
def interpolate(sess,
                ops,
                image_left,
                image_right,
                dataset_min,
                dataset_max,
                interpolation=N_IMAGES_PER_INTERPOLATION):
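    # Encode the two endpoint images, linearly interpolate between their
    # latents, decode every interpolated latent, and assemble the results
    # into one row flanked by the original endpoint images.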
    def batched_op(op, op_input, array):
        return sess.run(op, feed_dict={op_input: array})

    # Interpolations
    interpolation_x = np.array([image_left, image_right], 'f')
    latent_x = batched_op(ops.encode, ops.x, interpolation_x)
    latents = []
    for x in range(interpolation):
        latents.append((latent_x[:1] * (interpolation - x - 1) +
                        latent_x[1:] * x) / float(interpolation - 1))
    latents = np.concatenate(latents, axis=0)
    interpolation_y = batched_op(ops.decode, ops.h, latents)
    interpolation_y = interpolation_y.reshape(
        (interpolation, 1) + interpolation_y.shape[1:])
    interpolation_y = interpolation_y.transpose(1, 0, 2, 3, 4)
    image_interpolation = lib.utils.images_to_grid(interpolation_y)
    padding = np.ones((image_interpolation.shape[0], 2) + image_interpolation.shape[2:])
    image = np.concatenate(
        [image_left, padding, image_interpolation, padding, image_right],
        axis=1)
    image = (image - dataset_min)/(dataset_max - dataset_min)
    image = np.clip(image, 0, 1)
    return image

def get_dataset_samples(sess, ops, dataset, batches=100):
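    # Draw `batches` single-example batches from the dataset in a fresh graph
    # and session, then encode the images to get one flattened latent each.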
    batch = FLAGS.batch
    with tf.Graph().as_default():
        data_in = dataset.make_one_shot_iterator().get_next()
        with tf.Session() as sess_new:
            images = []
            labels = []
            while True:
                try:
                    payload = sess_new.run(data_in)
                    images.append(payload['x'])
                    assert images[-1].shape[0] == 1
                    labels.append(payload['label'])
                    if len(images) == batches:
                        break
                except tf.errors.OutOfRangeError:
                    break
    images = np.concatenate(images, axis=0)
    labels = np.concatenate(labels, axis=0)
    latents = [sess.run(ops.encode,
                        feed_dict={ops.x: images[p:p + batch]})
               for p in range(0, images.shape[0], batch)]
    latents = np.concatenate(latents, axis=0)
    latents = latents.reshape([latents.shape[0], -1])
    return images, latents, labels

left_images = collections.defaultdict(lambda: None)
right_images = collections.defaultdict(lambda: None)
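# Cache the endpoint images per row index so that every algorithm and latent
# size interpolates between the same pair of images.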

for dataset, depth in DATASETS_DEPTHS.items():
    for latent in LATENTS:
        for alg_name, alg_format in ALGS_FORMAT.items():
            for n in range(N_IMAGES):
                output_name = '{}_{}_latent_{}_interpolation_{}'.format(dataset, alg_name.lower(), latent, n + 1)
                alg_path = os.path.join(DBERTH_RESULTS_PATH, dataset, alg_format.format(depth=depth, latent=latent))

                if True:
                    ae, ds = lib.utils.load_ae(
                        alg_path, dataset, BATCH, all_aes.ALL_AES, return_dataset=True)
                    with lib.utils.HookReport.disable():
                        ae.eval_mode()

                    images, latents, labels = get_dataset_samples(ae.eval_sess,
                                                                  ae.eval_ops,
                                                                  ds.test)
                    labels = np.argmax(labels, axis=1)
                    if left_images[n] is None:
                        left_img_idx = n
                        if dataset == 'celeba32':
                            right_img_idx = N_IMAGES + n
                        else:
                            if n < N_IMAGES // 2:
                                right_img_idx = np.flatnonzero(labels == labels[n])[N_IMAGES + n]
                            else:
                                right_img_idx = np.flatnonzero(labels != labels[n])[N_IMAGES + n]
                        print(left_img_idx, labels[left_img_idx])
                        print(right_img_idx, labels[right_img_idx])
                        left_images[n] = images[left_img_idx]
                        right_images[n] = images[right_img_idx]
                    left_image = left_images[n]
                    right_image = right_images[n]
                    image = interpolate(ae.eval_sess, ae.eval_ops, left_image, right_image,
                                        DATASETS_MINS[dataset], DATASETS_MAXS[dataset])
                fig = plt.figure(figsize=(15, 1))
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.set_axis_off()
                fig.add_axes(ax)
                ax.imshow(np.squeeze(image), cmap=plt.cm.gray, interpolation='nearest')
                plt.gca().set_axis_off()

                plt.savefig('figures/{}.pdf'.format(output_name), aspect='normal')
                plt.close()
    for n in range(N_IMAGES):
        del left_images[n]
        del right_images[n]

In [ ]:
DATASET_NAMES = {'mnist32': 'MNIST', 'svhn32': 'SVHN', 'celeba32': 'CelebA'}

In [ ]:
output = ""

for dataset, depth in DATASETS_DEPTHS.items():
    for latent in LATENTS:
        output += r"""
    \begin{figure}
      \centering
"""
        for n in range(N_IMAGES):
            alg_list = collections.OrderedDict()
            for alg_name, alg_format in ALGS_FORMAT.items():
                figure_name = '{}_{}_latent_{}_interpolation_{}'.format(dataset, alg_name.lower(), latent, n + 1)
                alg_list[figure_name] = alg_name
                if alg_name == list(ALGS_FORMAT)[-1]:
                    reset = r"\addtocounter{{subfigure}}{{-{}}}".format(len(ALGS_FORMAT))
                else:
                    reset = ""
                output += r"""
      \begin{{subfigure}}[b]{{\textwidth}}
        \centering\parbox{{.09\linewidth}}{{\vspace{{0.3em}}\subcaption{{}}\label{{fig:{figure_name}}}}}
        \parbox{{.75\linewidth}}{{\includegraphics[width=\linewidth]{{figures/{figure_name}.pdf}}}}{reset}
      \end{{subfigure}}
""".format(figure_name=figure_name, reset=reset)
                if alg_name == list(ALGS_FORMAT)[-1]:
                    output += r"""
      \vspace{0.5em}
"""
        output += r"""
      \caption{{Example interpolations on {} with a latent dimensionality of {} for """.format(
            DATASET_NAMES[dataset], latent*16)
        output += ', '.join([r'(\subref{{fig:{}}}) {}'.format(fn, an) for fn, an in alg_list.items()])
        output += r""" autoencoders.}}
      \label{{fig:{}_{}_interpolations}}
    \end{{figure}}


""".format(dataset, latent)

print(output)

VAE line samples


In [ ]:
RESULTS_PATH = '/home/craffel/data/autoencoder/results_final/lines32'

line_interpolation = np.zeros((N_LINES, HEIGHT, WIDTH))
start_line = lib.data.draw_line(START_ANGLE, HEIGHT, WIDTH)[..., 0]
end_line = lib.data.draw_line(END_ANGLE, HEIGHT, WIDTH)[..., 0]

DATASET = 'lines32'
BATCH = 64

ae_path = os.path.join(RESULTS_PATH, 'VAE_beta1.0_depth16_latent16_scales4')
ae, _ = lib.utils.load_ae(ae_path, DATASET, BATCH, all_aes.ALL_AES)
with lib.utils.HookReport.disable():
    ae.eval_mode()

# Sample latent codes from the VAE prior N(0, I); with scales=4 on 32x32
# inputs, the latent map is 2x2 with 16 channels.
random_latents = np.random.standard_normal(size=(16*16, 2, 2, 16))

random_images = ae.eval_sess.run(ae.eval_ops.decode, {ae.eval_ops.h: random_latents})

fig = plt.figure(figsize=(15, 15))
ax = plt.Axes(fig, [0., 0., 1., 1.])
ax.set_axis_off()
fig.add_axes(ax)
padding = np.ones((2, WIDTH*N_LINES + 4*N_LINES))
line_matrix = np.concatenate([
    np.concatenate([padding, flatten_lines(random_images[n:n + 16, ..., 0]), padding], axis=0)
    for n in range(0, 16*16, 16)], axis=0)
ax.imshow(line_matrix, cmap=plt.cm.gray, interpolation='nearest')
plt.gca().set_axis_off()

plt.savefig('figures/line_vae_samples.pdf', aspect='normal')

Single-layer classifier table


In [ ]:
def get_all_results(results_path, event_key):
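    # Scan every run directory matching `results_path`, read its TF event
    # files, and return the per-run median of the last values logged under
    # `event_key`.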
    experiments = collections.defaultdict(list)
    for run_path in glob.glob(results_path):
        for path in glob.glob(os.path.join(run_path, '*')):
            experiments[os.path.split(path)[-1]].append(os.path.join(path, 'tf', 'summaries'))

    experiment_results = collections.defaultdict(
        lambda: collections.defaultdict(
            lambda: collections.defaultdict(
                lambda: collections.defaultdict(list))))
    for experiment_key, experiment_paths in experiments.items():
        for n, experiment_path in enumerate(experiment_paths):
            print('Getting results for', experiment_key, n)
            for events_file in glob.glob(os.path.join(experiment_path, 'events*')):
                try:
                    for e in tf.train.summary_iterator(events_file):
                        for v in e.summary.value:
                            experiment_results[experiment_key][n][v.tag]['step'].append(e.step)
                            experiment_results[experiment_key][n][v.tag]['value'].append(v.simple_value)
                    except Exception as exception:
                        print(exception)

    event_values = collections.defaultdict(list)
    for experiment_name, events_lists in experiment_results.items():
        for events in events_lists.values():
            event_values[experiment_name].append(get_final_value_median(
                events[event_key]['value'], events[event_key]['step']))
    return event_values

In [ ]:
RESULTS_PATH = '/home/craffel/data/dberth/RERUNS/*/mnist32'
accuracy = get_all_results(RESULTS_PATH, 'latent_accuracy_1')

In [ ]:
ALGS = collections.OrderedDict([
    ('Baseline', 'AEBaseline_depth16_latent{}_scales3'),
    ('Dropout', 'AEDropout_depth16_dropout0.5_latent{}_scales3'),
    ('Denoising', 'AEDenoising_depth16_latent{}_noise1.0_scales3'),
    ('VAE', 'VAE_beta1.0_depth16_latent{}_scales3'),
    ('AAE', 'AAE_adversary_lr0.0001_depth16_disc_layer_sizes100,100_latent{}_scales3'),
    ('VQ-VAE', 'AEVQVAE_advdepth16_advweight0.0_beta10.0_depth16_emaTrue_latent{}_noiseFalse_num_blocks1_num_latents10_num_residuals1_reg0.5_scales3_z_log_size14'),
    ('ACAI', 'ARAReg_advdepth16_advweight0.5_depth16_latent{}_reg0.2_scales3')])

for latent_size in [2, 16]:
    print('{} & '.format(latent_size*16) + ' & '.join(
        [r'{:.2f}$\pm${:.2f}'.format(
            np.mean(accuracy[alg_name.format(latent_size)]),
            np.std(accuracy[alg_name.format(latent_size)]))
         for alg_name in ALGS.values()]) + r' \\')

In [ ]:
RESULTS_PATH = '/home/craffel/data/dberth/RERUNS/*/svhn32'
accuracy = get_all_results(RESULTS_PATH, 'latent_accuracy_1')

In [ ]:
ALGS = collections.OrderedDict([
    ('Baseline', 'AEBaseline_depth64_latent{}_scales3'),
    ('Dropout', 'AEDropout_depth64_dropout0.5_latent{}_scales3'),
    ('Denoising', 'AEDenoising_depth64_latent{}_noise1.0_scales3'),
    ('VAE', 'VAE_beta1.0_depth64_latent{}_scales3'),
    ('AAE', 'AAE_adversary_lr0.0001_depth64_disc_layer_sizes100,100_latent{}_scales3'),
    ('VQ-VAE', 'AEVQVAE_advdepth16_advweight0.0_beta10.0_depth64_emaTrue_latent{}_noiseFalse_num_blocks1_num_latents10_num_residuals1_reg0.5_scales3_z_log_size14'),
    ('ACAI', 'ARAReg_advdepth64_advweight0.5_depth64_latent{}_reg0.2_scales3')])

for latent_size in [2, 16]:
    print('{} & '.format(latent_size*16) + ' & '.join(
        [r'{:.2f}$\pm${:.2f}'.format(
            np.mean(accuracy[alg_name.format(latent_size)]),
            np.std(accuracy[alg_name.format(latent_size)]))
         for alg_name in ALGS.values()]) + r' \\')

In [ ]:
RESULTS_PATH = '/home/craffel/data/dberth/RERUNS/*/cifar10'
accuracy = get_all_results(RESULTS_PATH, 'latent_accuracy_1')

In [ ]:
ALGS = collections.OrderedDict([
    ('Baseline', 'AEBaseline_depth64_latent{}_scales3'),
    ('Dropout', 'AEDropout_depth64_dropout0.75_latent{}_scales3'),
    ('Denoising', 'AEDenoising_depth64_latent{}_noise1.0_scales3'),
    ('VAE', 'VAE_beta1.0_depth64_latent{}_scales3'),
    ('AAE', 'AAE_adversary_lr0.0001_depth64_disc_layer_sizes100,100_latent{}_scales3'),
    ('VQ-VAE', 'AEVQVAE_advdepth16_advweight0.0_beta10.0_depth64_emaTrue_latent{}_noiseFalse_num_blocks1_num_latents10_num_residuals1_reg0.5_scales3_z_log_size14'),
    ('ACAI', 'ARAReg_advdepth64_advweight0.5_depth64_latent{}_reg0.2_scales3')])

for latent_size in [16, 64]:
    print('{} & '.format(latent_size*16) + ' & '.join(
        [r'{:.2f}$\pm${:.2f}'.format(
            np.mean(accuracy[alg_name.format(latent_size)]),
            np.std(accuracy[alg_name.format(latent_size)]))
         for alg_name in ALGS.values()]) + r' \\')