In [ ]:
import pickle
import glob
import os
from os import path
from astropy.io import fits
import astropy.units as u
from astropy.table import Table
from astropy.time import Time

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import h5py
from scipy.stats import scoreatpercentile

from twoface.io import load_samples

In [ ]:
def compute_stats(group):
    """Summarize the posterior period samples stored in one HDF5 group.

    Parameters
    ----------
    group : h5py.Group
        Group readable by ``twoface.io.load_samples``. If it carries a ``'P'``
        attribute, that value is treated as the true period (assumed to be in
        days, matching ``lnP`` below -- TODO confirm).

    Returns
    -------
    tuple of float
        Five values: the widths in ln(P) of the central 99%, 95%, 90% and
        "68%" intervals of the samples, followed by (median ln P - true ln P).
        The last entry is NaN when the group has no ``'P'`` attribute.
    """
    samples = load_samples(group)
    lnP = np.log(samples['P'].to(u.day).value)

    # (lower, upper) percentile pairs for each interval, then the median.
    # NOTE(review): 17/83 encloses 66% of the samples; a 1-sigma (68%)
    # interval would be 16/84 -- confirm which was intended.
    pers = scoreatpercentile(lnP, [0.5, 99.5,  # 99%
                                   2.5, 97.5,  # 95%
                                   5, 95,      # 90%
                                   17, 83,     # 68%
                                   50])

    # Only a missing attribute means "no true period available"; any other
    # failure (corrupt file, bad value) should surface instead of being hidden.
    try:
        true_lnP = np.log(group.attrs['P'])
    except KeyError:
        true_lnP = np.nan

    widths = tuple(pers[2*k + 1] - pers[2*k] for k in range(4))
    return widths + (pers[-1] - true_lnP,)

# Labels for the first four entries returned by compute_stats, in order.
width_names = ['99th', '95th', '90th', '68th']

In [ ]:
# Preview the pickle cache path that the next cell derives from each HDF5 file.
for fn in glob.glob('cache/*-128.hdf5'):
    root = path.splitext(fn)[0]
    stats_fn = root + '-stats.pickle'
    print(stats_fn)

In [ ]:
# What to plot: the ratio of two interval-width columns returned by
# compute_stats (index 3 = 68% width, index 1 = 95% width). Defined outside the
# loop so later cells can rely on it even if no files match.
width_idx = (3, 1)

for fn in glob.glob('cache/*-128.hdf5'):
    # Expect four '-'-separated fields, e.g. "circ-samples-uniform-128".
    ecccirc, _, loguniform, period = path.splitext(path.basename(fn))[0].split('-')

    stats_fn = '{0}-stats.pickle'.format(path.splitext(fn)[0])
    if not path.exists(stats_fn):
        # Open read-only: older h5py versions default to append mode, which
        # can modify (or even create) the cache file.
        with h5py.File(fn, 'r') as f:
            # Group keys look like "<N>-<i>"; collect the distinct N values.
            Ns = []
            for key in f:
                N, _ = map(int, key.split('-'))
                Ns.append(N)
            Ns = np.unique(Ns)
            NORBITS = len(f) // len(Ns)

            all_stats = dict()
            for key in f:
                N, i = map(int, key.split('-'))
                if N not in all_stats:
                    all_stats[N] = np.full((NORBITS, 5), np.nan)
                all_stats[N][i] = compute_stats(f[key])

        with open(stats_fn, 'wb') as f2:
            pickle.dump(all_stats, f2)

    # The pickle is written locally by the branch above; never point this at
    # untrusted files (pickle.load can execute arbitrary code).
    with open(stats_fn, 'rb') as f2:
        all_stats = pickle.load(f2)

    # NORBITS used to be defined only when the cache was (re)built, so it was
    # undefined or stale when an existing pickle was loaded. Recover it from
    # the cached arrays instead (every N has the same number of rows).
    NORBITS = len(all_stats[min(all_stats)]) if all_stats else 0

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    for N in sorted(all_stats):
        stats = all_stats[N]
        n_orbits = len(stats)
        # Offset by 0.5 so that a zero line is the reference level (presumably
        # because a Gaussian gives a 68%/95% width ratio of roughly 0.5).
        y_val = stats[:, width_idx[0]] / stats[:, width_idx[1]] - 0.5
        # Small horizontal jitter so points at the same N do not overlap.
        x_val = np.full(n_orbits, N) + np.random.uniform(-0.2, 0.2, size=n_orbits)
        ax.scatter(x_val, y_val, alpha=0.25, marker='.', color='k', linewidth=0)

    ax.set_xlabel('$N$ epochs')
    ax.set_ylabel(r'"Gaussianity"')
    ax.set_xlim(2.25, 12.75)
    ax.set_ylim(-0.55, 0.55)
    ax.xaxis.set_ticks(np.arange(3, 12+1))

    ax.axhline(0., linestyle='dashed', zorder=-100, linewidth=2, alpha=0.3, color='tab:blue')

    fig.savefig('plots/{2}-{0}-{1}.png'.format(ecccirc, loguniform, period), dpi=256)
    # `del fig` only dropped the name; pyplot kept the figure alive. Close it
    # explicitly since the result has been saved to disk.
    plt.close(fig)

In [ ]:
# with h5py.File('cache/circ-samples-uniform-128.hdf5') as f:
#     Ns = []
#     for key in f:
#         N, _ = map(int, key.split('-'))
#         Ns.append(N)
#     Ns = np.unique(Ns)
#     NORBITS = len(f) // len(Ns)
        
#     all_stats = dict()
#     for key in f:
#         N, i = map(int, key.split('-'))
#         if N not in all_stats:
#             all_stats[N] = np.full((NORBITS, 5), np.nan) 
#         all_stats[N][i] = compute_stats(f[key])

In [ ]:
# |median ln P - true ln P| per orbit vs. number of epochs, coloured by log10 of
# the interval-width ratio selected by `width_idx`. Uses the `all_stats` and
# `width_idx` currently defined by earlier cells.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

points = None
for N in sorted(all_stats):
    stats = all_stats[N]
    n_orbits = len(stats)  # do not rely on a possibly stale NORBITS
    y_val = np.abs(stats[:, -1])
    c_val = stats[:, width_idx[0]] / stats[:, width_idx[1]]
    points = ax.scatter(np.full(n_orbits, N) + np.random.uniform(-0.2, 0.2, size=n_orbits),
                        y_val, alpha=0.25, marker='.', c=np.log10(c_val),
                        vmin=-1, vmax=0)

ax.set_xlabel('$N$ epochs')
ax.set_ylabel(r'$\epsilon_{\ln P}$')
if points is not None:
    # The colour scale is meaningless without a key.
    fig.colorbar(points, ax=ax, label=r'$\log_{10}$ (width ratio)');


In [ ]:
# Rebuild the per-N statistics for the "circ-samples-uniform" run.
# NOTE: rebinds the globals `all_stats` and `NORBITS` used by later cells.
# Opened read-only: older h5py versions default to append mode.
with h5py.File('cache/circ-samples-uniform-128.hdf5', 'r') as f:
    # Group keys look like "<N>-<i>"; collect the distinct N values.
    Ns = []
    for key in f:
        N, _ = map(int, key.split('-'))
        Ns.append(N)
    Ns = np.unique(Ns)
    NORBITS = len(f) // len(Ns)

    all_stats = dict()
    for key in f:
        N, i = map(int, key.split('-'))
        if N not in all_stats:
            all_stats[N] = np.full((NORBITS, 5), np.nan)
        all_stats[N][i] = compute_stats(f[key])

In [ ]:
# For N = 10 epochs: orbits whose column-3 width (68%) is within 5% of the
# column-1 width (95%).
_y = all_stats[10][:, 3] / all_stats[10][:, 1]
close_to_one = np.abs(_y - 1.) < 0.05
np.where(close_to_one)

In [ ]:
# ln P histogram of the samples for orbit key '10-12' (presumably picked from
# the ratio search above).
# Read-only mode: older h5py versions default to append mode.
with h5py.File('cache/circ-samples-uniform-128.hdf5', 'r') as f:
    samples = load_samples(f['10-12'])

# Explicit figure so this histogram never lands on a figure left open elsewhere.
fig, ax = plt.subplots()
ax.hist(np.log(samples['P'].to(u.day).value), bins=64)
ax.set_xlabel(r'$\ln P$')
ax.set_ylabel('$N$ samples')
fig.tight_layout()
fig.savefig('plots/lnP_high_68_90.png', dpi=256)

In [ ]:
# ln P histogram of the samples for orbit key '7-6'.
# Read-only mode: older h5py versions default to append mode.
with h5py.File('cache/circ-samples-uniform-128.hdf5', 'r') as f:
    samples = load_samples(f['7-6'])

# Explicit figure so this histogram never lands on a figure left open elsewhere.
fig, ax = plt.subplots()
ax.hist(np.log(samples['P'].to(u.day).value), bins=64)
ax.set_xlabel(r'$\ln P$')
ax.set_ylabel('$N$ samples')
fig.tight_layout()
fig.savefig('plots/lnP_low_68_90.png', dpi=256)


In [ ]:
from thejoker import JokerParams, TheJoker, RVData
from thejoker.plot import plot_rv_curves
from twobody import KeplerOrbit

In [ ]:
# Reference Keplerian orbit whose radial velocities are simulated below.
orb = KeplerOrbit(P=150 * u.day,
                  e=0.35,
                  M0=35 * u.deg,
                  omega=92 * u.deg)

In [ ]:
# Simulate `size` RV measurements of `orb` at random (seeded) times over the
# baseline, each with a constant uncertainty `err`.
np.random.seed(42)

t0 = Time('2013-01-01')
baseline = 2 * u.yr  # similar to APOGEE2

K = 2 * u.km / u.s
err = 150 * u.m / u.s

size = 10
t = Time(np.random.uniform(t0.mjd, (t0 + baseline).mjd, size=size),
         format='mjd')
order = np.argsort(t.mjd)
t = t[order]

rv = K * orb.unscaled_radial_velocity(t)
stddev = err * np.ones_like(rv.value)
data = RVData(t=t, rv=rv, stddev=stddev)

In [ ]:
# Quick look at the simulated data; the trailing ';' suppresses the repr
# without overwriting IPython's `_` last-output cache.
data.plot();

In [ ]:
# Sampler configured with a period range of 8-1024 days.
pars = JokerParams(P_min=8 * u.day,
                   P_max=1024 * u.day)
joker = TheJoker(pars)

In [ ]:
# Dense 4096-point grid for drawing orbit curves: the observed span, padded by
# 1/8 of its length before the first epoch and 1/4 after the last one.
mjd_first, mjd_last = data.t.mjd.min(), data.t.mjd.max()
dt = mjd_last - mjd_first
t_grid = Time(np.linspace(mjd_first - dt/8., mjd_last + dt/4, 4096),
              format='mjd')

In [ ]:
def _plot_rv_samples(samples, rv_data, t_grid, xlims, ylims):
    """Draw posterior orbit curves plus the data and return the figure.

    Every curve is drawn in km/s so it agrees with ``ylims`` (computed in km/s).
    The shortest- and longest-period samples are overplotted in a darker style.
    """
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    plot_kwargs = dict(color='#aaaaaa', alpha=0.25, linewidth=0.5)
    fig = plot_rv_curves(samples, t_grid=t_grid, n_plot=128, rv_unit=u.km/u.s,
                         plot_kwargs=plot_kwargs, ax=ax)

    # Darken the shortest-, then the longest-period sample.
    dark_style = dict(color='#333333', alpha=0.5, linewidth=0.5, zorder=10)
    for extreme in (samples['P'].argmin(), samples['P'].argmax()):
        plot_rv_curves(samples[extreme], t_grid, rv_unit=u.km/u.s, ax=ax,
                       n_plot=1, plot_kwargs=dark_style)

    rv_data.plot(ax=fig.axes[0], marker='.')

    fig.axes[0].set_xlim(xlims)
    fig.axes[0].set_ylim(ylims)
    fig.tight_layout()
    return fig


def _plot_P_e(samples):
    """Scatter the samples in period vs. eccentricity and return the figure."""
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.scatter(samples['P'].to(u.day).value, samples['e'],
               marker='.', alpha=0.5, linewidth=0)
    ax.set_xscale('log')
    ax.set_xlim(1, 1024)
    ax.set_ylim(0, 1)
    ax.set_xlabel('period, $P$ [day]')
    ax.set_ylabel('eccentricity, $e$')
    fig.tight_layout()
    return fig


idx = np.arange(len(data))

# Fixed axis limits so every frame of the sequence is directly comparable.
xlims = (t_grid.mjd.min(), t_grid.mjd.max())
_rv = data.rv.to(u.km/u.s).value
h = np.ptp(_rv)
ylims = (_rv.min()-2*h, _rv.max()+2*h)

# Start from all epochs and drop one random epoch per iteration; the last
# iteration fits 4 epochs and leaves `idx` with 3 for the next cell.
for i in range(len(idx)-3):
    print(idx)
    _data = data[idx]
    samples = joker.iterative_rejection_sample(_data, n_requested_samples=128,
                                               n_prior_samples=1000000)

    fig = _plot_rv_samples(samples, _data, t_grid, xlims, ylims)
    fig.savefig('plots/{0}-orbits.png'.format(i), dpi=256)

    fig = _plot_P_e(samples)
    fig.savefig('plots/{0}-samples.png'.format(i), dpi=256)

    idx = np.delete(idx, np.random.randint(len(idx)))

In [ ]:
# Last step of the epoch-removal sequence: fit the epochs still left in `idx`
# after the loop in the previous cell. The output index is the number of epochs
# removed so far, continuing that loop's numbering (7 when 3 of 10 remain).
print(idx)
i = len(data) - len(idx)
_data = data[idx]
samples = joker.iterative_rejection_sample(_data, n_requested_samples=128,
                                           n_prior_samples=1000000)

fig, ax = plt.subplots(1, 1, figsize=(8, 6))
plot_kwargs = dict(color='#aaaaaa', alpha=0.25, linewidth=0.5)
# Draw every curve explicitly in km/s so it agrees with `ylims` (km/s).
fig = plot_rv_curves(samples, t_grid=t_grid, n_plot=128, rv_unit=u.km/u.s,
                     plot_kwargs=plot_kwargs, ax=ax)

# Darken the shortest-period sample
dark_style = dict(color='#333333', alpha=0.5, linewidth=0.5, zorder=10)

P_min_samples = samples[samples['P'].argmin()]
plot_rv_curves(P_min_samples, t_grid, rv_unit=u.km/u.s, ax=ax,
               n_plot=1, plot_kwargs=dark_style)

# Darken the longest-period sample
P_max_samples = samples[samples['P'].argmax()]
plot_rv_curves(P_max_samples, t_grid, rv_unit=u.km/u.s, ax=ax,
               n_plot=1, plot_kwargs=dark_style)

_data.plot(ax=fig.axes[0], marker='.')

fig.axes[0].set_xlim(xlims)
fig.axes[0].set_ylim(ylims)

fig.tight_layout()
fig.savefig('plots/{0}-orbits.png'.format(i), dpi=256)

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(samples['P'].to(u.day).value, samples['e'],
           marker='.', alpha=0.5, linewidth=0)
ax.set_xscale('log')
ax.set_xlim(1, 1024)
ax.set_ylim(0, 1)
ax.set_xlabel('period, $P$ [day]')
ax.set_ylabel('eccentricity, $e$')
fig.tight_layout()
fig.savefig('plots/{0}-samples.png'.format(i), dpi=256)

In [ ]:


In [ ]:


In [ ]:
# One panel per N (at most 10): 95% vs. 90% ln P interval width for each orbit,
# taken from whichever `all_stats` is currently defined.
fig, axes = plt.subplots(2, 5, figsize=(12, 4.8), sharex=True, sharey=True)

# zip stops at the shorter sequence, so N values beyond the 10 panels are skipped.
for ax, N in zip(axes.flat, sorted(all_stats)):
    ax.scatter(all_stats[N][:, 1], all_stats[N][:, 2], alpha=0.1, marker='.')
    ax.set_title('$N = {0}$'.format(N))

for ax in axes[-1]:
    ax.set_xlabel(r'95% width in $\ln P$')
for ax in axes[:, 0]:
    ax.set_ylabel(r'90% width in $\ln P$')

fig.tight_layout()

In [ ]: