Dataset and Classifier classes for uniform XD/RF/... interface


In [20]:
%%writefile PS1QLS.py
#%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cross_validation import train_test_split
from astroML.density_estimation import XDGMM
from astropy.io import fits
from sklearn.ensemble import RandomForestClassifier
from sklearn import grid_search
from sklearn import metrics
import triangle
import pickle

lQSO_prior = .06007 #~1800 in a 30,000 sq deg PS1-like survey according to OM10 TODO: what is the actual PS1 depth?
dud_prior = 12300 # reciprocal number of objects in PS1 data sample from eric 

def listify(covmat,length):
    if length <= 1:
        return covmat[:,:,np.newaxis]
    else:
        covmatlist = np.dstack((covmat,covmat))
        while covmatlist.shape[2] < length:
            covmatlist = np.dstack((covmatlist,covmat))
    return covmatlist.transpose()

class Dataset(object):

    def __init__(self,data,truth=None,text=None):
        self.data = data
        self.text = text
        self.truth=truth
        self.nFeatures = data.shape[1]
        self.covmat = .9*np.identity(self.nFeatures)+.1*np.ones(self.nFeatures)
        return
        
    def noisify(self):
        perturbations = np.random.multivariate_normal(np.zeros(self.nFeatures),self.covmat,size=self.data.shape[0])
        self.truedata = self.data
        self.data = self.data + perturbations
        return
        
    def resample_from(self,xd_classifier):
        self.resampled = np.concatenate((xd_classifier.lQSO_model.sample(np.sum(self.truth)), \
                                         xd_classifier.dud_model.sample(np.sum(1-self.truth))),axis=0)
        return
    
    def plot(self):
        if hasattr(self,'truedata'):
            if hasattr(self,'resampled'):
                fig1 = triangle.corner(self.truedata,labels=self.text)
                _ = triangle.corner(self.data,fig=fig1,labels=self.text,color='r')
                _ = triangle.corner(self.resampled,fig=fig1,color='b',labels=self.text)
                fig1.suptitle('Everything')
                return
            fig1 = triangle.corner(self.truedata,labels=self.text)
            _ = triangle.corner(self.data,fig=fig1,labels=self.text,color='r')
            fig1.suptitle('Noisy Data over Truth')
            return
        if self.truth is not None:
            fig1 = triangle.corner(self.data[self.truth==0],truths=self.data[self.truth==1],truth_color='b',labels=self.text)
            #fig1 = triangle.corner(self.data[self.truth==1],color='b',fig=fig1,plot_datapoints=True,plot_contours=False,labels=self.text)
            fig1.suptitle('Cornerplot for Raw Data')
        else:
            fig1 = triangle.corner(self.data,labels=self.text)
            fig1.suptitle('Cornerplot for Raw Data')
        return

class Classifier(object):

    def __init__(self,algorithm='XD',n_comp = 20):
        if algorithm == 'XD':
            self.algorithm='XD'
            self.lQSO_model = XDGMM(n_components=n_comp,verbose=True)
            self.dud_model = XDGMM(n_components=n_comp,verbose=True)
        elif algorithm == 'RandomForest':
            self.algorithm = 'RandomForest'
            self.trialRF = RandomForestClassifier()
            self.RF_params = {'n_estimators':(10,50,200),"max_features": ["auto",2,4],
                          'criterion':["gini","entropy"],"min_samples_leaf": [1,2]}
        return
    
    def train(self,train,truth,covmat=1):
        if self.algorithm == 'XD':
            self.XDtrain(train,truth,covmat)
        elif self.algorithm == 'RandomForest':
            self.RFtrain(train,truth)
        return
    
    def RFtrain(self,train,truth):
        tunedRF = grid_search.GridSearchCV(self.trialRF, self.RF_params, score_func=metrics.accuracy_score,\
                                    n_jobs = -1, cv = 3,verbose=1)
        self.optRF = tunedRF.fit(train, truth)
        return
    
    def XDtrain(self,train,truth,covmat=1):
        self.lQSO_model.fit(train[truth==1], listify(covmat,np.sum(truth)))
        self.dud_model.fit(train[truth==0], listify(covmat,np.sum(1-truth)))
        return
    
    def test(self,test,covmat=1):
        if self.algorithm == 'XD':
            self.XDprobs(test,covmat)
        elif self.algorithm == 'RandomForest':
            self.RFprobs(test)
        return
    
    def RFprobs(self,test):
        self.dud_probs = self.optRF.predict_proba(test)[:,0]
        self.lQSO_probs = self.optRF.predict_proba(test)[:,1]
        return
    
    def XDprobs(self,test,covmat):
        lQSO_like = np.sum(np.exp(self.lQSO_model.logprob_a(test, listify(covmat,test.shape[0]))),axis=1)
        dud_like= np.sum(np.exp(self.dud_model.logprob_a(test, listify(covmat,test.shape[0]))),axis=1)
        self.lQSO_probs = (lQSO_like * lQSO_prior) / (lQSO_like * lQSO_prior + dud_like * dud_prior)
        self.dud_probs =  (dud_like  *  dud_prior) / (lQSO_like * lQSO_prior + dud_like * dud_prior)
        return
    
    def make_roc(self,truth):
        fpr, tpr, _ = metrics.roc_curve(truth,self.lQSO_probs,pos_label=1)
        plt.title('ROC Curve')
        plt.plot(fpr,tpr,'b--')
        plt.xlabel('FPR')
        plt.ylabel('TPR')
        return fpr,tpr
    
    def save(self,pkl_fname='classifiers.pkl'):
        outfile = open(pkl_fname,'wb')
        outDict = {}
        if hasattr(self,'lQSO_model'):
            outDict.update({'lQSO_model':self.lQSO_model})
        if hasattr(self,'dud_model'):
            outDict.update({'dud_model':self.dud_model})
        if hasattr(self,'optRF'):
            outDict.update({'optRF':self.optRF})
        pickle.dump(outDict,outfile)
        outfile.close()
        return
    
    def load(self,pkl_fname='classifiers.pkl'):
        pkl_in = open(pkl_fname,'rb')
        inDict = pickle.load(pkl_in)
        pkl_in.close()
        if 'lQSO_model' in inDict.keys():
            self.lQSO_model = inDict['lQSO_model']
        if 'dud_model' in inDict.keys():
            self.dud_model = inDict['dud_model']
        if 'optRF' in inDict.keys():
            self.optRF = inDict['optRF']
        return


Overwriting PS1QLS.py

In [3]:
dataLoc = '../../'
pQSO = np.loadtxt(dataLoc+'pQSO/pSDSScolmag.txt')[:,2:]
lQSO = np.loadtxt(dataLoc+'lQSO/SDSScolmag.txt')[:,2:]
sinQSO = np.loadtxt(dataLoc+'sinQSO/sSDSScolmag.txt')[:,2:]
unlQSO = np.loadtxt(dataLoc+'unlQSO/nlSDSScolmag.txt')[:,2:]
unlQSO[:,3:5] = -unlQSO[:,3:5] #bug in WISE magnitudes for this file

duds = np.concatenate((pQSO,unlQSO,sinQSO),axis=0)
data = np.concatenate((lQSO,duds),axis=0) #all sims
truth = np.concatenate((np.ones(lQSO.shape[0]),np.zeros(duds.shape[0])),axis=0)

Testing Plotting Functionality Post-Object-Orientification


In [4]:
ds = Dataset(data,truth)
ds.plot()
ds.noisify()
ds.plot() #automatically plots both noisy and truth


Testing the RF Classifier via the uniform interface


In [5]:
RF = Classifier(algorithm='RandomForest')
RF.train(ds.data,truth)
RF.test(ds.data)
RF.make_roc(truth)


Fitting 3 folds for each of 36 candidates, totalling 108 fits
/Users/mbaumer/anaconda/lib/python2.7/site-packages/sklearn/grid_search.py:466: DeprecationWarning: Passing function as ``score_func`` is deprecated and will be removed in 0.15. Either use strings or score objects.The relevant new parameter is called ''scoring''.
  self.loss_func, self.score_func, self.scoring)
[Parallel(n_jobs=-1)]: Done   1 jobs       | elapsed:    0.2s
[Parallel(n_jobs=-1)]: Done  50 jobs       | elapsed:    7.4s
[Parallel(n_jobs=-1)]: Done 108 out of 108 | elapsed:   28.4s finished

In [6]:
RF.save('rf.pkl')

In [19]:
reloadedRF = Classifier(algorithm='RandomForest')
reloadedRF.load('classifiers.pkl')
reloadedRF.test(duds)
print reloadedRF.dud_probs


{'optRF': GridSearchCV(cv=3,
       estimator=RandomForestClassifier(bootstrap=True, compute_importances=None,
            criterion='gini', max_depth=None, max_features='auto',
            min_density=None, min_samples_leaf=1, min_samples_split=2,
            n_estimators=10, n_jobs=1, oob_score=False, random_state=None,
            verbose=0),
       fit_params={}, iid=True, loss_func=None, n_jobs=-1,
       param_grid={'n_estimators': (10, 50, 200), 'max_features': ['auto', 2, 4], 'criterion': ['gini', 'entropy'], 'min_samples_leaf': [1, 2]},
       pre_dispatch='2*n_jobs', refit=True,
       score_func=<function accuracy_score at 0x1096ca230>, scoring=None,
       verbose=1)}
[ 0.98291667  0.95595833  1.         ...,  1.          1.          0.99833333]

The ROC curve functionality worked here, but since we're presently testing on the training set, overfitting makes the ROC curve hug the axes!

Testing XD through the same interface!


In [7]:
XD = Classifier(algorithm='XD')
XD.train(ds.data,truth,ds.covmat)
XD.test(ds.data,ds.covmat)
XD.make_roc(truth)


1: log(L) = -4767.4
    (3.2 sec)
2: log(L) = -4695.3
    (3.1 sec)
3: log(L) = -4651.9
    (3.1 sec)
4: log(L) = -4623.2
    (3.2 sec)
5: log(L) = -4603
    (3.2 sec)
6: log(L) = -4588
    (3.1 sec)
7: log(L) = -4576.6
    (3.3 sec)
8: log(L) = -4567.6
    (3.3 sec)
9: log(L) = -4560.4
    (3.3 sec)
10: log(L) = -4554.5
    (3.2 sec)
11: log(L) = -4549.6
    (3.1 sec)
12: log(L) = -4545.6
    (3.1 sec)
13: log(L) = -4542.2
    (3.2 sec)
14: log(L) = -4539.3
    (3.1 sec)
15: log(L) = -4537
    (3.3 sec)
16: log(L) = -4535
    (3.3 sec)
17: log(L) = -4533.3
    (3.2 sec)
18: log(L) = -4532
    (3.2 sec)
19: log(L) = -4531
    (3.3 sec)
20: log(L) = -4530.2
    (3.3 sec)
21: log(L) = -4529.6
    (3.1 sec)
22: log(L) = -4529.3
    (3.1 sec)
23: log(L) = -4529.1
    (3.1 sec)
24: log(L) = -4529.2
    (3.1 sec)
1: log(L) = -21912
    (15 sec)
2: log(L) = -21585
    (14 sec)
3: log(L) = -21395
    (14 sec)
4: log(L) = -21273
    (14 sec)
5: log(L) = -21188
    (14 sec)
6: log(L) = -21127
    (14 sec)
7: log(L) = -21081
    (14 sec)
8: log(L) = -21044
    (14 sec)
9: log(L) = -21016
    (14 sec)
10: log(L) = -20992
    (14 sec)
11: log(L) = -20973
    (14 sec)
12: log(L) = -20957
    (15 sec)
13: log(L) = -20943
    (15 sec)
14: log(L) = -20931
    (14 sec)
15: log(L) = -20920
    (15 sec)
16: log(L) = -20911
    (15 sec)
17: log(L) = -20903
    (14 sec)
18: log(L) = -20896
    (14 sec)
19: log(L) = -20890
    (15 sec)
20: log(L) = -20884
    (14 sec)
21: log(L) = -20879
    (15 sec)
22: log(L) = -20875
    (15 sec)
23: log(L) = -20871
    (15 sec)
24: log(L) = -20867
    (14 sec)
25: log(L) = -20863
    (14 sec)
26: log(L) = -20860
    (14 sec)
27: log(L) = -20857
    (15 sec)
28: log(L) = -20855
    (14 sec)
29: log(L) = -20852
    (14 sec)
30: log(L) = -20850
    (14 sec)
31: log(L) = -20848
    (15 sec)
32: log(L) = -20846
    (14 sec)
33: log(L) = -20844
    (14 sec)
34: log(L) = -20842
    (14 sec)
35: log(L) = -20840
    (14 sec)
36: log(L) = -20839
    (14 sec)
37: log(L) = -20837
    (14 sec)
38: log(L) = -20836
    (14 sec)
39: log(L) = -20835
    (15 sec)
40: log(L) = -20834
    (15 sec)
41: log(L) = -20832
    (15 sec)
42: log(L) = -20831
    (14 sec)
43: log(L) = -20830
    (14 sec)
44: log(L) = -20829
    (14 sec)
45: log(L) = -20828
    (15 sec)
46: log(L) = -20827
    (14 sec)
47: log(L) = -20827
    (14 sec)
48: log(L) = -20826
    (15 sec)
49: log(L) = -20825
    (15 sec)
50: log(L) = -20824
    (14 sec)
51: log(L) = -20824
    (14 sec)
52: log(L) = -20823
    (14 sec)
53: log(L) = -20822
    (14 sec)
54: log(L) = -20822
    (14 sec)
55: log(L) = -20821
    (15 sec)
56: log(L) = -20820
    (15 sec)
57: log(L) = -20820
    (14 sec)
58: log(L) = -20819
    (14 sec)
59: log(L) = -20819
    (14 sec)
60: log(L) = -20818
    (14 sec)
61: log(L) = -20818
    (14 sec)
62: log(L) = -20817
    (14 sec)
63: log(L) = -20817
    (14 sec)
64: log(L) = -20816
    (14 sec)
65: log(L) = -20816
    (14 sec)
66: log(L) = -20816
    (14 sec)
67: log(L) = -20815
    (14 sec)
68: log(L) = -20815
    (14 sec)
69: log(L) = -20814
    (14 sec)
70: log(L) = -20814
    (14 sec)
71: log(L) = -20814
    (15 sec)
72: log(L) = -20813
    (14 sec)
73: log(L) = -20813
    (14 sec)
74: log(L) = -20813
    (15 sec)
75: log(L) = -20812
    (14 sec)
76: log(L) = -20812
    (14 sec)
77: log(L) = -20812
    (14 sec)
78: log(L) = -20811
    (15 sec)
79: log(L) = -20811
    (14 sec)
80: log(L) = -20811
    (14 sec)
81: log(L) = -20811
    (14 sec)
82: log(L) = -20810
    (15 sec)
83: log(L) = -20810
    (14 sec)
84: log(L) = -20810
    (14 sec)
85: log(L) = -20810
    (15 sec)
86: log(L) = -20809
    (14 sec)
87: log(L) = -20809
    (14 sec)
88: log(L) = -20809
    (14 sec)
89: log(L) = -20809
    (14 sec)
90: log(L) = -20808
    (14 sec)
91: log(L) = -20808
    (14 sec)
92: log(L) = -20808
    (14 sec)
93: log(L) = -20808
    (14 sec)
94: log(L) = -20808
    (14 sec)
95: log(L) = -20807
    (14 sec)
96: log(L) = -20807
    (15 sec)
97: log(L) = -20807
    (15 sec)
98: log(L) = -20807
    (15 sec)
99: log(L) = -20807
    (14 sec)
100: log(L) = -20806
    (15 sec)

In [8]:
XD.save('xd.pkl')

In [11]:
XD2 = Classifier(algorithm='XD')
XD2.load('xd.pkl')
print XD2.lQSO_model


<astroML.density_estimation.xdeconv.XDGMM object at 0x105a4ae10>

In [14]:
ds.resample_from(XD)
ds.plot()


/Users/mbaumer/anaconda/lib/python2.7/site-packages/astroML/density_estimation/xdeconv.py:202: DeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  r = np.atleast_1d(np.random.random(size))
/Users/mbaumer/anaconda/lib/python2.7/site-packages/astroML/density_estimation/xdeconv.py:215: DeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  return draw.reshape(shape)
/Users/mbaumer/anaconda/lib/python2.7/site-packages/astroML/density_estimation/xdeconv.py:202: DeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  r = np.atleast_1d(np.random.random(size))
/Users/mbaumer/anaconda/lib/python2.7/site-packages/astroML/density_estimation/xdeconv.py:215: DeprecationWarning: using a non-integer number instead of an integer will result in an error in the future
  return draw.reshape(shape)

In [ ]: