In this notebook we examine the ROC curve, the figure of merit (FoM; the true-positive rate at a false-positive rate of 0.005), the accuracy, and the ROC AUC of four models — the ML (random-forest) model, the simple model, the PS1 model, and the SDSS model — on the SDSS test set. The uncertainty on each statistic is estimated from 100 bootstrap resamples.
The ML model outperforms the alternative models on every statistic except accuracy. This reversal is caused by the bias of the SDSS test set towards point sources at the faint end, combined with the hard star–galaxy cut employed by the SDSS model. The same bias shapes the ROC curves: the kink and the crossing of the curves around FPR ≈ 0.015 are due to it.
In [47]:
    
import sys,os,math
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams["font.family"] = "sans-serif"
rcParams['font.sans-serif'] = ['DejaVu Sans']
from matplotlib import gridspec as grs
%matplotlib notebook
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import cm
from astropy.table import Table
import seaborn as sns
import statsmodels.nonparametric.api as smnp
from statsmodels.nonparametric.kernel_density import KDEMultivariate
from scipy.special import expit
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, accuracy_score, auc, make_scorer
from sklearn.model_selection import GridSearchCV, KFold
    
In [2]:
    
# Load the HST COSMOS feature table (used below as the training set) and the
# SDSS spectroscopic feature table (used below as the test set) as DataFrames.
hst_tbl = Table.read("HST_COSMOS_features_adamamiller.fit").to_pandas()
sdss_tbl = Table.read("sdssSP_MLfeats_adamamiller.fit").to_pandas()
    
In [5]:
    
# Keep only the first row for each SDSS objid (later duplicates are dropped).
sdss_tbl = sdss_tbl[~sdss_tbl.duplicated('objid')]
    
In [6]:
    
# HST sources with at least one detection (tuple of positional indices).
hst_det = np.where(hst_tbl.nDetections > 0)
# -2.5 log10(flux / 3631): an AB magnitude if wwKronFlux is in Jy (assumed).
hst_kron_mag = np.array(-2.5*np.log10(hst_tbl['wwKronFlux'].loc[hst_det]/3631))
sdss_det = np.where(sdss_tbl.nDetections > 0)
# SDSS sources with a finite photometric `type` and a positive countRatio
# (countRatio is the SDSS-model score used below).
sdss_photo_det = np.logical_and(np.isfinite(sdss_tbl.type),
                                sdss_tbl.countRatio > 0)
# ...that also have a finite iPSF - iKron value (the PS1-model score).
sdss_in_common = np.logical_and(np.isfinite(sdss_tbl.iPSFminusKron), 
                                sdss_photo_det)
# Exclude spectroscopic galaxies with z < 1e-4 and QSOs with z < 1.
low_z_gal = np.logical_and(sdss_tbl.z < 1e-4, sdss_tbl['class'] == 'GALAXY')
# NOTE(review): '   QSO' has leading spaces -- confirm this matches how the
# `class` strings are padded, otherwise this mask is all False.
low_z_qso = np.logical_and(sdss_tbl.z < 1, sdss_tbl['class'] == '   QSO')
# Boolean mask defining the SDSS test set.
sdss_test_set = sdss_in_common & ~low_z_gal & ~low_z_qso
    
In [8]:
    
# Feature columns fed to the random-forest model (train and test use the same list).
features = ['wwpsfChiSq', 'wwExtNSigma', 'wwpsfLikelihood',
            'wwPSFKronRatio', 'wwPSFKronDist',  'wwPSFApRatio',
            'wwmomentRH', 'wwmomentXX', 'wwmomentXY', 'wwmomentYY', 
            'wwKronRad']
    
In [9]:
    
# HST training features and labels (MU_CLASS - 1; label 1 is treated as
# "star" by the cells below).
hst_ml_train_X = np.array(hst_tbl[features].loc[hst_det])
hst_ml_train_y = np.array(hst_tbl["MU_CLASS"].loc[hst_det] - 1)
# SDSS test features; labels are 0 for spectroscopic GALAXY and 1 otherwise.
sdss_ml_test_X = np.array(sdss_tbl[features].loc[sdss_test_set])
sdss_spec_class = np.array(sdss_tbl['class'].loc[sdss_test_set])
sdss_ml_test_y = np.ones_like(sdss_spec_class).astype(int)
sdss_ml_test_y[np.where(sdss_spec_class == "GALAXY")] = 0
    
In [10]:
    
import star_galaxy_models
    
In [11]:
    
# Load the trained random forest through the project helper module.
# NOTE(review): presumably this unpickles a model file -- only load pickles
# from trusted sources.
rf_obj = star_galaxy_models.RandomForestModel()
rf_obj.read_rf_from_pickle()
    
In [12]:
    
# Class probabilities for the SDSS test set; column 1 is used as the star score.
ML_predict = rf_obj.rf_clf_.predict_proba(sdss_ml_test_X)
    
In [13]:
    
# ROC curves for each model, with label 1 as the positive class.  Larger scores
# must mean "more star-like", hence iPSF - iKron is negated for the PS1 model.
ML_fpr, ML_tpr, ML_thre = roc_curve(sdss_ml_test_y, ML_predict[:,1])
sdss_fpr, sdss_tpr,sdss_thre = roc_curve(sdss_ml_test_y, sdss_tbl["countRatio"].loc[sdss_test_set])
dist_fpr, dist_tpr, dist_thre = roc_curve(sdss_ml_test_y, sdss_tbl["wwPSFKronDist"].loc[sdss_test_set])
i_fpr, i_tpr, i_thre = roc_curve(sdss_ml_test_y, -1.*sdss_tbl["iPSFminusKron"].loc[sdss_test_set])
    
In [14]:
    
# Per-model scores on the SDSS test set (larger = more star-like).
ps1_preds = -1.*sdss_tbl["iPSFminusKron"].loc[sdss_test_set]  # pandas Series
simple_preds = sdss_tbl["wwPSFKronDist"].loc[sdss_test_set]   # pandas Series
rf_preds = ML_predict[:,1]                                    # numpy array
sdss_preds = sdss_tbl["countRatio"].loc[sdss_test_set]        # pandas Series
y_sdss = sdss_ml_test_y
    
In [19]:
    
# `scipy.interp` was only a deprecated alias of `numpy.interp` and no longer
# exists in current SciPy, so NumPy is used directly.
from sklearn.metrics import roc_curve, accuracy_score, auc, roc_auc_score
def calc_fom(fpr, tpr, thresh, fpr_at=0.005):
    """Evaluate an ROC curve at a fixed false-positive rate.

    The figure of merit (FoM) is the true-positive rate reached at ``fpr_at``;
    the decision threshold that attains it is returned alongside.

    Parameters
    ----------
    fpr, tpr, thresh : array-like
        The three outputs of ``sklearn.metrics.roc_curve``.  ``fpr`` is
        non-decreasing, as ``numpy.interp`` requires of its x-coordinates.
    fpr_at : float, optional
        False-positive rate at which to evaluate the curve (default 0.005).

    Returns
    -------
    tuple
        ``(threshold_at_fpr, tpr_at_fpr)``, both linearly interpolated.
    """
    return np.interp(fpr_at, fpr, thresh), np.interp(fpr_at, fpr, tpr)
    
In [50]:
    
def _roc_summary(y_true, scores, cut):
    """Return (ROC AUC, accuracy of `scores >= cut`, FoM tuple) for one model."""
    fpr, tpr, thresh = roc_curve(y_true, scores)
    fom = calc_fom(fpr, tpr, thresh)
    roc_auc = roc_auc_score(y_true, scores)
    acc = accuracy_score(y_true, scores >= cut)
    return roc_auc, acc, fom


def calc_summary_stats(y_sdss, ps1_preds, simple_preds, rf_preds, sdss_preds, 
                       ps1_ct = -0.05, 
                       simple_ct = 9.2e-7,
                       rf_ct = 0.5, 
                       sdss_ct = 10**(-0.185/2.5)):
    """ROC AUC, accuracy and FoM for the RF, simple, PS1 and SDSS models.

    Every model labels a source a star when its score is >= its cut.  In this
    notebook ``ps1_preds`` is ``-(iPSF - iKron)``, so ``ps1_ct = -0.05`` keeps
    sources with iPSF - iKron <= 0.05.  (The previous ``-1*ps1_preds <= ps1_ct``
    test instead required iPSF - iKron <= -0.05, disagreeing with the other
    uses of the same cut.)

    Parameters
    ----------
    y_sdss : array-like of int
        Ground-truth labels (1 = star, 0 = galaxy).
    ps1_preds, simple_preds, rf_preds, sdss_preds : array-like
        Scores of each model (larger = more star-like).
    ps1_ct, simple_ct, rf_ct, sdss_ct : float
        Decision thresholds used for the accuracy.

    Returns
    -------
    tuple
        12 values: ``(auc, acc, fom)`` for RF, simple, PS1 and SDSS in that
        order, where each ``fom`` is the ``(threshold, tpr)`` pair returned by
        ``calc_fom`` at FPR = 0.005.
    """
    rf_stats = _roc_summary(y_sdss, rf_preds, rf_ct)
    simple_stats = _roc_summary(y_sdss, simple_preds, simple_ct)
    ps1_stats = _roc_summary(y_sdss, ps1_preds, ps1_ct)
    sdss_stats = _roc_summary(y_sdss, sdss_preds, sdss_ct)
    return rf_stats + simple_stats + ps1_stats + sdss_stats
    
In [52]:
    
# Headline statistics on the full SDSS test set with the default thresholds.
summary_stats = calc_summary_stats(y_sdss, ps1_preds, simple_preds, rf_preds, sdss_preds)
    
In [61]:
    
# LaTeX table rows: FoM & accuracy & ROC AUC for each model.  The FoM entries
# of summary_stats are (threshold, tpr) tuples from calc_fom, so `[1]` selects
# the TPR; formatting the tuple itself with `.3f` raises TypeError.
print(r"""
    RF & {2[1]:.3f} & {1:.3f}  & {0:.3f}  \\
    simple & {5[1]:.3f}  & {4:.3f} & {3:.3f} \\
    PS1 & {8[1]:.3f} & {7:.3f} & {6:.3f} \\
    SDSS & {11[1]:.3f} & {10:.3f} & {9:.3f} \\
""".format(*summary_stats))
    
    
In [20]:
    
# Kron magnitudes of the SDSS test-set sources, aligned with y_sdss; read as a
# global by summary_stats_bootstrap for its magnitude cut.
kron_mag = np.array(-2.5*np.log10(sdss_tbl['wwKronFlux'].loc[sdss_test_set]/3631))
    
In [24]:
    
def summary_stats_bootstrap(gt, preds, ct=0.5, fom_at=(0.005, 0.01, 0.02, 0.05, 0.1), Nboot=100, mag_max = 30,
                            mag=None, random_state=None):
    """Accuracy, ROC AUC and FoM of one classifier, with bootstrap scatter.

    Parameters
    ----------
    gt : array-like of int
        Ground-truth labels (1 = star, 0 = galaxy).
    preds : array-like of float
        Classifier scores; a source is called a star when ``preds >= ct``.
    ct : float
        Decision threshold for the accuracy.
    fom_at : sequence of float
        False-positive rates at which the FoM (TPR) and threshold are reported.
    Nboot : int
        Number of bootstrap resamples.
    mag_max : float
        Only sources with ``mag <= mag_max`` are used.
    mag : array-like of float, optional
        Magnitudes aligned with ``gt``/``preds``.  Defaults to the notebook
        global ``kron_mag`` (the only behaviour of the original version).
    random_state : int, optional
        Seed for a private ``RandomState`` so the bootstrap is reproducible.
        When None the global ``np.random`` state is used, as before.

    Returns
    -------
    dict
        ``Num`` (number of sources used), ``Acc`` (percent), ``AUC``, ``FoM``
        and ``Thresh`` (arrays over ``fom_at``), plus the bootstrap standard
        deviations ``AccSTD`` (percent), ``AUCSTD``, ``FoMSTD``, ``ThreshSTD``.
    """
    if mag is None:
        mag = kron_mag
    # Positional indexing below needs plain arrays (a Series would index by label).
    gt = np.asarray(gt)
    preds = np.asarray(preds)
    mag = np.asarray(mag)
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Positional indices of the sources passing the magnitude cut.
    sel = np.flatnonzero(mag <= mag_max)
    gt_sel, preds_sel = gt[sel], preds[sel]

    acc = accuracy_score(gt_sel, preds_sel >= ct)
    auc_val = roc_auc_score(gt_sel, preds_sel)
    fpr, tpr, roc_thresh = roc_curve(gt_sel, preds_sel)
    fom_thresh = np.array([calc_fom(fpr, tpr, roc_thresh, f) for f in fom_at])
    thresh = fom_thresh[:, 0]
    fom = fom_thresh[:, 1]

    acc_boot = np.empty(Nboot)
    auc_boot = np.empty(Nboot)
    fom_boot = np.empty((Nboot, len(fom_at)))
    thresh_boot = np.empty_like(fom_boot)
    for i in range(Nboot):
        print("%d." % i, end='')  # progress marker
        # Resample (with replacement) among the sources passing the cut.
        boot = rng.choice(sel, len(sel), replace=True)
        auc_boot[i] = roc_auc_score(gt[boot], preds[boot])
        acc_boot[i] = accuracy_score(gt[boot], preds[boot] >= ct)
        _fpr, _tpr, _thresh = roc_curve(gt[boot], preds[boot])
        _fom_thresh = np.array([calc_fom(_fpr, _tpr, _thresh, f) for f in fom_at])
        thresh_boot[i, :] = _fom_thresh[:, 0]
        fom_boot[i, :] = _fom_thresh[:, 1]

    return {'Num': len(sel),
            'Acc': acc*100,
            'AUC': auc_val,
            'FoM': fom,
            'Thresh': thresh,
            'AccSTD': np.std(acc_boot)*100,
            'AUCSTD': np.std(auc_boot),
            'FoMSTD': np.std(fom_boot, axis=0),
            'ThreshSTD': np.std(thresh_boot, axis=0)}
    
In [38]:
    
# Decision thresholds for the accuracy statistics: a source is called a star
# when its score >= the cut.  (A trailing comma previously made ps1_ct the
# tuple (-0.05,) instead of a float.)
ps1_ct = -0.05  # score is -(iPSF - iKron): star when iPSF - iKron <= 0.05
simple_ct = 9.2e-7  # maximizes accuracy on the training set
rf_ct = 0.5
sdss_ct = 10**(-0.185/2.5)  # countRatio cut equivalent to 0.185 mag
    
In [25]:
    
# Bootstrap statistics for the RF model (rf_preds is already a numpy array).
stat_rf = summary_stats_bootstrap(y_sdss, rf_preds, ct=rf_ct, Nboot=100)
    
    
In [45]:
    
# Display the RF summary dict.
stat_rf
    
    Out[45]:
In [39]:
    
# Bootstrap statistics for the other models; `.values` passes plain numpy
# arrays so indexing inside summary_stats_bootstrap is positional.
stat_ps1 = summary_stats_bootstrap(y_sdss, ps1_preds.values, ct=ps1_ct, Nboot=100)
stat_simple = summary_stats_bootstrap(y_sdss, simple_preds.values, ct=simple_ct, Nboot=100)
stat_sdss = summary_stats_bootstrap(y_sdss, sdss_preds.values, ct=sdss_ct, Nboot=100)
    
    
In [40]:
    
# Display the PS1-model summary dict.
stat_ps1
    
    Out[40]:
In [41]:
    
# Display the simple-model summary dict.
stat_simple
    
    Out[41]:
In [42]:
    
# Display the SDSS-model summary dict.
stat_sdss
    
    Out[42]:
In [25]:
    
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Colours, line styles and widths of the four classifiers.  Earlier drafts of
# this cell assigned several palettes in succession (matplotlib "Dark2", an
# "apple" palette and a colour-blind-friendly set: ml '#0072b2',
# sdss '#d55e00', simple '#cc79a7', ps1 '#009e73'); only the last (ColorBrewer)
# assignment ever took effect, so it is the one kept.
color_dict = {'ml': '#7570b3',
              'sdss': '#d95f02',
              'simple': '#1b9e77',
              'ps1': '#34495e'}
col_star_sdss = '#e41a1c'
col_star_ps1 = '#001e43'
ls_dict = {'ml': '-',
           'sdss': '-.',
           'simple': '--',
           'ps1': '--'}
lw_dict = {'ml': 1.75,
           'sdss': 1.5,
           'simple': 1.5,
           'ps1': 1.5}

# (key, fpr, tpr, legend label, extra Line2D kwargs) in drawing order.
roc_curves = [('ml', ML_fpr, ML_tpr, 'RF model', {}),
              ('simple', dist_fpr, dist_tpr, 'Simple model', {}),
              ('ps1', i_fpr, i_tpr, 'PS1 model', {'dashes': (8, 4)}),
              ('sdss', sdss_fpr, sdss_tpr, 'SDSS photo', {})]

# ROC points closest to the default SDSS cut (countRatio = 10**(-0.185/2.5))
# and PS1 cut (-(iPSF - iKron) = -0.05); both are marked with a star.
sdss_star_idx = np.argmin(np.abs(sdss_thre - 10**(-0.185/2.5)))
ps1_star_idx = np.argmin(np.abs(i_thre + 0.05))


def _draw_roc_curves(ax, top_zorder, sdss_star_zorder):
    """Draw the four ROC curves and the two default-threshold stars on `ax`.

    Curves get z-orders top_zorder, top_zorder - 1, ... in `roc_curves` order.
    """
    for rank, (key, fpr, tpr, label, extra) in enumerate(roc_curves):
        ax.plot(fpr, tpr, color=color_dict[key], label=label,
                ls=ls_dict[key], lw=lw_dict[key], alpha=0.9,
                zorder=top_zorder - rank, **extra)
    ax.plot(sdss_fpr[sdss_star_idx], sdss_tpr[sdss_star_idx], '*',
            markersize=10, color=col_star_sdss, zorder=sdss_star_zorder)
    ax.plot(i_fpr[ps1_star_idx], i_tpr[ps1_star_idx], '*',
            markersize=10, color=col_star_ps1, zorder=5)


fig, main_ax = plt.subplots(figsize=(7, 5))
axins = inset_axes(main_ax, width="43.5%",
                   height="53%", loc=3,
                   bbox_to_anchor=(0.55, 0.07, 1., 1.),
                   bbox_transform=main_ax.transAxes)

# Main panel: full ROC curves on a logarithmic FPR axis.
main_ax.grid(alpha=0.5, lw=0.5, c='grey', linestyle=':')
main_ax.tick_params(which="both", top=True, right=True)
main_ax.minorticks_on()
_draw_roc_curves(main_ax, top_zorder=4, sdss_star_zorder=5)
# FPR = 0.005, the rate at which calc_fom evaluates the FoM by default.
main_ax.vlines([5e-3], 1e-3, 1.02,
               color='DarkSlateGrey', lw=2., linestyles=":", zorder=1)
main_ax.text(4.9e-3, 0.1, 'FoM',
             color='DarkSlateGrey',
             rotation=90, ha="right", fontsize=14)
main_ax.set_xlim([1.7e-4, 1]); main_ax.set_ylim([0, 1.02])
main_ax.set_xscale('log')
main_ax.tick_params(labelsize = 15)
main_ax.set_xlabel('False Positive Rate', fontsize=15)
main_ax.set_ylabel('True Positive Rate', fontsize=15)
main_ax.legend(loc=3, borderaxespad=0, fontsize=13,
               handlelength=1.5,
               bbox_to_anchor=(0.0125, 0.725, 1., 0.102))

# Inset: zoom on 7.5e-3 <= FPR <= 2.75e-2.  The SDSS star uses col_star_sdss,
# as on the main panel (the inset previously referenced an undefined name).
axins.tick_params(which="both", top=True)
axins.minorticks_on()
_draw_roc_curves(axins, top_zorder=5, sdss_star_zorder=6)
axins.set_xlim(np.array([7.5e-3, 2.75e-2])); axins.set_ylim([0.875, 0.965])
axins.tick_params(labelsize = 15)

fig.subplots_adjust(right=0.97,top=0.98,bottom=0.11,left=0.1)
fig.savefig('../paper/Figures/ROC_curves_log_inset2.pdf')
    
    
    
In [31]:
    
# Hard star (1) / galaxy (0) labels of each model on the SDSS test set:
# a source is a star when its score is not below the model's threshold.
ml_decision_thresh = 0.5
sdss_decision_thresh = 10**(-0.185/2.5)
simple_decision_thresh = 9.2e-07 #  maximize acc on training set
# The PS1 score below is -(iPSF - iKron), so a threshold of -0.05 marks a star
# when iPSF - iKron <= 0.05 -- the same cut as ps1_ct and the PS1 star on the
# ROC figure.  A threshold of +0.05 would instead require iPSF - iKron <= -0.05.
ps1_decision_thresh = -0.05
# `np.logical_not(x < t)` (rather than `x >= t`) labels NaN scores as 1.
# `.loc` replaces `.ix` (removed in pandas 1.0); with the boolean mask
# `sdss_test_set` both select the same rows.
ml_labels = np.logical_not(ML_predict[:,1] < ml_decision_thresh).astype(int)
sdss_labels = np.logical_not(np.array(sdss_tbl["countRatio"].loc[sdss_test_set]) < sdss_decision_thresh).astype(int)
simple_labels = np.logical_not(np.array(sdss_tbl["wwPSFKronDist"].loc[sdss_test_set]) < simple_decision_thresh).astype(int)
ps1_labels = np.logical_not(-1*np.array(sdss_tbl["iPSFminusKron"].loc[sdss_test_set]) < ps1_decision_thresh).astype(int)
    
In [32]:
    
# Accuracy of each model in bins [binedge, binedge + binwidth) of SDSS Kron
# magnitude, with the 16th/84th percentiles of Nboot bootstrap resamples.
binwidth = 0.5
Nboot = 100
mag_array = np.arange(13 , 23+binwidth, binwidth)
# `.loc` replaces `.ix` (removed in pandas 1.0); with the boolean mask
# `sdss_test_set` both select the same rows.
kron_mag = np.array(-2.5*np.log10(sdss_tbl['wwKronFlux'].loc[sdss_test_set]/3631))
sdss_acc_arr = np.zeros_like(mag_array)
simple_acc_arr = np.zeros_like(mag_array)
ps1_acc_arr = np.zeros_like(mag_array)
ml_acc_arr = np.zeros_like(mag_array)
# Row 0 / row 1: 16th / 84th percentile of the bootstrap accuracy per bin.
sdss_boot_scatt = np.zeros((2, mag_array.size))
simple_boot_scatt = np.zeros((2, mag_array.size))
ps1_boot_scatt = np.zeros((2, mag_array.size))
ml_boot_scatt = np.zeros((2, mag_array.size))
for bin_num, binedge in enumerate(mag_array):
    bin_sources = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth))
    sdss_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources],
                                           sdss_labels[bin_sources])
    simple_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources],
                                             simple_labels[bin_sources])
    ps1_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources],
                                          ps1_labels[bin_sources])
    ml_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources],
                                         ml_labels[bin_sources])
    sdss_boot_acc = np.empty(Nboot)
    simple_boot_acc = np.empty_like(sdss_boot_acc)
    ps1_boot_acc = np.empty_like(sdss_boot_acc)
    ml_boot_acc = np.empty_like(sdss_boot_acc)
    for i in range(Nboot):
        # Resample the bin's sources with replacement; all models share the draw.
        boot_sources = np.random.choice(bin_sources[0], len(bin_sources[0]),
                                        replace=True)
        sdss_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources],
                                          sdss_labels[boot_sources])
        simple_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources],
                                            simple_labels[boot_sources])
        ps1_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources],
                                         ps1_labels[boot_sources])
        ml_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources],
                                        ml_labels[boot_sources])
    sdss_boot_scatt[:,bin_num] = np.percentile(sdss_boot_acc, [16, 84])
    simple_boot_scatt[:,bin_num] = np.percentile(simple_boot_acc, [16, 84])
    ps1_boot_scatt[:,bin_num] = np.percentile(ps1_boot_acc, [16, 84])
    ml_boot_scatt[:,bin_num] = np.percentile(ml_boot_acc, [16, 84])
    
    
In [34]:
    
from sklearn.neighbors import KernelDensity
# Magnitude grid on which the SDSS test-set KDEs are evaluated.
kde_grid = np.linspace(13,23.5,200)
# Positional indices of spectroscopic stars (label 1) and galaxies (label 0).
sdss_stars = np.where(sdss_ml_test_y == 1)
sdss_gal = np.where(sdss_ml_test_y == 0)
# Class fractions; they scale the per-class PDFs so the two scaled PDFs add up
# (approximately) to the combined PDF.
sdss_kde_gal_norm = len(sdss_gal[0])/len(sdss_ml_test_y)
sdss_kde_star_norm = 1 - sdss_kde_gal_norm
# KernelDensity with its default (Gaussian) kernel and a normal-reference
# bandwidth of 1.059 * std * n**(-1/5), for all sources, stars and galaxies.
kde_sdss = KernelDensity(bandwidth=1.059*np.std(kron_mag, ddof=1)*len(kron_mag)**(-0.2),
                         rtol=1E-4)
kde_sdss.fit(kron_mag[:, np.newaxis])
kde_sdss_stars = KernelDensity(bandwidth=1.059*np.std(kron_mag[sdss_stars], ddof=1)*len(kron_mag[sdss_stars])**(-0.2),
                               rtol=1E-4)
kde_sdss_stars.fit(kron_mag[sdss_stars[0], np.newaxis])
kde_sdss_gal = KernelDensity(bandwidth=1.059*np.std(kron_mag[sdss_gal], ddof=1)*len(kron_mag[sdss_gal])**(-0.2),
                             rtol=1E-4)
kde_sdss_gal.fit(kron_mag[sdss_gal[0], np.newaxis])
# score_samples returns log-density, hence the exp.
pdf_sdss = np.exp(kde_sdss.score_samples(kde_grid[:, np.newaxis]))
pdf_sdss_stars = np.exp(kde_sdss_stars.score_samples(kde_grid[:, np.newaxis]))
pdf_sdss_gal = np.exp(kde_sdss_gal.score_samples(kde_grid[:, np.newaxis]))
    
In [61]:
    
from matplotlib.ticker import MultipleLocator
# Palette for this figure (replaces the ROC-figure colours).
color_dict = {'ml': "black", #"#1C1858",
              'sdss': "#5BC236", #"#00C78E",
              'simple': "#C864AF", #"#C70039",
              'ps1': "#C65400"}
# Centres of the bins whose left edges are mag_array.
mag_bin_centers = mag_array + binwidth/2
cmap_star = sns.cubehelix_palette(rot=0.5, light=0.7,dark=0.3,as_cmap=True)
cmap_gal = sns.cubehelix_palette(start=0.3,rot=-0.5,light=0.7,dark=0.3,as_cmap=True)
fig, ax = plt.subplots(figsize=(8, 5))
# Accuracy per bin; the (2, nbins) scatter minus the accuracy gives the
# lower/upper distances to the bootstrap 16th/84th percentiles.
ax.errorbar(mag_bin_centers, ml_acc_arr, 
            yerr=np.abs(ml_boot_scatt - ml_acc_arr), 
            ls='-', lw=.75, fmt='o',
            color=color_dict['ml'], label="ML model", zorder=2)
# The string literal below is a no-op; it disables the simple and PS1 curves.
"""
ax.errorbar(mag_bin_centers, simple_acc_arr, 
            yerr=np.abs(simple_boot_scatt - simple_acc_arr), 
            ls ='-', lw=.5, fmt='o',
            color=color_dict['simple'], label="Simple model")
ax.errorbar(mag_bin_centers, ps1_acc_arr, 
            yerr=np.abs(ps1_boot_scatt - ps1_acc_arr), 
            ls ='-', lw=.5, dashes=(8, 4), fmt='o',
            color=color_dict['ps1'], label=r'm$_{\rm iPSF}-$m$_{\rm iKron}$')
"""
ax.errorbar(mag_bin_centers, sdss_acc_arr, 
            yerr=np.abs(sdss_boot_scatt - sdss_acc_arr), 
            ls='-', lw=.5, fmt='o',
            color=color_dict['sdss'], label="SDSS photo", zorder=1)
# add KDE plots
# Magnitude distributions (all, galaxies, stars) drawn above y = 0.5; the
# per-class PDFs are scaled by their class fractions.
ax.fill(kde_grid, pdf_sdss + 0.5, alpha=0.4, color="0.7")
ax.fill(kde_grid, pdf_sdss_gal*sdss_kde_gal_norm + 0.5, alpha=0.7, color=cmap_gal(0.25))
ax.fill(kde_grid, pdf_sdss_stars*sdss_kde_star_norm + 0.5, alpha=0.7, color=cmap_star(0.25))
ax.set_ylim(0.5,1.01)
ax.set_xlim(13, 23.5)
ax.tick_params(which="both", top=True, right=True, labelsize=15)
ax.set_xlabel('whiteKronMag', fontsize=15)
ax.set_ylabel('Accuracy', fontsize=15)
ax.yaxis.set_minor_locator(MultipleLocator(0.025))
ax.xaxis.set_major_locator(MultipleLocator(2))
ax.xaxis.set_minor_locator(MultipleLocator(0.5))
ax.legend(bbox_to_anchor=(0.01, 0.3, 1., 0.102), loc=3, fontsize=13)
fig.subplots_adjust(top=0.98,right=0.98,left=0.1,bottom=0.12)
    
    
In [55]:
    
# Per-class accuracy in each magnitude bin: stars (label 1) and galaxies
# (label 0) are scored separately, each with the 16th/84th percentiles of
# Nboot bootstrap resamples drawn within that class and bin.  Reuses
# mag_array, kron_mag and the *_labels arrays from earlier cells.
binwidth = 0.5
Nboot = 100
# bootstrap star arrays
sdss_star_acc_arr = np.zeros_like(mag_array)
simple_star_acc_arr = np.zeros_like(mag_array)
ps1_star_acc_arr = np.zeros_like(mag_array)
ml_star_acc_arr = np.zeros_like(mag_array)
sdss_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_star_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
# bootstrap galaxy arrays
sdss_gal_acc_arr = np.zeros_like(mag_array)
simple_gal_acc_arr = np.zeros_like(mag_array)
ps1_gal_acc_arr = np.zeros_like(mag_array)
ml_gal_acc_arr = np.zeros_like(mag_array)
sdss_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_gal_boot_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
for bin_num, binedge in enumerate(mag_array):
    # Stars in this bin; their accuracy is the fraction labelled 1.
    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                         (sdss_ml_test_y == 1))
    sdss_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                           sdss_labels[bin_stars])
    simple_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                             simple_labels[bin_stars])
    ps1_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                          ps1_labels[bin_stars])
    ml_star_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_stars], 
                                         ml_labels[bin_stars])
    
    # Galaxies in this bin; their accuracy is the fraction labelled 0.
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                        (sdss_ml_test_y == 0))
    sdss_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                           sdss_labels[bin_gals])
    simple_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                             simple_labels[bin_gals])
    ps1_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                          ps1_labels[bin_gals])
    ml_gal_acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_gals], 
                                         ml_labels[bin_gals])    
    
    # get the bootstrap accuracies
    
    sdss_star_boot_acc = np.empty(Nboot)
    simple_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ps1_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ml_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    
    sdss_gal_boot_acc = np.empty_like(sdss_star_boot_acc)
    simple_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ps1_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ml_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    for i in range(Nboot):
        # Stars and galaxies are resampled independently, each with replacement.
        star_boot_sources = np.random.choice(bin_stars[0], len(bin_stars[0]), 
                                             replace=True)
        sdss_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                           sdss_labels[star_boot_sources])
        simple_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                             simple_labels[star_boot_sources])
        ps1_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                          ps1_labels[star_boot_sources])
        ml_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                         ml_labels[star_boot_sources])
        
        gal_boot_sources = np.random.choice(bin_gals[0], len(bin_gals[0]), 
                                            replace=True)
        sdss_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                           sdss_labels[gal_boot_sources])
        simple_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                             simple_labels[gal_boot_sources])
        ps1_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                          ps1_labels[gal_boot_sources])
        ml_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                         ml_labels[gal_boot_sources])
    # Row 0 / row 1: 16th / 84th percentile of the bootstrap accuracy.
    sdss_star_boot_scatt[:,bin_num] = np.percentile(sdss_star_boot_acc, [16, 84])
    simple_star_boot_scatt[:,bin_num] = np.percentile(simple_star_boot_acc, [16, 84])
    ps1_star_boot_scatt[:,bin_num] = np.percentile(ps1_star_boot_acc, [16, 84])
    ml_star_boot_scatt[:,bin_num] = np.percentile(ml_star_boot_acc, [16, 84])    
    
    sdss_gal_boot_scatt[:,bin_num] = np.percentile(sdss_gal_boot_acc, [16, 84])
    simple_gal_boot_scatt[:,bin_num] = np.percentile(simple_gal_boot_acc, [16, 84])
    ps1_gal_boot_scatt[:,bin_num] = np.percentile(ps1_gal_boot_acc, [16, 84])
    ml_gal_boot_scatt[:,bin_num] = np.percentile(ml_gal_boot_acc, [16, 84])
    
In [84]:
    
# Accuracy after matching each SDSS magnitude bin to the HST star/galaxy ratio
# of the same bin: stars and galaxies are drawn separately (with replacement)
# in the HST proportion, Nboot times.  The *_resamp_arr arrays hold the mean
# bootstrap accuracy (overall, stars only, galaxies only) and the
# *_resamp_scatt arrays the 16th/84th percentiles.
binwidth = 0.5
Nboot = 100
sdss_resamp_arr = np.zeros_like(mag_array)
simple_resamp_arr = np.zeros_like(mag_array)
ps1_resamp_arr = np.zeros_like(mag_array)
ml_resamp_arr = np.zeros_like(mag_array)
sdss_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
# bootstrap star arrays
sdss_star_resamp_arr = np.zeros_like(mag_array)
simple_star_resamp_arr = np.zeros_like(mag_array)
ps1_star_resamp_arr = np.zeros_like(mag_array)
ml_star_resamp_arr = np.zeros_like(mag_array)
sdss_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_star_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
# bootstrap galaxy arrays
sdss_gal_resamp_arr = np.zeros_like(mag_array)
simple_gal_resamp_arr = np.zeros_like(mag_array)
ps1_gal_resamp_arr = np.zeros_like(mag_array)
ml_gal_resamp_arr = np.zeros_like(mag_array)
sdss_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
simple_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ps1_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
ml_gal_resamp_scatt = np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
for bin_num, binedge in enumerate(mag_array):
    # Skip the four brightest bins (left edges 13.0-14.5 for mag_array starting
    # at 13 in 0.5 steps); their entries stay 0.
    if bin_num <= 3:
        continue
    
    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                         (sdss_ml_test_y == 1))
    
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) & 
                        (sdss_ml_test_y == 0))
    hst_stars = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) & 
                         (hst_ml_train_y == 1))
    hst_gals = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) & 
                        (hst_ml_train_y == 0))
    
    # figure out the number of stars and galaxies to select
    # Keep every SDSS source of the class that is relatively scarcer than in
    # HST and draw the other class so the star/galaxy ratio matches HST.
    # NOTE(review): raises ZeroDivisionError if a bin has no HST galaxies,
    # no SDSS galaxies, or (in the else branch) no HST stars.
    if len(hst_stars[0])/len(hst_gals[0]) > len(bin_stars[0])/len(bin_gals[0]):
        n_star_resamp = len(bin_stars[0])
        n_gal_resamp = int(len(bin_stars[0])*len(hst_gals[0])/len(hst_stars[0]))
    else:
        n_star_resamp = int(len(bin_gals[0])*len(hst_stars[0])/len(hst_gals[0]))
        n_gal_resamp = len(bin_gals[0])
    
    # Diagnostic: HST stars, HST galaxies, SDSS stars, SDSS galaxies, and the
    # numbers of SDSS stars/galaxies drawn per resample.
    print(len(hst_stars[0]), len(hst_gals[0]), 
          len(bin_stars[0]), len(bin_gals[0]),
          n_star_resamp, n_gal_resamp)
    
    # get the bootstrap accuracies    
    sdss_boot_acc = np.empty(Nboot)
    simple_boot_acc = np.empty_like(sdss_boot_acc)
    ps1_boot_acc = np.empty_like(sdss_boot_acc)
    ml_boot_acc = np.empty_like(sdss_boot_acc)
    sdss_star_boot_acc = np.empty(Nboot)
    simple_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ps1_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    ml_star_boot_acc = np.empty_like(sdss_star_boot_acc)
    
    sdss_gal_boot_acc = np.empty_like(sdss_star_boot_acc)
    simple_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ps1_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    ml_gal_boot_acc = np.empty_like(sdss_gal_boot_acc)
    for i in range(Nboot):
        # Draw stars and galaxies separately, then score the combined sample
        # as well as each class on its own.
        star_boot_sources = np.random.choice(bin_stars[0], n_star_resamp, 
                                             replace=True)
        gal_boot_sources = np.random.choice(bin_gals[0], n_gal_resamp, 
                                            replace=True)
        boot_sources = np.append(star_boot_sources, gal_boot_sources)
        sdss_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                           sdss_labels[boot_sources])
        simple_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                             simple_labels[boot_sources])
        ps1_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                          ps1_labels[boot_sources])
        ml_boot_acc[i] = accuracy_score(sdss_ml_test_y[boot_sources], 
                                         ml_labels[boot_sources])
        sdss_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                           sdss_labels[star_boot_sources])
        simple_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                             simple_labels[star_boot_sources])
        ps1_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                          ps1_labels[star_boot_sources])
        ml_star_boot_acc[i] = accuracy_score(sdss_ml_test_y[star_boot_sources], 
                                         ml_labels[star_boot_sources])
        
        sdss_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                           sdss_labels[gal_boot_sources])
        simple_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                             simple_labels[gal_boot_sources])
        ps1_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                          ps1_labels[gal_boot_sources])
        ml_gal_boot_acc[i] = accuracy_score(sdss_ml_test_y[gal_boot_sources], 
                                         ml_labels[gal_boot_sources])
    # Summaries: mean bootstrap accuracy and 16th/84th percentiles.
    sdss_resamp_arr[bin_num] = np.mean(sdss_boot_acc)
    simple_resamp_arr[bin_num] = np.mean(simple_boot_acc)
    ps1_resamp_arr[bin_num] = np.mean(ps1_boot_acc)
    ml_resamp_arr[bin_num] = np.mean(ml_boot_acc)
    
    sdss_resamp_scatt[:,bin_num] = np.percentile(sdss_boot_acc, [16, 84])
    simple_resamp_scatt[:,bin_num] = np.percentile(simple_boot_acc, [16, 84])
    ps1_resamp_scatt[:,bin_num] = np.percentile(ps1_boot_acc, [16, 84])
    ml_resamp_scatt[:,bin_num] = np.percentile(ml_boot_acc, [16, 84])    
    sdss_star_resamp_arr[bin_num] = np.mean(sdss_star_boot_acc)
    simple_star_resamp_arr[bin_num] = np.mean(simple_star_boot_acc)
    ps1_star_resamp_arr[bin_num] = np.mean(ps1_star_boot_acc)
    ml_star_resamp_arr[bin_num] = np.mean(ml_star_boot_acc)
    
    sdss_star_resamp_scatt[:,bin_num] = np.percentile(sdss_star_boot_acc, [16, 84])
    simple_star_resamp_scatt[:,bin_num] = np.percentile(simple_star_boot_acc, [16, 84])
    ps1_star_resamp_scatt[:,bin_num] = np.percentile(ps1_star_boot_acc, [16, 84])
    ml_star_resamp_scatt[:,bin_num] = np.percentile(ml_star_boot_acc, [16, 84])    
    sdss_gal_resamp_arr[bin_num] = np.mean(sdss_gal_boot_acc)
    simple_gal_resamp_arr[bin_num] = np.mean(simple_gal_boot_acc)
    ps1_gal_resamp_arr[bin_num] = np.mean(ps1_gal_boot_acc)
    ml_gal_resamp_arr[bin_num] = np.mean(ml_gal_boot_acc)
    
    sdss_gal_resamp_scatt[:,bin_num] = np.percentile(sdss_gal_boot_acc, [16, 84])
    simple_gal_resamp_scatt[:,bin_num] = np.percentile(simple_gal_boot_acc, [16, 84])
    ps1_gal_resamp_scatt[:,bin_num] = np.percentile(ps1_gal_boot_acc, [16, 84])
    ml_gal_resamp_scatt[:,bin_num] = np.percentile(ml_gal_boot_acc, [16, 84])
    
    
In [76]:
    
# Magnitude KDEs of the HST training set (all, stars, galaxies), evaluated on
# the same grid as the SDSS KDEs.
kde_grid = np.linspace(13,23.5,200)
hst_stars = np.where(hst_ml_train_y == 1)
hst_gal = np.where(hst_ml_train_y == 0)
# KernelDensity with its default (Gaussian) kernel and a normal-reference
# bandwidth of 1.059 * std * n**(-1/5).
kde_hst = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag, ddof=1)*len(hst_kron_mag)**(-0.2),
                         rtol=1E-4)
kde_hst.fit(hst_kron_mag[:, np.newaxis])
kde_hst_stars = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag[hst_stars], ddof=1)*len(hst_kron_mag[hst_stars])**(-0.2),
                               rtol=1E-4)
kde_hst_stars.fit(hst_kron_mag[hst_stars[0], np.newaxis])
kde_hst_gal = KernelDensity(bandwidth=1.059*np.std(hst_kron_mag[hst_gal], ddof=1)*len(hst_kron_mag[hst_gal])**(-0.2),
                             rtol=1E-4)
kde_hst_gal.fit(hst_kron_mag[hst_gal[0], np.newaxis])
# score_samples returns log-density, hence the exp.
pdf_hst = np.exp(kde_hst.score_samples(kde_grid[:, np.newaxis]))
pdf_hst_stars = np.exp(kde_hst_stars.score_samples(kde_grid[:, np.newaxis]))
pdf_hst_gal = np.exp(kde_hst_gal.score_samples(kde_grid[:, np.newaxis]))
# Class fractions used to scale the per-class PDFs in the figure.
hst_kde_gal_norm = len(hst_gal[0])/len(hst_ml_train_y)
hst_kde_star_norm = 1 - hst_kde_gal_norm
    
In [111]:
    
# Accuracy vs. whiteKronMag.  Left: stars and galaxies of the SDSS test set
# scored separately.  Right: accuracy after resampling each bin to the HST
# star/galaxy ratio (only bins 4 onward, i.e. mag >= 15, were computed).
# The simple and PS1 models are omitted from both panels.
color_dict = {'ml': "black", #"#1C1858",
              'sdss': "teal",
              'simple': "#C864AF", #"#C70039",
              'ps1': "#C65400"}
plt.figure(figsize=(12.5,5))
plt.subplot(1,2,1)
# Stars (triangles) are shifted by +0.1 mag and galaxies (squares) by -0.1 mag
# so their error bars do not overlap.
plt.errorbar(mag_bin_centers+0.1, ml_star_acc_arr,
             yerr=np.abs(ml_star_boot_scatt - ml_star_acc_arr),
             ls=ls_dict['ml'], lw=.75, fmt='^', ms=8,
             mec="0.2", mew=0.5,
             color=color_dict['ml'], label="ML stars", zorder=4)
plt.errorbar(mag_bin_centers+0.1, sdss_star_acc_arr,
             yerr=np.abs(sdss_star_boot_scatt - sdss_star_acc_arr),
             ls=ls_dict['sdss'], lw=.5, fmt='^', ms=8,
             mec="0.2", mew=0.5,
             color=color_dict['sdss'], label="SDSS stars", zorder=2)
plt.errorbar(mag_bin_centers-0.1, ml_gal_acc_arr,
             yerr=np.abs(ml_gal_boot_scatt - ml_gal_acc_arr),
             ls=ls_dict['ml'], lw=.75, fmt='s',
             color=color_dict['ml'], label="ML galaxies", zorder=3)
plt.errorbar(mag_bin_centers-0.1, sdss_gal_acc_arr,
             yerr=np.abs(sdss_gal_boot_scatt - sdss_gal_acc_arr),
             ls=ls_dict['sdss'], lw=.5, fmt='s',
             color=color_dict['sdss'], label="SDSS galaxies", zorder=1)
# SDSS test-set magnitude distributions (all, galaxies, stars) above y = 0.5.
plt.fill(kde_grid, pdf_sdss + 0.5, alpha=0.4, color="0.7")
plt.fill(kde_grid, pdf_sdss_gal*sdss_kde_gal_norm + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(kde_grid, pdf_sdss_stars*sdss_kde_star_norm + 0.5, alpha=0.7, color=cmap_star(0.25))
plt.ylim(0.5,1.01)
plt.xlim(13, 23.5)
plt.tick_params(which="both", top=True, right=True, labelsize=15)
plt.xlabel('whiteKronMag', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.minorticks_on()
plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)

plt.subplot(1,2,2)
plt.errorbar(mag_bin_centers[4:], ml_resamp_arr[4:],
             yerr=np.abs(ml_resamp_scatt - ml_resamp_arr)[:,4:],
             ls=ls_dict['ml'], lw=.75, fmt='o',
             color=color_dict['ml'], label="ML model")
plt.errorbar(mag_bin_centers[4:], sdss_resamp_arr[4:],
             yerr=np.abs(sdss_resamp_scatt - sdss_resamp_arr)[:,4:],
             ls=ls_dict['sdss'], lw=.5, fmt='o',
             color=color_dict['sdss'], label="SDSS photo")
# HST training-set magnitude distributions; the appended point (23.5, 0.5)
# closes each polygon at the right edge.
plt.fill(np.append(kde_grid, 23.5),
         np.append(pdf_hst, 0) + 0.5, alpha=0.4, color="0.7")
plt.fill(np.append(kde_grid, 23.5),
         np.append(pdf_hst_gal*hst_kde_gal_norm, 0) + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(np.append(kde_grid, 23.5),
         np.append(pdf_hst_stars*hst_kde_star_norm, 0) + 0.5, alpha=0.7, color=cmap_star(0.25))
plt.ylim(0.5,1.01)
plt.xlim(15, 23.5)
# labelleft takes a bool; matplotlib no longer interprets the strings
# 'on'/'off' as booleans, so 'off' did not reliably hide the y labels.
plt.tick_params(which="both", top=True, right=True, labelsize=15, labelleft=False)
plt.xlabel('whiteKronMag', fontsize=15)
plt.minorticks_on()
plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)
plt.tight_layout()
plt.savefig('Accuracy_SDSS_ML_model.pdf')
    
    
In [ ]: