In [1]:
from sklearn.ensemble import ExtraTreesClassifier
from sklearn import cross_validation
from sklearn.feature_selection import RFECV
from sklearn.grid_search import GridSearchCV
from sklearn.pipeline import Pipeline
from time import time
from operator import itemgetter
import numpy as np
import pandas as pd
In [2]:
# Read the training and test tables from the CSV files one directory up.
df_train, df_test = (
    pd.read_csv(csv_path)
    for csv_path in ('../Shelter_train.csv', '../Shelter_test.csv')
)
In [3]:
# Features are every column except the last one; the last column is the target.
# `.iloc` replaces the deprecated `.ix` indexer (removed in pandas 1.0) and
# makes the positional intent explicit.
X = df_train.iloc[:, :-1]
y = df_train.iloc[:, -1]

# Drop the identifier column from the test frame.
# `axis` is passed by keyword (the positional form was removed in pandas 2.0)
# and errors='ignore' keeps this cell re-runnable once 'ID' is already gone.
df_test = df_test.drop('ID', axis=1, errors='ignore')
In [4]:
# Baseline: cross-validated score of an untuned ExtraTreesClassifier.
# A fixed random_state makes the baseline reproducible across kernel restarts.
# NOTE(review): the 'log_loss' scorer name and sklearn.cross_validation are the
# pre-0.18 scikit-learn API; newer releases use 'neg_log_loss' and
# sklearn.model_selection.
clf = ExtraTreesClassifier(random_state=42)
cross_validation.cross_val_score(clf, X, y, scoring="log_loss")
Out[4]:
In [5]:
# Hyper-parameter grid for the search. Every key carries the `clf__` prefix,
# which routes the value to the pipeline step named 'clf'.
params = {
    "clf__max_depth": [5, 3, None],
    "clf__max_features": [0.1, 0.25, 0.5, 1.0],
    # A node needs at least two samples to be split, so 2 is the smallest
    # meaningful value; scikit-learn >= 0.18 rejects an integer value of 1.
    "clf__min_samples_split": [2, 3, 10],
    "clf__min_samples_leaf": [1, 3, 10],
    "clf__bootstrap": [True, False],
    "clf__criterion": ["gini", "entropy"],
    "clf__class_weight": ["balanced", "balanced_subsample", None],
}
In [6]:
def report(grid_scores, n_top=3):
    """Print the best-scoring grid-search candidates.

    Parameters
    ----------
    grid_scores : iterable
        Entries exposing ``mean_validation_score``, ``cv_validation_scores``
        and ``parameters`` attributes (e.g. ``GridSearchCV.grid_scores_``).
    n_top : int, default 3
        Number of candidates to print, highest mean score first.

    Returns
    -------
    None
        The summary is written to standard output.
    """
    # Rank by the named attribute rather than a positional index so the
    # ordering does not depend on the tuple layout of the entries.
    ranked = sorted(grid_scores,
                    key=lambda entry: entry.mean_validation_score,
                    reverse=True)
    for rank, entry in enumerate(ranked[:n_top], start=1):
        print("Model with rank: {0}".format(rank))
        print("Mean validation score: {0:.3f} (std: {1:.3f})".format(
            entry.mean_validation_score,
            np.std(entry.cv_validation_scores)))
        print("Parameters: {0}".format(entry.parameters))
        print("")
In [7]:
# Fixed seed so that feature selection, the tuned forest, and therefore the
# resulting predictions are reproducible across kernel restarts.
SEED = 42

# Recursive feature elimination (with its own cross-validation) feeds the
# selected features into the classifier whose hyper-parameters are searched.
pipeline = Pipeline([
    ('featureSelection', RFECV(estimator=ExtraTreesClassifier(random_state=SEED),
                               scoring='log_loss')),
    ('clf', ExtraTreesClassifier(n_estimators=20, random_state=SEED))
])

# Exhaustive search over `params`, parallelised over all available cores.
grid_search = GridSearchCV(pipeline, params, n_jobs=-1, scoring='log_loss')

# Time the fit so the cost of the search is visible in the cell output.
start = time()
grid_search.fit(X, y)
print("GridSearchCV took %.2f seconds for %d candidate parameter settings."
      % (time() - start, len(grid_search.grid_scores_)))
report(grid_search.grid_scores_)

# Per-class probabilities for the test frame from the fitted search object.
predictions = grid_search.predict_proba(df_test)

# NOTE(review): the column labels are hard-coded and assume the classifier's
# class order matches this list — confirm against the fitted classes_.
output = pd.DataFrame(predictions, columns=['Adoption', 'Died', 'Euthanasia', 'Return_to_owner', 'Transfer'])

# Label the index 'ID' and shift it so that it starts at 1 instead of 0.
output.index.names = ['ID']
output.index += 1
output.head()
Out[7]:
In [8]:
# Write the prediction table to disk; the index column is written under the
# header 'ID'.
submission_path = '../submission-extraTreesClassifier.2.0.csv'
output.to_csv(submission_path, index_label='ID')