Author: Christin Seifert, licensed under the Creative Commons Attribution 3.0 Unported License https://creativecommons.org/licenses/by/3.0/
This is a tutorial for learning and evaluating a simple decision tree on the famous breast cancer data set. In this notebook you will load the data set, train a decision tree, evaluate it on held-out test data, and visualize the resulting tree.
It is assumed that you have some general knowledge of machine learning and decision trees.
In [1]:
# classifying
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# pretty printing
from pprint import pprint
# visualizing
import matplotlib.pyplot as plt
from io import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
The dataset comes with sklearn, so the only thing we have to do is load it and see what's in there. The original version of the data set can be found at the UCI Machine Learning Repository. Some more description can be found in the scikit-learn documentation.
We will have a look at the shape of the data matrix, the names of the features, the names of the classes, and one example instance.
In [2]:
from sklearn.datasets import load_breast_cancer
data = load_breast_cancer()
# The shape of the data matrix (without class attribute)
print("Matrix shape: " + repr(data.data.shape))
# The names of the features
print("The data set has the following features:")
pprint(data.feature_names)
# The names of the classes
print("The data set has the following classes:")
pprint(data.target_names)
# One example instance (the second row of the data matrix)
pprint(data.data[1])
# This prints a rather long description.
#print(data.DESCR)
There are 30 features. For the classifier we will only use the first 10 (those are the ones that contain the mean values).
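If you want to convince yourself that the first 10 columns really hold the mean values, you can print their names (a quick check, not part of the original notebook):
# The first 10 feature names should all start with "mean"
pprint(data.feature_names[0:10])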
We further need to hold back some examples that the classifier will not see during training, in order to evaluate how well it does on examples it hasn't seen before. Those will end up in our test split. Whether an example goes into the training or the test split should be decided randomly with shuffle=True (because we do not know whether the data set is somehow ordered).
The variables below mean the following:
X_train - matrix with features for the training data
X_test - matrix with features for the test data
y_train - vector with the true labels for the training data set
y_test - vector with the true labels for the test data set
In [3]:
# Split into training and test
X_train, X_test, y_train, y_test = train_test_split(
    data.data[:, 0:10], data.target, shuffle=True, test_size=0.3, random_state=42)
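As a sanity check, we can look at the shapes of the resulting splits; with test_size=0.3, roughly 70% of the 569 examples should end up in the training split (this check is not part of the original notebook):
# The split should contain roughly 70% / 30% of the examples
print("Training set: " + repr(X_train.shape))
print("Test set: " + repr(X_test.shape))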
Now we are ready to train the classifier. We first create a decision tree object that can do the training, and then we tell it to train on the data we present. Note: At this point, the decision tree only sees X_train and y_train, i.e. the features of the training data and the true labels of the training data.
In [4]:
# DECISION TREE
# initialize the model, using entropy (information gain) as the splitting criterion
clf_dt = DecisionTreeClassifier(criterion="entropy")
# train the model
clf_dt.fit(X_train,y_train)
Out[4]:
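Training is a single call to fit. Before we look at the quality of the tree, we can already peek at its size; a minimal sketch (get_depth and get_n_leaves are available in scikit-learn 0.21 and later):
# Size of the learned tree (requires scikit-learn >= 0.21)
print("Tree depth: " + repr(clf_dt.get_depth()))
print("Number of leaves: " + repr(clf_dt.get_n_leaves()))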
Now we have a trained decision tree. So what? Is it a good one? What does it look like?
To answer these questions we first let it make predictions on the test data set X_test
(we know the true labels, but we won't tell them to the decision tree). And then we compare the labels the tree predicted y_test_pred
with the true labels y_test
. And we count how often they agree, which gives us the accuracy of the decision tree. We do the same for the training data (let it predict the training data and compare with the true labels).
In [5]:
# Evaluating on the test data
y_test_pred = clf_dt.predict(X_test)
a_dt_test = accuracy_score(y_test, y_test_pred)
# Evaluating on the training data
y_train_pred = clf_dt.predict(X_train)
a_dt_train = accuracy_score(y_train, y_train_pred)
print("Training data accuracy is " + repr(a_dt_train) + " and test data accuracy is " + repr(a_dt_test))
We see that the decision tree classifies the training data perfectly, and the test data quite well. The difference between training and test accuracy is an indication of overfitting, which is bad. So we might decide to try different parameters for the decision tree learner (we are not going to tune them in this notebook).
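For example, limiting the tree depth often reduces overfitting. A sketch of what such an experiment could look like (max_depth=3 is an arbitrary illustrative choice, not part of the original notebook):
# A shallower tree is less prone to overfitting
clf_dt_small = DecisionTreeClassifier(criterion="entropy", max_depth=3)
clf_dt_small.fit(X_train, y_train)
print("Training accuracy: " + repr(accuracy_score(y_train, clf_dt_small.predict(X_train))))
print("Test accuracy: " + repr(accuracy_score(y_test, clf_dt_small.predict(X_test))))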
In [6]:
# Export the tree to graphviz format and render it as an image
dot_data = StringIO()
export_graphviz(clf_dt, out_file=dot_data,
feature_names=data.feature_names[0:10],
class_names=data.target_names,
filled=True, rounded=True,
special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Out[6]:
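If graphviz/pydotplus is not available, newer versions of scikit-learn (0.21 and later) can draw the tree directly with matplotlib; a minimal sketch:
from sklearn.tree import plot_tree
# Render the same tree without graphviz (requires scikit-learn >= 0.21)
fig, ax = plt.subplots(figsize=(20, 10))
plot_tree(clf_dt, feature_names=list(data.feature_names[0:10]),
          class_names=list(data.target_names),
          filled=True, rounded=True, ax=ax)
plt.show()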
That's all for today.