Comparison of Fits

Still needed: compare the inferred parameter values from the MCMC and nested sampling fits, together with the errors/covariances (and possibly other variance statistics) that each method reports.

Initialization


In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table

import sncosmo
import simulate_lsst as sl
import triangle

function definitions:


In [2]:
# Helper Functions to load the SNANA Data
def snanadatafile(snanafileroot, filetype='head', location='./'):
    '''
    build the path to the head or phot FITS file of an SNANA simulation

    Parameters
    ----------
    snanafileroot: string, mandatory
        root file name, i.e. the prefix to '_HEAD.FITS' or '_PHOT.FITS'
    filetype: string, optional defaults to 'head'
        'phot' selects the '_PHOT.FITS' file; any other value selects the
        '_HEAD.FITS' file
    location: string, optional defaults to './'
        directory in which the file is located

    Returns
    -------
    string, `location` joined with the file name
    '''
    import os

    # anything other than 'phot' falls back to the head file
    ending = '_PHOT.FITS' if filetype == 'phot' else '_HEAD.FITS'
    return os.path.join(location, snanafileroot + ending)
    
def loadSNANAData(snanafileroot, location='./', snids=None, n=None):
    '''
    load a SNANA fits file into a list of `~astropy.table.Table` objects.


    Parameters
    ----------
    snanafileroot: string, mandatory
        root file name for the SNANA which is the prefix to '_HEAD.FITS', or '_PHOT.FITS'
    location: string, optional defaults to current working directory './'
        directory where the head and phot files are located
    snids: integer/string, optional defaults to None
        if not None, only SN observations corresponding to SNID snid are loaded
    n: Integer, defaults to None
        if not None, only the first n SN light curves are loaded


    Returns
    -------
    data: the result of `sncosmo.read_snana_fits`, documented there as a
        list of `~astropy.table.Table`, each Table containing a light curve of a SN.

    ..note: The column names of the SNANA data files are not reformatted for SNCosmo use
    '''
    headfile = snanadatafile(snanafileroot, filetype='head', location=location)
    photfile = snanadatafile(snanafileroot, filetype='phot', location=location)
    # pass `snids` and `n` straight through so the caller's selection is honoured
    data = sncosmo.read_snana_fits(head_file=headfile, phot_file=photfile,
                                   snids=snids, n=n)
    return data

In [3]:
def addbands(sn, lsstbands, replacement):
    '''
    add a column called 'band' to the `~astropy.table.Table` by
    applying the map of lsstbands to replacement to the content
    of a column called 'FLT'. The table is modified in place.

    Parameters
    ----------
    sn: `~astropy.table.Table` obtained by reading an SNANA light curve
    lsstbands: list of strings, mandatory
        list of strings representing the filters in sn, which can be found
        by `np.unique(sn['FLT'])`
    replacement: list of strings, mandatory
        list of strings representing the filters as registered in SNCosmo in
        the same order as lsstbands

    Returns
    -------
    the same table `sn`, with 'FLT' column removed and 'band' column added

    ..note: rows whose 'FLT' value is not in lsstbands get an empty band
        name, and names longer than 8 bytes are truncated by the 'S8' dtype
    '''
    filterarray = np.zeros(len(sn), dtype='S8')
    for i, flt in enumerate(lsstbands):
        # rows observed in filter `flt` receive the i-th replacement name
        filterarray[sn['FLT'] == flt] = replacement[i]
    # build the column once, after the loop, so it also exists when
    # lsstbands is empty
    band = Table.Column(filterarray, name='band', dtype='S8')
    sn.add_column(band)
    sn.remove_column('FLT')
    return sn

In [4]:
def reformat_SNANASN(sn, lsstbands=None, replacements=None):
    '''
    reformat an SNANA light curve for use with SNCosmo. The table is
    modified in place, so calling this twice on the same table fails at
    the column renaming step.

    Parameters
    ----------
    sn: `~astropy.table.Table`, mandatory
        representing SNANA light curve
    lsstbands: list of strings, optional defaults to None
        list of unique strings in the 'FLT' column of SNANA files;
        only used when replacements is not None
    replacements: list of strings, optional defaults to None
        list of unique strings of the same size as lsstbands, and indexed in the
        same order, representing the keys in the sncosmo.bandpass registry for the
        same filters. If None, the 'FLT' column is left untouched.


    Returns
    -------
    sn, truth: the reformatted table and a dictionary of the simulation
        input parameters ('c', 'x0', 'x1', 't0', 'mwebv', 'z') read from sn.meta
    '''
    # rename cols to names SNCosmo understands
    for snana_name, sncosmo_name in (("FLUXCAL", 'flux'),
                                     ("FLUXCALERR", 'fluxerr')):
        sn.rename_column(snana_name, sncosmo_name)

    # Add in SNANA magic ZP and sys
    sn["ZP"] = 27.5
    sn["ZPSYS"] = 'ab'

    # Set up a truth dictionary from the metadata to set sim models
    meta = sn.meta
    truth = {
        "c": meta["SIM_SALT2c"],
        # NOTE(review): the 0.27 mag offset presumably converts the SNANA x0
        # normalisation to the SNCosmo one -- confirm
        "x0": meta["SIM_SALT2x0"] * 10**(-0.4 * 0.27),
        "x1": meta["SIM_SALT2x1"],
        "t0": meta["SIM_PEAKMJD"],
        "mwebv": meta["SIM_MWEBV"],
        "z": meta["REDSHIFT_FINAL"],
    }

    if replacements is None:
        return sn, truth

    # translate the 'FLT' column into an SNCosmo 'band' column
    addbands(sn, lsstbands, replacements)
    return sn, truth

Load Data File and Initialize Model


In [5]:
# Called only for its side effect; presumably registers the LSST bandpasses
# with SNCosmo so that the 'LSST_*' band names below resolve -- TODO confirm
sl.getlsstbandpassobjs()
# NOTE(review): absolute, machine-specific path; point this at the directory
# holding LSST_Ia_HEAD.FITS / LSST_Ia_PHOT.FITS before running
lsstdeeplocation='/Users/lisaleemcb/data/'
# read the SNANA light curves from that directory
data = loadSNANAData(snanafileroot='LSST_Ia', location=lsstdeeplocation)
# filter names to look for in the SNANA 'FLT' column ...
lsstbands = ['u', 'g', 'r', 'i', 'z', 'Y']
# ... and the band names that replace them, in the same order
replacement = ['LSST_' + flt for flt in lsstbands]

In [6]:
# Build the light-curve model: the 'salt2-extended' source with CCM89 dust
# applied twice, once in the rest frame (named 'host') and once in the
# observer frame (named 'mw').
source = sncosmo.get_source('salt2-extended')
dust = sncosmo.CCM89Dust()
model = sncosmo.Model(source=source,
                      effects=[dust, dust],
                      effect_names=['host', 'mw'],
                      effect_frames=['rest', 'obs'])

In [7]:
# Work with the second light curve (index 1) of the loaded sample.
# reformat_SNANASN renames columns of data[1] in place, so this cell cannot
# be re-run without reloading `data` first.
sn, truth = reformat_SNANASN(data[1], lsstbands, replacement)
# signal-to-noise ratio of every observation
snr = sn['flux']/sn['fluxerr']
# histogram of the observations with SNR > 3, the same threshold that is
# passed as minsnr=3.0 to the fits below
plt.hist(snr[snr > 3.])


Out[7]:
(array([ 4.,  1.,  2.,  2.,  1.,  1.,  1.,  1.,  1.,  1.]),
 array([  3.07398057,   3.9444258 ,   4.81487103,   5.68531625,
          6.55576148,   7.42620671,   8.29665194,   9.16709716,
         10.03754239,  10.90798762,  11.77843285]),
 <a list of 10 Patch objects>)

In [8]:
# Set the model parameters to the simulation inputs collected in `truth`,
# then print the model summary.
model.set(**truth)
print model


<Model at 0x10d848190>
source:
  class      : SALT2Source
  name       : 'salt2-extended'
  version    : 1.0
  phases     : [-20, .., 50] days
  wavelengths: [300, .., 18000] Angstroms
effect (name='host' frame='rest'):
  class           : CCM89Dust
  wavelength range: [909.09, 33333.3] Angstroms
effect (name='mw' frame='obs'):
  class           : CCM89Dust
  wavelength range: [909.09, 33333.3] Angstroms
parameters:
  z       = 1.0665580034255981
  t0      = 50916.4609375
  x0      = 1.3218102045744581e-06
  x1      = -0.70846551656723022
  c       = 0.017801763489842415
  hostebv = 0.0
  hostr_v = 3.1000000000000001
  mwebv   = 0.082109816372394562
  mwr_v   = 3.1000000000000001

Run Fits


In [9]:
# Fit of t0, x0, x1 and c with sncosmo.fit_lc, using only observations
# with SNR >= 3; c and x1 are bounded.
fit_vparams = ['t0', 'x0', 'x1', 'c']
fit_bounds = {'c': (-0.3, 0.3), 'x1': (-3.0, 3.0)}
fit_res, fit_model = sncosmo.fit_lc(sn, model,
                                    vparam_names=fit_vparams,
                                    bounds=fit_bounds,
                                    minsnr=3.0)

In [10]:
# MCMC sampling of the same four parameters with the same bounds.
# note that the resulting number of samples is equal to the product of
# nwalkers and nsamples, i.e. len(res.samples) = nwalkers * nsamples
mcmc_vparams = ['t0', 'x0', 'x1', 'c']
mcmc_bounds = {'c': (-0.3, 0.3), 'x1': (-3.0, 3.0)}
mcmc_res, mcmc_model = sncosmo.mcmc_lc(sn, model,
                                       vparam_names=mcmc_vparams,
                                       bounds=mcmc_bounds,
                                       minsnr=3.0)

In [11]:
# Nested sampling of the same four parameters; only c and x1 are bounded
# explicitly and guess_amplitude_bound=True is requested.
nest_vparams = ['t0', 'x0', 'x1', 'c']
nest_bounds = {'c': (-0.3, 0.3), 'x1': (-3.0, 3.0)}
nest_res, nest_model = sncosmo.nest_lc(sn, model,
                                       vparam_names=nest_vparams,
                                       bounds=nest_bounds,
                                       guess_amplitude_bound=True,
                                       minsnr=3.0,
                                       verbose=True)


 iter=  1849 logz=-60.624227calls=6072 time=430.839s

Visualize Fits

First let's plot the light curve with all the fitted models


In [12]:
# Plot the data with three models: `model` (set to the simulation truth
# above; assumes the fitting calls did not modify it -- TODO confirm), the
# MCMC model and the nested sampling model.
# NOTE(review): fit_model from sncosmo.fit_lc is not included here --
# confirm whether that is intended.
fig_all = sncosmo.plot_lc(sn, [model, mcmc_model, nest_model], pulls=True)



In [13]:
# Short aliases for the names of the varied parameters in each result;
# used below for plot labels and for looking up truth values via model.get()

mcmc = mcmc_res.vparam_names
nest = nest_res.vparam_names

In [14]:
# initialize triangle plot

# without host ext
mcmc_ndim, mcmc_nsamples = len(mcmc), len(mcmc_res.samples)
mcmc_samples = mcmc_res.samples

print "number of mcmc dimensions:", mcmc_ndim
print "number of mcmc samples:", mcmc_nsamples

nest_ndim, nest_nsamples = len(nest), len(nest_res.samples)
nest_samples = nest_res.samples

print
print "number of nest dimensions:", nest_ndim
print "number of nest samples:", nest_nsamples

# with host ext
#mcmc_ext_ndim, mcmc_ext_nsamples = len(mcmc_ext_res.vparam_names), len(mcmc_ext_res.samples)
#mcmc_ext_samples = mcmc_ext_res.samples

#nest_ext_ndim, nest_ext_nsamples = len(nest_ext_res.vparam_names), len(nest_ext_res.samples)
#nest_ext_samples = nest_ext_res.samples


number of mcmc dimensions: 4
number of mcmc samples: 10000

number of nest dimensions: 4
number of nest samples: 1950

MCMC Triangle Plot


In [16]:
print "True Parameters"
print mcmc[0] + ":", model.get(mcmc[0])
print mcmc[1] + ":", model.get(mcmc[1])
print mcmc[2] + ":", model.get(mcmc[2])
print mcmc[3] + ":", model.get(mcmc[3])

figure_mcmc = triangle.corner(mcmc_samples, labels=[mcmc[0], mcmc[1], mcmc[2], mcmc[3]],
                         truths=[model.get(mcmc[0]), model.get(mcmc[1]),
                                 model.get(mcmc[2]), model.get(mcmc[3])],
                         range=mcmc_ndim*[0.9999],
                         show_titles=True, title_args={"fontsize": 12})

figure_mcmc.gca().annotate("mcmc sampling", xy=(0.5, 1.0), xycoords="figure fraction",
                      xytext=(0, -5), textcoords="offset points",
                      ha="center", va="top")

axes = figure_mcmc.axes


True Parameters
t0: 50916.4609375
x0: 1.32181020457e-06
x1: -0.708465516567
c: 0.0178017634898

Nest Triangle Plot


In [17]:
print "True Parameters"
print nest[0] + ":", model.get(nest[0])
print nest[1] + ":", model.get(nest[1])
print nest[2] + ":", model.get(nest[2])
print nest[3] + ":", model.get(nest[3])

print "niter: total number of iterations", nest_res.niter
print "ncall: total number of likelihood function calls", nest_res.ncall

figure = triangle.corner(nest_samples, labels=[nest[0], nest[1], nest[2], nest[3]], 
                         truths=[model.get(nest[0]), model.get(nest[1]),
                                 model.get(nest[2]), model.get(nest[3])],
                         weights=nest_res.weights, range=nest_ndim*[0.9999],
                         show_titles=True, title_args={"fontsize": 12})

figure.gca().annotate("nest sampling", xy=(0.5, 1.0), xycoords="figure fraction",
                      xytext=(0, -5), textcoords="offset points",
                      ha="center", va="top")


True Parameters
t0: 50916.4609375
x0: 1.32181020457e-06
x1: -0.708465516567
c: 0.0178017634898
niter: total number of iterations 1850
ncall: total number of likelihood function calls 6072
Out[17]:
<matplotlib.text.Annotation at 0x11033e5d0>

Trace Plots


In [18]:
# Trace plots on a 2 x 4 grid: the top row shows the MCMC samples and the
# bottom row the nested sampling samples, one column per parameter.
trace_fig = plt.figure(figsize=(20,8))

# panels 1-4 form the top row, panels 5-8 the bottom row
mcmc1, mcmc2, mcmc3, mcmc4 = [trace_fig.add_subplot(2, 4, panel)
                              for panel in (1, 2, 3, 4)]
nest1, nest2, nest3, nest4 = [trace_fig.add_subplot(2, 4, panel)
                              for panel in (5, 6, 7, 8)]

# sample value against sample index, titled with the parameter name
for col, ax in enumerate((mcmc1, mcmc2, mcmc3, mcmc4)):
    ax.plot(mcmc_samples[:, col])
    ax.set_title('mcmc: ' + mcmc[col])

for col, ax in enumerate((nest1, nest2, nest3, nest4)):
    ax.plot(nest_samples[:, col])
    ax.set_title('nest: ' + nest[col])

trace_fig.tight_layout()



In [ ]: