In [ ]:
import numpy as np
import wisps
import splat
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.model_selection import train_test_split, KFold, RepeatedKFold

import seaborn as sns
import pandas as pd
%matplotlib inline

In [ ]:
comb=pd.read_hdf(wisps.COMBINED_PHOTO_SPECTRO_FILE, key='new_stars')

In [ ]:
#keep only sources with all spectral indices measured (.loc, since dropna() returns index labels)
comb=comb.loc[comb[wisps.INDEX_NAMES].dropna().index]

In [ ]:
#drop any remaining rows with missing values
comb=comb.dropna()

In [ ]:
train_df=pd.read_pickle(wisps.LIBRARIES+'/training_set.pkl').reset_index(drop=True)
pred_df=wisps.Annotator.reformat_table(comb).reset_index(drop=True)

In [ ]:
pred_df=pred_df.drop_duplicates(subset='grism_id')

In [ ]:
pred_df.shape, train_df.shape

In [ ]:
def apply_scale(x):
    #put features on a log scale;
    #replace nans/infs with a random sentinel far below the data range
    y=np.log10(x)
    if np.isnan(y) or np.isinf(y):
        y=np.random.uniform(-99, -98)
    return y

def create_labels(row):
    #multi-class labels: 0=non-UCD, 1=M7-L0, 2=L, 3=T
    label=0
    if (row.label==1) and (row.spt<20):
        label=1
    elif (row.label==1) and (20<=row.spt<30):
        label=2
    elif (row.label==1) and (30<=row.spt<45):
        label=3
    return label
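
In [ ]:
#quick sanity check of the label mapping on toy rows (illustrative values only,
#not real data); expected output: [0, 1, 2, 3]
toy=pd.DataFrame({'label':[0., 1., 1., 1.], 'spt':[np.nan, 18, 25, 38]})
toy.apply(create_labels, axis=1)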

In [ ]:
pred_df['grism_id']=pred_df.grism_id.apply(lambda x: x.lower())

In [ ]:
#features: per-band S/N, goodness-of-fit statistics, and all spectral indices
features=np.concatenate([['snr2','snr1', 'snr3', 'snr4', 'f_test', 'line_chi', 'spex_chi'], wisps.INDEX_NAMES])

In [ ]:
#apply an S/N cut to the prediction set
pred_df=pred_df[pred_df.snr2>3.]
#convert spectral types to numeric values
train_df['spt']=train_df.spt.apply(wisps.make_spt_number)
pred_df['spt']=pred_df.spt.apply(wisps.make_spt_number)

In [ ]:
#assign multi-class labels to the training set
labels=train_df.apply(create_labels, axis=1).values

In [ ]:
#log-scale all features in both sets
train_df[features]=(train_df[features]).applymap(apply_scale)
pred_df[features]=(pred_df[features]).applymap(apply_scale)

In [ ]:
#fit the scaler on the training set only
scaler = MinMaxScaler(feature_range=(-100, 100))
scaler.fit(train_df[features])
X=scaler.transform(train_df[features])
y=labels

In [ ]:
#scale the data set to predict for the prediction set
pred_set=scaler.transform(pred_df[features])

In [ ]:
#class weights for the random forest (0=non-UCD, 1=M7-L0, 2=L, 3=T)
class_weights={0:1., 1:40/10000, 2:1/10000, 3:5/10000}

In [ ]:
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                   test_size=0.5,  random_state=np.random.randint(1000))
    
rf = RandomForestClassifier(n_estimators=10000, min_samples_split=2, verbose=True, bootstrap=True, n_jobs=-1,
                            class_weight=class_weights, criterion='entropy', random_state=np.random.randint(1000),
                            warm_start=False)
rf.fit(X_train, y_train)
pred_labels = rf.predict(X_test)
model_accuracy = accuracy_score(y_test, pred_labels)
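
In [ ]:
#optional sketch: a more robust accuracy estimate via k-fold cross-validation,
#using the KFold splitter imported above; the fold count and the smaller
#n_estimators are arbitrary choices to keep this cell fast
kf=KFold(n_splits=5, shuffle=True, random_state=0)
cv_scores=[]
for train_idx, test_idx in kf.split(X):
    clf=RandomForestClassifier(n_estimators=100, class_weight=class_weights,
                               criterion='entropy', n_jobs=-1)
    clf.fit(X[train_idx], y[train_idx])
    cv_scores.append(accuracy_score(y[test_idx], clf.predict(X[test_idx])))
np.mean(cv_scores), np.std(cv_scores)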

In [ ]:
print('accuracy score {}'.format(model_accuracy))
classes=['non-UCD', 'M7-L0', 'L', 'T']
cm = pd.DataFrame(confusion_matrix(y_test, pred_labels), 
                  columns=classes, index=classes)

In [ ]:
#plot the confusion matrix, with each column normalized by its total (predicted-class fractions)

fig, ax=plt.subplots(figsize=(8, 6))

matr=(cm/cm.sum()).applymap(lambda x: np.round(x, 2)).values
im = ax.imshow(matr, cmap='Blues')

# We want to show all ticks...
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
# ... and label them with the respective list entries
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)

# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=0, ha="right",
         rotation_mode="anchor")

# Loop over data dimensions and create text annotations.
for i in range(len(classes)):
    for j in range(len(classes)):
        text = ax.text(j, i, matr[i, j], ha="center", va="center", color="k", fontsize=18)
ax.set_xlim([-0.5, 3.5])
ax.set_ylim([3.5, -0.5])
plt.tight_layout()
plt.savefig(wisps.OUTPUT_FIGURES+'/confusion_matrix.pdf')
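
In [ ]:
#diagnostic sketch: rank features by their importance in the trained forest;
#feature_importances_ is standard scikit-learn API, and this ranking is for
#inspection only, not part of the selection pipeline
importances=pd.Series(rf.feature_importances_, index=features).sort_values(ascending=False)
importances.head(10)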

In [ ]:
#sanity check: size of the training split
X_train.shape

In [ ]:
'accuracy score {}'.format(model_accuracy)

In [ ]:
#classify the prediction set
rlabels=rf.predict(pred_set)

In [ ]:
#number of sources classified as UCDs (any class > 0)
len(rlabels[rlabels>0])
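
In [ ]:
#per-class breakdown of the predictions (0=non-UCD, 1=M7-L0, 2=L, 3=T)
dict(zip(*np.unique(rlabels, return_counts=True)))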

In [ ]:
cands=pd.read_pickle(wisps.OUTPUT_FILES+'/true_spectra_cands.pkl')

In [ ]:
cands['grism_id']=cands.grism_id.apply(lambda x: x.lower())
cands['spt']=[x.spectral_type for x in cands.spectra]

In [ ]:
len(cands), len( pred_df[pred_df.grism_id.isin(cands.grism_id.values)])

In [ ]:
strs=wisps.datasets['stars']

In [ ]:
#confirmed candidates missing from the prediction table
cands[~cands.grism_id.isin(pred_df.grism_id.values)]

In [ ]:
#confirmed candidates missing from the star sample
cands[~cands.grism_id.isin(strs.grism_id.values)]

In [ ]:
#keep only candidates with spectral type M7 or later
cands=cands[cands.spt>=17]

In [ ]:
true=(pred_df[(rlabels>0) & pred_df.grism_id.isin(cands.grism_id.values)]).drop_duplicates(subset='grism_id')
truep=len(true) #true positives: predicted UCDs that are confirmed candidates
ps=len(rlabels[rlabels>0]) #all positive predictions

In [ ]:
len(true[true.spt.between(17,20)]), len(true[true.spt.between(20,30)]), len(true[true.spt.between(30,40)])

In [ ]:
#fraction of predicted UCDs that are not confirmed candidates (false-discovery rate)
'FP rate {}'.format((ps-truep)/ps)
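
In [ ]:
#complementary sketch: completeness with respect to the visually-vetted sample,
#i.e. the fraction of confirmed candidates present in the prediction table
#that the forest recovers
truep/len(cands[cands.grism_id.isin(pred_df.grism_id.values)])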

In [ ]:
rf_dict={'classifier': rf,
         'sclr': scaler,
         'feats': features}

In [ ]:
import pickle

In [ ]:
#save the random forest
output_file=wisps.OUTPUT_FILES+'/random_forest_classifier.pkl'
with open(output_file, 'wb') as file:
    pickle.dump(rf_dict, file)
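
In [ ]:
#usage sketch: reload the saved classifier and re-apply it to another table;
#`new_table` below is a hypothetical stand-in for any table with the same
#feature columns, log-scaled with apply_scale as above
with open(output_file, 'rb') as f:
    loaded=pickle.load(f)
#loaded['classifier'].predict(loaded['sclr'].transform(new_table[loaded['feats']]))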

In [ ]:
#sources labelled as UCDs by the random forest
sv_df=pred_df[(rlabels>0)]

In [ ]:
sv_df.to_pickle(wisps.LIBRARIES+'/labelled_by_rf.pkl')

In [ ]:
slbyids=pd.read_pickle(wisps.OUTPUT_FILES+'/selected_by_indices.pkl')

In [ ]:
#slbyids

In [ ]:
#candidates recovered by both the random forest and the index-based selection
len(sv_df[(sv_df.grism_id.isin(slbyids.grism_id)) & (sv_df.grism_id.isin(cands.grism_id))])

In [ ]:
#candidates recovered by the random forest but missed by the index-based selection
len(sv_df[(~sv_df.grism_id.isin(slbyids.grism_id)) & (sv_df.grism_id.isin(cands.grism_id))])

In [ ]:
#candidates recovered by the index-based selection but missed by the random forest
len(slbyids[(~slbyids.grism_id.isin(sv_df.grism_id)) & (slbyids.grism_id.isin(cands.grism_id))])
