Summary of statistics of 4 models for the SDSS test set

In this notebook we examine the ROC curve, the FoM, the accuracy, and the ROC AUC for 4 models: the ML model, the simple model, the PS1 model, and the SDSS model, for the SDSS test set. The uncertainty for each statistic is derived by 100 bootstrap resampling.

The performance of the ML model is the best than that of altanative models except the accuracy. The bias in the SDSS test set towards point sources at the faint end and the hard star--galaxy cut employed in the SDSS model makes the reversal of the performance. This bias affects the ROC curve. The kink and the cross of the curve around the TPR=0.015 are due to the bias.


In [47]:
import sys,os,math
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams["font.family"] = "sans-serif"
rcParams['font.sans-serif'] = ['DejaVu Sans']
from matplotlib import gridspec as grs
%matplotlib notebook
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import cm
from astropy.table import Table
import seaborn as sns
import statsmodels.nonparametric.api as smnp
from statsmodels.nonparametric.kernel_density import KDEMultivariate
from scipy.special import expit
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, accuracy_score, auc, make_scorer
from sklearn.model_selection import GridSearchCV, KFold

In [2]:
hst_tbl = Table.read("HST_COSMOS_features_adamamiller.fit").to_pandas()
sdss_tbl = Table.read("sdssSP_MLfeats_adamamiller.fit").to_pandas()

In [5]:
sdss_tbl = sdss_tbl[~sdss_tbl.duplicated('objid')]

In [6]:
hst_det = np.where(hst_tbl.nDetections > 0)
hst_kron_mag = np.array(-2.5*np.log10(hst_tbl['wwKronFlux'].loc[hst_det]/3631))

sdss_det = np.where(sdss_tbl.nDetections > 0)
sdss_photo_det = np.logical_and(np.isfinite(sdss_tbl.type),
                                sdss_tbl.countRatio > 0)
sdss_in_common = np.logical_and(np.isfinite(sdss_tbl.iPSFminusKron), 
                                sdss_photo_det)
low_z_gal = np.logical_and(sdss_tbl.z < 1e-4, sdss_tbl['class'] == 'GALAXY')
low_z_qso = np.logical_and(sdss_tbl.z < 1, sdss_tbl['class'] == '   QSO')

sdss_test_set = sdss_in_common & ~low_z_gal & ~low_z_qso

In [8]:
features = ['wwpsfChiSq', 'wwExtNSigma', 'wwpsfLikelihood',
            'wwPSFKronRatio', 'wwPSFKronDist',  'wwPSFApRatio',
            'wwmomentRH', 'wwmomentXX', 'wwmomentXY', 'wwmomentYY', 
            'wwKronRad']

In [9]:
hst_ml_train_X = np.array(hst_tbl[features].loc[hst_det])
hst_ml_train_y = np.array(hst_tbl["MU_CLASS"].loc[hst_det] - 1)
sdss_ml_test_X = np.array(sdss_tbl[features].loc[sdss_test_set])
sdss_spec_class = np.array(sdss_tbl['class'].loc[sdss_test_set])
sdss_ml_test_y = np.ones_like(sdss_spec_class).astype(int)
sdss_ml_test_y[np.where(sdss_spec_class == "GALAXY")] = 0

In [10]:
import star_galaxy_models

In [11]:
rf_obj = star_galaxy_models.RandomForestModel()
rf_obj.read_rf_from_pickle()

In [12]:
ML_predict = rf_obj.rf_clf_.predict_proba(sdss_ml_test_X)

In [13]:
ML_fpr, ML_tpr, ML_thre = roc_curve(sdss_ml_test_y, ML_predict[:,1])
sdss_fpr, sdss_tpr,sdss_thre = roc_curve(sdss_ml_test_y, sdss_tbl["countRatio"].loc[sdss_test_set])
dist_fpr, dist_tpr, dist_thre = roc_curve(sdss_ml_test_y, sdss_tbl["wwPSFKronDist"].loc[sdss_test_set])
i_fpr, i_tpr, i_thre = roc_curve(sdss_ml_test_y, -1.*sdss_tbl["iPSFminusKron"].loc[sdss_test_set])

In [14]:
ps1_preds = -1.*sdss_tbl["iPSFminusKron"].loc[sdss_test_set]
simple_preds = sdss_tbl["wwPSFKronDist"].loc[sdss_test_set]
rf_preds = ML_predict[:,1]
sdss_preds = sdss_tbl["countRatio"].loc[sdss_test_set]
y_sdss = sdss_ml_test_y

In [19]:
from scipy import interp
from sklearn.metrics import roc_curve, accuracy_score, auc, roc_auc_score
def calc_fom(fpr, tpr, thresh, fpr_at=0.005):
    return interp(fpr_at, fpr, thresh), interp(fpr_at, fpr, tpr)

In [50]:
def calc_summary_stats(y_sdss, ps1_preds, simple_preds, rf_preds, sdss_preds, 
                       ps1_ct = -0.05, 
                       simple_ct = 9.2e-7,
                       rf_ct = 0.5, 
                       sdss_ct = 10**(-0.185/2.5)):
    ps1_fpr, ps1_tpr, ps1_thresh = roc_curve(y_sdss, ps1_preds)
    ps1_fom = calc_fom(ps1_fpr, ps1_tpr, ps1_thresh)
    ps1_auc = roc_auc_score(y_sdss, ps1_preds)
    ps1_acc = accuracy_score(y_sdss, -1*ps1_preds <= ps1_ct)
    
    simple_fpr, simple_tpr, simple_thresh = roc_curve(y_sdss, simple_preds)
    simple_fom = calc_fom(simple_fpr, simple_tpr, simple_thresh)
    simple_auc = roc_auc_score(y_sdss, simple_preds)
    simple_acc = accuracy_score(y_sdss, simple_preds >= simple_ct)
    
    rf_fpr, rf_tpr, rf_thresh = roc_curve(y_sdss, rf_preds)
    rf_fom = calc_fom(rf_fpr, rf_tpr, rf_thresh)
    rf_auc = roc_auc_score(y_sdss, rf_preds)
    rf_acc = accuracy_score(y_sdss, rf_preds >= rf_ct)
    
    sdss_fpr, sdss_tpr, sdss_thresh = roc_curve(y_sdss, sdss_preds)
    sdss_fom = calc_fom(sdss_fpr, sdss_tpr, sdss_thresh)
    sdss_auc = roc_auc_score(y_sdss, sdss_preds)
    sdss_acc = accuracy_score(y_sdss, sdss_preds >= sdss_ct)
    
    return  rf_auc, rf_acc, rf_fom,  simple_auc, simple_acc, simple_fom, ps1_auc, ps1_acc, ps1_fom, sdss_auc, sdss_acc, sdss_fom

In [52]:
summary_stats = calc_summary_stats(y_sdss, ps1_preds, simple_preds, rf_preds, sdss_preds)

In [61]:
print(r"""
    RF & {2:.3f} & {1:.3f}  & {0:.3f}  \\
    simple & {5:.3f}  & {4:.3f} & {3:.3f} \\
    PS1 & {8:.3f} & {7:.3f} & {6:.3f} \\
    SDSS & {11:.3f} & {10:.3f} & {9:.3f} \\
""".format(*summary_stats))


    RF & 0.843 & 0.963  & 0.987  \\
    simple & 0.798  & 0.956 & 0.985 \\
    PS1 & 0.290 & 0.900 & 0.984 \\
    SDSS & 0.777 & 0.971 & 0.987 \\

bootstrap resampling


In [20]:
kron_mag = np.array(-2.5*np.log10(sdss_tbl['wwKronFlux'].loc[sdss_test_set]/3631))

In [24]:
def summary_stats_bootstrap(gt, preds, ct=0.5, fom_at=[0.005, 0.01, 0.02, 0.05, 0.1], Nboot=100, mag_max = 30):
    mag_mask = np.where(kron_mag  <= mag_max)
    acc = accuracy_score(gt[mag_mask], preds[mag_mask] >= ct)
    auc = roc_auc_score(gt[mag_mask], preds[mag_mask])
    fpr, tpr, thresh = roc_curve(gt[mag_mask], preds[mag_mask])
    fom_thresh = np.array([calc_fom(fpr, tpr, thresh, f) for f in fom_at])
    thresh = fom_thresh[:,0]
    fom = fom_thresh[:,1]
    acc_std_arr = np.empty(Nboot)
    auc_std_arr = np.empty_like(acc_std_arr)
    fom_std_arr = np.empty((Nboot, len(fom_at)))
    thresh_std_arr = np.empty_like(fom_std_arr)
    for i in range(Nboot):
        if(i%1==0):
            print("%d."%i, end='')
        boot_sources = np.random.choice(mag_mask[0], len(mag_mask[0]), replace=True)
        auc_std_arr[i] = roc_auc_score(gt[boot_sources], preds[boot_sources])
        acc_std_arr[i] = accuracy_score(gt[boot_sources], preds[boot_sources] >= ct)
        _fpr, _tpr, _thresh = roc_curve(gt[boot_sources], preds[boot_sources])
        _fom_thresh = np.array([calc_fom(_fpr, _tpr, _thresh, f) for f in fom_at])
        thresh_std_arr[i,:] = _fom_thresh[:,0]
        fom_std_arr[i,:] = _fom_thresh[:,1]
    acc_std = np.std(acc_std_arr)
    auc_std = np.std(auc_std_arr)
    fom_std = np.std(fom_std_arr, axis=0)
    thresh_std = np.std(thresh_std_arr, axis=0)
    return {'Num': len(kron_mag[mag_mask]), 
                'Acc': acc*100, 
                'AUC': auc, 
                'FoM': fom, 
                'Thresh': thresh, 
                'AccSTD': acc_std*100, 
                'AUCSTD': auc_std, 
                'FoMSTD': fom_std, 
                'ThreshSTD': thresh_std}

In [38]:
ps1_ct = -0.05, 
simple_ct = 9.2e-7
rf_ct = 0.5
sdss_ct = 10**(-0.185/2.5)

In [25]:
stat_rf = summary_stats_bootstrap(y_sdss, rf_preds, ct=rf_ct, Nboot=100)


0.1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.57.58.59.60.61.62.63.64.65.66.67.68.69.70.71.72.73.74.75.76.77.78.79.80.81.82.83.84.85.86.87.88.89.90.91.92.93.94.95.96.97.98.99.

In [45]:
stat_rf


Out[45]:
{'AUC': 0.98713117838108511,
 'AUCSTD': 6.6746500963364271e-05,
 'Acc': 96.25373440899898,
 'AccSTD': 0.010122492154387774,
 'FoM': array([ 0.84308974,  0.91322055,  0.94742827,  0.96477697,  0.97594995]),
 'FoMSTD': array([ 0.00112882,  0.00053063,  0.00021365,  0.00017044,  0.0001365 ]),
 'Num': 3592938,
 'Thresh': array([ 0.810375  ,  0.50862524,  0.26413417,  0.135     ,  0.06066667]),
 'ThreshSTD': array([ 0.00272024,  0.00285574,  0.00100327,  0.00034501,  0.00014297])}

In [39]:
stat_ps1 = summary_stats_bootstrap(y_sdss, ps1_preds.values, ct=ps1_ct, Nboot=100)
stat_simple = summary_stats_bootstrap(y_sdss, simple_preds.values, ct=simple_ct, Nboot=100)
stat_sdss = summary_stats_bootstrap(y_sdss, sdss_preds.values, ct=sdss_ct, Nboot=100)


0.1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.57.58.59.60.61.62.63.64.65.66.67.68.69.70.71.72.73.74.75.76.77.78.79.80.81.82.83.84.85.86.87.88.89.90.91.92.93.94.95.96.97.98.99.0.1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.57.58.59.60.61.62.63.64.65.66.67.68.69.70.71.72.73.74.75.76.77.78.79.80.81.82.83.84.85.86.87.88.89.90.91.92.93.94.95.96.97.98.99.0.1.2.3.4.5.6.7.8.9.10.11.12.13.14.15.16.17.18.19.20.21.22.23.24.25.26.27.28.29.30.31.32.33.34.35.36.37.38.39.40.41.42.43.44.45.46.47.48.49.50.51.52.53.54.55.56.57.58.59.60.61.62.63.64.65.66.67.68.69.70.71.72.73.74.75.76.77.78.79.80.81.82.83.84.85.86.87.88.89.90.91.92.93.94.95.96.97.98.99.

In [40]:
stat_ps1


Out[40]:
{'AUC': 0.98411794372530115,
 'AUCSTD': 6.7086071564212565e-05,
 'Acc': 96.117383600830294,
 'AccSTD': 0.0090463838760812768,
 'FoM': array([ 0.28994216,  0.7613966 ,  0.935268  ,  0.9685331 ,  0.97898192]),
 'FoMSTD': array([ 0.0036232 ,  0.00401264,  0.00034549,  0.00013789,  0.00012308]),
 'Num': 3592938,
 'Thresh': array([ 0.13290072,  0.04500008, -0.07390013, -0.24310112, -0.35819817]),
 'ThreshSTD': array([ 0.00089204,  0.00100816,  0.00079227,  0.00042546,  0.00033153])}

In [41]:
stat_simple


Out[41]:
{'AUC': 0.98503017551718985,
 'AUCSTD': 8.3192975159283847e-05,
 'Acc': 95.568751812583457,
 'AccSTD': 0.011220871543247714,
 'FoM': array([ 0.79792979,  0.89413228,  0.95014616,  0.96744146,  0.97444493]),
 'FoMSTD': array([ 0.00158545,  0.00080701,  0.0002654 ,  0.00017194,  0.00015392]),
 'Num': 3592938,
 'Thresh': array([  1.70552150e-06,   9.31750294e-07,   5.13090947e-08,
         -2.12707404e-06,  -4.57542294e-06]),
 'ThreshSTD': array([  1.24870938e-08,   6.69814634e-09,   7.64245548e-09,
          9.21674138e-09,   8.17511288e-09])}

In [42]:
stat_sdss


Out[42]:
{'AUC': 0.98660473889196332,
 'AUCSTD': 7.8183918684798814e-05,
 'Acc': 97.128533807151697,
 'AccSTD': 0.0088219189858879405,
 'FoM': array([ 0.77720695,  0.90346533,  0.95586257,  0.97331779,  0.97965236]),
 'FoMSTD': array([ 0.00297895,  0.00087134,  0.00030467,  0.00015363,  0.00012823]),
 'Num': 3592938,
 'Thresh': array([ 0.96247655,  0.9231685 ,  0.87565661,  0.68230901,  0.60383077]),
 'ThreshSTD': array([ 0.00053257,  0.00047051,  0.00059737,  0.00040919,  0.00024205])}

ROC curve


In [25]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

cmap = plt.get_cmap("Dark2")

color_dict = {'ml': cmap(0.33), 
              'sdss': cmap(0.66), 
              'simple': cmap(0.),
              'ps1': cmap(1.)}

# apple colors
color_dict = {'ml': "#0072c6", #"#1C1858",
              'sdss': "#5BC236", #"#00C78E",
              'simple': "#C864AF", #"#C70039",
              'ps1': "#C65400"}

# color blind friendly
color_dict = {'ml': '#0072b2', 
              'sdss': '#d55e00', 
              'simple': '#cc79a7',
              'ps1': '#009e73'}

# color brewer
color_dict = {'ml': '#7570b3', 
              'sdss': '#d95f02', 
              'simple': '#1b9e77',
              'ps1': '#34495e'}

col_star_sdss = '#e41a1c'
col_star_ps1 = '#001e43'

ls_dict = {'ml': '-', 
              'sdss': '-.', 
              'simple': '--',
              'ps1': '--'}

lw_dict = {'ml': 1.75, 
              'sdss': 1.5, 
              'simple': 1.5, 
              'ps1': 1.5}

ylims = [0, 1.02]
xlims = [1.7e-4, 1]

fig, main_ax = plt.subplots(figsize=(7,5))


axins = inset_axes(main_ax, width="43.5%",  
                   height="53%", loc=3,
                   bbox_to_anchor=(0.55, 0.07, 1., 1.),
                   bbox_transform=main_ax.transAxes)

main_ax.grid(alpha=0.5, lw=0.5, c='grey', linestyle=':') 
main_ax.tick_params(which="both", top=True, right=True)
main_ax.minorticks_on()

main_ax.plot(ML_fpr, ML_tpr, color=color_dict['ml'], label='RF model', ls=ls_dict['ml'], lw=lw_dict['ml'], alpha=0.9, zorder=4)
main_ax.plot(dist_fpr, dist_tpr, color=color_dict['simple'], label='Simple model', ls=ls_dict['simple'], lw=lw_dict['simple'], alpha=0.9, zorder=3)
main_ax.plot(i_fpr, i_tpr, color=color_dict['ps1'], label='PS1 model', ls=ls_dict['ps1'], lw=lw_dict['ps1'], dashes=(8, 4), alpha=0.9, zorder=2)
main_ax.plot(sdss_fpr, sdss_tpr, color=color_dict['sdss'], label='SDSS photo', ls=ls_dict['sdss'], lw=lw_dict['sdss'], alpha=0.9, zorder=1)
main_ax.plot(sdss_fpr[np.argmin(np.abs(sdss_thre-10**(-0.185/2.5)))], 
        sdss_tpr[np.argmin(np.abs(sdss_thre-10**(-0.185/2.5)))], '*', markersize=10, color=col_star_sdss, zorder=5)
main_ax.plot(i_fpr[np.argmin(np.abs(i_thre+0.05))], 
        i_tpr[np.argmin(np.abs(i_thre+0.05))], '*', markersize=10, color=col_star_ps1, zorder=5)

main_ax.vlines([5e-3], 1e-3, 1.02, 
          color='DarkSlateGrey', lw=2., linestyles=":", zorder=1)
main_ax.text(4.9e-3, 0.1, 'FoM', 
        color='DarkSlateGrey', 
        rotation=90, ha="right", fontsize=14)

main_ax.set_xlim(xlims); main_ax.set_ylim(ylims)
main_ax.set_xscale('log')
main_ax.tick_params(labelsize = 15)
main_ax.set_xlabel('False Positive Rate', fontsize=15)
main_ax.set_ylabel('True Positive Rate', fontsize=15)

main_ax.legend(loc=3, borderaxespad=0, fontsize=13, 
               handlelength=1.5, 
               bbox_to_anchor=(0.0125, 0.725, 1., 0.102))


ylims = [0.875, 0.965]
xlims = np.array([7.5e-3, 2.75e-2])

origin = 'lower'
axins.tick_params(which="both", top=True)
axins.minorticks_on()

axins.plot(ML_fpr, ML_tpr, color=color_dict['ml'], label='RF model', ls=ls_dict['ml'], lw=lw_dict['ml'], alpha=0.9, zorder=5)
axins.plot(dist_fpr, dist_tpr, color=color_dict['simple'], label='Simple model', ls=ls_dict['simple'], lw=lw_dict['simple'], alpha=0.9, zorder=4)
axins.plot(i_fpr, i_tpr, color=color_dict['ps1'], label='PS1 model', ls=ls_dict['ps1'], lw=lw_dict['ps1'], dashes=(8, 4), alpha=0.9, zorder=3)
axins.plot(sdss_fpr, sdss_tpr, color=color_dict['sdss'], label='SDSS photo', ls=ls_dict['sdss'], lw=lw_dict['sdss'], alpha=0.9, zorder=2)
axins.plot(sdss_fpr[np.argmin(np.abs(sdss_thre-10**(-0.185/2.5)))], 
        sdss_tpr[np.argmin(np.abs(sdss_thre-10**(-0.185/2.5)))], '*', markersize=10, color=col_star, zorder=6)
axins.plot(i_fpr[np.argmin(np.abs(i_thre+0.05))], 
        i_tpr[np.argmin(np.abs(i_thre+0.05))], '*', markersize=10, color=col_star_ps1, zorder=5)

"""
axins.vlines([5e-3], 1e-3, 1, 
          color='DarkSlateGrey', lw=2., linestyles=":", zorder=1)
axins.text(5e-3, 0.59, 'FoM', 
        color='DarkSlateGrey', 
        rotation=90, ha="right", fontsize=14)
"""
axins.set_xlim(xlims); axins.set_ylim(ylims)
#axins.set_yticks([0.85,0.90, 0.95])
axins.tick_params(labelsize = 15)

fig.subplots_adjust(right=0.97,top=0.98,bottom=0.11,left=0.1)
fig.savefig('../paper/Figures/ROC_curves_log_inset2.pdf')


Accuracy curve for the original SDSS test set


In [31]:
ml_decision_thresh = 0.5
sdss_decision_thresh = 10**(-0.185/2.5)
simple_decision_thresh = 9.2e-07 #  maximize acc on training set
ps1_decision_thresh = 0.05

ml_labels = np.logical_not(ML_predict[:,1] < ml_decision_thresh).astype(int)
sdss_labels = np.logical_not(np.array(sdss_tbl["countRatio"].ix[sdss_test_set]) < sdss_decision_thresh).astype(int)
simple_labels = np.logical_not(np.array(sdss_tbl["wwPSFKronDist"].ix[sdss_test_set]) < simple_decision_thresh).astype(int)
ps1_labels = np.logical_not(-1*np.array(sdss_tbl["iPSFminusKron"].ix[sdss_test_set]) < ps1_decision_thresh).astype(int)

In [32]:
binwidth = 0.5
Nboot = 100
mag_array = np.arange(13 , 23+binwidth, binwidth)
kron_mag = np.array(-2.5*np.log10(sdss_tbl['wwKronFlux'].ix[sdss_test_set]/3631))

sdss_acc_arr = np.zeros_like(mag_array)
simple_acc_arr = np.zeros_like(mag_array)
ps1_acc_arr = np.zeros_like(mag_array)
ml_acc_arr = np.zeros_like(mag_array)

sdss_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))


for bin_num, binedge in enumerate(mag_array):
    bin_sources = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth))
    sdss_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources], 
                                           sdss_labels[bin_sources])
    simple_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources], 
                                             simple_labels[bin_sources])
    ps1_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources], 
                                          ps1_labels[bin_sources])
    ml_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources], 
                                         ml_labels[bin_sources])
    sdss_boot_acc = np.empty(Nboot1
    simple_boot_acc = np.empty_like(sdss_boot_acc)
    ps1_boot_acc = np.empty_like(sdss_boot_acc)
    ml_boot_acc = np.empty_like(sdss_boot_acc)
    for i in range(Nboot):
        boot_sources = np.random.choice(bin_sources[0], len(bin_sources[0]), 
                                        replace=True)
        sdss_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                           sdss_labels[boot_sources])
        simple_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                             simple_labels[boot_sources])
        ps1_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                          ps1_labels[boot_sources])
        ml_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                         ml_labels[boot_sources])

    sdss_boot_scatt[:,bin_num] = np.percentile(sdss_boot_acc, [16, 84])
    simple_boot_scatt[:,bin_num] = np.percentile(simple_boot_acc, [16, 84])
    ps1_boot_scatt[:,bin_num] = np.percentile(ps1_boot_acc, [16, 84])
    ml_boot_scatt[:,bin_num] = np.percentile(ml_boot_acc, [16, 84])


/Users/y_tachibana/anaconda/envs/py35/lib/python3.5/site-packages/ipykernel_launcher.py:4: DeprecationWarning: 
.ix is deprecated. Please use
.loc for label based indexing or
.iloc for positional indexing

See the documentation here:
http://pandas.pydata.org/pandas-docs/stable/indexing.html#ix-indexer-is-deprecated
  after removing the cwd from sys.path.

In [34]:
from sklearn.neighbors import KernelDensity
kde_grid = np.linspace(13,23.5,200)
sdss_stars = np.where(sdss_ml_test_y == 1)
sdss_gal = np.where(sdss_ml_test_y == 0)
sdss_kde_gal_norm = len(sdss_gal[0])/len(sdss_ml_test_y)
sdss_kde_star_norm = 1 - sdss_kde_gal_norm

kde_sdss = KernelDensity(bandwidth=1.059*np.std(kron_mag, ddof=1)*len(kron_mag)**(-0.2),
                         rtol=1E-4)
kde_sdss.fit(kron_mag[:, np.newaxis])

kde_sdss_stars = KernelDensity(bandwidth=1.059*np.std(kron_mag[sdss_stars], ddof=1)*len(kron_mag[sdss_stars])**(-0.2),
                               rtol=1E-4)
kde_sdss_stars.fit(kron_mag[sdss_stars[0], np.newaxis])

kde_sdss_gal = KernelDensity(bandwidth=1.059*np.std(kron_mag[sdss_gal], ddof=1)*len(kron_mag[sdss_gal])**(-0.2),
                             rtol=1E-4)
kde_sdss_gal.fit(kron_mag[sdss_gal[0], np.newaxis])

pdf_sdss = np.exp(kde_sdss.score_samples(kde_grid[:, np.newaxis]))
pdf_sdss_stars = np.exp(kde_sdss_stars.score_samples(kde_grid[:, np.newaxis]))
pdf_sdss_gal = np.exp(kde_sdss_gal.score_samples(kde_grid[:, np.newaxis]))

In [61]:
from matplotlib.ticker import MultipleLocator
color_dict = {'ml': "black", #"#1C1858",
              'sdss': "#5BC236", #"#00C78E",
              'simple': "#C864AF", #"#C70039",
              'ps1': "#C65400"}

mag_bin_centers = mag_array + binwidth/2
cmap_star = sns.cubehelix_palette(rot=0.5, light=0.7,dark=0.3,as_cmap=True)
cmap_gal = sns.cubehelix_palette(start=0.3,rot=-0.5,light=0.7,dark=0.3,as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 5))

ax.errorbar(mag_bin_centers, ml_acc_arr, 
            yerr=np.abs(ml_boot_scatt - ml_acc_arr), 
            ls='-', lw=.75, fmt='o',
            color=color_dict['ml'], label="ML model", zorder=2)
"""
ax.errorbar(mag_bin_centers, simple_acc_arr, 
            yerr=np.abs(simple_boot_scatt - simple_acc_arr), 
            ls ='-', lw=.5, fmt='o',
            color=color_dict['simple'], label="Simple model")
ax.errorbar(mag_bin_centers, ps1_acc_arr, 
            yerr=np.abs(ps1_boot_scatt - ps1_acc_arr), 
            ls ='-', lw=.5, dashes=(8, 4), fmt='o',
            color=color_dict['ps1'], label=r'm$_{\rm iPSF}-$m$_{\rm iKron}$')
"""
ax.errorbar(mag_bin_centers, sdss_acc_arr, 
            yerr=np.abs(sdss_boot_scatt - sdss_acc_arr), 
            ls='-', lw=.5, fmt='o',
            color=color_dict['sdss'], label="SDSS photo", zorder=1)

# add KDE plots
ax.fill(kde_grid, pdf_sdss + 0.5, alpha=0.4, color="0.7")
ax.fill(kde_grid, pdf_sdss_gal*sdss_kde_gal_norm + 0.5, alpha=0.7, color=cmap_gal(0.25))
ax.fill(kde_grid, pdf_sdss_stars*sdss_kde_star_norm + 0.5, alpha=0.7, color=cmap_star(0.25))

ax.set_ylim(0.5,1.01)
ax.set_xlim(13, 23.5)
ax.tick_params(which="both", top=True, right=True, labelsize=15)
ax.set_xlabel('whiteKronMag', fontsize=15)
ax.set_ylabel('Accuracy', fontsize=15)

ax.yaxis.set_minor_locator(MultipleLocator(0.025))
ax.xaxis.set_major_locator(MultipleLocator(2))
ax.xaxis.set_minor_locator(MultipleLocator(0.5))

ax.legend(bbox_to_anchor=(0.01, 0.3, 1., 0.102), loc=3, fontsize=13)

fig.subplots_adjust(top=0.98,right=0.98,left=0.1,bottom=0.12)


Debiasing by resampling


In [55]:
binwidth = 0.5
Nboot = 100

# bootstrap star arrays
sdss_star_acc_arr = np.zeros_like(mag_array)
simple_star_acc_arr = np.zeros_like(mag_array)
ps1_star_acc_arr = np.zeros_like(mag_array)
ml_star_acc_arr = np.zeros_like(mag_array)

sdss_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))

# bootstrap galaxy arrays
sdss_gal_acc_arr = np.zeros_like(mag_array)
simple_gal_acc_arr = np.zeros_like(mag_array)
ps1_gal_acc_arr = np.zeros_like(mag_array)
ml_gal_acc_arr = np.zeros_like(mag_array)

sdss_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))

for bin_num, binedge in enumerate(mag_array):
    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                         (sdss_ml_test_y == 1))
    sdss_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                           sdss_labels[bin_stars])
    simple_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                             simple_labels[bin_stars])
    ps1_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                          ps1_labels[bin_stars])
    ml_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                         ml_labels[bin_stars])
    
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                        (sdss_ml_test_y == 0))
    sdss_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                           sdss_labels[bin_gals])
    simple_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                             simple_labels[bin_gals])
    ps1_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                          ps1_labels[bin_gals])
    ml_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                         ml_labels[bin_gals])    
    
    # get the bootstrap accuracies
    
    sdss_star_boot_acc = np.empty(Nboot)
    simple_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ps1_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ml_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    
    sdss_gal_boot_acc = np.empty_like(sdss_star_boot_acc)
    simple_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ps1_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ml_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)

    for i in range(Nboot):
        star_boot_sources = np.random.choice(bin_stars[0], len(bin_stars[0]), 
                                             replace=True)
        sdss_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                           sdss_labels[star_boot_sources])
        simple_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                             simple_labels[star_boot_sources])
        ps1_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                          ps1_labels[star_boot_sources])
        ml_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                         ml_labels[star_boot_sources])
        
        gal_boot_sources = np.random.choice(bin_gals[0], len(bin_gals[0]), 
                                            replace=True)
        sdss_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                           sdss_labels[gal_boot_sources])
        simple_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                             simple_labels[gal_boot_sources])
        ps1_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                          ps1_labels[gal_boot_sources])
        ml_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                         ml_labels[gal_boot_sources])


    sdss_star_boot_scatt[:,bin_num] = np.percentile(sdss_star_boot_acc, [16, 84])
    simple_star_boot_scatt[:,bin_num] = np.percentile(simple_star_boot_acc, [16, 84])
    ps1_star_boot_scatt[:,bin_num] = np.percentile(ps1_star_boot_acc, [16, 84])
    ml_star_boot_scatt[:,bin_num] = np.percentile(ml_star_boot_acc, [16, 84])    
    
    sdss_gal_boot_scatt[:,bin_num] = np.percentile(sdss_gal_boot_acc, [16, 84])
    simple_gal_boot_scatt[:,bin_num] = np.percentile(simple_gal_boot_acc, [16, 84])
    ps1_gal_boot_scatt[:,bin_num] = np.percentile(ps1_gal_boot_acc, [16, 84])
    ml_gal_boot_scatt[:,bin_num] = np.percentile(ml_gal_boot_acc, [16, 84])

In [84]:
binwidth = 0.5
Nboot = 100

sdss_resamp_arr = np.zeros_like(mag_array)
simple_resamp_arr = np.zeros_like(mag_array)
ps1_resamp_arr = np.zeros_like(mag_array)
ml_resamp_arr = np.zeros_like(mag_array)

sdss_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))

# bootstrap star arrays
sdss_star_resamp_arr = np.zeros_like(mag_array)
simple_star_resamp_arr = np.zeros_like(mag_array)
ps1_star_resamp_arr = np.zeros_like(mag_array)
ml_star_resamp_arr = np.zeros_like(mag_array)

sdss_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))

# bootstrap galaxy arrays
sdss_gal_resamp_arr = np.zeros_like(mag_array)
simple_gal_resamp_arr = np.zeros_like(mag_array)
ps1_gal_resamp_arr = np.zeros_like(mag_array)
ml_gal_resamp_arr = np.zeros_like(mag_array)

sdss_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))

for bin_num, binedge in enumerate(mag_array):
    if bin_num <= 3:
        continue
    
    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                         (sdss_ml_test_y == 1))
    
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                        (sdss_ml_test_y == 0))

    hst_stars = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) & 
                         (hst_ml_train_y == 1))
    hst_gals = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) & 
                        (hst_ml_train_y == 0))
    
    # figure out the number of stars and galaxies to select
    if len(hst_stars[0])/len(hst_gals[0]) > len(bin_stars[0])/len(bin_gals[0]):
        n_star_resamp = len(bin_stars[0])
        n_gal_resamp = int(len(bin_stars[0])*len(hst_gals[0])/len(hst_stars[0]))
    else:
        n_star_resamp = int(len(bin_gals[0])*len(hst_stars[0])/len(hst_gals[0]))
        n_gal_resamp = len(bin_gals[0])
    
    print(len(hst_stars[0]), len(hst_gals[0]), 
          len(bin_stars[0]), len(bin_gals[0]),
          n_star_resamp, n_gal_resamp)
    
    # get the bootstrap accuracies    
    sdss_boot_acc = np.empty(Nboot)
    simple_boot_acc = np.empty_like(sdss_boot_acc)
    ps1_boot_acc = np.empty_like(sdss_boot_acc)
    ml_boot_acc = np.empty_like(sdss_boot_acc)

    sdss_star_boot_acc = np.empty(Nboot)
    simple_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ps1_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ml_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    
    sdss_gal_boot_acc = np.empty_like(sdss_star_boot_acc)
    simple_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ps1_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ml_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)

    for i in range(Nboot):
        star_boot_sources = np.random.choice(bin_stars[0], n_star_resamp, 
                                             replace=True)
        gal_boot_sources = np.random.choice(bin_gals[0], n_gal_resamp, 
                                            replace=True)
        boot_sources = np.append(star_boot_sources, gal_boot_sources)

        sdss_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                           sdss_labels[boot_sources])
        simple_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                             simple_labels[boot_sources])
        ps1_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                          ps1_labels[boot_sources])
        ml_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                         ml_labels[boot_sources])

        sdss_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                           sdss_labels[star_boot_sources])
        simple_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                             simple_labels[star_boot_sources])
        ps1_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                          ps1_labels[star_boot_sources])
        ml_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                         ml_labels[star_boot_sources])
        
        sdss_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                           sdss_labels[gal_boot_sources])
        simple_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                             simple_labels[gal_boot_sources])
        ps1_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                          ps1_labels[gal_boot_sources])
        ml_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                         ml_labels[gal_boot_sources])


    sdss_resamp_arr[bin_num] = np.mean(sdss_boot_acc)
    simple_resamp_arr[bin_num] = np.mean(simple_boot_acc)
    ps1_resamp_arr[bin_num] = np.mean(ps1_boot_acc)
    ml_resamp_arr[bin_num] = np.mean(ml_boot_acc)
    
    sdss_resamp_scatt[:,bin_num] = np.percentile(sdss_boot_acc, [16, 84])
    simple_resamp_scatt[:,bin_num] = np.percentile(simple_boot_acc, [16, 84])
    ps1_resamp_scatt[:,bin_num] = np.percentile(ps1_boot_acc, [16, 84])
    ml_resamp_scatt[:,bin_num] = np.percentile(ml_boot_acc, [16, 84])    

    sdss_star_resamp_arr[bin_num] = np.mean(sdss_star_boot_acc)
    simple_star_resamp_arr[bin_num] = np.mean(simple_star_boot_acc)
    ps1_star_resamp_arr[bin_num] = np.mean(ps1_star_boot_acc)
    ml_star_resamp_arr[bin_num] = np.mean(ml_star_boot_acc)
    
    sdss_star_resamp_scatt[:,bin_num] = np.percentile(sdss_star_boot_acc, [16, 84])
    simple_star_resamp_scatt[:,bin_num] = np.percentile(simple_star_boot_acc, [16, 84])
    ps1_star_resamp_scatt[:,bin_num] = np.percentile(ps1_star_boot_acc, [16, 84])
    ml_star_resamp_scatt[:,bin_num] = np.percentile(ml_star_boot_acc, [16, 84])    

    sdss_gal_resamp_arr[bin_num] = np.mean(sdss_gal_boot_acc)
    simple_gal_resamp_arr[bin_num] = np.mean(simple_gal_boot_acc)
    ps1_gal_resamp_arr[bin_num] = np.mean(ps1_gal_boot_acc)
    ml_gal_resamp_arr[bin_num] = np.mean(ml_gal_boot_acc)
    
    sdss_gal_resamp_scatt[:,bin_num] = np.percentile(sdss_gal_boot_acc, [16, 84])
    simple_gal_resamp_scatt[:,bin_num] = np.percentile(simple_gal_boot_acc, [16, 84])
    ps1_gal_resamp_scatt[:,bin_num] = np.percentile(ps1_gal_boot_acc, [16, 84])
    ml_gal_resamp_scatt[:,bin_num] = np.percentile(ml_gal_boot_acc, [16, 84])


126 3 23645 22763 23645 562
161 6 36897 45170 36897 1375
206 10 50611 94001 50611 2456
288 29 66261 166928 66261 6672
325 57 70723 293069 70723 12403
401 93 79154 174356 79154 18357
472 176 89247 131564 89247 33278
633 343 107231 188220 107231 58104
731 653 106618 255031 106618 95241
926 1057 111399 528090 111399 127158
1007 1756 133584 237271 133584 232942
1200 2811 137765 76111 32491 76111
1397 4375 137181 30476 9731 30476
1416 6133 97725 19126 4415 19126
1120 6738 31678 7312 1215 7312
765 5704 4325 1319 176 1319
370 3312 478 125 13 125

In [76]:
kde_grid = np.linspace(13,23.5,200)
hst_stars = np.where(hst_ml_train_y == 1)
hst_gal = np.where(hst_ml_train_y == 0)

kde_hst = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag, ddof=1)*len(hst_kron_mag)**(-0.2),
                         rtol=1E-4)
kde_hst.fit(hst_kron_mag[:, np.newaxis])

kde_hst_stars = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag[hst_stars], ddof=1)*len(hst_kron_mag[hst_stars])**(-0.2),
                               rtol=1E-4)
kde_hst_stars.fit(hst_kron_mag[hst_stars[0], np.newaxis])

kde_hst_gal = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag[hst_gal], ddof=1)*len(hst_kron_mag[hst_gal])**(-0.2),
                             rtol=1E-4)
kde_hst_gal.fit(hst_kron_mag[hst_gal[0], np.newaxis])

pdf_hst = np.exp(kde_hst.score_samples(kde_grid[:, np.newaxis]))
pdf_hst_stars = np.exp(kde_hst_stars.score_samples(kde_grid[:, np.newaxis]))
pdf_hst_gal = np.exp(kde_hst_gal.score_samples(kde_grid[:, np.newaxis]))

hst_kde_gal_norm = len(hst_gal[0])/len(hst_ml_train_y)
hst_kde_star_norm = 1 - hst_kde_gal_norm

In [111]:
color_dict = {'ml': "black", #"#1C1858",
              'sdss': "teal",
              'simple': "#C864AF", #"#C70039",
              'ps1': "#C65400"}

plt.figure(figsize=(12.5,5))

plt.subplot(1,2,1)

plt.errorbar(mag_bin_centers+0.1, ml_star_acc_arr, 
            yerr=np.abs(ml_star_boot_scatt - ml_star_acc_arr), 
            ls =ls_dict['ml'], lw=.75, fmt='^', ms=8,
            mec="0.2", mew=0.5,
            color=color_dict['ml'], label="ML stars", zorder=4)
# ax.errorbar(mag_bin_centers, simple_star_acc_arr, 
#             yerr=np.abs(simple_star_boot_scatt - simple_star_acc_arr), 
#             ls =ls_dict['simple'], lw=.5, fmt='*', ms=10,
#             mec="0.2", mew=0.5,
#             color=color_dict['simple'], label="Simple model")
# ax.errorbar(mag_bin_centers, ps1_star_acc_arr, 
#             yerr=np.abs(ps1_star_boot_scatt - ps1_star_acc_arr), 
#             ls =ls_dict['ps1'], lw=.5, dashes=(8, 4), fmt='*', ms=10,
#             mec="0.2", mew=0.5,
#             color=color_dict['ps1'], label=r'm$_{\rm iPSF}-$m$_{\rm iKron}$')
plt.errorbar(mag_bin_centers+0.1, sdss_star_acc_arr, 
            yerr=np.abs(sdss_star_boot_scatt - sdss_star_acc_arr), 
            ls=ls_dict['sdss'], lw=.5, fmt='^', ms=8,
            mec="0.2", mew=0.5,
            color=color_dict['sdss'], label="SDSS stars", zorder=2)


plt.errorbar(mag_bin_centers-0.1, ml_gal_acc_arr, 
            yerr=np.abs(ml_gal_boot_scatt - ml_gal_acc_arr), 
            ls =ls_dict['ml'], lw=.75, fmt='s',
            color=color_dict['ml'], label="ML galaxies", zorder=3)
# ax.errorbar(mag_bin_centers, simple_gal_acc_arr, 
#             yerr=np.abs(simple_gal_boot_scatt - simple_gal_acc_arr), 
#             ls =ls_dict['simple'], lw=.5, fmt=',',
#             color=color_dict['simple'], label="Simple model")
# ax.errorbar(mag_bin_centers, ps1_gal_acc_arr, 
#             yerr=np.abs(ps1_gal_boot_scatt - ps1_gal_acc_arr), 
#             ls =ls_dict['ps1'], lw=.5, dashes=(8, 4), fmt=',',
#             color=color_dict['ps1'], label=r'm$_{\rm iPSF}-$m$_{\rm iKron}$')
plt.errorbar(mag_bin_centers-0.1, sdss_gal_acc_arr, 
            yerr=np.abs(sdss_gal_boot_scatt - sdss_gal_acc_arr), 
            ls=ls_dict['sdss'], lw=.5, fmt='s',
            color=color_dict['sdss'], label="SDSS galaxies", zorder=1)

# add KDE plots
plt.fill(kde_grid, pdf_sdss + 0.5, alpha=0.4, color="0.7")
plt.fill(kde_grid, pdf_sdss_gal*sdss_kde_gal_norm + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(kde_grid, pdf_sdss_stars*sdss_kde_star_norm + 0.5, alpha=0.7, color=cmap_star(0.25))

plt.ylim(0.5,1.01)
plt.xlim(13, 23.5)
plt.tick_params(which="both", top=True, right=True, labelsize=15)
plt.xlabel('whiteKronMag', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.minorticks_on()

#plt.yaxis.set_minor_locator(MultipleLocator(0.025))
#plt.xaxis.set_major_locator(MultipleLocator(2))
#plt.xaxis.set_minor_locator(MultipleLocator(0.5))

plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)

plt.subplot(1,2,2)

plt.errorbar(mag_bin_centers[4:], ml_resamp_arr[4:], 
            yerr=np.abs(ml_resamp_scatt - ml_resamp_arr)[:,4:], 
            ls =ls_dict['ml'], lw=.75, fmt='o',
            color=color_dict['ml'], label="ML model")
"""
ax.errorbar(mag_bin_centers[4:], simple_resamp_arr[4:], 
            yerr=np.abs(simple_resamp_scatt - simple_resamp_arr)[:,4:], 
            ls =ls_dict['simple'], lw=.5, fmt='o',
            color=color_dict['simple'], label="Simple model")
ax.errorbar(mag_bin_centers[4:], ps1_resamp_arr[4:], 
            yerr=np.abs(ps1_resamp_scatt - ps1_resamp_arr)[:,4:], 
            ls =ls_dict['ps1'], lw=.5, dashes=(8, 4), fmt='o',
            color=color_dict['ps1'], label=r'm$_{\rm iPSF}-$m$_{\rm iKron}$')
"""
plt.errorbar(mag_bin_centers[4:], sdss_resamp_arr[4:], 
            yerr=np.abs(sdss_resamp_scatt - sdss_resamp_arr)[:,4:], 
            ls=ls_dict['sdss'], lw=.5, fmt='o',
            color=color_dict['sdss'], label="SDSS photo")

# add KDE plots
plt.fill(np.append(kde_grid, 23.5), 
        np.append(pdf_hst, 0) + 0.5, alpha=0.4, color="0.7")
plt.fill(np.append(kde_grid, 23.5), 
        np.append(pdf_hst_gal*hst_kde_gal_norm, 0) + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(np.append(kde_grid, 23.5), 
        np.append(pdf_hst_stars*hst_kde_star_norm, 0) + 0.5, alpha=0.7, color=cmap_star(0.25))

plt.ylim(0.5,1.01)
plt.xlim(15, 23.5)
plt.tick_params(which="both", top=True, right=True, labelsize=15, labelleft='off')
plt.xlabel('whiteKronMag', fontsize=15)
#plt.ylabel('Accuracy', fontsize=15)
plt.minorticks_on()

plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)

plt.tight_layout()

plt.savefig('Accuracy_SDSS_ML_model.pdf')



In [ ]: