In [ ]:
%matplotlib inline
%config IPython.matplotlib.backend = "retina"
# High-resolution defaults so saved and inline figures are paper quality.
from matplotlib import rcParams
rcParams["savefig.dpi"] = 300
rcParams["figure.dpi"] = 300

# Project-wide plot styling from celerite.
# NOTE(review): auto=False presumably disables automatic figure sizing —
# confirm against celerite.plot_setup.
from celerite import plot_setup
plot_setup.setup(auto=False)

In [ ]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import cho_solve, cho_factor

import celerite
from celerite import terms
from celerite.modeling import Model

class TrueModel(Model):
    """Direct (dense, O(N^3)) quasi-periodic GP model.

    Kernel: k(tau) = exp(log_amp) * exp(-0.5 tau^2 / ell^2) * cos(2 pi tau / P)
    with ell = exp(log_ell) and P = exp(log_period). ``__call__`` returns the
    log posterior, so instances can be passed directly to emcee.
    """

    parameter_names = ("log_amp", "log_ell", "log_period")

    def get_K(self, x):
        """Build the dense kernel matrix K[i, j] = k(x[i] - x[j])."""
        dt = x[:, None] - x[None, :]
        # Amplitude and squared-exponential envelope combined in one exp call.
        envelope = np.exp(self.log_amp - 0.5 * dt**2 * np.exp(-2.0*self.log_ell))
        return envelope * np.cos(2*np.pi*dt*np.exp(-self.log_period))

    def __call__(self, params, x, y, yerr):
        """Log posterior at ``params`` for data (x, y, yerr).

        Returns -inf when the prior rejects the parameters or the noisy
        kernel matrix is not positive definite.
        """
        self.set_parameter_vector(params)
        lp = self.log_prior()
        if not np.isfinite(lp):
            return -np.inf

        K = self.get_K(x)
        K[np.diag_indices_from(K)] += yerr**2
        try:
            # overwrite_a=True is safe: K is a local temporary.
            factor = cho_factor(K, overwrite_a=True)
        except (np.linalg.LinAlgError, ValueError):
            return -np.inf
        logdet = 2.0 * np.sum(np.log(np.diag(factor[0])))
        alpha = cho_solve(factor, y)
        return lp - 0.5*(np.dot(y, alpha) + logdet)

true_model = TrueModel(log_amp=0.0, log_ell=np.log(5.0), log_period=0.0,
                       bounds=[(-10, 10), (-10, 10), (-10, 10)])
# Lag grid used for plotting the kernels later; tau[0] == 0, so the first
# row of get_K(tau) is k(tau) itself.
tau = np.linspace(0, 10, 1000)
true_k = true_model.get_K(tau)[0]

# Simulate a dataset from the true model.
# FIX: the original called np.random.seed(123) immediately before
# np.random.seed(42); the first call had no effect and was removed.
np.random.seed(42)
N = 100
t = np.sort(np.random.uniform(0, 20, N))
yerr = 0.5  # homoscedastic measurement uncertainty
K = true_model.get_K(t)
K[np.diag_indices_from(K)] += yerr**2
y = np.random.multivariate_normal(np.zeros(N), K)

# Set up the celerite model that we will use to fit - product of two SHOs
log_Q = 1.0  # initial (log) quality factor; also used to scale log_S0
# First SHO term carries the overall amplitude; start near the data variance
# with an initial period of one day (omega0 = 2*pi).
kernel = terms.SHOTerm(
    log_S0=np.log(np.var(y))-2*log_Q,
    log_Q=log_Q,
    log_omega0=np.log(2*np.pi),
    bounds=dict(
        log_S0=(-15, 5),
        log_Q=(-10, 10),
        log_omega0=(-10, 10),
    )
)
# Second SHO factor: only its frequency will be fit (see freezes below);
# log_Q = -0.5*log(2) is the critically-damped value.
kernel *= terms.SHOTerm(
    log_S0=0.0, log_omega0=0.0, log_Q=-0.5*np.log(2),
    bounds=dict(
        log_S0=(-10, 10),
        log_Q=(-10, 10),
        log_omega0=(-5, 5),
    )
)
# kernel.freeze_parameter("k1:log_S0")
kernel.freeze_parameter("k2:log_S0")
kernel.freeze_parameter("k2:log_Q")

gp = celerite.GP(kernel)
gp.compute(t, yerr)  # factorize at the observed times once up front
# NOTE(review): init_params does not appear to be used later in this file.
init_params = gp.get_parameter_vector()

# Quick look at the simulated dataset.
plt.errorbar(t, y, yerr=yerr, fmt=".k", lw=1)

In [ ]:
import copy
from scipy.optimize import minimize

def nll(params, gp, y):
    """Negative log-likelihood objective for scipy.optimize.minimize.

    Returns the large finite penalty 1e10 (rather than inf/nan) when the
    prior rejects ``params`` or the likelihood is not finite, so that
    L-BFGS-B always sees a finite objective.
    """
    gp.set_parameter_vector(params)
    if not np.isfinite(gp.log_prior()):
        return 1e10
    ll = gp.log_likelihood(y)
    if not np.isfinite(ll):
        return 1e10
    return -ll

# Maximum-likelihood fit of the celerite model.
p0 = gp.get_parameter_vector()
soln = minimize(nll, p0, method="L-BFGS-B", args=(gp, y))
gp.set_parameter_vector(soln.x)  # leave gp at the best-fit parameters
print(soln)

# Keep an independent copy of the ML GP so the MCMC below (which mutates
# gp's parameters) does not disturb the prediction plots.
ml_gp = copy.deepcopy(gp)
ml_gp.log_likelihood(y)  # evaluate once at the ML parameters

In [ ]:
def _base_sho_term():
    """Fresh SHOTerm started near the data variance with a one-day period."""
    return terms.SHOTerm(
        log_S0=np.log(np.var(y))-2*log_Q,
        log_Q=log_Q,
        log_omega0=np.log(2*np.pi),
        bounds=dict(
            log_S0=(-15, 5),
            log_Q=(-10, 10),
            log_omega0=(-10, 10),
        )
    )

# Model comparison: BIC for sums (i == 0) and products (i == 1) of
# 1..4 SHO terms beyond the base term.
# FIX: the original redefined `nll2`, a byte-identical copy of `nll`,
# inside the loop; the existing `nll` is reused instead. The repeated
# SHOTerm construction is hoisted into _base_sho_term().
bics = []
for i in range(2):
    for j in range(4):
        kernel2 = _base_sho_term()
        if i == 0:
            # Sum of SHO terms.
            for _ in range(j):
                kernel2 += _base_sho_term()
        else:
            # Product with fixed-amplitude, fixed-Q factors; only each
            # factor's frequency is free (mirrors the fit above).
            for _ in range(j):
                kernel2 *= terms.SHOTerm(
                    log_S0=0.0, log_omega0=0.0, log_Q=-0.5*np.log(2),
                    bounds=dict(
                        log_S0=(-15, 5),
                        log_Q=(-10, 10),
                        log_omega0=(-10, 10),
                    )
                )
                kernel2.k2.freeze_parameter("log_S0")
                kernel2.k2.freeze_parameter("log_Q")

        gp2 = celerite.GP(kernel2)
        gp2.compute(t, yerr)

        # Small random jitter on the start to avoid degenerate starts.
        p0 = gp2.get_parameter_vector()
        p0 += 1e-4 * np.random.randn(len(p0))
        soln2 = minimize(nll, p0, method="L-BFGS-B", args=(gp2, y))
        gp2.set_parameter_vector(soln2.x)
        # BIC = -2 log L + k log N.
        bics.append(2 * soln2.fun + len(gp2.get_parameter_vector()) * np.log(len(t)))

In [ ]:
# BIC of the original product-of-two-SHOs fit, for reference.
bic0 = 2 * soln.fun + len(gp.get_parameter_vector()) * np.log(len(t))
s = np.array(bics[:4])  # sums of 1..4 terms (i == 0 in the loop above)
p = np.array(bics[4:])  # products of 1..4 terms (i == 1)

fig, ax = plt.subplots(1, 1, figsize=plot_setup.get_figsize(1, 1))
ax.plot(np.arange(4)+1, s, "o-", label="sum")
ax.plot(np.arange(4)+1, p, "s-", label="product")
ax.axhline(bic0, color="k", ls="dashed")  # reference model
ax.legend()
ax.set_ylabel("BIC")
ax.set_xlabel("number of terms")
fig.savefig("bic.pdf", bbox_inches="tight")

In [ ]:
import emcee

# Do the MCMC with the correct model
ndim = 3
nwalkers = 32
# Start walkers in a tiny Gaussian ball around the true parameters.
coords = true_model.get_parameter_vector() + 1e-4 * np.random.randn(nwalkers,
                                                                    ndim)
true_sampler = emcee.EnsembleSampler(nwalkers, ndim, true_model,
                                     args=(t, y, yerr))
# 500-step burn-in, then reset and run the 2000-step production chain.
coords, _, _ = true_sampler.run_mcmc(coords, 500)
true_sampler.reset()
coords, _, _ = true_sampler.run_mcmc(coords, 2000);

In [ ]:
# Do the MCMC with the (wrong) celerite model
def log_probability(params):
    """Log posterior for the celerite model (reads module globals gp and y).

    Returns -inf when the prior rejects ``params`` or the likelihood is
    not finite; otherwise log likelihood + log prior.
    """
    gp.set_parameter_vector(params)

    lp = gp.log_prior()
    if not np.isfinite(lp):
        return -np.inf

    ll = gp.log_likelihood(y)
    if not np.isfinite(ll):
        return -np.inf
    return ll + lp

# MCMC for the celerite model, started around the ML solution.
ndim = len(soln.x)
nwalkers = 32
coords = soln.x + 1e-4 * np.random.randn(nwalkers, ndim)

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability)
# 500-step burn-in, then reset and run the 2000-step production chain.
coords, _, _ = sampler.run_mcmc(coords, 500)
sampler.reset()
coords, _, _ = sampler.run_mcmc(coords, 2000)

In [ ]:
# Posterior kernel draws (thin black), the true kernel (dotted), and the
# maximum-likelihood kernel, all on the lag grid `tau` from the simulation
# cell. NOTE: `tau` is rebound to the autocorrelation time in a later cell,
# so this cell must run before that one.
samples = sampler.flatchain
for s in samples[np.random.randint(len(samples), size=50)]:
    gp.set_parameter_vector(s)
    plt.plot(tau, gp.kernel.get_value(tau), "k", lw=0.5, alpha=0.3)
plt.plot(tau, true_k, ":")
plt.plot(tau, ml_gp.kernel.get_value(tau))
plt.ylim(-2, 2)
plt.ylabel(r"$k(\tau)$")
plt.xlabel(r"$\tau$")
plt.savefig("kernel-wrong.png")

In [ ]:
# Final figure: data + ML prediction (top left), residuals (bottom left),
# and the period posterior from both models (right).
x = np.linspace(-0.5, 20.5, 1000)
mu, var = ml_gp.predict(y, x, return_var=True)
std = np.sqrt(var)  # 1-sigma prediction uncertainty on the fine grid

fig = plt.figure(figsize=plot_setup.get_figsize(1, 2))

ax1 = plt.subplot2grid((3, 2), (0, 0), rowspan=2)
ax2 = plt.subplot2grid((3, 2), (2, 0), rowspan=1)
ax3 = plt.subplot2grid((3, 2), (0, 1), rowspan=3)
fig.subplots_adjust(hspace=0, wspace=0.1)

# Data with the ML prediction and its 1-sigma band.
ax1.errorbar(t, y, yerr=yerr, fmt=".k", lw=1)
ax1.plot(x, mu)
ax1.fill_between(x, mu+std, mu-std, alpha=0.5, edgecolor="none", zorder=100)
ax1.set_xticklabels([])

ax1.annotate("simulated data", xy=(0, 1), xycoords="axes fraction",
             xytext=(5, -5), textcoords="offset points",
             ha="left", va="top")
ax1.annotate("N = {0}".format(len(t)), xy=(0, 0),
             xycoords="axes fraction",
             xytext=(5, 5), textcoords="offset points",
             ha="left", va="bottom")

# Residuals at the observed times; error bars combine the measurement and
# prediction variances. (Renamed from `std` to avoid clobbering the band
# width used above.)
pred_mu, pred_var = ml_gp.predict(y, return_var=True)
resid_err = np.sqrt(yerr**2 + pred_var)
ax2.errorbar(t, y - pred_mu, yerr=resid_err, fmt=".k", lw=1)
ax2.axhline(0.0, color="k", lw=0.75)

ax1.set_ylim(-3.25, 3.25)
ax1.set_xlim(-0.5, 20.5)
ax2.set_ylim(-2.1, 2.1)
ax2.set_xlim(-0.5, 20.5)

ax2.set_xlabel("time [day]")
ax1.set_ylabel("relative flux [ppm]")
ax2.set_ylabel("residuals")

for ax in [ax1, ax2]:
    ax.yaxis.set_label_coords(-0.2, 0.5)

# Period posteriors: chain column -2 is k1:log_omega0 (the only free
# parameters are k1:log_S0, k1:log_Q, k1:log_omega0, k2:log_omega0), so
# period = 2*pi*exp(-log_omega0); the true chain's last column is log_period.
# FIX: `normed` was removed from Axes.hist in matplotlib 3.1 — use
# density=True, which is numerically equivalent here.
n, b, p = ax3.hist(np.exp(-sampler.flatchain[:, -2])*(2*np.pi), 20,
                   color="k", histtype="step", lw=2, density=True)
ax3.hist(np.exp(true_sampler.flatchain[:, -1]), b,
         color=plot_setup.COLORS["MODEL_1"],
         lw=2, histtype="step", density=True, ls="dashed")
ax3.yaxis.set_major_locator(plt.NullLocator())
ax3.set_xlim(b.min(), b.max())
ax3.axvline(1.0, color="k", lw=3, alpha=0.5)  # true period = 1 day
ax3.set_xlabel("period [day]")

ax2.xaxis.set_label_coords(0.5, -0.3)
ax3.xaxis.set_label_coords(0.5, -0.1)

fig.savefig("wrong-qpo.pdf", bbox_inches="tight")

In [ ]:
# Benchmark one posterior evaluation with the direct O(N^3) model...
p0 = true_model.get_parameter_vector()
slow_timing = %timeit -o true_model(p0, t, y, yerr)

# ...and with the celerite solver, for the results table below.
p0 = gp.get_parameter_vector()
fast_timing = %timeit -o log_probability(p0)

In [ ]:
# Integrated autocorrelation time of chain column -2 (k1:log_omega0).
# NOTE: this rebinds `tau` (previously the kernel lag grid) to a scalar.
tau = sampler.get_autocorr_time(c=3)[-2]
neff = len(sampler.flatchain) / tau  # effective number of independent samples
tau, neff

In [ ]:
import json
# Dump the summary numbers quoted in the paper.
# NOTE(review): J counts the celerite coefficient arrays c[0] and c[2] —
# presumably the real and complex term counts; confirm against the
# celerite kernel.coefficients layout.
c = gp.kernel.coefficients
with open("wrong-qpo.json", "w") as f:
    json.dump(dict(
        N=len(t),
        J=len(c[0]) + len(c[2]),
        tau=tau,
        neff=neff,
        time=fast_timing.average,
        direct_time=slow_timing.average,
        nwalkers=nwalkers,
        ndim=ndim,
        nburn=500,
        nsteps=2000,
    ), f)

In [ ]:
# LaTeX labels for the free parameters, keyed by celerite parameter name.
name_map = {
    "kernel:k1:log_S0": r"$\ln(S_0)$",
    "kernel:k1:log_Q": r"$\ln(Q)$",
    "kernel:k1:log_omega0": r"$\ln(\omega_1)$",
    "kernel:k2:log_omega0": r"$\ln(\omega_2)$",
}
# Pair each label with the parameter's prior bounds for the paper's table.
params = [
    (name_map[param_name], param_bounds)
    for param_name, param_bounds in zip(gp.get_parameter_names(),
                                        gp.get_parameter_bounds())
]
with open("wrong-qpo-params.json", "w") as f:
    json.dump(params, f)

In [ ]: