In [ ]:
from __future__ import division, print_function
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt
import pyjags
import pystan
import pickle
import triangle_linear
from IPython.display import display, Math, Latex
## Note: in pandas >= 0.20 these live in pandas.plotting rather than pandas.tools.plotting.
from pandas.tools.plotting import scatter_matrix, autocorrelation_plot
from matplotlib import rcParams
rcParams["savefig.dpi"] = 100
rcParams["font.size"] = 20
#plt.style.use('ggplot')
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
#%qtconsole
In [ ]:
## Read in the eclipsing-binary dataset containing the projected eccentricity components and orbital periods.
df_EB = pd.read_csv('EBs_for_jags.txt', delimiter=r"\s+")
print(df_EB)
print(len(df_EB.loc[:, "H_OBS"]))
## The only columns used below are P, H_OBS, K_OBS, H_SIGMA, and K_SIGMA.
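In [ ]:
## Optional sanity check (an added sketch, not part of the original analysis):
## assuming H_OBS and K_OBS are the e*cos(w) and e*sin(w) components, the implied
## eccentricities should all lie below 1, and a quick look at e vs. P is useful
## before fitting. No axis units are asserted here.
e_obs = np.sqrt(df_EB.loc[:, "H_OBS"]**2 + df_EB.loc[:, "K_OBS"]**2)
print("Maximum implied eccentricity:", e_obs.max())
plt.semilogx(df_EB.loc[:, "P"], e_obs, 'k.', alpha=0.5)
plt.xlabel('Orbital period P')
plt.ylabel(r'$e = \sqrt{h^2 + k^2}$')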
In [ ]:
code = '''
model {
    # Period break-point hyperparameter:
    p_break ~ dunif(0.001, 20)

    # Mixture hyperparameters used for long-period systems (P >= p_break):
    for (j in 1:Nm) {
        e_sigma_a[j] ~ dunif(0, 1)
        e_phi_a[j] <- 1 / (e_sigma_a[j] * e_sigma_a[j])   # convert sigma to precision
        a[j] <- 1
    }
    # Single-component hyperparameter for short-period systems (P < p_break):
    e_sigma_b ~ dunif(0, 1)
    e_phi_b <- 1 / (e_sigma_b * e_sigma_b)
    f_a ~ ddirch(a[])

    for (n in 1:Ndata) {
        c_a[n] ~ dcat(f_a[])
        e_phi_ref[n] <- ifelse(P[n] < p_break, e_phi_b, e_phi_a[c_a[n]])

        # True (latent) eccentricity components of each binary:
        h[n] ~ dnorm(0, e_phi_ref[n]) T(-1, 1)
        k[n] ~ dnorm(0, e_phi_ref[n]) T(-sqrt(1 - h[n]*h[n]), sqrt(1 - h[n]*h[n]))

        # Observed eccentricity components:
        hhat[n] ~ dnorm(h[n], 1.0 / (hhat_sigma[n]*hhat_sigma[n])) T(-1, 1)
        khat[n] ~ dnorm(k[n], 1.0 / (khat_sigma[n]*khat_sigma[n])) T(-sqrt(1 - hhat[n]*hhat[n]), sqrt(1 - hhat[n]*hhat[n]))
    }
}
'''
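In [ ]:
## Prior-predictive sketch (an added illustration, not the fitted model): draw (h, k)
## from the truncated-normal prior defined in the JAGS code above for a fixed e_sigma,
## to see which eccentricity distribution it implies. Note that scipy.stats.truncnorm
## takes its truncation bounds in units of the standard deviation.
def draw_prior_ecc(e_sigma, n=5000, seed=42):
    """Draw h, k ~ N(0, e_sigma^2) with the same truncation as in the model block."""
    h = stats.truncnorm.rvs(-1.0 / e_sigma, 1.0 / e_sigma, loc=0.0, scale=e_sigma,
                            size=n, random_state=seed)
    k_lim = np.sqrt(1.0 - h**2)
    k = stats.truncnorm.rvs(-k_lim / e_sigma, k_lim / e_sigma, loc=0.0, scale=e_sigma,
                            size=n, random_state=seed + 1)
    return np.sqrt(h**2 + k**2)

for sig in [0.05, 0.3]:
    plt.hist(draw_prior_ecc(sig), bins=50, histtype='step', label='e_sigma = {}'.format(sig))
plt.xlabel('Eccentricity')
plt.legend()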
In [ ]:
## Load additional JAGS modules
pyjags.load_module('glm')
pyjags.load_module('dic')
## See this blog post for the origin of the analysis tools adapted here and below:
## https://martynplummer.wordpress.com/2016/01/11/pyjags/
num_chains = 4
iterations = 10000
## The data dict should include only variables that appear in the model.
model = pyjags.Model(code, data=dict(Nm=2, Ndata=len(df_EB.loc[:, "P"]), P=df_EB.loc[:, "P"],
                                     hhat=df_EB.loc[:, "H_OBS"], hhat_sigma=df_EB.loc[:, "H_SIGMA"],
                                     khat=df_EB.loc[:, "K_OBS"], khat_sigma=df_EB.loc[:, "K_SIGMA"]),
                     chains=num_chains, adapt=1000)
## Options to reduce wall-clock time; this feature may not be
## well tested in pyjags at this time:
## threads=4, chains_per_thread=1
## 500 warm-up / burn-in iterations, not used for inference.
model.sample(500, vars=[])
## Run the model for the desired number of iterations, monitoring the hyperparameters
## (and, optionally, the latent variables) of the hierarchical Bayesian model.
## Returns a dictionary with a NumPy array for each monitored variable.
## Shapes of the returned arrays are (... shape of variable ..., iterations, chains).
## e.g. samples = model.sample(iterations_per_chain, vars=['e_sigma_a', 'h'])
samples_JAGS = model.sample(iterations, vars=['p_break', 'e_sigma_a', 'e_sigma_b', 'f_a'])
## Code to save, open, and use a pickled dictionary of samples:
## -- Pickle the data --
#with open('ecc_1_test.pkl', 'wb') as handle:
#    pickle.dump(samples_JAGS, handle)
## -- Retrieve pickled data --
#with open('ecc_1_test.pkl', 'rb') as handle:
#    retrieved_results = pickle.load(handle)
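In [ ]:
## Simple convergence check (an added sketch based on the array shapes noted above,
## i.e. (... shape of variable ..., iterations, chains)): the Gelman-Rubin potential
## scale reduction factor, R-hat, for each monitored scalar. Values near 1 suggest the
## chains have mixed; this is not a substitute for inspecting the trace plots below.
def gelman_rubin(chains_2d):
    """R-hat for one scalar parameter stored as an (iterations, chains) array."""
    n = chains_2d.shape[0]
    chain_means = chains_2d.mean(axis=0)
    W = chains_2d.var(axis=0, ddof=1).mean()    # mean within-chain variance
    B = n * chain_means.var(ddof=1)             # between-chain variance of the means
    var_hat = (n - 1.0) / n * W + B / n
    return np.sqrt(var_hat / W)

for name, arr in samples_JAGS.items():
    for idx in range(arr.shape[0]):
        print('{}[{}]: R-hat = {:.3f}'.format(name, idx, gelman_rubin(arr[idx])))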
In [ ]:
iters = iterations
chain_thin = 100   # available for optional thinning; not applied below
start = int(iters - 1000)   # keep only the last 1000 iterations of each chain
## The two components of the long-period mixture are exchangeable, so their labels can
## switch between (and within) chains. Resolve this by ordering the components of
## e_sigma_a draw by draw and reordering the mixture weights f_a in the same way.
es_a0 = samples_JAGS['e_sigma_a'][0, start:, :]
es_a1 = samples_JAGS['e_sigma_a'][1, start:, :]
f_a0 = samples_JAGS['f_a'][0, start:, :]
f_a1 = samples_JAGS['f_a'][1, start:, :]
esigma_low = np.where(es_a0 <= es_a1, es_a0, es_a1)
esigma_hi = np.where(es_a0 > es_a1, es_a0, es_a1)
f_low = np.where(es_a0 <= es_a1, f_a0, f_a1)
f_hi = np.where(es_a0 > es_a1, f_a0, f_a1)
print(np.min(f_hi))
plt.hist(f_low)
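In [ ]:
## Equivalent label-switching fix (an added sketch; it reproduces the np.where ordering
## above but generalizes to more than two mixture components). Requires NumPy >= 1.15
## for np.take_along_axis.
es_a = samples_JAGS['e_sigma_a'][:, start:, :]   # shape (2, kept iterations, chains)
f_a_arr = samples_JAGS['f_a'][:, start:, :]
order = np.argsort(es_a, axis=0)                 # per-draw ordering of the components
es_sorted = np.take_along_axis(es_a, order, axis=0)
f_sorted = np.take_along_axis(f_a_arr, order, axis=0)
print(np.allclose(es_sorted[0], esigma_low), np.allclose(es_sorted[1], esigma_hi))
print(np.allclose(f_sorted[0], f_low), np.allclose(f_sorted[1], f_hi))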
In [ ]:
## Scatter matrix plot:
## Only the six population-level (hyper)parameters are visualized here; the latent
## per-system h and k values are not monitored above, and including them would make
## the scatter matrix far too cumbersome.
samples_EB_for_scatter_matrix = {}
numHyperParams = 6
dim = numHyperParams
print(dim)
samples_EB_for_scatter_matrix.update({r'$f_{a_{low}}$': f_low})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{a_{low}}}$': esigma_low})
samples_EB_for_scatter_matrix.update({r'$f_{a_{high}}$': f_hi})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{a_{high}}}$': esigma_hi})
samples_EB_for_scatter_matrix.update({r'$e_{\sigma_{b}}$': samples_JAGS['e_sigma_b'][0, start:, :]})
samples_EB_for_scatter_matrix.update({r'$P_{break}$': samples_JAGS['p_break'][0, start:, :]})
for name in samples_EB_for_scatter_matrix:
    print(name)
trace_2 = pd.Panel(samples_EB_for_scatter_matrix)
sm = scatter_matrix(trace_2.to_frame(), color="darkturquoise", alpha=0.2,
                    figsize=(dim*2, dim*2), diagonal='hist',
                    hist_kwds={'bins': 25, 'histtype': 'step', 'edgecolor': 'r', 'linewidth': 2})
## y labels size
[plt.setp(item.yaxis.get_label(), 'size', 20) for item in sm.ravel()]
## x labels size
[plt.setp(item.xaxis.get_label(), 'size', 20) for item in sm.ravel()]
## Change label rotation
## This is helpful for very long labels
#[s.xaxis.label.set_rotation(45) for s in sm.reshape(-1)]
[s.xaxis.label.set_rotation(0) for s in sm.reshape(-1)]
[s.yaxis.label.set_rotation(0) for s in sm.reshape(-1)]
## May need to offset label when rotating to prevent overlap of figure
[s.get_yaxis().set_label_coords(-0.5,0.5) for s in sm.reshape(-1)]
## Hide all ticks
#[s.set_xticks(()) for s in sm.reshape(-1)]
#[s.set_yticks(()) for s in sm.reshape(-1)]
plt.savefig('scatter_matrix_period_break_point.png')
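In [ ]:
## Corner-style view of the same hyperparameters (an added sketch). The notebook
## imports triangle_linear, whose interface is not shown here, so this uses the
## publicly available corner package as a stand-in; that substitution and the
## plotting choices below are assumptions.
import corner
labels = [r'$f_{a_{low}}$', r'$e_{\sigma_{a_{low}}}$', r'$f_{a_{high}}$',
          r'$e_{\sigma_{a_{high}}}$', r'$e_{\sigma_{b}}$', r'$P_{break}$']
flat_samples = np.column_stack([samples_EB_for_scatter_matrix[lab].ravel() for lab in labels])
fig_corner = corner.corner(flat_samples, labels=labels, quantiles=[0.16, 0.5, 0.84])
fig_corner.savefig('corner_period_break_point.png')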
In [ ]:
## Use a pandas three-dimensional Panel to represent the trace:
trace = pd.Panel(samples_EB_for_scatter_matrix)
trace.axes[0].name = 'Variable'
trace.axes[1].name = 'Iteration'
trace.axes[2].name = 'Chain'
## Point estimates:
print(trace.to_frame().mean())
## Bayesian equal-tailed 90% credible intervals (5th and 95th percentiles):
print(trace.to_frame().quantile([0.05, 0.95]))
## ^ Entering the values here could make a good question part.
def plot(trace, var):
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    fig.suptitle(var, y=0.95, fontsize='xx-large')
    ## Marginal posterior density estimate:
    trace[var].plot.density(ax=axes[0])
    axes[0].set_xlabel('Parameter value')
    axes[0].locator_params(tight=True)
    ## Autocorrelation for each chain:
    axes[1].set_xlim(0, 100)
    for chain in trace[var].columns:
        autocorrelation_plot(trace[var, :, chain], axes[1], label=chain)
    ## Trace plot:
    axes[2].set_ylabel('Parameter value')
    trace[var].plot(ax=axes[2])
    ## Save figure: strip the LaTeX markup from the variable name to build the filename.
    filename = var.replace("\\", "")
    filename = filename.replace("/", "")
    filename = filename.replace("$", "")
    filename = filename.replace("}", "")
    filename = filename.replace("{", "")
    plt.tight_layout(pad=3)
    fig.savefig('Break_point_{}.png'.format(filename))
## Display diagnostic plots for each monitored variable:
for var in trace:
    plot(trace, var)
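In [ ]:
## Note: pd.Panel was removed in pandas 1.0. An added sketch of the same point
## estimates and credible intervals without Panel, using a long-format DataFrame
## (assumes each monitored array has shape (iterations, chains), as above).
rows = []
for name, arr in samples_EB_for_scatter_matrix.items():
    n_iter, n_chain = arr.shape
    rows.append(pd.DataFrame({'Variable': name,
                              'Chain': np.repeat(np.arange(n_chain), n_iter),
                              'Value': arr.T.ravel()}))
long_df = pd.concat(rows, ignore_index=True)
print(long_df.groupby('Variable')['Value'].mean())
print(long_df.groupby('Variable')['Value'].quantile([0.05, 0.95]))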
In [ ]: