The Analysis Pipeline

Alex Malz (NYU) & Phil Marshall (SLAC)

In this notebook we use the "survey mode" machinery to demonstrate how one should choose the optimal parametrization for photo-$z$ PDF storage given the nature of the data, the storage constraints, and the fidelity necessary for a science use case.


In [ ]:
#comment out for NERSC
%load_ext autoreload

#comment out for NERSC
%autoreload 2

In [ ]:
from __future__ import print_function
    
import pickle
import hickle
import numpy as np
import random
import cProfile
import pstats
import StringIO
import sys
import os
import timeit
import bisect
import re

import qp
from qp.utils import calculate_kl_divergence as make_kld

# np.random.seed(seed=42)
# random.seed(a=42)

In [ ]:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['text.usetex'] = True
mpl.rcParams['mathtext.rm'] = 'serif'
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Times New Roman'
mpl.rcParams['axes.titlesize'] = 16
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['savefig.dpi'] = 250
mpl.rcParams['savefig.format'] = 'pdf'
mpl.rcParams['savefig.bbox'] = 'tight'

#comment out for NERSC
%matplotlib inline

Analysis

We want to compare parametrizations for large catalogs, so we'll need to be more efficient. The qp.Ensemble object is a wrapper around many qp.PDF objects that lets conversions be performed and metrics be calculated in parallel. We'll experiment on a subsample of 100 galaxies.
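
As a quick orientation, here is a minimal sketch of the qp.Ensemble calls used throughout the rest of this notebook: build an ensemble from gridded PDFs, then convert it to the three competing parametrizations. The toy Gaussians, grid, limits, and sizes below are placeholders, not part of the analysis.

In [ ]:
# a minimal sketch of the qp.Ensemble API as used later in this notebook;
# the toy Gaussians, grid, and limits are placeholders
demo_z_grid = np.linspace(0.01, 3.51, 350)
demo_means = np.array([[0.5], [1.0], [1.5]])
demo_pdfs = np.exp(-0.5 * ((demo_z_grid[np.newaxis, :] - demo_means) / 0.1) ** 2)

demo_ensemble = qp.Ensemble(len(demo_pdfs), gridded=(demo_z_grid, demo_pdfs),
                            limits=(0.01, 3.51), vb=False)
demo_quantiles = demo_ensemble.quantize(N=10, vb=False)  # 10 quantiles per PDF
demo_histogram = demo_ensemble.histogramize(N=10, binrange=(0.01, 3.51), vb=False)  # 10 bins per PDF
demo_samples = demo_ensemble.sample(samps=10, vb=False)  # 10 samples per PDF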


In [ ]:
def setup_dataset(dataset_key, skip_rows, skip_cols):
    start = timeit.default_timer()
    with open(dataset_info[dataset_key]['filename'], 'rb') as data_file:
        lines = (line.split(None) for line in data_file)
        for r in range(skip_rows):
            next(lines)  # skip header rows
        pdfs = np.array([[float(line[k]) for k in range(skip_cols, len(line))] for line in lines])
    print('read in data file in '+str(timeit.default_timer()-start))
    return(pdfs)

In [ ]:
def make_instantiation(dataset_key, n_gals_use, pdfs, bonus=None):
    
    start = timeit.default_timer()
    
    n_gals_tot = len(pdfs)
    full_gal_range = range(n_gals_tot)
    subset = np.random.choice(full_gal_range, n_gals_use, replace=False)#range(n_gals_use)
#     subset = indices
    print('randos for debugging: '+str(subset))
    pdfs_use = pdfs[subset]
    
    modality = []
    dpdfs = pdfs_use[:,1:] - pdfs_use[:,:-1]
    iqrs = []
    for i in range(n_gals_use):
        modality.append(len(np.where(np.diff(np.signbit(dpdfs[i])))[0]))
        cdf = np.cumsum(qp.utils.normalize_integral((dataset_info[dataset_key]['z_grid'], pdfs_use[i]), vb=False)[1])
        iqr_lo = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.25)]
        iqr_hi = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.75)]
        iqrs.append(iqr_hi - iqr_lo)
    modality = np.array(modality)
        
    dataset_info[dataset_key]['N_GMM'] = int(np.median(modality))+1
#     print('n_gmm for '+dataset_info[dataset_key]['name']+' = '+str(dataset_info[dataset_key]['N_GMM']))
      
    # using the same grid for output as the native format, but doesn't need to be so
    dataset_info[dataset_key]['in_z_grid'] = dataset_info[dataset_key]['z_grid']
    dataset_info[dataset_key]['metric_z_grid'] = dataset_info[dataset_key]['z_grid']
    
    print('preprocessed data in '+str(timeit.default_timer()-start))
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['randos'] = randos  # indices of the example PDFs to plot, set globally in the driver loop
        info['z_grid'] = dataset_info[dataset_key]['in_z_grid']
        info['pdfs'] = pdfs_use
        info['modes'] = modality
        info['iqrs'] = iqrs
        hickle.dump(info, filename)
    
    return(pdfs_use)

In [ ]:
def plot_examples(n_gals_use, dataset_key, bonus=None, norm=False):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        randos = info['randos']
        z_grid = info['z_grid']
        pdfs = info['pdfs']
    
    plt.figure()
    for i in range(n_plot):
        data = (z_grid, pdfs[randos[i]])
        data = qp.utils.normalize_integral(qp.utils.normalize_gridded(data))
        pz_max.append(np.max(data[1]))  # track the largest p(z) value for a shared y-axis limit
        plt.plot(data[0], data[1], label=dataset_info[dataset_key]['name']+' \#'+str(randos[i]), color=color_cycle[i])
    plt.xlabel(r'$z$', fontsize=14)
    plt.ylabel(r'$p(z)$', fontsize=14)
    plt.xlim(min(z_grid), max(z_grid))
    plt.title(dataset_info[dataset_key]['name']+' data examples', fontsize=16)
    if norm:
        plt.ylim(0., max(pz_max))
        plt.savefig(loc+'norm.pdf', dpi=250)
    else:
        plt.savefig(loc+'.pdf', dpi=250)
    plt.close()
    
    if 'modes' in info.keys():
        modes = info['modes']
        modes_max.append(np.max(modes))
        plt.figure()
        ax = plt.hist(modes, color='k', alpha=1./n_plot, histtype='stepfilled', bins=range(max(modes_max)+1))
        plt.xlabel('modes')
        plt.ylabel('frequency')
        plt.title(dataset_info[dataset_key]['name']+' data modality distribution (median='+str(dataset_info[dataset_key]['N_GMM'])+')', fontsize=16)
        plt.savefig(loc+'modality.pdf', dpi=250)
        plt.close()
        
    if 'iqrs' in info.keys():
        iqrs = info['iqrs']
        iqr_min.append(min(iqrs))
        iqr_max.append(max(iqrs))
        plot_bins = np.linspace(min(iqr_min), max(iqr_max), 20)
        plt.figure()
        ax = plt.hist(iqrs, bins=plot_bins, color='k', alpha=1./n_plot, histtype='stepfilled')
        plt.xlabel('IQR')
        plt.ylabel('frequency')
        plt.title(dataset_info[dataset_key]['name']+' data IQR distribution', fontsize=16)
        plt.savefig(loc+'iqrs.pdf', dpi=250)
        plt.close()

We're going to incrementally save the quantities that are costly to calculate.


In [ ]:
def save_one_stat(dataset_name, n_gals_use, N_f, i, stat, stat_name):
    path = os.path.join(dataset_name, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_name+str(N_f)+'_'+str(i))
    with open(loc+'.hkl', 'w') as filename:
        hickle.dump(stat, filename)
        
def load_one_stat(dataset_name, n_gals_use, N_f, i, stat_name):
    path = os.path.join(dataset_name, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_name+str(N_f)+'_'+str(i))
    with open(loc+'.hkl', 'r') as filename:
        stat = hickle.load(filename)
#     print(stat)
    return stat

def save_moments_wrapper(dataset_name, n_gals_use, N_f, i, stat_name):
    stat = load_one_stat(dataset_name, n_gals_use, N_f, i, stat_name)
    save_moments(dataset_name, n_gals_use, N_f, stat, stat_name)
        
def save_metrics_wrapper(dataset_name, n_gals_use, N_f, i, stat_name):
    stat = load_one_stat(dataset_name, n_gals_use, N_f, i, stat_name)
    save_nz_metrics(dataset_name, n_gals_use, N_f, stat, stat_name)
    
def clear_stats(dataset_name, n_gals_use, stat_name):
    path = os.path.join(dataset_name, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_name+'.hkl')
    if os.path.isfile(loc):
        os.remove(loc)
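
For instance, a single statistic for one instantiation can be checkpointed and re-read like this; the dataset key, catalog size, $N_{f}$, instantiation index, and statistic values below are placeholders.

In [ ]:
# hypothetical usage of the checkpointing helpers above; all arguments are placeholders
demo_path = os.path.join('mg', '100')
if not os.path.exists(demo_path):
    os.makedirs(demo_path)
demo_stat = {'quantiles': 0.01, 'histogram': 0.02, 'samples': 0.03}
save_one_stat('mg', 100, 10, 0, demo_stat, 'nz_klds')
recovered = load_one_stat('mg', 100, 10, 0, 'nz_klds')

The wrapper functions above then re-read these per-instantiation files and pass them to save_moments and save_nz_metrics (defined further down), which maintain the combined files that the plotting functions consume.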

We'll start by reading in our catalog of gridded PDFs, sampling them, fitting GMMs to the samples, and establishing a new qp.Ensemble object where each member qp.PDF object has qp.PDF.truth$\neq$None.


In [ ]:
def setup_from_grid(dataset_key, in_pdfs, z_grid, N_comps, high_res=1000, bonus=None):
    
    #read in the data, happens to be gridded
    zlim = (min(z_grid), max(z_grid))
    delta_z = (zlim[-1] - zlim[0]) / len(z_grid)  # grid spacing for the moment integrals below
    N_pdfs = len(in_pdfs)
    
    start = timeit.default_timer()
#     print('making the initial ensemble of '+str(N_pdfs)+' PDFs')
    E0 = qp.Ensemble(N_pdfs, gridded=(z_grid, in_pdfs), limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made the initial ensemble of '+str(N_pdfs)+' PDFs in '+str(timeit.default_timer() - start))    
    
    #fit GMMs to gridded pdfs based on samples (faster than fitting to gridded)
    start = timeit.default_timer()
#     print('sampling for the GMM fit')
    samparr = E0.sample(high_res, vb=False)
    print('took '+str(high_res)+' samples in '+str(timeit.default_timer() - start))
    
    start = timeit.default_timer()
#     print('making a new ensemble from samples')
    Ei = qp.Ensemble(N_pdfs, samples=samparr, limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made a new ensemble from samples in '+str(timeit.default_timer() - start))
    
    start = timeit.default_timer()
#     print('fitting the GMM to samples')
    GMMs = Ei.mix_mod_fit(comps=N_comps, vb=False)
    print('fit the GMM to samples in '+str(timeit.default_timer() - start))
    
    #set the GMMS as the truth
    start = timeit.default_timer()
#     print('making the final ensemble')
    Ef = qp.Ensemble(N_pdfs, truth=GMMs, limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made the final ensemble in '+str(timeit.default_timer() - start))
    
    path = os.path.join(dataset_key, str(N_pdfs))
    loc = os.path.join(path, 'pzs'+str(N_pdfs)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['randos'] = randos
        info['z_grid'] = z_grid
        info['pdfs'] = Ef.evaluate(z_grid, using='truth', norm=True, vb=False)[1]
        hickle.dump(info, filename)
        
    start = timeit.default_timer()
#     print('calculating '+str(n_moments_use)+' moments of original PDFs')
    in_moments, vals = [], []
    for n in range(n_moments_use):
        in_moments.append(Ef.moment(n, using='truth', limits=zlim, 
                                    dx=delta_z, vb=False))
        vals.append(n)
    moments = np.array(in_moments)
    print('calculated '+str(n_moments_use)+' moments of original PDFs in '+str(timeit.default_timer() - start))
    
    path = os.path.join(dataset_key, str(N_pdfs))
    loc = os.path.join(path, 'pz_moments'+str(N_pdfs)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['truth'] = moments
        info['orders'] = vals
        hickle.dump(info, filename)
    
    return(Ef)

Next, we compute the KLD between each approximation and the truth for every member of the ensemble. We then treat the set of per-object KLD values returned by qp.Ensemble.kld as a qp.PDF object in its own right, so that we can compare the moments of the KLD distributions across parametrizations.
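
For reference, the KLD computed by qp.utils.calculate_kl_divergence is (up to the numerical integration over the metric grid) the information lost when an approximation $\hat{p}(z)$ stands in for the true $p(z)$:

$$\mathrm{KLD}\left[\hat{p}\,\|\,p\right] = \int \mathrm{d}z\, p(z)\,\log\frac{p(z)}{\hat{p}(z)},$$

measured in nats; it is nonnegative and vanishes only when $\hat{p}(z) = p(z)$.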


In [ ]:
def analyze_individual(E, z_grid, N_floats, dataset_key, N_moments=4, i=None, bonus=None):
    zlim = (min(z_grid), max(z_grid))
    z_range = zlim[-1] - zlim[0]
    delta_z = z_range / len(z_grid)
    path = os.path.join(dataset_key, str(n_gals_use))
    
    Eq, Eh, Es = E, E, E  # aliases to the same ensemble; each conversion below is stored in inits and used to build a new ensemble
    inits = {}
    for f in formats:
        inits[f] = {}
        for ff in formats:
            inits[f][ff] = None
            
    qstart = timeit.default_timer()
    inits['quantiles']['quantiles'] = Eq.quantize(N=N_floats, vb=True)
    print('finished quantization in '+str(timeit.default_timer() - qstart))
    hstart = timeit.default_timer()
    inits['histogram']['histogram'] = Eh.histogramize(N=N_floats, binrange=zlim, vb=False)
    print('finished histogramization in '+str(timeit.default_timer() - hstart))
    sstart = timeit.default_timer()
    inits['samples']['samples'] = Es.sample(samps=N_floats, vb=False)
    print('finished sampling in '+str(timeit.default_timer() - sstart))
        
    Eo = {}
    
    metric_start = timeit.default_timer()
    inloc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+bonus)
    with open(inloc+'.hkl', 'r') as infilename:
        pz_moments = hickle.load(infilename)
    
    klds, metrics, kld_moments, pz_moment_deltas = {}, {}, {}, {}
    
    for f in formats:
        fstart = timeit.default_timer()
        Eo[f] = qp.Ensemble(E.n_pdfs, truth=E.truth, 
                            quantiles=inits[f]['quantiles'], 
                            histogram=inits[f]['histogram'],
                            samples=inits[f]['samples'], 
                            limits=dataset_info[dataset_key]['z_lim'])
        
        fbonus = str(N_floats)+f+str(i)
        loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+fbonus)
        with open(loc+'.hkl', 'w') as filename:
            info = {}
            info['randos'] = randos
            info['z_grid'] = z_grid
            info['pdfs'] = Eo[f].evaluate(z_grid, using=f, norm=True, vb=False)[1]
            hickle.dump(info, filename)
        print('made '+f+' ensemble in '+str(timeit.default_timer()-fstart))

        key = f
        
        fstart = timeit.default_timer()
        klds[key] = Eo[key].kld(using=key, limits=zlim, dx=delta_z, vb=False)
        print('calculated the '+key+' individual klds in '+str(timeit.default_timer() - fstart))
        
        fstart = timeit.default_timer()
        kld_moments[key] = []
        samp_metric = qp.PDF(samples=klds[key])
        gmm_metric = samp_metric.mix_mod_fit(n_components=dataset_info[dataset_key]['N_GMM'], 
                                             using='samples', vb=False)
        metrics[key] = qp.PDF(truth=gmm_metric)
        for n in range(N_moments):
            kld_moments[key].append(qp.utils.calculate_moment(metrics[key], n,
                                                          using='truth', 
                                                          limits=zlim, 
                                                          dx=delta_z, 
                                                          vb=False))
        save_one_stat(dataset_key, n_gals_use, N_floats, i, kld_moments, 'pz_kld_moments')
        print('calculated the '+key+' kld moments in '+str(timeit.default_timer() - fstart))
        
        pz_moment_deltas[key], pz_moments[key] = [], []
        for n in range(N_moments):
            start = timeit.default_timer()
            new_moment = Eo[key].moment(n, using=key, limits=zlim, 
                                                  dx=delta_z, vb=False)
            pz_moments[key].append(new_moment)
            #NOTE: delta_moment is crazy for clean data!
            delta_moment = (new_moment - pz_moments['truth'][n]) / pz_moments['truth'][n]
            pz_moment_deltas[key].append(delta_moment)
            print('calculated the '+key+' individual moment '+str(n)+' in '+str(timeit.default_timer() - start))
        save_one_stat(dataset_key, n_gals_use, N_floats, i, pz_moments, 'pz_moments')
        save_one_stat(dataset_key, n_gals_use, N_floats, i, pz_moment_deltas, 'pz_moment_deltas')
        
    loc = os.path.join(path, 'kld_hist'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['z_grid'] = z_grid
        info['N_floats'] = N_floats
        info['pz_klds'] = klds
        hickle.dump(info, filename)

    outloc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(outloc+'.hkl', 'w') as outfilename:
        hickle.dump(pz_moments, outfilename)
    
#     save_moments(name, size, n_floats_use, kld_moments, 'pz_kld_moments')
#     save_moments(name, size, n_floats_use, pz_moments, 'pz_moments')
#     save_moments(name, size, n_floats_use, pz_moment_deltas, 'pz_moment_deltas')
    
    return(Eo)#, klds, kld_moments, pz_moments, pz_moment_deltas)

In [ ]:
def plot_all_examples(name, size, N_floats, init, bonus={}):
    path = os.path.join(name, str(size))
    fig, ax = plt.subplots()
#     fig_check, ax_check = plt.subplots()
    lines = []
    loc = os.path.join(path, 'pzs'+str(size)+name+'_postfit'+str(init))
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        ref_pdfs = info['pdfs']  
#     klds = {}
    for bonus_key in bonus.keys():
        loc = os.path.join(path, 'pzs'+str(size)+name+bonus_key)
        with open(loc+'.hkl', 'r') as filename:
            info = hickle.load(filename)
            randos = info['randos']
            z_grid = info['z_grid']
            pdfs = info['pdfs']
        ls = bonus[bonus_key][0]
        a = bonus[bonus_key][1]
        lab = re.sub(r'[\_]', '', bonus_key)
        line, = ax.plot([-1., 0.], [0., 0.], linestyle=ls, alpha=a, color='k', label=lab[:-1])
        lines.append(line)
        leg = ax.legend(loc='upper right', handles=lines)
#         klds[bonus_key] = []
        for i in range(n_plot):
            data = (z_grid, pdfs[randos[i]])
            data = qp.utils.normalize_integral(qp.utils.normalize_gridded(data))
            ax.plot(data[0], data[1], linestyle=ls, alpha=a, color=color_cycle[i])
            #     ax.legend(loc='upper right')
#         for i in range(size):
#             data = (z_grid, pdfs[i])
#             kld = qp.utils.quick_kl_divergence(ref_pdfs[i], pdfs[i], dx=0.01)
#             klds[bonus_key].append(kld)
#     plot_bins = np.linspace(-3., 3., 20)
#     for bonus_key in bonus.keys()[1:-1]:
#         ax_check.hist(np.log(np.array(klds[bonus_key])), alpha=a, 
#                       histtype='stepfilled', edgecolor='k', 
#                       label=bonus_key, normed=True, bins=plot_bins, lw=2)
    ax.set_xlabel(r'$z$', fontsize=14)
    ax.set_ylabel(r'$p(z)$', fontsize=14)
    ax.set_xlim(min(z_grid), max(z_grid))
    ax.set_title(dataset_info[name]['name']+r' examples with $N_{f}=$'+str(N_floats), fontsize=16)
    saveloc = os.path.join(path, 'pzs'+str(size)+name+str(N_floats)+'all'+str(init))
    fig.savefig(saveloc+'.pdf', dpi=250)
#     ax_check.legend()
#     ax_check.set_ylabel('frequency', fontsize=14)
#     ax_check.set_xlabel(r'$\mathrm{KLD}$', fontsize=14)
#     ax_check.set_title(name+r' data $p(\mathrm{KLD})$ with $N_{f}='+str(N_floats)+r'$', fontsize=16)
#     fig_check.savefig(saveloc+'kld_check.pdf', dpi=250)
    plt.close()
#     with open(saveloc+'.p', 'w') as kldfile:
#         pickle.dump(klds, kldfile)

In [ ]:
def plot_individual_kld(n_gals_use, dataset_key, N_floats, i):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    a = 1./len(formats)
    loc = os.path.join(path, 'kld_hist'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        z_grid = info['z_grid']
        N_floats = info['N_floats']
        pz_klds = info['pz_klds']
    
    plt.figure()
    plot_bins = np.linspace(-10., 5., 30)
    for key in pz_klds.keys():
        logdata = qp.utils.safelog(pz_klds[key])
        dist_min.append(min(logdata))
        dist_max.append(max(logdata))
#         plot_bins = np.linspace(-10., 5., 20)
        kld_hist = plt.hist(logdata, color=colors[key], alpha=a, histtype='stepfilled', edgecolor='k',
             label=key, normed=True, bins=plot_bins, linestyle=stepstyles[key], ls=stepstyles[key], lw=2)
#         kld_hist = plt.hist(pz_klds[key], color=colors[key], alpha=a, histtype='stepfilled', edgecolor='k',
#              label=key, normed=True, bins=plot_bins, linestyle=stepstyles[key], ls=stepstyles[key], lw=2)
        hist_max.append(max(kld_hist[0]))
#     print(loc+': min log[KLD]='+str(logdata)+' at N='+str(np.argmin(logdata)))
    plt.legend()
    plt.ylabel('frequency', fontsize=14)
#     plt.xlabel(r'$\log[\mathrm{KLD}]$', fontsize=14)
    plt.xlabel(r'$\log[\mathrm{KLD}]$', fontsize=14)
#     plt.xlim(min(dist_min), max(dist_max))
    plt.ylim(0., max(hist_max))
    plt.title(dataset_info[dataset_key]['name']+r' data $p(\log[\mathrm{KLD}])$ with $N_{f}='+str(N_floats)+r'$', fontsize=16)
    plt.savefig(loc+'.pdf', dpi=250)
    plt.close()

In [ ]:
def plot_all_kld(size, name, i):
    path = os.path.join(name, str(size))
    fig, ax = plt.subplots()
    fig.canvas.draw()
    for i in instantiations:
        to_plot = {}
        for f in formats:
            to_plot[f] = []
        for Nf in floats:
            place = os.path.join(path, 'kld_hist'+str(size)+name+str(Nf)+'_'+str(i))
            with open(place+'.hkl', 'r') as filename:
                klds = hickle.load(filename)['pz_klds']
                for f in formats:
                    to_plot[f].append(klds[f])
#                         print(name, size, i, Nf, f, klds[f])
        for f in formats:
            to_plot[f] = np.array(to_plot[f])
            # change in per-object KLD between successive parameter budgets
            # (final row: change from N_f=100 to the exact PDF, whose KLD is 0)
            delta_info = np.ones((len(floats), size))
            delta_info[:-1] = to_plot[f][1:] - to_plot[f][:-1]
            delta_info[-1] = -1. * to_plot[f][-1]
            ax.plot(floats, delta_info, color=colors[f])
    ax.set_xlabel('change in number of parameters')
    ax.set_ylabel(r'$\Delta\mathrm{KLD}$ (nats)')
    ax.semilogx()
    ax.set_xticks(floats)
    ax.set_xticklabels([r'$3\to 10$', r'$10\to 30$', r'$30\to 100$', r'$100\to \infty$'])

Finally, we calculate metrics on the stacked estimator $\hat{n}(z)$, which is the average of all members of the ensemble.
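
Concretely, for an ensemble of $N$ galaxies the stacked estimator is

$$\hat{n}(z) = \frac{1}{N}\sum_{i=1}^{N}\hat{p}_{i}(z),$$

and we compare the $\hat{n}(z)$ built from each compressed parametrization against the one built from the GMM truth.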


In [ ]:
def analyze_stacked(E0, E, z_grid, n_floats_use, dataset_key, i=None):
    
    zlim = (min(z_grid), max(z_grid))
    z_range = zlim[-1] - zlim[0]
    delta_z = z_range / len(z_grid)
    
    n_gals_use = E0.n_pdfs
    
#     print('stacking the ensembles')
#     stack_start = timeit.default_timer()
    stacked_pdfs, stacks = {}, {}
    for key in formats:
        start = timeit.default_timer()
        stacked_pdfs[key] = qp.PDF(gridded=E[key].stack(z_grid, using=key, 
                                                        vb=False)[key])
        stacks[key] = stacked_pdfs[key].evaluate(z_grid, using='gridded', norm=True, vb=False)[1]
        print('stacked '+key+ ' in '+str(timeit.default_timer()-start))
    
    stack_start = timeit.default_timer()
    stacked_pdfs['truth'] = qp.PDF(gridded=E0.stack(z_grid, using='truth', 
                                                    vb=False)['truth'])
    
    stacks['truth'] = stacked_pdfs['truth'].evaluate(z_grid, using='gridded', norm=True, vb=False)[1]
    print('stacked truth in '+str(timeit.default_timer() - stack_start))
    
    klds = {}
    for key in formats:
        kld_start = timeit.default_timer()
        klds[key] = qp.utils.calculate_kl_divergence(stacked_pdfs['truth'],
                                                     stacked_pdfs[key], 
                                                     limits=zlim, dx=delta_z)
        print('calculated the '+key+' stacked kld in '+str(timeit.default_timer() - kld_start))
    save_one_stat(dataset_key, n_gals_use, n_floats_use, i, klds, 'nz_klds')
#     save_nz_metrics(name, size, n_floats_use, klds, 'nz_klds')
        
    moments = {}
    for key in formats_plus:
        moment_start = timeit.default_timer()
        moments[key] = []
        for n in range(n_moments_use):
            moments[key].append(qp.utils.calculate_moment(stacked_pdfs[key], n, 
                                                          limits=zlim, 
                                                          dx=delta_z, 
                                                          vb=False))
        print('calculated the '+key+' stacked moments in '+str(timeit.default_timer() - moment_start))
    save_one_stat(dataset_key, n_gals_use, n_floats_use, i, moments, 'nz_moments')
#     save_moments(name, size, n_floats_use, moments, 'nz_moments') 
    
    path = os.path.join(dataset_key, str(E0.n_pdfs))
    loc = os.path.join(path, 'nz_comp'+str(n_gals_use)+dataset_key+str(n_floats_use)+'_'+str(i))
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['z_grid'] = z_grid
        info['stacks'] = stacks
        info['klds'] = klds
        info['moments'] = moments
        hickle.dump(info, filename)
    
    return(stacked_pdfs)

In [ ]:
def plot_estimators(n_gals_use, dataset_key, n_floats_use, i=None):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_comp'+str(n_gals_use)+dataset_key+str(n_floats_use)+'_'+str(i))
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        z_grid = info['z_grid']
        stacks = info['stacks']
        klds = info['klds']
    
    plt.figure()
    plt.plot(z_grid, stacks['truth'], color='black', lw=3, alpha=0.3, label='original')
    nz_max.append(max(stacks['truth']))
    for key in formats:
        nz_max.append(max(stacks[key]))
        plt.plot(z_grid, stacks[key], label=key+r' KLD='+str(klds[key])[:8], color=colors[key], linestyle=styles[key])
    plt.xlabel(r'$z$', fontsize=14)
    plt.ylabel(r'$\hat{n}(z)$', fontsize=14)
    plt.xlim(min(z_grid), max(z_grid))
    plt.ylim(0., max(nz_max))
    plt.legend()
    plt.title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ with $N_{f}='+str(n_floats_use)+r'$', fontsize=16)
    plt.savefig(loc+'.pdf', dpi=250)
    plt.close()

We save the data so we can remake the plots later without running everything again.

Scaling

We'd like to do this for many values of $N_{f}$ as well as larger catalog subsamples, repeating the analysis many times to establish error bars on the KLD as a function of format, $N_{f}$, and dataset. The things we want to plot across multiple datasets/numbers of parameters are:

  1. KLD of stacked estimator, i.e. N_f vs. nz_output[dataset][format][instantiation][KLD_val_for_N_f]
  2. moments of KLD of individual PDFs, i.e. n_moment, N_f vs. pz_output[dataset][format][n_moment][instantiation][moment_val_for_N_f]

So, we need to make sure these are saved! (A sketch of how the combined files can be indexed follows below.)
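
As a concrete, hypothetical example of that nesting, once the pipeline has been run the combined files written by save_nz_metrics and save_moments (defined below) can be indexed like this; the dataset key and catalog size are placeholders.

In [ ]:
# hypothetical indexing of the combined stat files for dataset 'mg' with 100 galaxies
with open(os.path.join('mg', '100', 'nz_klds100mg.hkl'), 'r') as nz_file:
    nz_stats = hickle.load(nz_file)
print(nz_stats['N_f'])           # the N_f values analyzed so far, e.g. [3, 10, 30, 100]
print(nz_stats['quantiles'][0])  # n(z) KLDs, one per instantiation, at N_f = nz_stats['N_f'][0]

with open(os.path.join('mg', '100', 'pz_kld_moments100mg.hkl'), 'r') as pz_file:
    pz_stats = hickle.load(pz_file)
print(pz_stats['histogram'][1][0])  # mean of the per-PDF KLD distribution at N_f = pz_stats['N_f'][0],
                                    # one entry per instantiation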

We want to plot the moments of the KLD distribution for each format as $N_{f}$ changes.


In [ ]:
def save_moments(dataset_name, n_gals_use, N_f, stat, stat_name):

    path = os.path.join(dataset_name, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_name)
    
    if os.path.exists(loc+'.hkl'):
        with open(loc+'.hkl', 'r') as stat_file:
        #read in content of list/dict
            stats = hickle.load(stat_file)
    else:
        stats = {}
        stats['N_f'] = []
        for f in stat.keys():
            stats[f] = []
            for m in range(n_moments_use):
                stats[f].append([])

    if N_f not in stats['N_f']:
        stats['N_f'].append(N_f)
        for f in stat.keys():
            for m in range(n_moments_use):
                stats[f][m].append([])
        
    where_N_f = stats['N_f'].index(N_f)
        
    for f in stat.keys():
        for m in range(n_moments_use):
            stats[f][m][where_N_f].append(stat[f][m])

    with open(loc+'.hkl', 'w') as stat_file:
        hickle.dump(stats, stat_file)

In [ ]:
# include second axis with mean KLD values?
# somehow combining pz_kld_moments with this?
# something is not right here with limits, need to check after nersc run
def plot_kld_stats(name, size):
    a = 1./len(formats)
    topdir = os.path.join(name, str(size))
    
    fig_one, ax_one = plt.subplots(figsize=(5, 5))
    fig_one.canvas.draw()
    mean_deltas, std_deltas = {}, {}
    for f in formats:
        mean_deltas[f], std_deltas[f] = [], []
        ax_one.plot([1000., 1000.], [1., 10.], color=colors[f], alpha=a, label=f, linestyle=styles[f])
    for i in instantiations:
        to_plot = {}
        for f in formats:
            to_plot[f] = []
            mean_deltas[f].append([])
            std_deltas[f].append([])
        for Nf in floats:
            loc = os.path.join(topdir, 'kld_hist'+str(size)+name+str(Nf)+'_'+str(i))
            with open(loc+'.hkl', 'r') as filename:
                klds = hickle.load(filename)['pz_klds']
                for f in formats:
                    to_plot[f].append(klds[f])
        for f in formats:
            to_plot[f] = np.array(to_plot[f])
            delta_info = np.zeros((len(floats), size))
            delta_info[:-1] = to_plot[f][:-1] - to_plot[f][1:]
            delta_info[-1] = to_plot[f][-1]
#             delta_info[delta_info < qp.utils.epsilon] = qp.utils.epsilon
#             log_delta_info = np.log(delta_info)
#             ax_one.plot(floats, log_delta_info)
            mean_deltas[f][i] = np.mean(delta_info, axis=1)
            std_deltas[f][i] = np.std(delta_info, axis=1)
            indie_delta_kld_min.append(np.min(mean_deltas[f][i] - std_deltas[f][i]))
            indie_delta_kld_max.append(np.max(mean_deltas[f][i] + std_deltas[f][i]))
            ax_one.plot(floats, mean_deltas[f][i], color=colors[f], alpha=a, linestyle=styles[f])
    ax_one.set_ylabel(r'$\Delta\mathrm{KLD}$ (nats)')
#     ax_one.semilogy()
    ax_one.set_ylim(0., np.max(indie_delta_kld_max))
    ax_one.set_xlim(min(floats), max(floats))
    ax_one.set_xlabel('change in number of parameters')
    ax_one.semilogx()
    ax_one.set_xticks(floats)
    ax_one.set_xticklabels([r'$3\to 10$', r'$10\to 30$', r'$30\to 100$', r'$100\to \infty$'])
    ax_one.legend(loc='upper right')
    ax_one.set_title(dataset_info[name]['name']+r' data per-PDF $\Delta\mathrm{KLD}$', fontsize=16)
    place = os.path.join(topdir, 'indie_klds'+str(size)+name)
    fig_one.savefig(place+'_each.pdf', dpi=250)
    plt.close()
    
    fig, ax = plt.subplots(figsize=(5, 5))
    for f in formats:
        mean_deltas[f] = np.array(mean_deltas[f])
        std_deltas[f] = np.array(std_deltas[f])
        global_delta_mean = np.mean(mean_deltas[f], axis=0)
        global_delta_std = np.sqrt(np.sum(mean_deltas[f]**2, axis=0))
        print(global_delta_mean, global_delta_std)
#         x_cor = np.array([floats[:-1], floats[:-1], floats[1:], floats[1:]])
        y_plus = global_delta_mean + global_delta_std
        y_minus = global_delta_mean - global_delta_std
#         y_minus[y_minus < qp.utils.epsilon] = qp.utils.epsilon
        indie_delta_kld_min.append(np.min(y_minus))
        indie_delta_kld_max.append(np.max(y_plus))
#         y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
#         ax.fill(x_cor, y_cor, color=colors[f], alpha=a, linewidth=0.)
        ax.fill_between(floats, y_minus, y_plus, color=colors[f], alpha=a, linewidth=0.)
        ax.plot(floats, global_delta_mean, color=colors[f], linestyle=styles[f], label=f)
    ax.set_ylabel(r'$\Delta\mathrm{KLD}$ (nats)')
#     ax.semilogy()
    ax.set_ylim(0., np.max(indie_delta_kld_max))
    ax.set_xlim(min(floats), max(floats))
    ax.set_xlabel('change in number of parameters')
    ax.semilogx()
    ax.set_xticks(floats)
    ax.set_xticklabels([r'$3\to 10$', r'$10\to 30$', r'$30\to 100$', r'$100\to \infty$'])
    ax.legend(loc='upper right')
    ax.set_title(dataset_info[name]['name']+r' data per-PDF $\Delta\mathrm{KLD}$', fontsize=16)
    place = os.path.join(topdir, 'indie_klds'+str(size)+name)
    fig.savefig(place+'_clean.pdf', dpi=250)
    plt.close()

In [ ]:
def plot_pz_metrics(dataset_key, n_gals_use):

    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pz_kld_moments'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as pz_file:
        pz_stats = hickle.load(pz_file)
  
    flat_floats = np.array(pz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)

    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)

    shapes = moment_shapes
    marksize = 10
    a = 1./len(formats)
    
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-1], [0], color=colors[key], label=key, linewidth=2, linestyle=styles[key], alpha=0.5)
    for n in range(1, n_moments_use):
        ax.scatter([-1], [0], color='k', alpha=0.5, marker=shapes[n], s=50, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('pz metrics data shape '+str(pz_stats[f][n]))
            data_arr = np.log(np.swapaxes(np.array(pz_stats[f][n]), 0, 1))#go from n_floats*instantiations to instantiations*n_floats
            mean = np.mean(data_arr, axis=0).flatten()
            std = np.std(data_arr, axis=0).flatten()
            y_plus = mean + std
            y_minus = mean - std
#             y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
            ax_n.plot(np.exp(in_x+n_factor), mean, marker=shapes[n], mfc='none', markersize=marksize, linestyle=styles[f], alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
            pz_mean_max[n] = max(pz_mean_max[n], np.max(y_plus))
            pz_mean_min[n] = min(pz_mean_min[n], np.min(y_minus))
        ax_n.set_ylabel(r'$\log[\mathrm{'+moment_names[n]+r'}]$', rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((pz_mean_min[n]-1., pz_mean_max[n]+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\log[\mathrm{KLD}]$ log-moments', fontsize=16)
    ax.legend(loc='lower left')
    fig.tight_layout()
    fig.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
    
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-1], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-1], [0], color='k', alpha=0.5, marker=shapes[n], s=50, label=moment_names[n])
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('pz metrics data shape '+str(pz_stats[f][n]))
            data_arr = np.log(np.swapaxes(np.array(pz_stats[f][n]), 0, 1))#go from n_floats*instantiations to instantiations*n_floats
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, color=colors[f], alpha=a)
#                 pz_moment_max[n-1].append(max(i))
        ax_n.set_ylabel(r'$\log[\mathrm{'+moment_names[n]+r'}]$', rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim(pz_mean_min[n]-1., pz_mean_max[n]+1.)
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\log[\mathrm{KLD}]$ log-moments', fontsize=16)
    ax.legend(loc='lower left')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

In [ ]:
def plot_pz_delta_moments(name, size):
    n_gals_use = size
    extremum = np.zeros(n_moments_use)
    
    # should look like nz_moments
    path = os.path.join(name, str(size))
    loc = os.path.join(path, 'pz_moment_deltas'+str(size)+name)
    with open(loc+'.hkl', 'r') as pz_file:
        pz_stats = hickle.load(pz_file)
    flat_floats = np.array(pz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)
    a = 1./len(formats)
    shapes = moment_shapes
    marksize = 10
    
    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)   
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
    for n in range(1, n_moments_use):
        ax.scatter([-10], [0], color='k', alpha=0.5, marker=shapes[n], s=50, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
            old_shape = np.shape(np.array(pz_stats[f][n]))
            new_shape = (old_shape[0], np.prod(old_shape[1:]))
            data_arr = np.abs(np.array(pz_stats[f][n]).reshape(new_shape)) * 100.#go from n_floats*instantiations*n_gals n_floats*(n_gals*n_instantiations)
#             data_arr = np.median(data_arr, axis=2) * 100.
#             data_arr = np.swapaxes(np.array(nz_stats[f][n]), 0, 1)* 100.#np.log(np.swapaxes(np.array(nz_stats[f]), 0, 1)[:][:][n])#go from n_floats*instantiations to instantiations*n_floats
#             mean = np.mean(data_arr, axis=0).flatten()
#             std = np.std(data_arr, axis=0).flatten()
#             mean = np.median(data_arr, axis=-1)
            std = np.log10(np.percentile(data_arr, [25, 50, 75], axis=-1))
            y_plus = std[-1]#mean + std
            y_minus = std[0]#mean - std
            mean = std[1]
#             y_cor = np.array([y_minus, y_plus, y_plus, y_minus])
            ax_n.plot(np.exp(in_x+n_factor), mean, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
#             print('before '+str((np.shape(data_arr), n, n_delta_max, n_delta_min, y_plus, y_minus)))
            n_delta_max[n] = max(n_delta_max[n], np.max(y_plus))
            n_delta_min[n] = min(n_delta_min[n], np.min(y_minus))
#             old_shape = np.shape(np.array(pz_stats[f][n]))
#             new_shape = (old_shape[0], np.prod(old_shape[1:]))
#             data_arr = np.array(pz_stats[f][n]).reshape(new_shape)#go from n_floats*instantiations to instantiations*n_floats
# #             data_arr = np.array(pz_stats[f][n])
# #             data_arr = np.median(data_arr, axis=2) * 100.
#             mean = np.mean(data_arr, axis=1)
#             std = np.std(data_arr, axis=1)
#             y_plus = (mean + std) * 100.
#             y_minus = (mean - std) * 100.
# #             y_cor = np.array([y_minus, y_plus, y_plus, y_minus])
#             ax_n.plot(np.exp(in_x+n_factor), mean, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, alpha=a, color=colors[f])
#             ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
#             print('before '+str((np.shape(data_arr), n, n_delta_max, n_delta_min, y_plus, y_minus)))
#             n_delta_max[n] = np.max(n_delta_max[n], np.max(y_plus))
#             n_delta_min[n] = np.min(n_delta_min[n], np.min(y_minus))
#             print('after '+str((n_delta_max, n_delta_min)))
        ax_n.set_ylabel(r'$\log_{10}$-percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        extremum[n] = np.max(np.abs(np.array([n_delta_min[n], n_delta_max[n]]))) + 0.25
        ax_n.set_ylim(-1.*extremum[n], extremum[n])
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[name]['name']+r' data $\hat{p}(z)$ moment log-percent errors', fontsize=16)
    ax.legend(loc=dataset_info[name]['legloc_p'])
    fig.tight_layout()
    fig.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-10], [0], color='k', alpha=a, marker=shapes[n], s=50, label=moment_names[n])
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
            data_arr = np.swapaxes(np.array(pz_stats[f][n]), 0, 1)
            data_arr = np.median(data_arr, axis=2) * 100.
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, color=colors[f], alpha=a)
        ax_n.set_ylabel(r'median percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim(-10., 10.)#(-1.*extremum[n], extremum[n])
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[name]['name']+r' data $\hat{p}(z)$ moment percent errors', fontsize=16)
    ax.legend(loc='upper left')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

We want to plot the KLD on $\hat{n}(z)$ for all formats as $N_{f}$ changes. We want to repeat this for many subsamples of the catalog to establish error bars on the KLD values.
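
Before these plots can be made, the per-instantiation checkpoints have to be rolled up into the combined files. A minimal sketch of one way to do that with the wrapper functions defined earlier (an assumption about the intended workflow, to be run after all instantiations have finished):

In [ ]:
# a sketch (assumption) of aggregating the per-instantiation checkpoints into the
# combined files read by the plotting functions; clear_stats avoids appending
# duplicates on re-runs
def collect_stats(dataset_name, n_gals_use):
    for stat_name in ['pz_kld_moments', 'pz_moment_deltas', 'nz_moments', 'nz_klds']:
        clear_stats(dataset_name, n_gals_use, stat_name)
    for N_f in floats:
        for i in instantiations:
            save_moments_wrapper(dataset_name, n_gals_use, N_f, i, 'pz_kld_moments')
            save_moments_wrapper(dataset_name, n_gals_use, N_f, i, 'pz_moment_deltas')
            save_moments_wrapper(dataset_name, n_gals_use, N_f, i, 'nz_moments')
            save_metrics_wrapper(dataset_name, n_gals_use, N_f, i, 'nz_klds')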


In [ ]:
def save_nz_metrics(dataset_name, n_gals_use, N_f, nz_klds, stat_name):
    
    path = os.path.join(dataset_name, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_name)
    if os.path.exists(loc+'.hkl'):
        with open(loc+'.hkl', 'r') as nz_file:
        #read in content of list/dict
            nz_stats = hickle.load(nz_file)
    else:
        nz_stats = {}
        nz_stats['N_f'] = []
        for f in formats:
            nz_stats[f] = []
    
    if N_f not in nz_stats['N_f']:
        nz_stats['N_f'].append(N_f)
        for f in formats:
            nz_stats[f].append([])
        
    where_N_f = nz_stats['N_f'].index(N_f) 
    
    for f in formats:
        nz_stats[f][where_N_f].append(nz_klds[f])

    with open(loc+'.hkl', 'w') as nz_file:
        hickle.dump(nz_stats, nz_file)

In [ ]:
def plot_nz_klds(dataset_key, n_gals_use):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_klds'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as nz_file:
        nz_stats = hickle.load(nz_file)
#     if len(instantiations) == 10:
#         for f in formats:
#             if not np.shape(nz_stats[f]) == (4, 10):
#                 for s in range(len(floats)):
#                     nz_stats[f][s] = np.array(np.array(nz_stats[f][s])[:10]).flatten()

    flat_floats = np.array(nz_stats['N_f']).flatten()
    
    plt.figure(figsize=(5, 5))
    for f in formats:
#         print('nz klds data shape '+str(nz_stats[f][n]))
        data_arr = np.swapaxes(np.array(nz_stats[f]), 0, 1)#turn N_f * instantiations into instantiations * N_f
        n_i = len(data_arr)
        a = 1./len(formats)#1./n_i
        plt.plot([10. * max(flat_floats), 10. * max(flat_floats)], [1., 10.], color=colors[f], alpha=a, label=f, linestyle=styles[f])
        for i in data_arr:
            plt.plot(flat_floats, i, color=colors[f], alpha=a, linestyle=styles[f])
            kld_min.append(min(i))
            kld_max.append(max(i))
    plt.semilogy()
    plt.semilogx()
    plt.xticks(flat_floats, [str(ff) for ff in flat_floats])
    plt.ylim(min(kld_min), max(kld_max))
    plt.xlim(min(flat_floats), max(flat_floats))
    plt.xlabel(r'number of parameters', fontsize=14)
    plt.ylabel(r'KLD', fontsize=14)
    plt.legend(loc='upper right')
    plt.title(r'$\hat{n}(z)$ KLD on '+str(n_gals_use)+' from '+dataset_info[dataset_key]['name']+' mock catalog', fontsize=16)
    plt.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

    plt.figure(figsize=(5, 5))
    a = 1./len(formats)
    for f in formats:
#         print('nz klds data shape '+str(nz_stats[f][n]))
        data_arr = np.swapaxes(np.array(nz_stats[f]), 0, 1)#turn N_f * instantiations into instantiations * N_f
        plt.plot([10. * max(flat_floats), 10. * max(flat_floats)], [1., 10.], color=colors[f], label=f, linestyle=styles[f])
        kld_min.append(np.min(data_arr))
        kld_max.append(np.max(data_arr))
        mean = np.mean(np.log(data_arr), axis=0)
        std = np.std(np.log(data_arr), axis=0)
        x_cor = np.array([flat_floats[:-1], flat_floats[:-1], flat_floats[1:], flat_floats[1:]])
        y_plus = np.exp(mean + std)
        y_minus = np.exp(mean - std)
        y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
        plt.plot(flat_floats, np.exp(mean), color=colors[f], linestyle=styles[f])
        plt.fill(x_cor, y_cor, color=colors[f], alpha=a, linewidth=0.)
    plt.semilogy()
    plt.semilogx()
    plt.xticks(flat_floats, [str(ff) for ff in flat_floats])
    plt.ylim(min(kld_min), max(kld_max))
    plt.xlim(min(flat_floats), max(flat_floats))
    plt.xlabel(r'number of parameters', fontsize=14)
    plt.ylabel(r'KLD', fontsize=14)
    plt.legend(loc='upper right')
    plt.title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ KLD', fontsize=16)
    plt.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()

In [ ]:
def plot_nz_moments(dataset_key, n_gals_use):

    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_moments'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as nz_file:
        nz_stats = hickle.load(nz_file)
    flat_floats = np.array(nz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)
    a = 1./len(formats)
    shapes = moment_shapes
    marksize = 10
    
    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)   
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
    for n in range(1, n_moments_use):
        ax.scatter([-10], [0], color='k', alpha=0.5, marker=shapes[n], s=50, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        truth = np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
            data_arr = (np.swapaxes(np.array(nz_stats[f][n]), 0, 1) - truth) / truth * 100.#np.log(np.swapaxes(np.array(nz_stats[f]), 0, 1)[:][:][n])#go from n_floats*instantiations to instantiations*n_floats
#             data_arr = np.abs(np.array(pz_stats[f][n]).reshape(new_shape)) * 100.#go from n_floats*instantiations*n_gals n_floats*(n_gals*n_instantiations)
#             data_arr = np.median(data_arr, axis=2) * 100.
#             data_arr = np.swapaxes(np.array(nz_stats[f][n]), 0, 1)* 100.#np.log(np.swapaxes(np.array(nz_stats[f]), 0, 1)[:][:][n])#go from n_floats*instantiations to instantiations*n_floats
#             mean = np.mean(data_arr, axis=0).flatten()
#             std = np.std(data_arr, axis=0).flatten()
#             mean = np.median(data_arr, axis=-1)
#             std = np.log10(np.percentile(np.abs(data_arr), [25, 50, 75], axis=0))
            std = np.percentile(data_arr, [25, 50, 75], axis=0)
            y_plus = std[-1]#mean + std
            y_minus = std[0]#mean - std
            mean = std[1]
#             mean = np.mean(data_arr, axis=0).flatten()
#             std = np.std(data_arr, axis=0).flatten()
#             y_plus = mean + std
#             y_minus = mean - std
#             y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
            ax_n.plot(np.exp(in_x+n_factor), mean, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
            nz_mean_max[n] = max(nz_mean_max[n], np.max(y_plus))
            nz_mean_min[n] = min(nz_mean_min[n], np.min(y_minus))
        ax_n.set_ylabel(r'percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        extremum = np.max(np.abs([nz_mean_min[n], nz_mean_max[n]])) + 1.#0.25
        ax_n.set_ylim(-1. * extremum, extremum)
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ moment percent errors', fontsize=16)
    ax.legend(loc=dataset_info[dataset_key]['legloc_n'])
    fig.tight_layout()
    fig.savefig(loc+'_clean_unlog.pdf', dpi=250)
    plt.close()
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-10], [0], color='k', alpha=0.5, marker=shapes[n], s=50, label=moment_names[n])
        truth = np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
            data_arr = (np.swapaxes(np.array(nz_stats[f][n]), 0, 1) - truth) / truth * 100.
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], mfc='none', markersize=marksize, color=colors[f], alpha=a)
        ax_n.set_ylabel(r'median percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim(-15., 15.)
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ moment percent errors', fontsize=16)
    ax.legend(loc='lower left')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

In [ ]:
# def print_nz_moments(dataset_key, n_gals_use):
#     path = os.path.join(dataset_key, str(n_gals_use))
    
#     dz = dataset_info[dataset_key]['delta_z']
#     z_grid = dataset_info[dataset_key]['z_grid']
#     full_stack = {}
#     all_moments = {}
#     for f in formats_plus:
#         full_stack[f] = []
#         all_moments[f] = []
#     for nf in range(len(floats)):
#         n_floats_use = floats[nf]
#         for f in formats_plus:
#             full_stack[f].append(np.zeros(len(z_grid)))
#             all_moments[f].append([])
#         for i in instantiations:
#             loc = os.path.join(path, 'nz_comp'+str(n_gals_use)+dataset_key+str(n_floats_use)+'_'+str(i))
#             with open(loc+'.hkl', 'r') as filename:
#                 info = hickle.load(filename)
# #                 z_grid = info['z_grid']
#                 stacks = info['stacks']
# #                 klds = info['klds']
#             for key in formats_plus:
#                 full_stack[key][nf] += stacks[key]
#         for n in range(1, n_moments_use):
#             ngrid = z_grid**n
#             all_moments['truth'][nf].append(qp.utils.quick_moment(full_stack['truth'][nf], ngrid, dz))
#             for key in formats:
#                 all_moments[key][nf].append((qp.utils.quick_moment(full_stack[key][nf], ngrid, dz) - all_moments['truth'][nf][-1]) / all_moments['truth'][nf][-1])
#     for f in formats:
#         all_moments[f] = np.array(all_moments[f])
#     print(dataset_key, n_gals_use, all_moments)
          
#     in_x = np.log(floats)
#     a = 1./len(formats)
#     shapes = moment_shapes
#     marksize = 7.5
    
#     def make_patch_spines_invisible(ax):
#         ax.set_frame_on(True)
#         ax.patch.set_visible(False)
#         for sp in ax.spines.values():
#             sp.set_visible(False)
        
#     fig, ax = plt.subplots()
#     fig.subplots_adjust(right=1.)
#     ax_n = ax
#     for key in formats:
#         ax_n.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], alpha=0.5, linewidth=2)
#     for n in range(1, n_moments_use):
#         n_factor = 0.1 * (n - 2)
#         ax.scatter([-10], [0], color='k', alpha=0.5, marker=shapes[n], facecolors='none', s=50, label=moment_names[n])
# #         truth = np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1)
#         if n>1:
#             ax_n = ax.twinx()
#             rot_ang = 270
#             label_space = 15.
#         else:
#             rot_ang = 90
#             label_space = 0.
#         if n>2:
#             ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
#             make_patch_spines_invisible(ax_n)
#             ax_n.spines["right"].set_visible(True)
#         for s in range(len(formats)):
#             f = formats[s]
#             f_factor = 0.05 * (s - 1)
#             data_arr = np.swapaxes(all_moments[f], 0, 1) * 100.
#             ax_n.plot(np.exp(in_x+n_factor), data_arr[n-1], linestyle=styles[f], color=colors[f], alpha=a)
#             ax_n.scatter(np.exp(in_x+n_factor), data_arr[n-1], marker=shapes[n], mfc='none', markersize=marksize, color=colors[f], alpha=0.5)
#         ax_n.set_ylabel(r'percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
# #         ax_n.set_ylim(-1. * extremum, extremum)
#     ax.set_xscale('log')
#     ax.set_xticks(floats)
#     ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
#     ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
#     ax.set_xlabel('number of parameters', fontsize=14)
#     ax.set_title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ moments', fontsize=16)
#     ax.legend(loc='lower left')
#     fig.tight_layout()
#     outloc = os.path.join(path, 'nz_moments'+str(n_gals_use)+dataset_key)
#     fig.savefig(outloc+'_final.pdf', dpi=250)
#     plt.close()
    
#     for nf in range(len(floats)):
#         n_floats_use = floats[nf]
#         plt.figure()
#         plt.plot(z_grid, full_stack['truth'][nf], color='black', lw=3, alpha=0.3, label='original')
#         for key in formats:
#             kld = qp.utils.quick_kl_divergence(full_stack['truth'][nf], full_stack[key][nf], dx=dz)
#             plt.plot(z_grid, full_stack[key][nf], color=colors[key], linestyle=styles[key], label=key+r' KLD='+str(kld)[:8])#+r'; '+str(all_moments[f][nf])+' percent error')
#         plt.xlabel(r'$z$', fontsize=14)
#         plt.ylabel(r'$\hat{n}(z)$', fontsize=14)
#         plt.xlim(min(z_grid), max(z_grid))
#     #     plt.ylim(0., max(nz_max))
#         plt.legend()
#         plt.title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ with $N_{f}='+str(n_floats_use)+r'$', fontsize=16)
#         outloc = os.path.join(path, 'global_nz'+str(n_gals_use)+dataset_key+str(n_floats_use))
#         plt.savefig(outloc+'.pdf', dpi=250)
#         plt.close()

Okay, now all that remains is to loop this over both datasets, the number of galaxies, the number of floats, and the instantiations!

Note: it takes about 5 minutes per number-of-floats value for 100 galaxies and about 40 minutes per number-of-floats value for 1000 galaxies, so the runtime scales roughly linearly with the number of galaxies, as expected.
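
The back-of-the-envelope sketch below turns those quoted timings into an estimate of the total wall-clock time for the full loop; the per-number-of-floats minutes are the figures quoted above, treated as assumptions rather than measurements.


In [ ]:
# rough runtime estimate from the quoted timings (assumptions, not measurements):
# ~5 minutes per number-of-floats value at 100 galaxies, ~40 minutes at 1000
minutes_per_floats_value = {100: 5., 1000: 40.}
n_datasets_est = 2        # 'mg' and 'ss'
n_instantiations_est = 10
n_floats_values_est = 4   # [3, 10, 30, 100]
for est_size in sorted(minutes_per_floats_value):
    total_minutes = (n_datasets_est * n_instantiations_est * n_floats_values_est
                     * minutes_per_floats_value[est_size])
    print(str(est_size)+' galaxies: ~'+str(round(total_minutes / 60., 1))+' hours for the full loop')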


In [ ]:
dataset_info = {}
delta = 0.01

dataset_keys = ['mg', 'ss']

for name in dataset_keys:
    dataset_info[name] = {}
    if name == 'mg':
        datafilename = 'bpz_euclid_test_10_3.probs'
        z_low = 0.01
        z_high = 3.51
        nc_needed = 3
        plotname = 'brighter'
        skip_rows = 1
        skip_cols = 1
        legloc_p = 'upper right'
        legloc_n = 'upper left'
    elif name == 'ss':
        datafilename = 'test_magscat_trainingfile_probs.out'
        z_low = 0.005
        z_high = 2.11
        nc_needed = 5
        plotname = 'fainter'
        skip_rows = 1
        skip_cols = 1
        legloc_p = 'lower left'
        legloc_n = 'lower right'
    dataset_info[name]['filename'] = datafilename  
    
    dataset_info[name]['z_lim'] = (z_low, z_high)
    z_grid = np.arange(z_low, z_high, delta, dtype='float')#np.arange(z_low, z_high + delta, delta, dtype='float')
    z_range = z_high - z_low
    delta_z = z_range / len(z_grid)
    dataset_info[name]['z_grid'] = z_grid
    dataset_info[name]['delta_z'] = delta_z

    dataset_info[name]['N_GMM'] = nc_needed  # will be overwritten later
    dataset_info[name]['name'] = plotname
    dataset_info[name]['legloc_p'] = legloc_p
    dataset_info[name]['legloc_n'] = legloc_n

In [ ]:
formats = ['quantiles', 'histogram', 'samples']
formats_plus = list(formats)
formats_plus.append('truth')
n_formats = len(formats)

high_res = 300

color_cycle = np.array([(230, 159, 0), (86, 180, 233), (0, 158, 115), (240, 228, 66), (0, 114, 178), (213, 94, 0), (204, 121, 167)])/256.
color_cycle_names = ['Orange', 'Sky blue', 'Bluish green', 'Yellow', 'Blue', 'Vermilion', 'Reddish purple']
n_plot = len(color_cycle)

n_moments_use = 4
n_symb = 5
moment_names = ['integral', 'mean', 'variance', 'kurtosis']
# matplotlib marker tuples (numsides, style, angle): style 0 = polygon, 1 = star, 2 = asterisk, 3 = circle
moment_shapes = [(n_symb, 3, 0), (n_symb, 0, 0), (n_symb, 1, 0), (n_symb, 2, 0)]

For debugging, specify the randomly selected PDFs; a sketch of how to pin them follows the next cell.


In [ ]:
#change all for NERSC

floats = [3, 10, 30, 100]
sizes = [100]#[10, 100, 1000]
names = dataset_info.keys()
instantiations = range(0, 10)

all_randos = [[np.random.choice(size, n_plot, replace=False) for size in sizes] for name in names]
# all_randos = [[np.random.choice(indices, n_plot, replace=False) for size in sizes] for name in names]
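
Below is a minimal sketch of one way to pin those selections for debugging: seed numpy's generator before drawing all_randos (and then rerun the cell above), or hard-code the plotting indices outright. The seed value and the hard-coded indices are arbitrary illustrations, not part of the pipeline.


In [ ]:
# debugging sketch: make the "random" selections reproducible (arbitrary seed)
np.random.seed(seed=42)  # every subsequent np.random.choice call is now repeatable

# or bypass the draw entirely with hand-picked indices (hypothetical values):
# fixed_indices = np.arange(n_plot)
# all_randos = [[fixed_indices for size in sizes] for name in names]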

The "pipeline" is a bunch of nested for loops because qp.Ensemble makes heavy use of multiprocessing. Doing multiprocessing within multiprocessing may or may not cause problems, but I am certain that it makes debugging a nightmare.

Okay, without further ado, let's do it!


In [ ]:
# # the "pipeline"
# global_start = timeit.default_timer()
# for n in range(len(names)):
#     name = names[n]
    
#     dataset_start = timeit.default_timer()
#     print('started '+name)
    
#     pdfs = setup_dataset(name, skip_rows, skip_cols)
    
#     for s in range(len(sizes)):
#         size = sizes[s]
        
#         size_start = timeit.default_timer()
#         print('started '+name+str(size))
        
#         path = os.path.join(name, str(size))
#         if not os.path.exists(path):
#             os.makedirs(path)
        
#         n_gals_use = size
        
#         randos = all_randos[n][s]
        
#         for i in instantiations:
# #             top_bonusdict = {}
#             i_start = timeit.default_timer()
#             print('started '+name+str(size)+' #'+str(i))
        
#             original = '_original'+str(i)
#             pdfs_use = make_instantiation(name, size, pdfs, bonus=original)
# #             plot = plot_examples(size, name, bonus=original)
# #             top_bonusdict[original] = ['-', 0.25]
        
#             z_grid = dataset_info[name]['in_z_grid']
#             N_comps = dataset_info[name]['N_GMM']
        
#             postfit = '_postfit'+str(i)
#             catalog = setup_from_grid(name, pdfs_use, z_grid, N_comps, high_res=high_res, bonus=postfit)
# #             plot = plot_examples(size, name, bonus=postfit)
# #             top_bonusdict[postfit] = ['-', 0.5]
        
#             for n_floats_use in floats:
# #                 bonusdict = top_bonusdict.copy()
#                 float_start = timeit.default_timer()
#                 print('started '+name+str(size)+' #'+str(i)+' with '+str(n_floats_use))
        
#                 ensembles = analyze_individual(catalog, z_grid, n_floats_use, name, n_moments_use, i=i, bonus=postfit)
                
# #                 for f in formats:
# #                     fname = str(n_floats_use)+f+str(i)
# #                     plot = plot_examples(size, name, bonus=fname)
# #                     bonusdict[fname] = [styles[f], 0.5]
# #                 plot = plot_all_examples(name, size, n_floats_use, i, bonus=bonusdict)
# #                 plot = plot_individual_kld(size, name, n_floats_use, i=i)
            
#                 stack_evals = analyze_stacked(catalog, ensembles, z_grid, n_floats_use, name, i=i)
# #                 plot = plot_estimators(size, name, n_floats_use, i=i)
            
#                 print('FINISHED '+name+str(size)+' #'+str(i)+' with '+str(n_floats_use)+' in '+str(timeit.default_timer() - float_start))
#             print('FINISHED '+name+str(size)+' #'+str(i)+' in '+str(timeit.default_timer() - i_start))
# #         plot = plot_pz_metrics(name, size)
# #         plot = plot_pz_delta_moments(name, size)      
# #         plot = plot_nz_klds(name, size)
# #         plot = plot_nz_moments(name, size)
        
#         print('FINISHED '+name+str(size)+' in '+str(timeit.default_timer() - size_start))
        
#     print('FINISHED '+name+' in '+str(timeit.default_timer() - dataset_start))
# print('FINISHED everything in '+str(timeit.default_timer() - global_start))

Remake the plots so that they share axes, enabling results from separate runs to be combined.


In [ ]:
floats = [3, 10, 30, 100]
sizes = [100]#[10, 100, 1000]
names = dataset_info.keys()
instantiations = range(0, 10)

all_randos = [[np.random.choice(size, n_plot, replace=False) 
               for size in sizes] for name in names]

In [ ]:
# TODO: consolidate these plotting conventions into a single dict keyed by format
colors = {'quantiles': 'darkviolet', 'histogram': 'darkorange', 'samples': 'g'}
styles = {'quantiles': '--', 'histogram': ':', 'samples': '-.'}
stepstyles = {'quantiles': 'dashed', 'histogram': 'dotted', 'samples': 'dashdot'}

colors_plus = colors.copy()
colors_plus['truth'] = 'black'
styles_plus = styles.copy()
styles_plus['truth'] = '-'

iqr_min = [3.5]
iqr_max = [delta]
modes_max = [0]
pz_max = [1.]
nz_max = [1.]
hist_max = [1.]
dist_min = [0.]
dist_max = [0.]
pz_mean_max = -10.*np.ones(n_moments_use)
pz_mean_min = 10.*np.ones(n_moments_use)
kld_min = [1.]
kld_max = [1.]
indie_delta_kld_min = [1.]
indie_delta_kld_max = [-1.]
nz_mean_max = -10.*np.ones(n_moments_use)
nz_mean_min = 10.*np.ones(n_moments_use)
n_delta_max = -10.*np.ones(n_moments_use)
n_delta_min = 10.*np.ones(n_moments_use)

norm = False#true for shared axes on individual instantiation plots, otherwise false

moments_to_save = ['pz_kld_moments', 'pz_moments', 'pz_moment_deltas', 'nz_moments']
metrics_to_save = ['nz_klds']
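
The single-element lists above are, presumably, mutable containers that the plotting helpers defined earlier in the notebook update in place, so that axis limits accumulate across instantiations and runs. Below is a minimal sketch of that assumed pattern, using a hypothetical update_extrema helper, stand-in containers, and made-up KLD values so as not to disturb the real bookkeeping.


In [ ]:
# sketch of the assumed running-extrema pattern; update_extrema, demo_min/demo_max,
# and the example KLD values are hypothetical stand-ins for illustration only
def update_extrema(data, running_min, running_max):
    running_min[0] = min(running_min[0], np.min(data))
    running_max[0] = max(running_max[0], np.max(data))
    return running_min[0], running_max[0]

demo_min, demo_max = [np.inf], [-np.inf]
example_klds = np.array([0.3, 2.7, 11.4])
print(update_extrema(example_klds, demo_min, demo_max))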

In [ ]:
# comment out for NERSC
# set norm to True and run twice to match axis limits

for name in names:
    for size in sizes:
#         for stat_name in moments_to_save + metrics_to_save:
#             clear_stats(name, size, stat_name)
#         for i in instantiations:
#             top_bonusdict = {}
#             bo = '_original'+str(i)
#             plot = plot_examples(size, name, bonus=bo, norm=norm)
#             top_bonusdict[bo] = ['-', 0.25]
#             bp = '_postfit'+str(i)
#             plot = plot_examples(size, name, bonus=bp, norm=norm)
#             top_bonusdict[bp] = ['-', 0.5]
#             for n in range(len(floats)):
#                 bonusdict = top_bonusdict.copy()
#                 n_floats_use = floats[n]
#                 for f in formats:
#                     fname = str(n_floats_use)+f+str(i)
#                     plot = plot_examples(size, name, bonus=fname, norm=norm)
#                     bonusdict[fname] = [styles[f], 0.5]
#                 plot = plot_all_examples(name, size, n_floats_use, i, bonus=bonusdict)
#                 plot = plot_individual_kld(size, name, n_floats_use, i)
#                 plot = plot_estimators(size, name, n_floats_use, i)
#                 for stat_name in moments_to_save:
#                     save_moments_wrapper(name, size, n_floats_use, i, stat_name)
#                 for stat_name in metrics_to_save:
#                     save_metrics_wrapper(name, size, n_floats_use, i, stat_name)
        plot = plot_kld_stats(name, size)
        plot = plot_pz_metrics(name, size)
        plot = plot_pz_delta_moments(name, size)
#         plot = plot_nz_klds(name, size)
        plot = plot_nz_moments(name, size)

In [ ]:
# def just_modality(dataset_key, n_gals_use, bonus=None):
#     import scipy.signal
#     path = os.path.join(dataset_key, str(n_gals_use))
#     loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
#     with open(loc+'.hkl', 'r') as filename:
#         info = hickle.load(filename)
#         pdfs_use = info['pdfs']
#     modality, iqrs = [], []
#     dpdfs = pdfs_use[:,1:] - pdfs_use[:,:-1]
#     ddpdfs = dpdfs[:, 1:] - dpdfs[:, :-1]
#     for i in range(n_gals_use):
#         modality.append(len(scipy.signal.argrelmax(pdfs_use[i])[0]))#(len(np.where(np.signbit(ddpdfs[i]))[0]))
#         cdf = np.cumsum(qp.utils.normalize_integral((dataset_info[dataset_key]['z_grid'], pdfs_use[i]), vb=False)[1])
#         iqr_lo = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.25)]
#         iqr_hi = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.75)]
#         iqrs.append(iqr_hi - iqr_lo)
# #     modality = np.array(modality)
# #     iqrs = np.array(iqrs)

# #     loc = os.path.join(path, 'modality'+str(n_gals_use)+dataset_key+bonus)
# #     with open(loc+'.hkl', 'w') as filename:
# #         info = {}
# #         info['modes'] = modality
# #         info['iqrs'] = iqrs
# #         hickle.dump(info, filename)
#     return(modality, iqrs)

In [ ]:
# all_modes, all_iqrs = {}, {}
# for name in names:
#     all_modes[name], all_iqrs[name] = {}, {}
#     for size in sizes:
#         all_modes[name][str(size)], all_iqrs[name][str(size)] = [], []
#         for i in instantiations:
# #         print_nz_moments(name, size)
#             original = '_original'+str(i)
#             (modality, iqrs) = just_modality(name, size, bonus=original)
#             all_modes[name][str(size)].append(modality)
#             all_iqrs[name][str(size)].append(iqrs)

In [ ]:
# for name in names:
#     for size in sizes:
#         modality = np.array(all_modes[name][str(size)]).flatten()
#         modality_cdf = []
#         modegrid = range(np.max(modality))
#         for x in modegrid:
#             modality_cdf.append(len(modality[modality==x]))
#         plt.hist(modality, normed=True)
#         plt.title(name+str(size)+'modality'+str(np.median(modality)))
#         plt.show()
#         plt.close()
#         print(zip(modegrid, modality_cdf))
# #         iqrdist = np.array(all_iqrs[name][str(size)]).flatten()
# #         plt.title(name+str(size)+'iqrdist'+str(np.median(iqrdist)))
# #         plt.hist(iqrdist, normed=True)
# #         plt.show()
# #         plt.close()

In [ ]:
# thing = load_one_stat('ss', 100, 3, 0, 'pz_moment_deltas')
# print(np.mean(np.shape(thing['quantiles']), axis=0))

In [ ]:
# save_moments('ss', 100, 3, thing, 'pz_moment_deltas')

In [ ]:
# path = os.path.join('ss', str(100))
# loc = os.path.join(path, 'pz_moment_deltas'+str(100)+'ss')
# with open(loc+'.hkl', 'r') as pz_file:
#     pz_stats = hickle.load(pz_file)
    
# print(np.shape(pz_stats['quantiles'][0]))#N_f * n_m * n_i * n_g

In [ ]:
# modified = np.array(pz_stats['quantiles']).reshape(4, 4, 1000)*100.
# print(np.shape(modified))

In [ ]:
# print(np.shape(np.array(pz_stats[f][0]).reshape(4, 1000)))

In [ ]:
# more_modified = modified * 100.
# mean = np.mean(more_modified, axis=-1)
# print(mean)
# std = np.std(more_modified, axis=-1)
# print(std)

In [ ]:
# # print(np.shape(modified))
# # plt.hist(modified[0][3])
# weird_x = np.log(np.array(floats))

# moment_num = 3
# for s in range(3):
#     f = formats[s]
#     const = 0.1
#     f_factor = const * (s - 1)
#     new_data = np.array(pz_stats[f][moment_num]).reshape(4, 1000)*100.
#     plt.plot(np.exp(weird_x+f_factor), np.median(new_data, axis=-1), linestyle=styles[f], marker=moment_shapes[moment_num], mfc='none', markersize=5, alpha=0.5, color=colors[f])
#     violin = plt.violinplot(list(new_data), np.exp(weird_x+f_factor), showextrema=False, showmeans=False, showmedians=False, widths=np.exp(weird_x+const/2.)-np.exp(weird_x))
# #     for partname in ['cmedians']:
# #         vp = violin[partname]
# #         vp.set_edgecolor(colors[f])
# #         vp.set_linewidth(3)
# # Color the violin bodies to match each format's color and make them semi-transparent:
#     for vp in violin['bodies']:
#         vp.set_facecolor(colors[f])
# #         vp.set_edgecolor('k')
# #         vp.set_linewidth(0)
#         vp.set_alpha(0.5)
# plt.semilogx()
# plt.ylim(-50., 50.)

In [ ]:
# print(np.shape(new_data))

In [ ]:
# plt.boxplot(list(new_data), positions=floats, sym='')

In [ ]:
# print(np.shape(pz_stats['quantiles'][0][0]))

In [ ]:
# print(violin.keys())

In [ ]:
# help(plt.boxplot)

In [ ]: