In [ ]:
import numpy as np
import wisps
import splat
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix,accuracy_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split, KFold,RepeatedKFold
import seaborn as sns
import pandas as pd
%matplotlib inline
In [ ]:
# Load the combined photometric + spectroscopic catalog of new stars
# from the project's HDF store.
comb=pd.read_hdf(wisps.COMBINED_PHOTO_SPECTRO_FILE, key='new_stars')
In [ ]:
# Keep only rows where every spectral index is defined.
# NOTE(review): .iloc is positional but dropna().index returns labels —
# this is only correct if comb has a clean 0..n-1 RangeIndex here;
# .loc looks intended, confirm upstream index state.
comb=comb.iloc[(comb[wisps.INDEX_NAMES]).dropna().index]
In [ ]:
# Drop any remaining rows with missing values in any column.
comb=comb.dropna()
In [ ]:
# Labelled training set vs. unlabelled survey objects to classify.
train_df=pd.read_pickle(wisps.LIBRARIES+'/training_set.pkl').reset_index(drop=True)
pred_df=wisps.Annotator.reformat_table(comb).reset_index(drop=True)
In [ ]:
# One row per spectrum.
pred_df=pred_df.drop_duplicates(subset='grism_id')
In [ ]:
pred_df.shape, train_df.shape
In [ ]:
def apply_scale(x):
    """Put a single feature value on a log10 scale.

    Values whose log is not finite (x <= 0, or x already NaN) are
    replaced by a random sentinel drawn uniformly from [-99, -98) so
    downstream scalers/classifiers only ever see finite numbers.
    """
    # Suppress the divide-by-zero / invalid-value RuntimeWarnings that
    # np.log10 emits for x <= 0; we handle those results explicitly below.
    with np.errstate(divide='ignore', invalid='ignore'):
        y = np.log10(x)
    # not-finite covers both NaN and +/-inf (original checked each separately)
    if not np.isfinite(y):
        y = np.random.uniform(-99, -98)
    return y
def create_labels(row):
    """Map a catalog row to a multiclass label.

    0 -> non-UCD (row.label != 1, or UCD with spt outside [17, 45))
    1 -> M7-L0   (UCD with spt < 20)
    2 -> L dwarf (UCD with 20 <= spt < 30)
    3 -> T dwarf (UCD with 30 <= spt < 45)
    """
    # Anything not flagged as a UCD is the background class.
    if row.label != 1:
        return 0
    spt = row.spt
    if spt < 20:
        return 1
    if 20 <= spt < 30:
        return 2
    if 30 <= spt < 45:
        return 3
    # UCD flag set but spt out of range (or NaN): fall back to background.
    return 0
In [ ]:
# Normalize grism IDs to lower case for later cross-matching.
pred_df['grism_id']=pred_df.grism_id.apply(lambda x: x.lower())
In [ ]:
#features=wisps.INDEX_NAMES
# Feature vector: per-order signal-to-noise ratios, fit statistics,
# and all spectral indices defined by the wisps package.
features=np.concatenate([['snr2','snr1', 'snr3', 'snr4', 'f_test', 'line_chi', 'spex_chi'], wisps.INDEX_NAMES])
#features=['snr2','snr1', 'snr3', 'snr4', 'f_test']
In [ ]:
# Minimal S/N cut on the prediction set; convert spectral types to numbers.
pred_df=pred_df[pred_df.snr2>3.]
train_df['spt']=train_df.spt.apply(wisps.make_spt_number)
pred_df['spt']=pred_df.spt.apply(wisps.make_spt_number)
In [ ]:
# Multiclass targets (0=non-UCD, 1=M7-L0, 2=L, 3=T) for the training set.
labels=train_df.apply(create_labels, axis=1).values
In [ ]:
# Log-scale every feature; non-finite values become the -99..-98 sentinel.
train_df[features]=(train_df[features]).applymap(apply_scale)
pred_df[features]=(pred_df[features]).applymap(apply_scale)
In [ ]:
# Fit the scaler on the training set only, then reuse it everywhere.
scaler = MinMaxScaler(feature_range=(-100, 100))
scaler.fit(train_df[features])
X=scaler.transform(train_df[features])
y=labels
In [ ]:
#scale the data set to predict for the prediction set
pred_set=scaler.transform(pred_df[features])
In [ ]:
class_weigths={0:1., 1:40/10000, 2:1/10000, 3:5/10000}
In [ ]:
X_train, X_test, y_train, y_test = train_test_split(X, y,
test_size=0.5, random_state=np.random.randint(1000))
rf = RandomForestClassifier(n_estimators=10000, min_samples_split=2, verbose=True,bootstrap=True, n_jobs=-1,
class_weight=class_weigths, criterion='entropy', random_state=np.random.randint(1000),
warm_start=False)
rf.fit(X_train, y_train)
pred_labels = rf.predict(X_test)
model_accuracy = accuracy_score(y_test, pred_labels)
In [ ]:
print ('accuracy score {}'.format(model_accuracy))
# Human-readable class names matching labels 0-3 from create_labels.
classes=['non-UCD', 'M7-L0', 'L', 'T']
cm = pd.DataFrame(confusion_matrix(y_test, pred_labels),
                 columns=classes, index=classes)
In [ ]:
#create a table a confusion matrix
fig, ax=plt.subplots(figsize=(8, 6))
# NOTE(review): DataFrame.sum() sums per COLUMN, so each column is
# normalized by its predicted-class total (precision-like); confirm that
# per-true-class (row) normalization was not intended.
matr=(cm/cm.sum()).applymap(lambda x: np.round(x, 2)).values
im = ax.imshow(matr, cmap='Blues')
# We want to show all ticks...
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
# ... and label them with the respective list entries
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=0, ha="right",
        rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
for i in range(len(classes)):
    for j in range(len(classes)):
        text = ax.text(j, i, matr[i, j], ha="center", va="center", color="k", fontsize=18)
# Flip the y-axis so class 0 is at the top-left, matching the matrix layout.
ax.set_xlim([-0.5, 3.5])
ax.set_ylim([3.5, -0.5])
plt.tight_layout()
plt.savefig(wisps.OUTPUT_FIGURES+'/confusion_matrix.pdf')
In [ ]:
#cleanup
X_train.shape
In [ ]:
'accuracy score {}'.format(model_accuracy)
In [ ]:
# Classify the unlabelled survey objects (aligned row-for-row with pred_df).
rlabels=rf.predict(pred_set)
In [ ]:
# Number of objects predicted to be UCDs (any class > 0).
len(rlabels[rlabels>0])
In [ ]:
# Visually confirmed candidates, used as ground truth for the FP rate.
cands=pd.read_pickle(wisps.OUTPUT_FILES+'/true_spectra_cands.pkl')
In [ ]:
cands['grism_id']=cands.grism_id.apply(lambda x: x.lower())
cands['spt']=[x.spectral_type for x in cands.spectra]
In [ ]:
# How many confirmed candidates survive into the prediction set.
len(cands), len( pred_df[pred_df.grism_id.isin(cands.grism_id.values)])
In [ ]:
strs=wisps.datasets['stars']
In [ ]:
# Sanity checks: candidates missing from the prediction set / star catalog.
cands[~ cands.grism_id.isin(pred_df.grism_id.values) ]
In [ ]:
cands[~ cands.grism_id.isin(strs.grism_id.values) ]
In [ ]:
# Restrict to ultracool spectral types (M7 and later).
cands=cands[cands.spt>=17]
In [ ]:
# True positives: predicted UCDs that are also confirmed candidates.
# (rlabels is a positional boolean mask over pred_df's current rows.)
true=(pred_df[(rlabels>0) & pred_df.grism_id.isin(cands.grism_id.values)]).drop_duplicates(subset='grism_id')
truep=len(true)
ps=len(rlabels[rlabels>0])
In [ ]:
# Breakdown of recovered candidates by spectral-type range (M7-L0, L, T).
len(true[true.spt.between(17,20)]), len(true[true.spt.between(20,30)]), len(true[true.spt.between(30,40)])
In [ ]:
'FP rate {}'.format((ps-truep)/ps)
In [ ]:
# Bundle the trained classifier with its scaler and feature list so the
# exact same preprocessing can be replayed at prediction time.
rf_dict={'classifier': rf,
        'sclr':scaler,
        'feats':features}
In [ ]:
import pickle
In [ ]:
#save the random forest
output_file=wisps.OUTPUT_FILES+'/random_forest_classifier.pkl'
with open(output_file, 'wb') as file:
    pickle.dump(rf_dict,file)
In [ ]:
# Spectra flagged as UCDs by the random forest.
sv_df=pred_df[(rlabels>0)]
In [ ]:
sv_df.to_pickle(wisps.LIBRARIES+'/labelled_by_rf.pkl')
In [ ]:
# Objects chosen by the independent spectral-index selection, for
# comparison against the random-forest selection.
slbyids=pd.read_pickle(wisps.OUTPUT_FILES+'/selected_by_indices.pkl')
In [ ]:
#slbyids
In [ ]:
# Confirmed candidates recovered by BOTH methods.
len(sv_df[(sv_df.grism_id.isin(slbyids.grism_id)) & (sv_df.grism_id.isin(cands.grism_id))])
In [ ]:
# Confirmed candidates recovered by the forest but missed by the indices.
len(sv_df[( ~sv_df.grism_id.isin(slbyids.grism_id)) & (sv_df.grism_id.isin(cands.grism_id))])
In [ ]:
# Confirmed candidates recovered by the indices but missed by the forest.
len(slbyids[( ~slbyids.grism_id.isin(sv_df.grism_id)) & (slbyids.grism_id.isin(cands.grism_id))])
In [ ]: