Problem 3 - Practical HBM

Joint Orbital Period Breakpoint and Eccentricity Distribution Hierarchical Bayesian Model for Eclipsing Binaries with PyJAGS

Simulating a joint eccentricity and orbital period distribution of eclipsing binaries from the Kepler Mission

LSSTC DSFP Session 4, September 21st, 2017

Author: Megan I. Shabram, PhD, NASA Postdoctoral Program Fellow, mshabram@gmail.com

In [ ]:
from __future__ import division, print_function

import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import pyjags
import pystan
import pickle
import triangle_linear

from IPython.display import display, Math, Latex
from pandas.plotting import scatter_matrix, autocorrelation_plot
from matplotlib import rcParams
rcParams["savefig.dpi"] = 100
rcParams["font.size"] = 20
#plt.style.use('ggplot')

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
#%qtconsole

In [ ]:
## Read in Eclipsing Binary dataset containing projected eccentricity and orbital periods.
df_EB = pd.read_csv('EBs_for_jags.txt', delimiter=r"\s+")
print(df_EB)

print(len(df_EB.loc[:, "H_OBS"]))

## The only columns needed below are P, H_OBS, K_OBS, H_SIGMA, and K_SIGMA.
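
If `EBs_for_jags.txt` is not at hand, a synthetic stand-in with the same column layout can be simulated and written to disk. This is a minimal sketch: the 50 systems, the population dispersion of 0.1 in each projected eccentricity component, the 0.01 measurement uncertainties, and the output filename are all illustrative assumptions, not the Kepler values.

In [ ]:
## Simulate a stand-in dataset with the column names the notebook expects.
N = 50
rng = np.random.RandomState(42)
h_true = rng.normal(0, 0.1, N)   ## h, k: projected eccentricity components
k_true = rng.normal(0, 0.1, N)
df_synth = pd.DataFrame({'P': rng.uniform(0.5, 50, N),
                         'H_OBS': h_true + rng.normal(0, 0.01, N),
                         'K_OBS': k_true + rng.normal(0, 0.01, N),
                         'H_SIGMA': np.full(N, 0.01),
                         'K_SIGMA': np.full(N, 0.01)})
## Space-separated output is compatible with the read_csv call above.
df_synth.to_csv('EBs_for_jags_synthetic.txt', sep=' ', index=False)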

In [ ]:
code = '''

model {
    
    # Orbital period breakpoint hyperparameter:
    p_break ~ dunif(0.001,20)

    # Population hyperparameters for the two-component mixture above the break:
    for (j in 1:Nm) {
        e_sigma_a[j] ~ dunif(0, 1)
        e_phi_a[j] <- 1/(e_sigma_a[j]*e_sigma_a[j])
        a[j] <- 1    # equal Dirichlet concentration for each component
    }
    
    # Single-component dispersion below the break:
    e_sigma_b ~ dunif(0., 1.)
    e_phi_b <- 1/(e_sigma_b*e_sigma_b)
    
    f_a ~ ddirch(a[])    # mixture weights for the long-period components

    for (n in 1:Ndata){

        c_a[n] ~ dcat(f_a[])

        # Short periods (P[n] < p_break) use component b; long periods use mixture a:
        e_phi_ref[n] <- ifelse(P[n] < p_break, e_phi_b, e_phi_a[c_a[n]])

        # True (latent) projected eccentricity components for each binary
        h[n] ~ dnorm(0, e_phi_ref[n]) T(-1,1)
        k[n] ~ dnorm(0, e_phi_ref[n]) T(-sqrt(1-h[n]*h[n]),sqrt(1-h[n]*h[n]))
    
        # Observed projected eccentricity components
        hhat[n] ~ dnorm(h[n], 1.0/(hhat_sigma[n]*hhat_sigma[n])) T(-1,1)
        khat[n] ~ dnorm(k[n], 1.0/(khat_sigma[n]*khat_sigma[n])) T(-sqrt(1-hhat[n]*hhat[n]),sqrt(1-hhat[n]*hhat[n]))
    }
        
}
'''
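
Before running MCMC, it can help to sanity-check the generative model by forward-simulating it in numpy. The sketch below mirrors the JAGS code above: periods below the break draw both eccentricity components from the single population b, periods above the break first draw a mixture label and then draw from the corresponding component of population a, and the components are truncated so that h^2 + k^2 < 1. All numerical values here (a break at 5 days, dispersions of 0.05, 0.1, and 0.3, equal weights) are illustrative assumptions, not fitted values.

In [ ]:
def draw_truncnorm(sigma, lo, hi, rng):
    ## Rejection-sample one draw from N(0, sigma^2) truncated to (lo, hi).
    while True:
        x = rng.normal(0, sigma)
        if lo < x < hi:
            return x

def simulate_population(P, p_break=5.0, sigma_b=0.05,
                        sigma_a=(0.1, 0.3), f_a=(0.5, 0.5), seed=0):
    rng = np.random.RandomState(seed)
    h = np.empty(len(P))
    k = np.empty(len(P))
    for n, period in enumerate(P):
        if period < p_break:
            sigma = sigma_b                        ## below the break: component b
        else:
            sigma = sigma_a[rng.choice(2, p=f_a)]  ## above the break: mixture a
        h[n] = draw_truncnorm(sigma, -1.0, 1.0, rng)
        bound = np.sqrt(1.0 - h[n]**2)             ## keep h^2 + k^2 < 1, i.e. e < 1
        k[n] = draw_truncnorm(sigma, -bound, bound, rng)
    return h, k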

In [ ]:
## Load additional JAGS modules
pyjags.load_module('glm')
pyjags.load_module('dic')


## See blog post for origination of the adapted analysis tools used here and below:
## https://martynplummer.wordpress.com/2016/01/11/pyjags/

num_chains = 4
iterations = 10000


## The data list should include only variables that appear in the model.
## Note that hhat (the observed H_OBS column) must be passed in as well.
model = pyjags.Model(code, data=dict(Nm=2, Ndata=len(df_EB.loc[:, "P"]), P=df_EB.loc[:, "P"],
                                     hhat=df_EB.loc[:, "H_OBS"], hhat_sigma=df_EB.loc[:, "H_SIGMA"],
                                     khat=df_EB.loc[:, "K_OBS"], khat_sigma=df_EB.loc[:, "K_SIGMA"]),
                     chains=num_chains, adapt=1000)

## Optional keyword arguments to run chains in parallel threads. This feature
## might not be well tested in pyjags at this time:
## threads=4, chains_per_thread=1

## 500 warmup / burn-in iterations, not used for inference.
model.sample(500, vars=[])

## Run the model for the desired number of steps, monitoring the hyperparameters
## (and optionally the latent variables) of the hierarchical Bayesian model.
## Returns a dictionary with a numpy array for each monitored variable.
## Shapes of returned arrays are (... shape of variable ..., iterations, chains).
## samples = model.sample(<iterations per chain>, vars=['e_sigma_a', 'h'])
samples_JAGS = model.sample(iterations, vars=['p_break','e_sigma_a', 'e_sigma_b', 'f_a' ])

## Code to save, open, and reuse the pickled dictionary of samples:
## -- Pickle the data --
#with open('ecc_1_test.pkl', 'wb') as handle:
#   pickle.dump(samples_JAGS, handle)
## -- Retrieve pickled data --
#with open('ecc_1_test.pkl', 'rb') as handle:
#   retrieved_results = pickle.load(handle)
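
Before trusting these samples, a quick convergence check is worthwhile. The Gelman-Rubin statistic compares between-chain and within-chain variance and should be close to 1 for a converged run. Below is a minimal sketch of the classic version of the statistic (modern references use split chains and rank normalization), operating on the (iterations, chains) arrays returned above:

In [ ]:
def gelman_rubin(chains):
    ## chains: array of shape (iterations, chains).
    n = chains.shape[0]
    W = chains.var(axis=0, ddof=1).mean()      ## mean within-chain variance
    B = n * chains.mean(axis=0).var(ddof=1)    ## between-chain variance
    var_plus = (n - 1.0) / n * W + B / n       ## pooled variance estimate
    return np.sqrt(var_plus / W)

for var in ['p_break', 'e_sigma_b']:
    print(var, gelman_rubin(samples_JAGS[var][0]))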

In [ ]:
## Keep the last 1000 iterations of each chain, and relabel the two mixture
## components by their dispersion so that "low" and "high" are well defined
## in every sample (the components are exchangeable, so chains can swap them).
iters = iterations
start = int(iters - 1000)

esig = samples_JAGS['e_sigma_a'][:, start:, :]
f_samp = samples_JAGS['f_a'][:, start:, :]

esigma_low = np.where(esig[0] <= esig[1], esig[0], esig[1])
esigma_hi = np.where(esig[0] > esig[1], esig[0], esig[1])
f_low = np.where(esig[0] <= esig[1], f_samp[0], f_samp[1])
f_hi = np.where(esig[0] > esig[1], f_samp[0], f_samp[1])

print(np.min(f_hi))
plt.hist(f_low.ravel(), bins=25)
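
The `np.where` calls above undo label switching: under the symmetric Dirichlet prior the two mixture components are exchangeable, so a chain can swap which index plays "low" and which plays "high" partway through a run. With a recent numpy (>= 1.15 for `np.take_along_axis`), sorting along the component axis does the same relabeling more compactly; a minimal equivalent sketch:

In [ ]:
## Relabel mixture components by sorting on e_sigma_a along the component axis.
order = np.argsort(esig, axis=0)
esigma_sorted = np.take_along_axis(esig, order, axis=0)   ## [0] = low, [1] = high
f_sorted = np.take_along_axis(f_samp, order, axis=0)      ## weights permuted to match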

In [ ]:
## Scatter matrix plot:

## Collect the relabeled hyperparameter traces into a dictionary for the
## scatter_matrix plot below. Only the six hyperparameters are visualized;
## including the per-system latent variables would make the matrix unreadable.

samples_EB_for_scatter_matrix = {}
numHyperParams = 6
dim = numHyperParams
print(dim)

samples_EB_for_scatter_matrix.update({r'$f_{a_{low}}$': f_low})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{a_{low}}}$': esigma_low})
samples_EB_for_scatter_matrix.update({r'$f_{a_{high}}$': f_hi})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{a_{high}}}$': esigma_hi})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{b}}$': samples_JAGS['e_sigma_b'][0, start:, :]})
samples_EB_for_scatter_matrix.update({r'$P_{break}$': samples_JAGS['p_break'][0, start:, :]})

for j, i in samples_EB_for_scatter_matrix.items():
    print(j)
#    print(i)

trace_2 = pd.Panel(samples_EB_for_scatter_matrix)

sm = scatter_matrix(trace_2.to_frame(), color="darkturquoise", alpha=0.2,
                    figsize=(dim*2, dim*2), diagonal='hist',
                    hist_kwds={'bins': 25, 'histtype': 'step',
                               'edgecolor': 'r', 'linewidth': 2})
## y labels size
[plt.setp(item.yaxis.get_label(), 'size', 20) for item in sm.ravel()]
## x labels size 
[plt.setp(item.xaxis.get_label(), 'size', 20) for item in sm.ravel()]
## Change label rotation
## This is helpful for very long labels
#[s.xaxis.label.set_rotation(45) for s in sm.reshape(-1)]
[s.xaxis.label.set_rotation(0) for s in sm.reshape(-1)]
[s.yaxis.label.set_rotation(0) for s in sm.reshape(-1)]
## May need to offset label when rotating to prevent overlap of figure
[s.get_yaxis().set_label_coords(-0.5,0.5) for s in sm.reshape(-1)]
## Hide all ticks
#[s.set_xticks(()) for s in sm.reshape(-1)]
#[s.set_yticks(()) for s in sm.reshape(-1)]

plt.savefig('scatter_matrix_period_break_point.png')

In [ ]:
## Use a pandas three-dimensional Panel to represent the trace:

trace = pd.Panel(samples_EB_for_scatter_matrix)
trace.axes[0].name = 'Variable'
trace.axes[1].name = 'Iteration'
trace.axes[2].name = 'Chain'
 
## Point estimates:
print(trace.to_frame().mean())
 
## Bayesian equal-tailed 90% credible intervals:
print(trace.to_frame().quantile([0.05, 0.95]))
  ## ^ reporting these interval endpoints could make a good question part
    
def plot(trace, var):
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    fig.suptitle(var, y=0.95, fontsize='xx-large')
 
    ## Marginal posterior density estimate:
    trace[var].plot.density(ax=axes[0])
    axes[0].set_xlabel('Parameter value')
    axes[0].locator_params(tight=True)
 
    ## Autocorrelation for each chain:
    axes[1].set_xlim(0, 100)
    for chain in trace[var].columns:
        autocorrelation_plot(trace[var,:,chain], axes[1], label=chain)
 
    ## Trace plot:
    axes[2].set_ylabel('Parameter value')
    trace[var].plot(ax=axes[2])
 
    ## Save figure, stripping LaTeX markup from the variable name to build a
    ## safe filename.
    filename = var
    for ch in ['\\', '/', '$', '{', '}']:
        filename = filename.replace(ch, '')
    plt.tight_layout(pad=3)
    fig.savefig('Break_point_'+'{}.png'.format(filename))

    
## Display diagnostic plots
for var in trace:
    plot(trace, var)

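The autocorrelation panels above translate directly into an effective sample size: the more slowly the autocorrelation decays, the fewer effectively independent draws each chain contributes. Below is a minimal sketch using an initial-positive-sequence cutoff (a simple truncation rule; more careful estimators exist):

In [ ]:
def effective_sample_size(x):
    ## x: 1-D array of draws from a single chain for one parameter.
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = len(x)
    ## Empirical autocorrelation function, normalized so that acf[0] == 1.
    acf = np.correlate(x, x, mode='full')[n - 1:] / (np.arange(n, 0, -1) * x.var())
    tau = 1.0
    for rho in acf[1:]:
        if rho < 0:   ## truncate at the first negative autocorrelation
            break
        tau += 2.0 * rho
    return n / tau

print(effective_sample_size(samples_JAGS['p_break'][0, start:, 0]))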