In this notebook we examine the ROC curve, the FoM, the accuracy, and the ROC AUC of four models (the ML model, the simple model, the PS1 model, and the SDSS model) on the SDSS test set. The uncertainty of each statistic is estimated from 100 bootstrap resamples.
The ML model performs better than the alternative models on every statistic except the accuracy. The SDSS test set is biased towards point sources at the faint end, and the SDSS model employs a hard star--galaxy cut; together these cause the reversal in accuracy. This bias also affects the ROC curves: the kink and the crossing of the curves around TPR=0.015 are due to the bias.
In [47]:
import sys,os,math
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams["font.family"] = "sans-serif"
rcParams['font.sans-serif'] = ['DejaVu Sans']
from matplotlib import gridspec as grs
%matplotlib notebook
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import cm
from astropy.table import Table
import seaborn as sns
import statsmodels.nonparametric.api as smnp
from statsmodels.nonparametric.kernel_density import KDEMultivariate
from scipy.special import expit
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, accuracy_score, auc, make_scorer
from sklearn.model_selection import GridSearchCV, KFold
In [2]:
# Read the two FITS feature tables with astropy and convert them to pandas
# DataFrames.  Judging by the file names, hst_tbl is the HST/COSMOS sample and
# sdss_tbl the SDSS spectroscopic sample -- confirm against the data.
hst_tbl = Table.read("HST_COSMOS_features_adamamiller.fit").to_pandas()
sdss_tbl = Table.read("sdssSP_MLfeats_adamamiller.fit").to_pandas()
In [5]:
# Keep a single row per SDSS objid: rows flagged as repeats of an earlier
# objid are dropped, the first occurrence is retained.
_repeated_objid = sdss_tbl.duplicated(subset='objid')
sdss_tbl = sdss_tbl[~_repeated_objid]
In [6]:
# Row positions with at least one detection in each table.  hst_det is reused
# below for the HST magnitudes and the training arrays.
hst_det = np.where(hst_tbl.nDetections > 0)
# Magnitude from the Kron flux with a zero point of 3631 (the AB zero point if
# wwKronFlux is in Jy -- TODO confirm the flux unit).
hst_kron_mag = np.array(-2.5*np.log10(hst_tbl['wwKronFlux'].loc[hst_det]/3631))

sdss_det = np.where(sdss_tbl.nDetections > 0)

# Rows that carry SDSS photometric information (finite type, positive
# countRatio) ...
sdss_photo_det = np.logical_and(np.isfinite(sdss_tbl.type),
                                sdss_tbl.countRatio > 0)
# ... and also a finite iPSFminusKron, so that every model can be evaluated
# on the same sources.
sdss_in_common = np.logical_and(np.isfinite(sdss_tbl.iPSFminusKron),
                                sdss_photo_det)

# Compare against whitespace-stripped class strings.  The previous literal
# ' QSO' (leading space) cannot match an unpadded 'QSO', which silently left
# every low-redshift quasar in the test set.
_spec_class = sdss_tbl['class'].str.strip()
low_z_gal = np.logical_and(sdss_tbl.z < 1e-4, _spec_class == 'GALAXY')
low_z_qso = np.logical_and(sdss_tbl.z < 1, _spec_class == 'QSO')

# Final test-set mask: common sources minus the low-redshift galaxies and QSOs.
sdss_test_set = sdss_in_common & ~low_z_gal & ~low_z_qso
In [8]:
# Names of the table columns that are selected as the model input features.
features = ['wwpsfChiSq', 'wwExtNSigma', 'wwpsfLikelihood',
'wwPSFKronRatio', 'wwPSFKronDist', 'wwPSFApRatio',
'wwmomentRH', 'wwmomentXX', 'wwmomentXY', 'wwmomentYY',
'wwKronRad']
In [9]:
# HST training matrix and labels, restricted to detected sources.
# MU_CLASS is shifted down by one to give the label (presumably mapping
# {1, 2} onto {0, 1} -- confirm against the catalogue definition).
hst_ml_train_X = hst_tbl[features].loc[hst_det].values
hst_ml_train_y = (hst_tbl["MU_CLASS"].loc[hst_det] - 1).values

# SDSS test matrix and spectroscopic classes for the test-set rows.
sdss_ml_test_X = sdss_tbl[features].loc[sdss_test_set].values
sdss_spec_class = sdss_tbl['class'].loc[sdss_test_set].values

# Binary truth labels: 0 for spectroscopic galaxies, 1 for everything else.
sdss_ml_test_y = np.where(sdss_spec_class == "GALAXY", 0, 1)
In [10]:
import star_galaxy_models
In [11]:
# Create the project's RandomForestModel wrapper and load its classifier
# (presumably a previously pickled fit; the loader and its default path live
# in star_galaxy_models, which is not shown here).  The estimator is accessed
# afterwards as rf_obj.rf_clf_.
rf_obj = star_galaxy_models.RandomForestModel()
rf_obj.read_rf_from_pickle()
In [12]:
# Class probabilities of the loaded classifier for the SDSS test features;
# column 1 is the score used for the positive (label 1) class in later cells.
ML_predict = rf_obj.rf_clf_.predict_proba(sdss_ml_test_X)
In [13]:
# ROC curves on the SDSS test set for the four scores.  roc_curve treats
# larger scores as more likely to be label 1, hence the sign flip on
# iPSFminusKron.
_test_rows = sdss_tbl.loc[sdss_test_set]

ML_fpr, ML_tpr, ML_thre = roc_curve(sdss_ml_test_y, ML_predict[:, 1])
sdss_fpr, sdss_tpr, sdss_thre = roc_curve(sdss_ml_test_y,
                                          _test_rows["countRatio"])
dist_fpr, dist_tpr, dist_thre = roc_curve(sdss_ml_test_y,
                                          _test_rows["wwPSFKronDist"])
i_fpr, i_tpr, i_thre = roc_curve(sdss_ml_test_y,
                                 -1.*_test_rows["iPSFminusKron"])
In [14]:
# Per-model scores for the test-set rows (larger = more star-like) together
# with the shared truth labels.
y_sdss = sdss_ml_test_y
rf_preds = ML_predict[:, 1]
simple_preds = sdss_tbl.loc[sdss_test_set, "wwPSFKronDist"]
sdss_preds = sdss_tbl.loc[sdss_test_set, "countRatio"]
# Sign-flipped so that, like the other scores, larger means more star-like.
ps1_preds = -1.*sdss_tbl.loc[sdss_test_set, "iPSFminusKron"]
In [19]:
from scipy import interp
from sklearn.metrics import roc_curve, accuracy_score, auc, roc_auc_score
def calc_fom(fpr, tpr, thresh, fpr_at=0.005):
    """Interpolate a ROC curve at a fixed false positive rate.

    Parameters
    ----------
    fpr, tpr, thresh : array-like
        False positive rates, true positive rates and decision thresholds of
        a ROC curve, with ``fpr`` in increasing order (as returned by
        ``sklearn.metrics.roc_curve``).
    fpr_at : float, optional
        False positive rate at which the curve is evaluated.

    Returns
    -------
    tuple
        ``(threshold, tpr)``, both linearly interpolated at ``fpr_at``.
    """
    # numpy.interp is used directly: scipy.interp was only a deprecated alias
    # of it, so the result is unchanged.
    return np.interp(fpr_at, fpr, thresh), np.interp(fpr_at, fpr, tpr)
In [50]:
def _model_summary(y_true, scores, cut):
    """Return ``(auc, accuracy, fom)`` for one model's scores.

    A source is classified as label 1 when ``scores >= cut``; ``fom`` is the
    ``(threshold, tpr)`` pair returned by ``calc_fom`` at its default FPR.
    """
    fpr, tpr, thresh = roc_curve(y_true, scores)
    return (roc_auc_score(y_true, scores),
            accuracy_score(y_true, scores >= cut),
            calc_fom(fpr, tpr, thresh))


def calc_summary_stats(y_sdss, ps1_preds, simple_preds, rf_preds, sdss_preds,
                       ps1_ct = -0.05,
                       simple_ct = 9.2e-7,
                       rf_ct = 0.5,
                       sdss_ct = 10**(-0.185/2.5)):
    """Compute ROC AUC, accuracy and FoM for the four models.

    Parameters
    ----------
    y_sdss : array-like
        Truth labels (1 is the positive class).
    ps1_preds, simple_preds, rf_preds, sdss_preds : array-like
        Scores of each model, oriented so that larger values favour label 1.
    ps1_ct, simple_ct, rf_ct, sdss_ct : float, optional
        Classification cut of each model, applied as ``score >= cut``.

    Returns
    -------
    tuple
        ``(auc, acc, fom)`` for the RF, simple, PS1 and SDSS models, in that
        order, flattened into one 12-element tuple.  Each ``fom`` is the
        ``(threshold, tpr)`` pair from ``calc_fom``.
    """
    rf_stats = _model_summary(y_sdss, rf_preds, rf_ct)
    simple_stats = _model_summary(y_sdss, simple_preds, simple_ct)
    # ps1_preds is oriented like the other scores (it is also what is passed
    # to roc_curve), so the cut is applied in the same direction.  The former
    # test ``-1*ps1_preds <= ps1_ct`` flipped the sign of the score but not of
    # the cut and therefore evaluated a different selection than the one the
    # ROC curve and summary_stats_bootstrap use.
    ps1_stats = _model_summary(y_sdss, ps1_preds, ps1_ct)
    sdss_stats = _model_summary(y_sdss, sdss_preds, sdss_ct)
    return rf_stats + simple_stats + ps1_stats + sdss_stats
In [52]:
# Summary statistics of all four models with the default classification cuts.
summary_stats = calc_summary_stats(y_sdss,
                                   ps1_preds=ps1_preds,
                                   simple_preds=simple_preds,
                                   rf_preds=rf_preds,
                                   sdss_preds=sdss_preds)
In [61]:
# Print the LaTeX table rows (FoM, accuracy, AUC per model).
# calc_summary_stats stores each FoM as the (threshold, TPR) pair returned by
# calc_fom, and a tuple cannot be rendered with a float format spec, so every
# such pair is reduced to its TPR before formatting.
_table_values = [val[-1] if isinstance(val, tuple) else val
                 for val in summary_stats]
print(r"""
RF & {2:.3f} & {1:.3f} & {0:.3f} \\
simple & {5:.3f} & {4:.3f} & {3:.3f} \\
PS1 & {8:.3f} & {7:.3f} & {6:.3f} \\
SDSS & {11:.3f} & {10:.3f} & {9:.3f} \\
""".format(*_table_values))
In [20]:
# Magnitudes of the test-set sources from the Kron flux (zero point 3631).
_test_kron_flux = sdss_tbl['wwKronFlux'].loc[sdss_test_set]
kron_mag = np.array(-2.5*np.log10(_test_kron_flux/3631))
In [24]:
def _sample_stats(gt, preds, ct, fom_at):
    """Return ``(accuracy, auc, thresholds, tprs)`` for one sample.

    ``thresholds`` and ``tprs`` hold one ``calc_fom`` result per entry of
    ``fom_at``.
    """
    fpr, tpr, thresh = roc_curve(gt, preds)
    fom_thresh = np.array([calc_fom(fpr, tpr, thresh, f) for f in fom_at])
    return (accuracy_score(gt, preds >= ct),
            roc_auc_score(gt, preds),
            fom_thresh[:, 0],
            fom_thresh[:, 1])


def summary_stats_bootstrap(gt, preds, ct=0.5, fom_at=(0.005, 0.01, 0.02, 0.05, 0.1), Nboot=100, mag_max = 30, mags=None):
    """Accuracy, ROC AUC and FoM with bootstrap standard deviations.

    Parameters
    ----------
    gt : array-like
        Truth labels (1 is the positive class).
    preds : array-like
        Model scores; a source is classified as label 1 when ``preds >= ct``.
    ct : float, optional
        Classification cut used for the accuracy.
    fom_at : sequence of float, optional
        False positive rates at which the FoM (TPR) and the matching
        threshold are evaluated.
    Nboot : int, optional
        Number of bootstrap resamples.
    mag_max : float, optional
        Only sources with magnitude ``<= mag_max`` are used.
    mags : array-like, optional
        Magnitude of every source, aligned with ``gt`` and ``preds``.  When
        omitted, the module-level ``kron_mag`` array is used, as before.

    Returns
    -------
    dict
        Keys ``Num``, ``Acc``, ``AUC``, ``FoM``, ``Thresh`` and the matching
        ``*STD`` entries.  Accuracy values are in percent.
    """
    # Plain arrays so that positional indexing also works for pandas input.
    gt = np.asarray(gt)
    preds = np.asarray(preds)
    mags = kron_mag if mags is None else np.asarray(mags)
    selected = np.where(mags <= mag_max)[0]

    acc, auc_val, thresh, fom = _sample_stats(gt[selected], preds[selected],
                                              ct, fom_at)

    acc_std_arr = np.empty(Nboot)
    auc_std_arr = np.empty_like(acc_std_arr)
    fom_std_arr = np.empty((Nboot, len(fom_at)))
    thresh_std_arr = np.empty_like(fom_std_arr)
    for i in range(Nboot):
        # Progress indicator, printed for every resample.
        print("%d."%i, end='')
        # Resample the selected sources with replacement.
        boot_sources = np.random.choice(selected, len(selected), replace=True)
        (acc_std_arr[i], auc_std_arr[i],
         thresh_std_arr[i, :], fom_std_arr[i, :]) = _sample_stats(
            gt[boot_sources], preds[boot_sources], ct, fom_at)

    return {'Num': len(selected),
            'Acc': acc*100,
            'AUC': auc_val,
            'FoM': fom,
            'Thresh': thresh,
            'AccSTD': np.std(acc_std_arr)*100,
            'AUCSTD': np.std(auc_std_arr),
            'FoMSTD': np.std(fom_std_arr, axis=0),
            'ThreshSTD': np.std(thresh_std_arr, axis=0)}
In [38]:
# Classification cuts of the four models, applied as ``score >= cut``.
# ps1_ct is a plain float: the earlier trailing comma turned it into the
# one-element tuple (-0.05,).
ps1_ct = -0.05
simple_ct = 9.2e-7
rf_ct = 0.5
sdss_ct = 10**(-0.185/2.5)
In [25]:
# Bootstrap summary of the RF model scores.
stat_rf = summary_stats_bootstrap(gt=y_sdss, preds=rf_preds,
                                  ct=rf_ct, Nboot=100)
In [45]:
# Display the summary dictionary of the RF model.
stat_rf
Out[45]:
In [39]:
# Bootstrap summaries of the PS1, simple and SDSS models, evaluated in that
# order.  The pandas scores are passed as plain arrays.
stat_ps1, stat_simple, stat_sdss = (
    summary_stats_bootstrap(y_sdss, scores.values, ct=cut, Nboot=100)
    for scores, cut in ((ps1_preds, ps1_ct),
                        (simple_preds, simple_ct),
                        (sdss_preds, sdss_ct))
)
In [40]:
# Display the summary dictionary of the PS1 model.
stat_ps1
Out[40]:
In [41]:
# Display the summary dictionary of the simple model.
stat_simple
Out[41]:
In [42]:
# Display the summary dictionary of the SDSS model.
stat_sdss
Out[42]:
In [25]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Line colours, styles and widths per model.
color_dict = {'ml': '#7570b3',
              'sdss': '#d95f02',
              'simple': '#1b9e77',
              'ps1': '#34495e'}
# Colours of the star markers that flag the SDSS and PS1 classification cuts.
col_star_sdss = '#e41a1c'
col_star_ps1 = '#001e43'
ls_dict = {'ml': '-',
           'sdss': '-.',
           'simple': '--',
           'ps1': '--'}
lw_dict = {'ml': 1.75,
           'sdss': 1.5,
           'simple': 1.5,
           'ps1': 1.5}

# (model key, fpr, tpr, legend label, extra line options), listed from the
# top-most to the bottom-most curve.
_roc_curves = [
    ('ml', ML_fpr, ML_tpr, 'RF model', {}),
    ('simple', dist_fpr, dist_tpr, 'Simple model', {}),
    ('ps1', i_fpr, i_tpr, 'PS1 model', {'dashes': (8, 4)}),
    ('sdss', sdss_fpr, sdss_tpr, 'SDSS photo', {}),
]

# ROC points closest to the SDSS cut (10**(-0.185/2.5)) and to the PS1 cut
# (-0.05 on the sign-flipped score).
_sdss_cut_idx = np.argmin(np.abs(sdss_thre - 10**(-0.185/2.5)))
_ps1_cut_idx = np.argmin(np.abs(i_thre + 0.05))


def _draw_roc(ax, top_zorder, sdss_star_zorder, ps1_star_zorder=5):
    """Draw the four ROC curves and the two cut markers on ``ax``.

    Curves get descending zorder starting at ``top_zorder``.
    """
    for rank, (key, fpr, tpr, label, extra) in enumerate(_roc_curves):
        ax.plot(fpr, tpr, color=color_dict[key], label=label,
                ls=ls_dict[key], lw=lw_dict[key], alpha=0.9,
                zorder=top_zorder - rank, **extra)
    # The undefined name ``col_star`` used for the inset marker is replaced
    # by col_star_sdss, the colour the main panel uses for the same marker.
    ax.plot(sdss_fpr[_sdss_cut_idx], sdss_tpr[_sdss_cut_idx], '*',
            markersize=10, color=col_star_sdss, zorder=sdss_star_zorder)
    ax.plot(i_fpr[_ps1_cut_idx], i_tpr[_ps1_cut_idx], '*',
            markersize=10, color=col_star_ps1, zorder=ps1_star_zorder)


ylims = [0, 1.02]
xlims = [1.7e-4, 1]

fig, main_ax = plt.subplots(figsize=(7,5))
axins = inset_axes(main_ax, width="43.5%",
                   height="53%", loc=3,
                   bbox_to_anchor=(0.55, 0.07, 1., 1.),
                   bbox_transform=main_ax.transAxes)

# Main panel: full ROC curves on a logarithmic FPR axis.
main_ax.grid(alpha=0.5, lw=0.5, c='grey', linestyle=':')
main_ax.tick_params(which="both", top=True, right=True)
main_ax.minorticks_on()
_draw_roc(main_ax, top_zorder=4, sdss_star_zorder=5)

# Vertical marker at FPR = 0.005, the rate at which the FoM is defined.
main_ax.vlines([5e-3], 1e-3, 1.02,
               color='DarkSlateGrey', lw=2., linestyles=":", zorder=1)
main_ax.text(4.9e-3, 0.1, 'FoM',
             color='DarkSlateGrey',
             rotation=90, ha="right", fontsize=14)

main_ax.set_xlim(xlims); main_ax.set_ylim(ylims)
main_ax.set_xscale('log')
main_ax.tick_params(labelsize = 15)
main_ax.set_xlabel('False Positive Rate', fontsize=15)
main_ax.set_ylabel('True Positive Rate', fontsize=15)
main_ax.legend(loc=3, borderaxespad=0, fontsize=13,
               handlelength=1.5,
               bbox_to_anchor=(0.0125, 0.725, 1., 0.102))

# Inset: zoom on the region where the curves approach and cross.
ylims = [0.875, 0.965]
xlims = np.array([7.5e-3, 2.75e-2])

axins.tick_params(which="both", top=True)
axins.minorticks_on()
_draw_roc(axins, top_zorder=5, sdss_star_zorder=6)
axins.set_xlim(xlims); axins.set_ylim(ylims)
axins.tick_params(labelsize = 15)

fig.subplots_adjust(right=0.97,top=0.98,bottom=0.11,left=0.1)
fig.savefig('../paper/Figures/ROC_curves_log_inset2.pdf')
In [31]:
# Decision thresholds; a source is labelled 1 (star) unless its score is
# below the threshold.
ml_decision_thresh = 0.5
sdss_decision_thresh = 10**(-0.185/2.5)
simple_decision_thresh = 9.2e-07 # maximize acc on training set
# The PS1 score is the sign-flipped iPSFminusKron, so the cut carries the
# flipped sign as well (-0.05, matching ps1_ct and the marker on the ROC
# curve); the former +0.05 selected iPSFminusKron <= -0.05 instead.
ps1_decision_thresh = -0.05

# .loc replaces the removed pandas .ix indexer; with a boolean mask both
# select the same rows.
ml_labels = np.logical_not(ML_predict[:,1] < ml_decision_thresh).astype(int)
sdss_labels = np.logical_not(np.array(sdss_tbl["countRatio"].loc[sdss_test_set]) < sdss_decision_thresh).astype(int)
simple_labels = np.logical_not(np.array(sdss_tbl["wwPSFKronDist"].loc[sdss_test_set]) < simple_decision_thresh).astype(int)
ps1_labels = np.logical_not(-1*np.array(sdss_tbl["iPSFminusKron"].loc[sdss_test_set]) < ps1_decision_thresh).astype(int)
In [32]:
# Accuracy of every model in 0.5 mag bins, with the 16th/84th percentiles of
# Nboot bootstrap resamples of each bin as the scatter.
binwidth = 0.5
Nboot = 100
mag_array = np.arange(13 , 23+binwidth, binwidth)  # lower bin edges

# .loc replaces the removed pandas .ix indexer.
kron_mag = np.array(-2.5*np.log10(sdss_tbl['wwKronFlux'].loc[sdss_test_set]/3631))

sdss_acc_arr, simple_acc_arr, ps1_acc_arr, ml_acc_arr = (
    np.zeros_like(mag_array) for _ in range(4))
# Row 0 holds the 16th and row 1 the 84th percentile of each bin.
sdss_boot_scatt, simple_boot_scatt, ps1_boot_scatt, ml_boot_scatt = (
    np.vstack((np.zeros_like(mag_array), np.zeros_like(mag_array)))
    for _ in range(4))

# (labels, accuracy array, scatter array) of every model.
_binned = {'sdss': (sdss_labels, sdss_acc_arr, sdss_boot_scatt),
           'simple': (simple_labels, simple_acc_arr, simple_boot_scatt),
           'ps1': (ps1_labels, ps1_acc_arr, ps1_boot_scatt),
           'ml': (ml_labels, ml_acc_arr, ml_boot_scatt)}

for bin_num, binedge in enumerate(mag_array):
    bin_sources = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth))

    for labels, acc_arr, _ in _binned.values():
        acc_arr[bin_num] = accuracy_score(sdss_ml_test_y[bin_sources],
                                          labels[bin_sources])

    # One accuracy per resample and model; all models share each resample.
    # (The garbled ``np.empty(Nboot1`` of the original did not parse.)
    boot_acc = {name: np.empty(Nboot) for name in _binned}
    for i in range(Nboot):
        boot_sources = np.random.choice(bin_sources[0], len(bin_sources[0]),
                                        replace=True)
        for name, (labels, _, _) in _binned.items():
            boot_acc[name][i] = accuracy_score(sdss_ml_test_y[boot_sources],
                                               labels[boot_sources])

    for name, (_, _, scatt) in _binned.items():
        scatt[:,bin_num] = np.percentile(boot_acc[name], [16, 84])
In [34]:
from sklearn.neighbors import KernelDensity


def _fit_sdss_mag_kde(mags):
    """Fit a 1-D KernelDensity to ``mags``.

    The bandwidth is 1.059 * sample std (ddof=1) * n**(-0.2).
    """
    width = 1.059*np.std(mags, ddof=1)*len(mags)**(-0.2)
    return KernelDensity(bandwidth=width, rtol=1E-4).fit(mags[:, np.newaxis])


kde_grid = np.linspace(13,23.5,200)
_grid_column = kde_grid[:, np.newaxis]

# Positions of the label-1 (star) and label-0 (galaxy) test sources.
sdss_stars = np.where(sdss_ml_test_y == 1)
sdss_gal = np.where(sdss_ml_test_y == 0)

# Fraction of each class, used to scale the per-class densities.
sdss_kde_gal_norm = len(sdss_gal[0])/len(sdss_ml_test_y)
sdss_kde_star_norm = 1 - sdss_kde_gal_norm

kde_sdss = _fit_sdss_mag_kde(kron_mag)
kde_sdss_stars = _fit_sdss_mag_kde(kron_mag[sdss_stars])
kde_sdss_gal = _fit_sdss_mag_kde(kron_mag[sdss_gal])

# score_samples returns log-densities; exponentiate to get the PDFs.
pdf_sdss = np.exp(kde_sdss.score_samples(_grid_column))
pdf_sdss_stars = np.exp(kde_sdss_stars.score_samples(_grid_column))
pdf_sdss_gal = np.exp(kde_sdss_gal.score_samples(_grid_column))
In [61]:
from matplotlib.ticker import MultipleLocator

color_dict = {'ml': "black",
              'sdss': "#5BC236",
              'simple': "#C864AF",
              'ps1': "#C65400"}

mag_bin_centers = mag_array + binwidth/2

cmap_star = sns.cubehelix_palette(rot=0.5, light=0.7,dark=0.3,as_cmap=True)
cmap_gal = sns.cubehelix_palette(start=0.3,rot=-0.5,light=0.7,dark=0.3,as_cmap=True)

fig, ax = plt.subplots(figsize=(8, 5))

# Binned accuracy with asymmetric error bars (distance of the bootstrap
# percentiles from the measured accuracy).  Only the ML and SDSS models are
# drawn; the simple and PS1 curves are left out of this figure.
for key, acc, scatt, width, label, order in (
        ('ml', ml_acc_arr, ml_boot_scatt, .75, "ML model", 2),
        ('sdss', sdss_acc_arr, sdss_boot_scatt, .5, "SDSS photo", 1)):
    ax.errorbar(mag_bin_centers, acc,
                yerr=np.abs(scatt - acc),
                ls='-', lw=width, fmt='o',
                color=color_dict[key], label=label, zorder=order)

# add KDE plots, shifted up by 0.5 to sit on the lower y-limit
for density, opacity, shade in (
        (pdf_sdss, 0.4, "0.7"),
        (pdf_sdss_gal*sdss_kde_gal_norm, 0.7, cmap_gal(0.25)),
        (pdf_sdss_stars*sdss_kde_star_norm, 0.7, cmap_star(0.25))):
    ax.fill(kde_grid, density + 0.5, alpha=opacity, color=shade)

ax.set_ylim(0.5,1.01)
ax.set_xlim(13, 23.5)
ax.tick_params(which="both", top=True, right=True, labelsize=15)
ax.set_xlabel('whiteKronMag', fontsize=15)
ax.set_ylabel('Accuracy', fontsize=15)

ax.yaxis.set_minor_locator(MultipleLocator(0.025))
ax.xaxis.set_major_locator(MultipleLocator(2))
ax.xaxis.set_minor_locator(MultipleLocator(0.5))

ax.legend(bbox_to_anchor=(0.01, 0.3, 1., 0.102), loc=3, fontsize=13)
fig.subplots_adjust(top=0.98,right=0.98,left=0.1,bottom=0.12)
In [55]:
# Per-class (star / galaxy) accuracy of every model in magnitude bins, with
# the 16th/84th percentiles of Nboot bootstrap resamples as the scatter.
binwidth = 0.5
Nboot = 100

_models = ('sdss', 'simple', 'ps1', 'ml')
_kinds = ('star', 'gal')
_model_labels = dict(zip(_models, (sdss_labels, simple_labels,
                                   ps1_labels, ml_labels)))

# One accuracy array and one (2, n_bins) scatter array per (model, class).
_acc = {(m, k): np.zeros_like(mag_array) for k in _kinds for m in _models}
_scatt = {(m, k): np.vstack((np.zeros_like(mag_array),
                             np.zeros_like(mag_array)))
          for k in _kinds for m in _models}

# bootstrap star arrays
sdss_star_acc_arr, simple_star_acc_arr, ps1_star_acc_arr, ml_star_acc_arr = (
    _acc[(m, 'star')] for m in _models)
(sdss_star_boot_scatt, simple_star_boot_scatt,
 ps1_star_boot_scatt, ml_star_boot_scatt) = (
    _scatt[(m, 'star')] for m in _models)

# bootstrap galaxy arrays
sdss_gal_acc_arr, simple_gal_acc_arr, ps1_gal_acc_arr, ml_gal_acc_arr = (
    _acc[(m, 'gal')] for m in _models)
(sdss_gal_boot_scatt, simple_gal_boot_scatt,
 ps1_gal_boot_scatt, ml_gal_boot_scatt) = (
    _scatt[(m, 'gal')] for m in _models)

for bin_num, binedge in enumerate(mag_array):
    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) &
                         (sdss_ml_test_y == 1))
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) &
                        (sdss_ml_test_y == 0))
    _members = {'star': bin_stars[0], 'gal': bin_gals[0]}

    # Measured accuracy of each class in this bin.
    for kind in _kinds:
        rows = _members[kind]
        for model in _models:
            _acc[(model, kind)][bin_num] = accuracy_score(
                sdss_ml_test_y[rows], _model_labels[model][rows])

    # get the bootstrap accuracies
    _boot = {key: np.empty(Nboot) for key in _acc}
    for i in range(Nboot):
        # Stars are resampled before galaxies in every iteration.
        for kind in _kinds:
            draw = np.random.choice(_members[kind], len(_members[kind]),
                                    replace=True)
            for model in _models:
                _boot[(model, kind)][i] = accuracy_score(
                    sdss_ml_test_y[draw], _model_labels[model][draw])

    for key, samples in _boot.items():
        _scatt[key][:,bin_num] = np.percentile(samples, [16, 84])
In [84]:
# Accuracy per magnitude bin after resampling the SDSS test set so that its
# star-to-galaxy ratio in each bin matches that of the HST sample.  For every
# model the mean and the 16th/84th percentiles over Nboot resamples are kept
# for all sources, for stars only and for galaxies only.
binwidth = 0.5
Nboot = 100

_models = ('sdss', 'simple', 'ps1', 'ml')
_groups = ('all', 'star', 'gal')
_model_labels = dict(zip(_models, (sdss_labels, simple_labels,
                                   ps1_labels, ml_labels)))

_mean = {(m, g): np.zeros_like(mag_array) for g in _groups for m in _models}
_scatt = {(m, g): np.vstack((np.zeros_like(mag_array),
                             np.zeros_like(mag_array)))
          for g in _groups for m in _models}

sdss_resamp_arr, simple_resamp_arr, ps1_resamp_arr, ml_resamp_arr = (
    _mean[(m, 'all')] for m in _models)
(sdss_resamp_scatt, simple_resamp_scatt,
 ps1_resamp_scatt, ml_resamp_scatt) = (_scatt[(m, 'all')] for m in _models)

# bootstrap star arrays
(sdss_star_resamp_arr, simple_star_resamp_arr,
 ps1_star_resamp_arr, ml_star_resamp_arr) = (
    _mean[(m, 'star')] for m in _models)
(sdss_star_resamp_scatt, simple_star_resamp_scatt,
 ps1_star_resamp_scatt, ml_star_resamp_scatt) = (
    _scatt[(m, 'star')] for m in _models)

# bootstrap galaxy arrays
(sdss_gal_resamp_arr, simple_gal_resamp_arr,
 ps1_gal_resamp_arr, ml_gal_resamp_arr) = (
    _mean[(m, 'gal')] for m in _models)
(sdss_gal_resamp_scatt, simple_gal_resamp_scatt,
 ps1_gal_resamp_scatt, ml_gal_resamp_scatt) = (
    _scatt[(m, 'gal')] for m in _models)

for bin_num, binedge in enumerate(mag_array):
    # The four brightest bins are skipped and keep their zero entries.
    if bin_num <= 3:
        continue

    bin_stars = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) &
                         (sdss_ml_test_y == 1))
    bin_gals = np.where((kron_mag >= binedge) & (kron_mag < binedge + binwidth) &
                        (sdss_ml_test_y == 0))
    hst_stars = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) &
                         (hst_ml_train_y == 1))
    hst_gals = np.where((hst_kron_mag >= binedge) & (hst_kron_mag < binedge + binwidth) &
                        (hst_ml_train_y == 0))

    # figure out the number of stars and galaxies to select: keep all
    # sources of the class that is relatively under-represented in SDSS and
    # scale the other class to the HST star/galaxy ratio
    if len(hst_stars[0])/len(hst_gals[0]) > len(bin_stars[0])/len(bin_gals[0]):
        n_star_resamp = len(bin_stars[0])
        n_gal_resamp = int(len(bin_stars[0])*len(hst_gals[0])/len(hst_stars[0]))
    else:
        n_star_resamp = int(len(bin_gals[0])*len(hst_stars[0])/len(hst_gals[0]))
        n_gal_resamp = len(bin_gals[0])

    print(len(hst_stars[0]), len(hst_gals[0]),
          len(bin_stars[0]), len(bin_gals[0]),
          n_star_resamp, n_gal_resamp)

    # get the bootstrap accuracies
    _boot = {key: np.empty(Nboot) for key in _mean}
    for i in range(Nboot):
        # Stars are drawn before galaxies in every iteration.
        star_boot_sources = np.random.choice(bin_stars[0], n_star_resamp,
                                             replace=True)
        gal_boot_sources = np.random.choice(bin_gals[0], n_gal_resamp,
                                            replace=True)
        boot_sources = np.append(star_boot_sources, gal_boot_sources)
        _draws = {'all': boot_sources,
                  'star': star_boot_sources,
                  'gal': gal_boot_sources}

        for (model, group), store in _boot.items():
            rows = _draws[group]
            store[i] = accuracy_score(sdss_ml_test_y[rows],
                                      _model_labels[model][rows])

    for key, samples in _boot.items():
        _mean[key][bin_num] = np.mean(samples)
        _scatt[key][:,bin_num] = np.percentile(samples, [16, 84])
In [76]:
def _fit_hst_mag_kde(mags):
    """Fit a 1-D KernelDensity to ``mags``.

    The bandwidth is 1.059 * sample std (ddof=1) * n**(-0.2).
    """
    width = 1.059*np.std(mags, ddof=1)*len(mags)**(-0.2)
    return KernelDensity(bandwidth=width, rtol=1E-4).fit(mags[:, np.newaxis])


kde_grid = np.linspace(13,23.5,200)
_hst_grid_column = kde_grid[:, np.newaxis]

# Positions of the label-1 (star) and label-0 (galaxy) HST sources.
hst_stars = np.where(hst_ml_train_y == 1)
hst_gal = np.where(hst_ml_train_y == 0)

kde_hst = _fit_hst_mag_kde(hst_kron_mag)
kde_hst_stars = _fit_hst_mag_kde(hst_kron_mag[hst_stars])
kde_hst_gal = _fit_hst_mag_kde(hst_kron_mag[hst_gal])

# score_samples returns log-densities; exponentiate to get the PDFs.
pdf_hst = np.exp(kde_hst.score_samples(_hst_grid_column))
pdf_hst_stars = np.exp(kde_hst_stars.score_samples(_hst_grid_column))
pdf_hst_gal = np.exp(kde_hst_gal.score_samples(_hst_grid_column))

# Fraction of each class, used to scale the per-class densities.
hst_kde_gal_norm = len(hst_gal[0])/len(hst_ml_train_y)
hst_kde_star_norm = 1 - hst_kde_gal_norm
In [111]:
color_dict = {'ml': "black",
              'sdss': "teal",
              'simple': "#C864AF",
              'ps1': "#C65400"}


def _accuracy_errorbar(x, acc, scatt, key, lw, label, **style):
    """Draw one accuracy curve with asymmetric error bars.

    The bars are the absolute distance of the two ``scatt`` rows from
    ``acc``; line style and colour are looked up by model ``key``.
    """
    plt.errorbar(x, acc, yerr=np.abs(scatt - acc),
                 ls=ls_dict[key], lw=lw, color=color_dict[key],
                 label=label, **style)


_star_marker = dict(fmt='^', ms=8, mec="0.2", mew=0.5)
_gal_marker = dict(fmt='s')

plt.figure(figsize=(12.5,5))

# Left panel: accuracy on the SDSS test set, split into stars and galaxies.
# Only the ML and SDSS models are drawn; stars and galaxies are offset by
# +/-0.1 mag so that their error bars do not overlap.
plt.subplot(1,2,1)
_accuracy_errorbar(mag_bin_centers+0.1, ml_star_acc_arr, ml_star_boot_scatt,
                   'ml', .75, "ML stars", zorder=4, **_star_marker)
_accuracy_errorbar(mag_bin_centers+0.1, sdss_star_acc_arr, sdss_star_boot_scatt,
                   'sdss', .5, "SDSS stars", zorder=2, **_star_marker)
_accuracy_errorbar(mag_bin_centers-0.1, ml_gal_acc_arr, ml_gal_boot_scatt,
                   'ml', .75, "ML galaxies", zorder=3, **_gal_marker)
_accuracy_errorbar(mag_bin_centers-0.1, sdss_gal_acc_arr, sdss_gal_boot_scatt,
                   'sdss', .5, "SDSS galaxies", zorder=1, **_gal_marker)

# add KDE plots of the SDSS magnitudes, shifted up by 0.5 to sit on the
# lower y-limit
plt.fill(kde_grid, pdf_sdss + 0.5, alpha=0.4, color="0.7")
plt.fill(kde_grid, pdf_sdss_gal*sdss_kde_gal_norm + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(kde_grid, pdf_sdss_stars*sdss_kde_star_norm + 0.5, alpha=0.7, color=cmap_star(0.25))

plt.ylim(0.5,1.01)
plt.xlim(13, 23.5)
plt.tick_params(which="both", top=True, right=True, labelsize=15)
plt.xlabel('whiteKronMag', fontsize=15)
plt.ylabel('Accuracy', fontsize=15)
plt.minorticks_on()
plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)

# Right panel: accuracy after resampling to the HST class ratio.  The first
# four bins are not filled by the resampling loop and are left out.
plt.subplot(1,2,2)
_accuracy_errorbar(mag_bin_centers[4:], ml_resamp_arr[4:],
                   ml_resamp_scatt[:,4:], 'ml', .75, "ML model", fmt='o')
_accuracy_errorbar(mag_bin_centers[4:], sdss_resamp_arr[4:],
                   sdss_resamp_scatt[:,4:], 'sdss', .5, "SDSS photo", fmt='o')

# add KDE plots of the HST magnitudes; the extra point at (23.5, 0) closes
# each filled polygon on the baseline at the right edge
_hst_fill_x = np.append(kde_grid, 23.5)
plt.fill(_hst_fill_x,
         np.append(pdf_hst, 0) + 0.5, alpha=0.4, color="0.7")
plt.fill(_hst_fill_x,
         np.append(pdf_hst_gal*hst_kde_gal_norm, 0) + 0.5, alpha=0.7, color=cmap_gal(0.25))
plt.fill(_hst_fill_x,
         np.append(pdf_hst_stars*hst_kde_star_norm, 0) + 0.5, alpha=0.7, color=cmap_star(0.25))

plt.ylim(0.5,1.01)
plt.xlim(15, 23.5)
# labelleft must be a boolean: the string 'off' is truthy on current
# Matplotlib, so it no longer hides the y tick labels.
plt.tick_params(which="both", top=True, right=True, labelsize=15, labelleft=False)
plt.xlabel('whiteKronMag', fontsize=15)
plt.minorticks_on()
plt.legend(bbox_to_anchor=(0.01, 0.4, 1., 0.102), loc=3, fontsize=13)

plt.tight_layout()
plt.savefig('Accuracy_SDSS_ML_model.pdf')
In [ ]: