Sample Decision Trees with Scikit-Learn
Load the required components.
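Note: visualizing the decision tree requires Graphviz (http://www.graphviz.org/) and pydotplus (https://pypi.python.org/pypi/pydotplus) to be installed. The imports below also rely on scikit-learn modules (`sklearn.cross_validation`, `sklearn.externals.six`) that were removed in later scikit-learn releases, so an older scikit-learn version is required to run it as written.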
In [2]:
from __future__ import print_function
import os
from IPython.display import Image
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.cross_validation import train_test_split
from sklearn.cross_validation import cross_val_score
from sklearn import tree
from sklearn.externals.six import StringIO
import pydotplus
import matplotlib.pyplot as plt
%matplotlib inline
In [3]:
# Raw data lives in ../data/raw relative to the notebook's working directory.
raw_dir = os.path.join(os.getcwd(), os.pardir, "data/raw/")
# The CSV's first column is used as the row index (index_col=0).
diabetesdf = pd.read_csv(os.path.join(raw_dir, "pimadiabetes.csv"), index_col=0)
# Labels now match what is actually shown (10 rows each, not the 5-row default).
print("* diabetesdf.head(10)", diabetesdf.head(10), sep="\n", end="\n\n")
print("* diabetesdf.tail(10)", diabetesdf.tail(10), sep="\n", end="\n\n")
In [4]:
# Distinct labels of the target column -- the same "class" column used as y for training.
print("* Class types:", diabetesdf["class"].unique(), sep="\n")
In [5]:
# Predictor columns: the first eight columns of the frame (used as X below).
features = diabetesdf.columns[:8].tolist()
print("* features:", features, sep="\n")
In [6]:
# Target labels and predictor matrix.
y = diabetesdf["class"]
X = diabetesdf[features]

# Hold out 20% of the rows for validation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.20,
    random_state=42,
)

n_train, n_valid = len(X_train), len(X_test)
print("* Training sample size : ", n_train)
print("* Validation sample size : ", n_valid)
Building the Decision Tree. For the full list of `DecisionTreeClassifier` options, see the scikit-learn API reference: http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
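Tunable parameters used for the baseline tree below:
- `criterion` — the split-quality measure (`'gini'` or `'entropy'`).
- `min_samples_split` — the minimum number of samples a node must contain before it can be split.
- `random_state` — a fixed seed so the fitted tree is reproducible.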
In [7]:
# Baseline tree: Gini impurity, at least 5 samples required to split a node,
# and a fixed seed so repeated runs build the same tree.
dt = tree.DecisionTreeClassifier(
    criterion='gini',
    min_samples_split=5,
    random_state=1024,
)
# Fit on the training split only; the held-out split is reserved for validation.
dt.fit(X_train, y_train)
Out[7]:
In [8]:
def printScores(amodel, xtrain, ytrain, xtest, ytest):
    """Print the training score, validation score and tree depth of a fitted model.

    Parameters
    ----------
    amodel : fitted estimator exposing ``score(X, y)`` and ``tree_.max_depth``
        (e.g. a scikit-learn ``DecisionTreeClassifier``; for classifiers
        ``score`` is the mean accuracy).
    xtrain, ytrain : training features and labels.
    xtest, ytest : validation features and labels.

    Returns
    -------
    None -- results are only printed.
    """
    results = (
        ("Training score is %f", amodel.score(xtrain, ytrain)),
        ("Validation score is %f", amodel.score(xtest, ytest)),
    )
    for template, value in results:
        print(template % value)
    print("Model depth is %i" % amodel.tree_.max_depth)
# Report train/validation scores and depth of the baseline tree.
printScores(amodel=dt, xtrain=X_train, ytrain=y_train, xtest=X_test, ytest=y_test)
In [9]:
# Render the baseline tree with Graphviz (requires graphviz + pydotplus).
# class_names needs one label per class, in the order of dt.classes_. The
# original bare string 'class' was indexed per class, so nodes were labelled
# with single characters ('c', 'l') instead of the real class values.
dot_data = StringIO()
tree.export_graphviz(dt, out_file=dot_data,
                     feature_names=features,
                     class_names=[str(label) for label in dt.classes_],
                     filled=True, rounded=True,
                     special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Out[9]:
In [15]:
def listFeatureImportance(amodel, features):
    """Print the model's features ranked by ``feature_importances_``, highest first.

    Only features with a strictly positive importance are printed. Each entry
    shows its 1-based rank, the column index, the importance value, and the
    feature name taken from ``features``.

    Parameters
    ----------
    amodel : fitted estimator exposing a ``feature_importances_`` array.
    features : sequence of feature names, aligned with the importance array.
    """
    # Adapted from
    # https://stackoverflow.com/questions/34239355/feature-importance-extraction-of-decision-trees-scikit-learn
    scores = amodel.feature_importances_
    ranking = np.argsort(scores)[::-1]  # column indices, most important first
    print('Feature Ranking:')
    for position in range(1, len(features) + 1):
        col = ranking[position - 1]
        weight = scores[col]
        if not weight > 0:
            continue  # unused features carry zero importance; skip them
        print("%d. feature %d (%f)" % (position, col, weight))
        print("\tfeature name: ", features[col])
In [11]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score
try:
    from sklearn.model_selection import GridSearchCV  # scikit-learn >= 0.18
except ImportError:
    from sklearn.grid_search import GridSearchCV  # older releases; module removed in 0.20

# Exhaustive hyper-parameter search with 3-fold CV on the training split only.
# min_samples_split starts at 2: a node needs at least two samples to be split,
# and values below 2 are either rejected or behave exactly like 2.
param_grid = {
    'criterion': ['gini', 'entropy'],
    'max_depth': range(1, 15),
    'min_samples_split': range(2, 15),
    'min_samples_leaf': range(1, 10),
}
# Seed the tree (same seed as the baseline) so the search result is reproducible.
Gridtree = GridSearchCV(DecisionTreeClassifier(random_state=1024), param_grid, n_jobs=2, cv=3)
Gridtree.fit(X_train, y_train)

# ROC AUC on the held-out split, scored on the probability of the second entry in classes_.
tree_preds = Gridtree.predict_proba(X_test)[:, 1]
tree_performance = roc_auc_score(y_test, tree_preds)
print("Best Classifier : %s (Best Score %0.2f)" % (Gridtree.best_estimator_ , Gridtree.best_score_ ))
print("ROC AUC : (%0.2f)" % (tree_performance ))
bc = Gridtree.best_estimator_
In [17]:
# bc is Gridtree.best_estimator_, which GridSearchCV already refit on the whole
# training split (refit=True is the default). Fitting it again would mutate the
# same object and, for an unseeded tree, could pick different tie-breaking splits
# than the model whose ROC AUC was just reported, so it is scored as-is.
printScores(bc, X_train, y_train, X_test, y_test)
In [13]:
# Render the tuned (grid-search best) tree with Graphviz.
# class_names needs one label per class, in the order of bc.classes_; a bare
# string such as 'class' would be indexed per class and label nodes 'c', 'l'.
dot_data = StringIO()
tree.export_graphviz(bc, out_file=dot_data,
                     feature_names=features,
                     class_names=[str(label) for label in bc.classes_],
                     filled=True, rounded=True,
                     special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Out[13]:
In [14]:
# Side-by-side comparison: feature ranking plus scores for the baseline tree
# and the grid-search winner, printed in that order.
for heading, fitted_tree in (("Original Model", dt), ("\n\nBest Model", bc)):
    print(heading)
    listFeatureImportance(fitted_tree, features)
    printScores(fitted_tree, X_train, y_train, X_test, y_test)
In [ ]: