decision trees


In [1]:
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.datasets
import sklearn.tree

In [2]:
plt.rcParams["figure.figsize"] = [17, 10]

Decision trees are directed graphs beginning with one node and branching out to many. A decision tree is a hierarchical data structure that represents data by implementing a divide-and-conquer strategy. There are two main types of decision tree, classification and regression, and both are used to make predictions from data: classification trees output a discrete category/class/target, while regression trees output real values. Regression tree algorithms were introduced in 1963 (reference).

Moving through a decision tree, each node splits up the input data. Each node can be thought of as a cluster of cases that further branches of the tree split up again. Trees are often binary, with each node splitting into two subsamples, but they don't have to be.

So, imagine there are some colored shapes that can be classified as A, B or C.

A classification decision tree for the colored shapes could look like this:

Decision trees can be seen as a compact way to represent a lot of data. A usual goal in defining a decision tree is to search for one that is as small as possible.
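
As a rough sketch, a small classification tree like that is just a set of nested if/else rules. The colors and side counts below are made up for illustration, not taken from any dataset used later in this notebook:


In [ ]:
# a toy, hand-written "decision tree" for the colored-shapes example;
# the features (color, number of sides) and the rules are invented for illustration
def classify_shape(color, n_sides):
    # root node: split on color
    if color == "red":
        # leaf node
        return "A"
    else:
        # internal node: split again on the number of sides
        if n_sides <= 4:
            return "B"
        else:
            return "C"

classify_shape("red", 3), classify_shape("blue", 4), classify_shape("blue", 6)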

super simple example of decision tree classifier

scikit-learn provides a DecisionTreeClassifier. It takes as input two arrays: an array of feature vectors and an array of class labels, one label for each feature vector.

Create some data: a few feature vectors and a classification for each one.


In [3]:
# features
X = [
        [0, 0],
        [1, 1]
    ]

# targets
Y = [
        0,
        1
    ]

classifier = sklearn.tree.DecisionTreeClassifier()
classifier = classifier.fit(X, Y)

Now, predict the class of an example feature vector.


In [4]:
classifier.predict([[2, 2]])


Out[4]:
array([1])

The probability of each class can be predicted too; it is the fraction of training samples of that class in the leaf that the example falls into.


In [5]:
classifier.predict_proba([[2, 2]])


Out[5]:
array([[0., 1.]])

We can look at the tree in Graphviz format.


In [6]:
graph = graphviz.Source(sklearn.tree.export_graphviz(classifier, out_file=None))
graph;

more detailed example of decision tree classifier using the iris dataset

Get the iris dataset.


In [7]:
iris = sklearn.datasets.load_iris()

The top bit of the dataset looks like this:


In [8]:
pd.DataFrame(
    data    = np.c_[iris["data"], iris["target"]],
    columns = iris["feature_names"] + ["target"]
).head()


Out[8]:
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
0                5.1               3.5                1.4               0.2     0.0
1                4.9               3.0                1.4               0.2     0.0
2                4.7               3.2                1.3               0.2     0.0
3                4.6               3.1                1.5               0.2     0.0
4                5.0               3.6                1.4               0.2     0.0

Make a decision tree and then fit it using the features ("data") and class labels ("target") of the iris dataset.


In [9]:
classifier = sklearn.tree.DecisionTreeClassifier()
classifier = classifier.fit(iris.data, iris.target)

Ok, let's look at the tree, but we'll fancy it up this time with colors and shit.


In [10]:
graph = graphviz.Source(
    sklearn.tree.export_graphviz(
        classifier,
        out_file           = None,
        feature_names      = iris.feature_names,
        class_names        = iris.target_names,
        filled             = True,
        rounded            = False,
        special_characters = True,
        proportion         = True,
    )
)
graph.render('iris_DT')
graph


Out[10]:
[Rendered Graphviz tree: the root splits on petal width (cm) ≤ 0.8, sending all of the setosa samples (33.3%) to a pure leaf; the remaining samples split on petal width (cm) ≤ 1.75 to separate a mostly-versicolor branch from a mostly-virginica branch, with further splits on petal length, petal width and sepal length resolving the overlap down to pure leaves.]

In [11]:
sklearn.tree.export_graphviz(
    classifier,
    out_file           = "tree_1.dot",  # export_graphviz writes Graphviz DOT source, not a rendered image
    feature_names      = iris.feature_names,
    class_names        = iris.target_names,
    filled             = True,
    rounded            = False,
    special_characters = True,
    proportion         = True,
)

Right, so now let's make some predictions.


In [12]:
classifier.predict(iris.data)


Out[12]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

How accurate is it? Well, here is what it should have got:


In [13]:
iris.target


Out[13]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Boom, it's awesome. Well done, decision tree. :) (Mind you, these are the very samples the tree was trained on, so a perfect score is expected here rather than impressive.)
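
For a more honest check, hold some data out and score on that. This is just a sketch; it uses sklearn.model_selection.train_test_split and sklearn.metrics.accuracy_score, which aren't imported or used elsewhere in this notebook.


In [ ]:
import sklearn.metrics
import sklearn.model_selection

# hold out 25% of the iris data for testing
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    iris.data, iris.target, test_size=0.25, random_state=0
)

held_out_classifier = sklearn.tree.DecisionTreeClassifier(random_state=0)
held_out_classifier.fit(X_train, y_train)

# accuracy on data the tree has never seen
sklearn.metrics.accuracy_score(y_test, held_out_classifier.predict(X_test))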

decision tree regressor

Now, let's take a glance at a decision tree for regression, that is, for modelling a continuous target. Here, let's model a slightly noisy sine curve.


In [14]:
rng = np.random.RandomState(1)
# 80 points drawn uniformly from [0, 5), sorted so they plot as a curve
X = np.sort(5*rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
# add noise to every fifth point
y[::5] += 3*(0.5-rng.rand(16))

In [15]:
plt.scatter(X, y, s=30, edgecolor="black", c="red", label="data")
plt.title("a fuck off noisy sine curve")
plt.xlabel("data")
plt.ylabel("target")
plt.show();


Aait, let's create and fit a decision tree with a max depth of 2.


In [16]:
regressor = sklearn.tree.DecisionTreeRegressor(max_depth=2)
regressor.fit(X, y);

Ok, let's make some predictions and see how it does.


In [17]:
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_prediction = regressor.predict(X_test)

In [18]:
plt.scatter(X, y, s=30, edgecolor="black", c = "red", label="data")
plt.plot(X_test, y_prediction, color="cornflowerblue", label="max_depth = 2", linewidth=2)
plt.title("just fittin' a noisy sine curve, it's fine")
plt.xlabel("data")
plt.ylabel("target")
plt.legend()
plt.show();


Damn, that's actually pretty decent: the depth-2 tree steps through the broad shape of the sine curve without chasing the noise.


In [19]:
graph = graphviz.Source(
    sklearn.tree.export_graphviz(
        regressor,
        out_file           = None,
        filled             = True,
        rounded            = False
    )
)
graph;

Ok, now let's try a deeper tree, with a max depth of 5.


In [20]:
regressor = sklearn.tree.DecisionTreeRegressor(max_depth=5)
regressor.fit(X, y);

In [21]:
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_prediction = regressor.predict(X_test)

In [22]:
plt.scatter(X, y, s=30, edgecolor="black", c="red", label="data")
plt.plot(X_test, y_prediction, color="cornflowerblue", label="max_depth = 5", linewidth=2)
plt.title("just fittin' a noisy sine curve, but what the Bjork?")
plt.xlabel("data")
plt.ylabel("target")
plt.legend()
plt.show();


Yeah ok, naw. Now the tree is chasing the noise.

It turns out that learning a tree that classifies or models the training data perfectly may not lead to a tree with good generalization performance. There could be noise in the data (as there was in this example), or the algorithm might be making splits based on very few samples (low statistics).
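
One way to see this, as a sketch, is to compare cross-validated scores of the regressor for a few different depths; deeper trees fit the training points more closely but don't necessarily score better on held-out folds. This uses sklearn.model_selection.cross_val_score and KFold, which aren't imported above, and the default R² scoring for regressors.


In [ ]:
import sklearn.model_selection

# shuffled 5-fold splits, since X is sorted
folds = sklearn.model_selection.KFold(n_splits=5, shuffle=True, random_state=0)

# mean cross-validated R^2 on the noisy sine data for a few tree depths
for depth in [2, 3, 5, 10]:
    scores = sklearn.model_selection.cross_val_score(
        sklearn.tree.DecisionTreeRegressor(max_depth=depth, random_state=0),
        X, y, cv=folds
    )
    print(depth, scores.mean())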


In [27]:
graph = graphviz.Source(
    sklearn.tree.export_graphviz(
        regressor,
        out_file           = None,
        filled             = True,
        rounded            = False,
        special_characters = True,
        proportion         = True,
    )
)
graph.render('sine_DT')
graph


Out[27]:
[Rendered Graphviz tree for the max_depth=5 regressor: the root splits on X[0] ≤ 3.133, and successive splits carve the input range into narrower and narrower intervals, ending in leaves that predict the mean target of as little as a single training point (samples = 1.2%), including the noisy outliers.]
