In [1]:
%load_ext autoreload
%autoreload 2
# Get the final results for the ICLR paper.
import pandas as pd
import glob
import os
import numpy as np
from results_utils import filter_result_files
from results_utils import latest_checkpoints
from results_utils import filename_to_method
from results_utils import extract_global_step
from results_utils import compute_mean_std
from IPython.display import display, HTML

import cPickle as pickle
from collections import namedtuple
from collections import defaultdict

In [2]:
# Immutable two-field record: `label` and `mask`.
# NOTE(review): not referenced by any cell visible in this section -- confirm
# it is still needed before keeping it.
Query = namedtuple('Query', 'label mask')

In [3]:
def build_latex_results(
    te_4_iid='', jm_4_iid='', bi_4_iid='',
    te_3_iid='', jm_3_iid='', bi_3_iid='',
    te_2_iid='', jm_2_iid='', bi_2_iid='',
    te_1_iid='', jm_1_iid='', bi_1_iid='',
    te_4_cmp='', jm_4_cmp='', bi_4_cmp='',
    te_4_cor_iid_pct=0.0, te_4_cor_iid_err=0.0,
    jm_4_cor_iid_pct=0.0, jm_4_cor_iid_err=0.0,
    bi_4_cor_iid_pct=0.0, bi_4_cor_iid_err=0.0,
    te_3_cov_iid_pct=0.0, te_3_cov_iid_err=0.0,
    jm_3_cov_iid_pct=0.0, jm_3_cov_iid_err=0.0,
    bi_3_cov_iid_pct=0.0, bi_3_cov_iid_err=0.0,
    te_3_cor_iid_pct=0.0, te_3_cor_iid_err=0.0,
    jm_3_cor_iid_pct=0.0, jm_3_cor_iid_err=0.0,
    bi_3_cor_iid_pct=0.0, bi_3_cor_iid_err=0.0,
    te_2_cov_iid_pct=0.0, te_2_cov_iid_err=0.0,
    jm_2_cov_iid_pct=0.0, jm_2_cov_iid_err=0.0,
    bi_2_cov_iid_pct=0.0, bi_2_cov_iid_err=0.0,
    te_2_cor_iid_pct=0.0, te_2_cor_iid_err=0.0,
    jm_2_cor_iid_pct=0.0, jm_2_cor_iid_err=0.0,
    bi_2_cor_iid_pct=0.0, bi_2_cor_iid_err=0.0,
    te_1_cov_iid_pct=0.0, te_1_cov_iid_err=0.0,
    jm_1_cov_iid_pct=0.0, jm_1_cov_iid_err=0.0,
    bi_1_cov_iid_pct=0.0, bi_1_cov_iid_err=0.0,
    te_1_cor_iid_pct=0.0, te_1_cor_iid_err=0.0,
    jm_1_cor_iid_pct=0.0, jm_1_cor_iid_err=0.0,
    bi_1_cor_iid_pct=0.0, bi_1_cor_iid_err=0.0,
    te_4_cor_cmp_pct=0.0, te_4_cor_cmp_err=0.0,
    jm_4_cor_cmp_pct=0.0, jm_4_cor_cmp_err=0.0,
    bi_4_cor_cmp_pct=0.0, bi_4_cor_cmp_err=0.0,
    **_
  ):
  """Render the results table as LaTeX source.

  Argument names follow `<method>_<n>[_<column>]_<split>[_<part>]`:
    * method: `te`, `jm`, `bi` -- the telbo, jmvae and bivcca rows.
    * n: the number shown in the #Attributes column (4, 3, 2 or 1).
    * column: `cov` fills the Coverage column, `cor` the Correctness column.
    * split: `iid` or `cmp` -- the two sections of the table.
    * part: `pct` is printed before the plus-minus sign, `err` after it;
      both are formatted with two decimals.

  The string arguments without a column/part (e.g. `te_4_iid`) are only
  written as LaTeX comment lines above the corresponding row.

  Any additional keyword argument is accepted and has no effect on the output.

  Returns:
    The LaTeX source of the table as a single string.
  """
  # Snapshot of the call arguments (the `_` catch-all dict included), taken
  # before any other local exists, so the template is filled from exactly the
  # same mapping a bare `locals()` would provide.
  fields = dict(locals())

  template = r"""
\begin{{table}}
  \centering
  \begin{{tabular}}{{cccc}}
    \toprule

    \textbf{{Method}}  & \textbf{{\#Attributes}} & \textbf{{Coverage}} (\%)                        & \textbf{{\Correctness}} (\%)                                \\

    \toprule
    \rowcolor{{Gray}}\multicolumn{{4}}{{c}}{{\textbf{{\iid}}}}\\
    \toprule
    % {te_4_iid}
    \telbo  & \multirow{{3}}{{*}}{{4}} & -                                               & {te_4_cor_iid_pct:.2f} {{\tiny$\pm$}} {te_4_cor_iid_err:.2f}\\
    % {jm_4_iid}
    \jmvae  &  & -                                               & {jm_4_cor_iid_pct:.2f} {{\tiny$\pm$}} {jm_4_cor_iid_err:.2f}  \\
    % {bi_4_iid}
    \bivcca &  & -                                               & {bi_4_cor_iid_pct:.2f} {{\tiny$\pm$}} {bi_4_cor_iid_err:.2f} \\

    \midrule

    % {te_3_iid}
    \telbo  & \multirow{{3}}{{*}}{{3}} & {te_3_cov_iid_pct:.2f} {{\tiny$\pm$}} {te_3_cov_iid_err:.2f} & {te_3_cor_iid_pct:.2f} {{\tiny$\pm$}} {te_3_cor_iid_err:.2f} \\
    % {jm_3_iid}
    \jmvae  & & {jm_3_cov_iid_pct:.2f} {{\tiny$\pm$}} {jm_3_cov_iid_err:.2f} & {jm_3_cor_iid_pct:.2f} {{\tiny$\pm$}} {jm_3_cor_iid_err:.2f}  \\
    % {bi_3_iid}
    \bivcca & & {bi_3_cov_iid_pct:.2f} {{\tiny$\pm$}} {bi_3_cov_iid_err:.2f} & {bi_3_cor_iid_pct:.2f} {{\tiny$\pm$}} {bi_3_cor_iid_err:.2f}  \\

    \midrule

    % {te_2_iid}
    \telbo  & \multirow{{3}}{{*}}{{2}} & {te_2_cov_iid_pct:.2f} {{\tiny$\pm$}} {te_2_cov_iid_err:.2f} & {te_2_cor_iid_pct:.2f} {{\tiny$\pm$}} {te_2_cor_iid_err:.2f}  \\
    % {jm_2_iid}
    \jmvae  &  & {jm_2_cov_iid_pct:.2f} {{\tiny$\pm$}} {jm_2_cov_iid_err:.2f} & {jm_2_cor_iid_pct:.2f} {{\tiny$\pm$}} {jm_2_cor_iid_err:.2f} \\
    % {bi_2_iid}
    \bivcca &  & {bi_2_cov_iid_pct:.2f} {{\tiny$\pm$}} {bi_2_cov_iid_err:.2f} & {bi_2_cor_iid_pct:.2f} {{\tiny$\pm$}} {bi_2_cor_iid_err:.2f} \\

    \midrule

    % {te_1_iid}
    \telbo  & \multirow{{3}}{{*}}{{1}} & {te_1_cov_iid_pct:.2f} {{\tiny$\pm$}} {te_1_cov_iid_err:.2f} & {te_1_cor_iid_pct:.2f} {{\tiny$\pm$}} {te_1_cor_iid_err:.2f} \\
    % {jm_1_iid}
    \jmvae  & & {jm_1_cov_iid_pct:.2f} {{\tiny$\pm$}} {jm_1_cov_iid_err:.2f} & {jm_1_cor_iid_pct:.2f} {{\tiny$\pm$}} {jm_1_cor_iid_err:.2f} \\
    % {bi_1_iid}
    \bivcca & & {bi_1_cov_iid_pct:.2f} {{\tiny$\pm$}} {bi_1_cov_iid_err:.2f} & {bi_1_cor_iid_pct:.2f} {{\tiny$\pm$}} {bi_1_cor_iid_err:.2f} \\
    
    \toprule
    \rowcolor{{Gray}}\multicolumn{{4}}{{c}}{{\textbf{{\comp}}}}\\
    \toprule
    % {te_4_cmp}
    \telbo  & \multirow{{3}}{{*}}{{4}} & -                                               & {te_4_cor_cmp_pct:.2f} {{\tiny$\pm$}} {te_4_cor_cmp_err:.2f}  \\
    % {jm_4_cmp}
    \jmvae  & & -                                               & {jm_4_cor_cmp_pct:.2f} {{\tiny$\pm$}} {jm_4_cor_cmp_err:.2f}  \\
    % {bi_4_cmp}
    \bivcca & & -                                               & {bi_4_cor_cmp_pct:.2f} {{\tiny$\pm$}} {bi_4_cor_cmp_err:.2f}  \\

    \bottomrule
  \end{{tabular}}
\end{{table}}
  """

  rendered = template.format(**fields)

  # Collapse the alignment padding that only exists to keep the template
  # readable; the substitutions are applied one after another, in this order.
  padding_fixes = (
      ('-                                               &', '-            &'),
      (r'(\%)                        &', r'(\%) &'),
      (r'(\%)                               &', r'(\%) &'),
  )
  for padded, compact in padding_fixes:
    rendered = rendered.replace(padded, compact)
  return rendered

In [4]:
# Load the results files and put them in a nice pandas dataframe.
def plot_metrics(result_files, filt='multimodal_elbo',
                 metrics=((4.0, 'comprehensibility'),),
                 ban='_val_iclr_mnista_fresh'):
    """Gather metrics from pickled result files into DataFrames and display them.

    Args:
      result_files: iterable of paths to pickled result files.
      filt: substring a path must contain for the file to be included.
      metrics: iterable of 2-tuples. Each tuple is used verbatim as the key
        into the unpickled object, and `str(t[0]) + '_' + str(t[1])` becomes
        the column name. When the last element is 'comprehensibility' the
        mean is stored as `1 - mean / t[0]` and the std as `std / t[0]`;
        other metrics are stored unchanged.
      ban: forwarded to `filename_to_method` as its `ban` keyword.
        NOTE(review): presumably a substring removed from the file name to
        form the row label -- confirm in results_utils.

    Returns:
      (results_data, results_data_mean_only): two DataFrames with one row per
      method. The first holds `(mean, std)` tuples plus a `global_step`
      column; the second holds only the means.

    Side effects:
      Displays `results_data` in the notebook.
    """
    results_data = defaultdict(dict)
    results_data_mean_only = defaultdict(dict)

    for rfile in result_files:
        if filt not in rfile:
            continue

        # Pickles are binary data: open in 'rb' and close the handle
        # deterministically instead of leaking it.
        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at result files you produced yourself.
        with open(rfile, 'rb') as result_handle:
            pf = pickle.load(result_handle)

        # The row label depends only on the file, so compute it once.
        method = filename_to_method(rfile, ban=ban)

        # Extract the mean and standard deviation of each metric (the
        # original author's note: computed across multiple splits).
        for metric in metrics:
            metric_value, metric_std = compute_mean_std(pf[metric])
            if metric[-1] == 'comprehensibility':
                # Rescale by the first tuple element and flip the mean so the
                # stored value is 1 - mean / scale.
                metric_value = 1 - metric_value/metric[0]
                metric_std = metric_std/metric[0]

            metric_str = str(metric[0]) + '_' + str(metric[1])
            results_data[metric_str][method] = (metric_value, metric_std)
            results_data_mean_only[metric_str][method] = metric_value

        results_data['global_step'][method] = extract_global_step(rfile)

    results_data = pd.DataFrame(results_data)
    results_data_mean_only = pd.DataFrame(results_data_mean_only)
    display(results_data)

    return results_data, results_data_mean_only

In [5]:
def pick_best_method(data_frame, sort_by='consolidated_jsd_sim'):
    """Return the row label with the largest summed score.

    NOTE: data_frame should not have the std term, just mean.

    Args:
      data_frame: pandas DataFrame with one row per method and numeric
        (mean-only) metric columns.
      sort_by: substring; every column whose name contains it contributes to
        a row's score.

    Returns:
      The index label of the row whose sum over the selected columns is
      largest.

    Raises:
      ValueError: if no column name contains `sort_by`. Previously every row
        then scored 0 and an arbitrary row was silently reported as "best".
    """
    col_subset = [col for col in data_frame.columns if sort_by in col]
    if not col_subset:
        raise ValueError(
            'No column of the data frame contains %r; available columns: %r'
            % (sort_by, list(data_frame.columns)))

    # NOTE: Assumes that larger is better!!
    scores = data_frame[col_subset].sum(axis=1)
    return scores.sort_values(ascending=False).index[0]

In [6]:
# Locate the result pickles of the IID experiment on the validation split.
SPLIT = 'val'
EXP_PREFIX = 'iclr_mnista_fresh_iid'
PATH_TO_RESULTS = '/coc/scratch/rvedantam3/runs/imagination/%s_%s' % (
    EXP_PREFIX, SPLIT)

# Every '*.p' file in that directory, then narrowed by the results_utils
# helpers before being handed to plot_metrics.
result_files = glob.glob(PATH_TO_RESULTS + '/*.p')
results_to_show = latest_checkpoints(
    filter_result_files(result_files))

In [7]:
# Metrics of every 'multimodal_elbo' run in `results_to_show`; only the
# mean-only frame (second return value) is kept.
_, triple_elbo_results = plot_metrics(
    results_to_show,
    filt='multimodal_elbo',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ])


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.855260416667, 0.0119988515799) (0.916530703034, 0.00176841815444) (0.911301876555, 0.000476157525744) (0.830859375, 0.00492800268907) (0.912668456266, 0.00127829888608) (0.904941838378, 0.00293141758778) (0.824756944444, 0.00231261469519) (0.91327479986, 0.00155193704777) (0.895332853305, 0.00531431582596) (0.8230859375, 0.00339057099611) (0.921269880175, 0.00246967447622) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.671875, 0.0124619647021) (0.896213330004, 0.00292616029605) (0.915394424145, 0.00210683439837) (0.649270833333, 0.0103322097003) (0.861449896751, 0.00240049438653) (0.899659194504, 0.00253792586341) (0.636979166667, 0.00785424802247) (0.835398439855, 0.00365639226455) (0.882448465678, 0.00269786973169) (0.60171875, 0.00438022846062) (0.799118557746, 0.00336525101372) 246859
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.84953125, 0.0131950840488) (0.909283500795, 0.00150075460976) (0.903487765104, 0.000886120720294) (0.828411458333, 0.00766545648284) (0.911760145235, 0.00141427482062) (0.902245632008, 0.00169751840471) (0.820208333333, 0.00470623826708) (0.911130409966, 0.00141624031042) (0.8873653034, 0.00447628051904) (0.814205729167, 0.00394284077552) (0.918545537449, 0.00232487672662) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.8578125, 0.0133500058521) (0.908075511093, 0.000926053829972) (0.899595389002, 0.00114543527532) (0.812526041667, 0.00591231867822) (0.901596782877, 0.0011905683371) (0.892034496115, 0.00222374546659) (0.795034722222, 0.00535285597879) (0.898628561178, 0.00270789874102) (0.882549061366, 0.00450840630135) (0.776549479167, 0.00360406693925) (0.896574317358, 0.00263121356643) 61549
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.629270833333, 0.0094464867073) (0.891313673665, 0.00123916761562) (0.915751769515, 0.000452623529145) (0.610260416667, 0.00973603154935) (0.856171698585, 0.00257097046006) (0.908342014105, 0.00451920745113) (0.610711805556, 0.006951732287) (0.827570487601, 0.00257868307556) (0.890904142358, 0.00575443344768) (0.58046875, 0.00232231225788) (0.789469353063, 0.00190360662024) 87692
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.857395833333, 0.0146219726211) (0.907668457605, 0.00223606825459) (0.898556675069, 0.000673219475358) (0.809453125, 0.00662278035405) (0.905797610424, 0.00146949751292) (0.899435326479, 0.00113930011476) (0.80828125, 0.00340676141534) (0.906390568374, 0.00142392717721) (0.889452079199, 0.00436218759749) (0.802747395833, 0.00380297506666) (0.912871839023, 0.0020818796278) 157975
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.849635416667, 0.0149838961588) (0.912535187611, 0.00128414688446) (0.90763428063, 0.00209494016337) (0.818359375, 0.00798125040786) (0.908140743532, 0.00211710131646) (0.904372768173, 0.00237108601426) (0.816302083333, 0.00437224339964) (0.910901473159, 0.00133520471272) (0.896173428081, 0.0043222122254) (0.805963541667, 0.00373237139707) (0.9136277814, 0.00248184223237) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.640520833333, 0.0133780175724) (0.886862138614, 0.00205198647207) (0.908632839122, 0.0013025524121) (0.627630208333, 0.0119199737726) (0.855299454409, 0.00322747809813) (0.900053273335, 0.00289162852713) (0.62671875, 0.00533565442476) (0.829408015166, 0.00246231320708) (0.874065784643, 0.00347868829835) (0.604765625, 0.00335154327727) (0.802558979773, 0.00218203974026) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.856354166667, 0.00755944325276) (0.915070130391, 0.00100576443479) (0.907727065737, 0.00116678079356) (0.824114583333, 0.00706223881031) (0.911911594133, 0.00121431813238) (0.902264459367, 0.00276461115835) (0.819479166667, 0.00481206657901) (0.915642313804, 0.00244158403084) (0.896333364742, 0.00648708911956) (0.8208203125, 0.00556454418782) (0.923060436471, 0.00280914797257) 250000

In [8]:
# Metrics of every 'jmvae' run in `results_to_show`; only the mean-only
# frame (second return value) is kept.
_, jmvae_results = plot_metrics(
    results_to_show,
    filt='jmvae',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ])


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_10_l1_pxyz_5e-6kl (0.84640625, 0.00730541857973) (0.903667637661, 0.000850983109392) (0.893409469645, 0.00144667802837) (0.817083333333, 0.00969554717294) (0.902882696082, 0.00197665854057) (0.885480764429, 0.00353597985551) (0.820486111111, 0.00181005715541) (0.911221739681, 0.0016159949409) (0.874098334938, 0.00726751590116) (0.838619791667, 0.00195243043205) (0.93434573567, 0.00098316499354) 262075
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_1_l1_pxyz_5e-6kl (0.675, 0.0131173506148) (0.878332405709, 0.00194590359759) (0.894262457134, 0.00135543290468) (0.669947916667, 0.009072222993) (0.851031487835, 0.00145668940319) (0.876520055521, 0.00306704034706) (0.671128472222, 0.00652177250366) (0.836573687371, 0.00186421463572) (0.854416836321, 0.00516678658147) (0.653528645833, 0.00318582302019) (0.8223064198, 0.00179937579376) 90908
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_50_l1_pxyz_5e-6kl (0.834322916667, 0.0106324014253) (0.897724525138, 0.00167424953689) (0.887948164229, 0.00190952818385) (0.81734375, 0.0122147354161) (0.902183648052, 0.00170937953784) (0.885548864737, 0.00359954468789) (0.831944444444, 0.00343249846578) (0.916068225137, 0.000764309326666) (0.879288490414, 0.00538393177257) (0.856341145833, 0.00297223160248) (0.941697087811, 0.00123434239788) 268201

In [9]:
# Metrics of every 'bivcca' run in `results_to_show`; only the mean-only
# frame (second return value) is kept.
_, bivcca_results = plot_metrics(
    results_to_show,
    filt='bivcca',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ])


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_10_l1_pxyz_5e-6kl (0.750625, 0.0155907610974) (0.846681250923, 0.000922412525209) (0.838372371919, 0.00200495326197) (0.6921875, 0.0121469790233) (0.818733930461, 0.00302257716197) (0.81041115435, 0.00198955220934) (0.671701388889, 0.00602708169304) (0.797739837146, 0.0019770450987) (0.765401335867, 0.00718778182752) (0.644244791667, 0.00582638281043) (0.784037612568, 0.00402106107369) 250000
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_1_l1_pxyz_5e-6kl (0.709166666667, 0.00963246013199) (0.85278301865, 0.000764532978388) (0.857322217565, 0.00240523835364) (0.672317708333, 0.011457623084) (0.818544295513, 0.0022177353761) (0.823693178383, 0.00477124184696) (0.656388888889, 0.00420217795621) (0.797744383726, 0.00188916525321) (0.783996024656, 0.00283975574106) (0.627356770833, 0.00642923341602) (0.77742568658, 0.0044047440634) 215441
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_50_l1_pxyz_5e-6kl (0.701510416667, 0.015889599859) (0.840186063714, 0.00170531124526) (0.844257240313, 0.0027877642866) (0.71765625, 0.009917602153) (0.845554782963, 0.00342673902417) (0.844665057102, 0.00342928592779) (0.706788194444, 0.0111650168074) (0.830685770911, 0.00397868880007) (0.812695184301, 0.00694197968591) (0.658125, 0.00633893966157) (0.799895932058, 0.00413472641861) 215774

In [10]:
# Load test numbers: same IID experiment, now the test split. This rebinds
# SPLIT, PATH_TO_RESULTS, result_files and results_to_show.
SPLIT = 'test'
EXP_PREFIX = 'iclr_mnista_fresh_iid'
PATH_TO_RESULTS = '/coc/scratch/rvedantam3/runs/imagination/%s_%s' % (
    EXP_PREFIX, SPLIT)

result_files = glob.glob(PATH_TO_RESULTS + '/*.p')
results_to_show = latest_checkpoints(
    filter_result_files(result_files))

In [11]:
# NOTE(review): this recomputation repeats the last statement of the
# test-split loading cell; it looks redundant, but is kept because the purity
# of the results_utils helpers cannot be confirmed from this notebook.
results_to_show = latest_checkpoints(
    filter_result_files(result_files))

# Metrics of every 'multimodal_elbo' run in `results_to_show`; here the
# (mean, std) frame (first return value) is kept.
raw_triple_elbo_results, _ = plot_metrics(
    results_to_show,
    filt='multimodal_elbo',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ],
    ban='_test_iclr_mnista_fresh')


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.832552083333, 0.0164603110626) (0.916570871521, 0.000993537575057) (0.914220182417, 0.00209515167714) (0.8296875, 0.0114070229) (0.91202428749, 0.00102775685713) (0.905487560782, 0.00551402699894) (0.822083333333, 0.005625267912) (0.916956098959, 0.00216049373914) (0.91238697421, 0.00374369495676) (0.8230859375, 0.00339057099611) (0.921269880175, 0.00246967447622) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.65421875, 0.0255863317665) (0.891724788451, 0.00360082809501) (0.912951386521, 0.00158509131198) (0.646041666667, 0.015499189677) (0.860462266595, 0.00311069390321) (0.898798933087, 0.00387988998253) (0.620902777778, 0.0075948462032) (0.831149391634, 0.00348307472838) (0.896753690418, 0.00411889355293) (0.6053515625, 0.00403507200656) (0.800349991716, 0.00331039291912) 243458
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_100_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.83578125, 0.0174689705065) (0.911755028161, 0.00159288915079) (0.908214354556, 0.00220034083445) (0.82578125, 0.0137447216036) (0.910784366106, 0.00162435374551) (0.901759225817, 0.00512432148128) (0.813645833333, 0.00550169200966) (0.914205724421, 0.00164717332312) (0.908861245989, 0.00438816682836) (0.814205729167, 0.00394284077552) (0.918545537449, 0.00232487672662) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.833958333333, 0.0205016566488) (0.907819150722, 0.00167965974964) (0.903421696192, 0.00312429832723) (0.812109375, 0.0134555051369) (0.901022420223, 0.00185718217617) (0.892612956549, 0.00561951810169) (0.795086805556, 0.00560675195084) (0.901722661748, 0.00266253206353) (0.895538863315, 0.00619962686885) (0.776549479167, 0.00360406693925) (0.896574317358, 0.00263121356643) 61549
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.61359375, 0.0221555668089) (0.889614373804, 0.00148349417217) (0.917889665399, 0.00265169364627) (0.602213541667, 0.0147018989703) (0.85241803657, 0.00262347942184) (0.905071355445, 0.0038376202795) (0.585555555556, 0.00421363859306) (0.820284101789, 0.00199303192793) (0.903744632476, 0.00484909407223) (0.58046875, 0.00232231225788) (0.789469353063, 0.00190360662024) 87692
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_1_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.82859375, 0.0169494243593) (0.908072655405, 0.00217619276614) (0.90410634997, 0.00317964318078) (0.814609375, 0.0130352524329) (0.905762663804, 0.00148267518316) (0.897407683543, 0.00518618245511) (0.802552083333, 0.00608917305957) (0.909872761974, 0.00241572223578) (0.90913577524, 0.00530342528859) (0.802747395833, 0.00380297506666) (0.912871839023, 0.0020818796278) 157975
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_10_l1_pxyz_5e-6kl (0.830052083333, 0.017568071971) (0.913823241809, 0.00177861575333) (0.912448207954, 0.00212694032765) (0.825078125, 0.0138167178703) (0.909551417362, 0.00194135708306) (0.902762371209, 0.00402951964186) (0.81015625, 0.00592698264414) (0.912308355289, 0.00240791224129) (0.908607368979, 0.00349799090672) (0.805963541667, 0.00373237139707) (0.9136277814, 0.00248184223237) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_1_l1_pxyz_5e-6kl (0.624427083333, 0.0295604004422) (0.884076441295, 0.00399014168474) (0.910286897716, 0.00144535875377) (0.629166666667, 0.0152698349776) (0.856155662683, 0.00237196038811) (0.897299635367, 0.00543558752126) (0.606875, 0.00451469009983) (0.823832058661, 0.00205379323207) (0.890590860292, 0.00693075598032) (0.604765625, 0.00335154327727) (0.802558979773, 0.00218203974026) 250000
_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.836666666667, 0.0169758112738) (0.914071998262, 0.000895342392999) (0.90939019272, 0.00189095115273) (0.820260416667, 0.0137416621658) (0.911812905791, 0.00115156288156) (0.903168244609, 0.00573054026315) (0.816267361111, 0.00376724110152) (0.917885766707, 0.00209835373405) (0.911391488447, 0.0053247526609) (0.8208203125, 0.00556454418782) (0.923060436471, 0.00280914797257) 250000

In [12]:
# Metrics of every 'jmvae' run in `results_to_show`; the (mean, std) frame
# (first return value) is kept.
raw_jmvae_results, _ = plot_metrics(
    results_to_show,
    filt='jmvae',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ],
    ban='_test_iclr_mnista_fresh')


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_10_l1_pxyz_5e-6kl (0.8175, 0.0148207973128) (0.903706152911, 0.00205383616253) (0.898023650119, 0.00309518161576) (0.816354166667, 0.00952591209817) (0.901465443863, 0.00092736170613) (0.882567733203, 0.00551761657947) (0.812395833333, 0.00656204069527) (0.913724148733, 0.00272738656436) (0.895983147143, 0.00397726207604) (0.833776041667, 0.00194329121618) (0.932204578914, 0.00096880650964) 259972
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_1_l1_pxyz_5e-6kl (0.66203125, 0.0240754609755) (0.878314494975, 0.00293686342358) (0.896543130859, 0.00258266380241) (0.667005208333, 0.0167418956477) (0.850469531614, 0.00262346051442) (0.875562491203, 0.00620055838627) (0.659131944444, 0.00579335881555) (0.836373355313, 0.00343563828309) (0.876619984008, 0.00516937052054) (0.6541796875, 0.00385544214537) (0.821971118595, 0.00275717842871) 89997
_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_50_l1_pxyz_5e-6kl (0.81578125, 0.0177661433827) (0.894305074941, 0.00235164647967) (0.886973533421, 0.00353865224374) (0.810208333333, 0.0105375794942) (0.896098937111, 0.000972679583202) (0.878949221666, 0.00688308981491) (0.820034722222, 0.00373663384406) (0.912085828513, 0.00120482959756) (0.885150592157, 0.00369402579347) (0.851458333333, 0.00262511108809) (0.939121858396, 0.0011250505976) 266054

In [13]:
# Metrics of every 'bivcca' run in `results_to_show`; the (mean, std) frame
# (first return value) is kept.
raw_bivcca_results, _ = plot_metrics(
    results_to_show,
    filt='bivcca',
    metrics=[
        (4.0, 'comprehensibility'),
        (3.0, 'comprehensibility'),
        (2.0, 'comprehensibility'),
        (1.0, 'comprehensibility'),
        (3.0, 'parametric_jsd_sim'),
        (2.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_jsd_sim'),
        (1.0, 'parametric_consolidated_jsd_sim'),
        (2.0, 'parametric_consolidated_jsd_sim'),
        (3.0, 'parametric_consolidated_jsd_sim'),
        (4.0, 'parametric_consolidated_jsd_sim'),
    ],
    ban='_test_iclr_mnista_fresh')


1.0_comprehensibility 1.0_parametric_consolidated_jsd_sim 1.0_parametric_jsd_sim 2.0_comprehensibility 2.0_parametric_consolidated_jsd_sim 2.0_parametric_jsd_sim 3.0_comprehensibility 3.0_parametric_consolidated_jsd_sim 3.0_parametric_jsd_sim 4.0_comprehensibility 4.0_parametric_consolidated_jsd_sim global_step
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_10_l1_pxyz_5e-6kl (0.7303125, 0.0171816593612) (0.848171968332, 0.00144434087215) (0.844414130407, 0.00458211052723) (0.70140625, 0.0178063692544) (0.820257051925, 0.000939517219537) (0.808237436321, 0.0111181327022) (0.670190972222, 0.00886991382906) (0.802271313123, 0.00383662865963) (0.793114444712, 0.00651560419748) (0.644244791667, 0.00582638281043) (0.784037612568, 0.00402106107369) 250000
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_1_l1_pxyz_5e-6kl (0.67859375, 0.0260048698364) (0.851426208611, 0.00373123358139) (0.862283417006, 0.0025750452861) (0.673515625, 0.0200159644943) (0.815578513125, 0.00271297317554) (0.816886487236, 0.00744830284926) (0.650607638889, 0.00908153731607) (0.797097267933, 0.00344049304896) (0.805699369967, 0.00790432457609) (0.629661458333, 0.00572718360566) (0.777963474647, 0.00342467832633) 213682
_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_50_l1_pxyz_5e-6kl (0.683645833333, 0.0220838246801) (0.845377616564, 0.00339023802545) (0.855290358294, 0.00269744174028) (0.723333333333, 0.0230750916018) (0.848823404959, 0.00433718635375) (0.850902309419, 0.00761648547158) (0.706840277778, 0.00872409221152) (0.841862787571, 0.00324935413829) (0.85284348958, 0.00677547205304) (0.67375, 0.00690956136835) (0.815095344497, 0.0053100973658) 214015

In [14]:
# Some logic which picks which methods to choose
# For each model family, keep the row label that pick_best_method returns for
# its mean-only frame (the frames bound from the second return value of
# plot_metrics in the earlier cells). With the default `sort_by`, the choice
# is driven by the columns whose name contains 'consolidated_jsd_sim'.
best_triple_elbo = pick_best_method(triple_elbo_results)
best_jmvae = pick_best_method(jmvae_results)
best_bivcca = pick_best_method(bivcca_results)

# TODO(vrama): Load the test set results here and run val on test.

In [29]:
# Compositional split results
# Load test numbers.
SPLIT = 'test'
EXP_PREFIX = 'CORRECTCOMP_iclr_mnista_fresh_comp'
PATH_TO_RESULTS = '/coc/scratch/rvedantam3/runs/imagination/%s_%s' % (
    EXP_PREFIX, SPLIT)

result_files = glob.glob(PATH_TO_RESULTS + '/*.p')
results_to_show = latest_checkpoints(
    filter_result_files(result_files))

# 'bivcca' runs only; the (mean, std) frame (first return value) is kept.
comp_bivcca_results, _ = plot_metrics(
    results_to_show,
    filt='bivcca',
    metrics=[(4.0, 'comprehensibility')],
    ban='_test_CORRECTCOMP_iclr_mnista_fresh')


4.0_comprehensibility global_step
_comp_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_50_l1_pxyz_5e-6kl (0.685789473684, 0.0102310564086) 137256

In [28]:
# Compositional split results (triple ELBO).
# Load test numbers.
SPLIT = 'test'
EXP_PREFIX = 'CORRECTCOMP_iclr_mnista_fresh_comp'
PATH_TO_RESULTS = '/coc/scratch/rvedantam3/runs/imagination/%s_%s' % (EXP_PREFIX, SPLIT)

# Collect every pickled result file of the experiment, then narrow the list
# down with filter_result_files / latest_checkpoints.
result_files = glob.glob(os.path.join(PATH_TO_RESULTS, '*.p'))
results_to_show = latest_checkpoints(filter_result_files(result_files))

# Only the 4-attribute comprehensibility metric is requested for this split;
# the second value returned by plot_metrics is not needed here.
comp_triple_elbo_results, _ = plot_metrics(
    results_to_show,
    filt='multimodal_elbo',
    metrics=[(4.0, 'comprehensibility')],
    ban='_test_CORRECTCOMP_iclr_mnista_fresh')


4.0_comprehensibility global_step
_comp_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl (0.756052631579, 0.0142955545848) 120007

In [27]:
# Compositional split results (JMVAE).
# Load test numbers.
SPLIT = 'test'
EXP_PREFIX = 'CORRECTCOMP_iclr_mnista_fresh_comp'
PATH_TO_RESULTS = '/coc/scratch/rvedantam3/runs/imagination/%s_%s' % (EXP_PREFIX, SPLIT)

# Collect every pickled result file of the experiment, then narrow the list
# down with filter_result_files / latest_checkpoints.
result_files = glob.glob(os.path.join(PATH_TO_RESULTS, '*.p'))
results_to_show = latest_checkpoints(filter_result_files(result_files))

# Only the 4-attribute comprehensibility metric is requested for this split;
# the second value returned by plot_metrics is not needed here.
comp_jmvae, _ = plot_metrics(
    results_to_show,
    filt='jmvae',
    metrics=[(4.0, 'comprehensibility')],
    ban='_test_CORRECTCOMP_iclr_mnista_fresh')


4.0_comprehensibility global_step
_comp_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_50_l1_pxyz_5e-6kl (0.768552631579, 0.0130150137041) 183564

In [42]:
# Map each selected method to the variables consumed by build_latex_results.
#
# Naming scheme: <method>_<#attributes>_<cor|cov>_<iid|cmp>_<pct|err>
#   te / jm / bi : triple ELBO / JMVAE / bi-VCCA
#   cor          : read from the '<k>.0_comprehensibility' column
#   cov          : read from the '<k>.0_parametric_jsd_sim' column
#   pct / err    : element [0] / element [1] of the stored tuple, times 100
#                  (presumably mean and std -- confirm against plot_metrics)

# Display names used for the rows of the table.
te_4_iid = 'triple ELBO'
jm_4_iid = 'JMVAE'
bi_4_iid = 'bi-VCCA'
te_3_iid = 'triple ELBO'
jm_3_iid = 'JMVAE'
bi_3_iid = 'bi-VCCA'
te_2_iid = 'triple ELBO'
jm_2_iid = 'JMVAE'
bi_2_iid = 'bi-VCCA'
te_1_iid = 'triple ELBO'
jm_1_iid = 'JMVAE'
bi_1_iid = 'bi-VCCA'
te_4_cmp = 'triple ELBO'
jm_4_cmp = 'JMVAE'
bi_4_cmp = 'bi-VCCA'

# Look up the row of the selected run once per method instead of repeating
# the .loc[...] lookup for every metric.
te_iid_row = raw_triple_elbo_results.loc[best_triple_elbo]
jm_iid_row = raw_jmvae_results.loc[best_jmvae]
bi_iid_row = raw_bivcca_results.loc[best_bivcca]

# iid split, 4 attributes (no coverage metric is read for this setting).
te_4_cor_iid_pct = te_iid_row['4.0_comprehensibility'][0] * 100
te_4_cor_iid_err = te_iid_row['4.0_comprehensibility'][1] * 100
jm_4_cor_iid_pct = jm_iid_row['4.0_comprehensibility'][0] * 100
jm_4_cor_iid_err = jm_iid_row['4.0_comprehensibility'][1] * 100
bi_4_cor_iid_pct = bi_iid_row['4.0_comprehensibility'][0] * 100
bi_4_cor_iid_err = bi_iid_row['4.0_comprehensibility'][1] * 100

# iid split, 3 attributes.
te_3_cov_iid_pct = te_iid_row['3.0_parametric_jsd_sim'][0] * 100
te_3_cov_iid_err = te_iid_row['3.0_parametric_jsd_sim'][1] * 100
jm_3_cov_iid_pct = jm_iid_row['3.0_parametric_jsd_sim'][0] * 100
jm_3_cov_iid_err = jm_iid_row['3.0_parametric_jsd_sim'][1] * 100
bi_3_cov_iid_pct = bi_iid_row['3.0_parametric_jsd_sim'][0] * 100
bi_3_cov_iid_err = bi_iid_row['3.0_parametric_jsd_sim'][1] * 100

te_3_cor_iid_pct = te_iid_row['3.0_comprehensibility'][0] * 100
te_3_cor_iid_err = te_iid_row['3.0_comprehensibility'][1] * 100
jm_3_cor_iid_pct = jm_iid_row['3.0_comprehensibility'][0] * 100
jm_3_cor_iid_err = jm_iid_row['3.0_comprehensibility'][1] * 100
bi_3_cor_iid_pct = bi_iid_row['3.0_comprehensibility'][0] * 100
bi_3_cor_iid_err = bi_iid_row['3.0_comprehensibility'][1] * 100

# iid split, 2 attributes.
te_2_cov_iid_pct = te_iid_row['2.0_parametric_jsd_sim'][0] * 100
te_2_cov_iid_err = te_iid_row['2.0_parametric_jsd_sim'][1] * 100
jm_2_cov_iid_pct = jm_iid_row['2.0_parametric_jsd_sim'][0] * 100
jm_2_cov_iid_err = jm_iid_row['2.0_parametric_jsd_sim'][1] * 100
bi_2_cov_iid_pct = bi_iid_row['2.0_parametric_jsd_sim'][0] * 100
bi_2_cov_iid_err = bi_iid_row['2.0_parametric_jsd_sim'][1] * 100

te_2_cor_iid_pct = te_iid_row['2.0_comprehensibility'][0] * 100
te_2_cor_iid_err = te_iid_row['2.0_comprehensibility'][1] * 100
jm_2_cor_iid_pct = jm_iid_row['2.0_comprehensibility'][0] * 100
jm_2_cor_iid_err = jm_iid_row['2.0_comprehensibility'][1] * 100
bi_2_cor_iid_pct = bi_iid_row['2.0_comprehensibility'][0] * 100
bi_2_cor_iid_err = bi_iid_row['2.0_comprehensibility'][1] * 100

# iid split, 1 attribute.
te_1_cov_iid_pct = te_iid_row['1.0_parametric_jsd_sim'][0] * 100
te_1_cov_iid_err = te_iid_row['1.0_parametric_jsd_sim'][1] * 100
jm_1_cov_iid_pct = jm_iid_row['1.0_parametric_jsd_sim'][0] * 100
jm_1_cov_iid_err = jm_iid_row['1.0_parametric_jsd_sim'][1] * 100
bi_1_cov_iid_pct = bi_iid_row['1.0_parametric_jsd_sim'][0] * 100
bi_1_cov_iid_err = bi_iid_row['1.0_parametric_jsd_sim'][1] * 100

te_1_cor_iid_pct = te_iid_row['1.0_comprehensibility'][0] * 100
te_1_cor_iid_err = te_iid_row['1.0_comprehensibility'][1] * 100
jm_1_cor_iid_pct = jm_iid_row['1.0_comprehensibility'][0] * 100
jm_1_cor_iid_err = jm_iid_row['1.0_comprehensibility'][1] * 100
bi_1_cor_iid_pct = bi_iid_row['1.0_comprehensibility'][0] * 100
bi_1_cor_iid_err = bi_iid_row['1.0_comprehensibility'][1] * 100

# Compositional split, 4 attributes. The compositional runs are indexed by
# the iid run name with 'iid' replaced by 'comp'.
# The triple ELBO table is bound to `comp_triple_elbo_results` by the cell
# that loads it; the former `comp_telbo_results` name is not defined by any
# visible cell and fails with a NameError on a fresh kernel.
te_cmp_row = comp_triple_elbo_results.loc[best_triple_elbo.replace('iid', 'comp')]
jm_cmp_row = comp_jmvae.loc[best_jmvae.replace('iid', 'comp')]
bi_cmp_row = comp_bivcca_results.loc[best_bivcca.replace('iid', 'comp')]

te_4_cor_cmp_pct = te_cmp_row['4.0_comprehensibility'][0] * 100
te_4_cor_cmp_err = te_cmp_row['4.0_comprehensibility'][1] * 100
jm_4_cor_cmp_pct = jm_cmp_row['4.0_comprehensibility'][0] * 100
jm_4_cor_cmp_err = jm_cmp_row['4.0_comprehensibility'][1] * 100
bi_4_cor_cmp_pct = bi_cmp_row['4.0_comprehensibility'][0] * 100
bi_4_cor_cmp_err = bi_cmp_row['4.0_comprehensibility'][1] * 100

In [43]:
# Render the LaTeX table. Arguments are positional and follow the parameter
# order of build_latex_results: row display names first, then one
# (pct, err) pair per method for every metric.
table = build_latex_results(
    # Row display names: iid 4 / 3 / 2 / 1 attributes, then compositional.
    te_4_iid, jm_4_iid, bi_4_iid,
    te_3_iid, jm_3_iid, bi_3_iid,
    te_2_iid, jm_2_iid, bi_2_iid,
    te_1_iid, jm_1_iid, bi_1_iid,
    te_4_cmp, jm_4_cmp, bi_4_cmp,
    # iid, 4 attributes: cor.
    te_4_cor_iid_pct, te_4_cor_iid_err,
    jm_4_cor_iid_pct, jm_4_cor_iid_err,
    bi_4_cor_iid_pct, bi_4_cor_iid_err,
    # iid, 3 attributes: cov, then cor.
    te_3_cov_iid_pct, te_3_cov_iid_err,
    jm_3_cov_iid_pct, jm_3_cov_iid_err,
    bi_3_cov_iid_pct, bi_3_cov_iid_err,
    te_3_cor_iid_pct, te_3_cor_iid_err,
    jm_3_cor_iid_pct, jm_3_cor_iid_err,
    bi_3_cor_iid_pct, bi_3_cor_iid_err,
    # iid, 2 attributes: cov, then cor.
    te_2_cov_iid_pct, te_2_cov_iid_err,
    jm_2_cov_iid_pct, jm_2_cov_iid_err,
    bi_2_cov_iid_pct, bi_2_cov_iid_err,
    te_2_cor_iid_pct, te_2_cor_iid_err,
    jm_2_cor_iid_pct, jm_2_cor_iid_err,
    bi_2_cor_iid_pct, bi_2_cor_iid_err,
    # iid, 1 attribute: cov, then cor.
    te_1_cov_iid_pct, te_1_cov_iid_err,
    jm_1_cov_iid_pct, jm_1_cov_iid_err,
    bi_1_cov_iid_pct, bi_1_cov_iid_err,
    te_1_cor_iid_pct, te_1_cor_iid_err,
    jm_1_cor_iid_pct, jm_1_cor_iid_err,
    bi_1_cor_iid_pct, bi_1_cor_iid_err,
    # Compositional split, 4 attributes: cor.
    te_4_cor_cmp_pct, te_4_cor_cmp_err,
    jm_4_cor_cmp_pct, jm_4_cor_cmp_err,
    bi_4_cor_cmp_pct, bi_4_cor_cmp_err)

In [44]:
# Print the value returned by build_latex_results as plain text so the LaTeX
# source can be copied out of the cell output.
print(table)


\begin{table}
  \centering
  \begin{tabular}{cccc}
    \toprule

    \textbf{Method}  & \textbf{\#Attributes} & \textbf{Coverage} (\%) & \textbf{\Correctness} (\%)                                \\

    \toprule
    \rowcolor{Gray}\multicolumn{4}{c}{\textbf{\iid}}\\
    \toprule
    % triple ELBO
    \telbo  & \multirow{3}{*}{4} & -            & 82.08 {\tiny$\pm$} 0.56\\
    % JMVAE
    \jmvae  &  & -            & 85.15 {\tiny$\pm$} 0.26  \\
    % bi-VCCA
    \bivcca &  & -            & 67.38 {\tiny$\pm$} 0.69 \\

    \midrule

    % triple ELBO
    \telbo  & \multirow{3}{*}{3} & 91.14 {\tiny$\pm$} 0.53 & 81.63 {\tiny$\pm$} 0.38 \\
    % JMVAE
    \jmvae  & & 88.52 {\tiny$\pm$} 0.37 & 82.00 {\tiny$\pm$} 0.37  \\
    % bi-VCCA
    \bivcca & & 85.28 {\tiny$\pm$} 0.68 & 70.68 {\tiny$\pm$} 0.87  \\

    \midrule

    % triple ELBO
    \telbo  & \multirow{3}{*}{2} & 90.32 {\tiny$\pm$} 0.57 & 82.03 {\tiny$\pm$} 1.37  \\
    % JMVAE
    \jmvae  &  & 87.89 {\tiny$\pm$} 0.69 & 81.02 {\tiny$\pm$} 1.05 \\
    % bi-VCCA
    \bivcca &  & 85.09 {\tiny$\pm$} 0.76 & 72.33 {\tiny$\pm$} 2.31 \\

    \midrule

    % triple ELBO
    \telbo  & \multirow{3}{*}{1} & 90.94 {\tiny$\pm$} 0.19 & 83.67 {\tiny$\pm$} 1.70 \\
    % JMVAE
    \jmvae  & & 88.70 {\tiny$\pm$} 0.35 & 81.58 {\tiny$\pm$} 1.78 \\
    % bi-VCCA
    \bivcca & & 85.53 {\tiny$\pm$} 0.27 & 68.36 {\tiny$\pm$} 2.21 \\
    
    \toprule
    \rowcolor{Gray}\multicolumn{4}{c}{\textbf{\comp}}\\
    \toprule
    % triple ELBO
    \telbo  & \multirow{3}{*}{4} & -            & 75.61 {\tiny$\pm$} 1.43  \\
    % JMVAE
    \jmvae  & & -            & 76.86 {\tiny$\pm$} 1.30  \\
    % bi-VCCA
    \bivcca & & -            & 68.58 {\tiny$\pm$} 1.02  \\

    \bottomrule
  \end{tabular}
\end{table}
  

In [49]:
# Show which JMVAE run pick_best_method selected.
best_jmvae


Out[49]:
'_iid_affine_mnist_loss_jmvae_poe_1_nl_10_jmalpha_1.0_ax_1_ay_50_l1_pxyz_5e-6kl'

In [50]:
# Show which triple ELBO run pick_best_method selected.
best_triple_elbo


Out[50]:
'_iid_affine_mnist_loss_multimodal_elbo_poe_1_nl_10_sg_1_privatey_50_xsg_1_ax_1_ay_50_l1_pxyz_5e-6kl'

In [51]:
# Show which bi-VCCA run pick_best_method selected.
best_bivcca


Out[51]:
'_iid_affine_mnist_loss_bivcca_poe_1_nl_10_bmu0.7_ax_1_ay_50_l1_pxyz_5e-6kl'

In [ ]: