This notebook contains various implementations of decision tree visualization using dtreeviz. We have tested and tuned the visualizations for different datasets (both classification and regression) and for various tree depths. The plots have been carefully tweaked to look good inside Jupyter notebooks as well as when saved as image (svg/pdf) files.
Make sure you have the latest Graphviz installed.
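If you are not sure whether Graphviz is available, a quick check from inside the notebook is to ask the dot executable for its version (a minimal sketch; it assumes the dot binary is on the PATH and that apt-get is available if you need to install it):

!dot -V
# if Graphviz is missing, install it at the OS level, e.g. on Debian/Ubuntu or Colab:
# !apt-get -qq install -y graphviz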
In [15]:
import sys
if 'google.colab' in sys.modules:
    !pip install -q dtreeviz
In [2]:
import sys
import os
# add library module to PYTHONPATH
sys.path.append(f"{os.getcwd()}/../")
In [3]:
from sklearn.datasets import *
from dtreeviz.trees import *
from IPython.display import Image, display_svg, SVG
The default regression tree visualization is the fancy version with top-down orientation, which shows each decision node as a scatter plot of the split feature against the target. Below is a decision tree with depth=3 on the Boston house-prices dataset (regression).
In [13]:
regr = tree.DecisionTreeRegressor(max_depth=3)
boston = load_boston()
X_train = boston.data
y_train = boston.target
regr.fit(X_train, y_train)

viz = dtreeviz(regr,
               X_train,
               y_train,
               target_name='price',  # this name will be displayed at the leaf node
               feature_names=boston.feature_names)
viz
# viz.view() will give a popup with the graph in a pdf
Out[13]:
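Besides rendering inline, the intro above mentions saving the plot as an svg/pdf file. As a rough sketch (assuming the object returned by dtreeviz exposes a save() method that writes to the path you pass, which recent versions of the library appear to provide), it looks like this:

# hedged sketch: write the rendered tree to disk (check that save() exists in your dtreeviz version)
viz.save("boston_tree.svg")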
Here's how to scale the overall image
In [14]:
dtreeviz(regr,
         X_train,
         y_train,
         target_name='price',  # this name will be displayed at the leaf node
         feature_names=boston.feature_names,
         scale=.5)
Out[14]:
[rendered regression tree visualization, scaled to 50%]
Let's see visualizations on the classic iris multi-class dataset. For classification trees, you must pass the class_names argument so that the legend labels can be matched with the right class category codes. The labels must be listed in the same order as the class category codes.
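Since the datasets bundled with scikit-learn expose the category names in code order, one simple way to guarantee the ordering is to take them straight from the dataset object rather than typing them by hand:

class_names = list(iris.target_names)  # ['setosa', 'versicolor', 'virginica'], in the order of the target codes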
In [6]:
clas = tree.DecisionTreeClassifier(max_depth=2)
iris = load_iris()
X_train = iris.data
y_train = iris.target
clas.fit(X_train, y_train)

viz = dtreeviz(clas,
               X_train,
               y_train,
               target_name='variety',
               feature_names=iris.feature_names,
               class_names=["setosa", "versicolor", "virginica"],
               histtype='barstacked')  # 'barstacked' is the default
viz
Out[6]:
[rendered classification tree with stacked feature histograms and a class legend]
Breast Cancer Wisconsin Dataset
We can use the orientation parameter to display the tree from left to right ('LR') rather than top down.
In [7]:
clas = tree.DecisionTreeClassifier(max_depth=2)
cancer = load_breast_cancer()
X_train = cancer.data
y_train = cancer.target
clas.fit(X_train, y_train)

viz = dtreeviz(clas,
               X_train,
               y_train,
               target_name='cancer',
               feature_names=cancer.feature_names,
               class_names=["malignant", "benign"],
               orientation='LR')
viz
Out[7]:
[rendered classification tree in left-to-right orientation with a class legend]
When there are more than four or five classes, stacked histograms become difficult to read, so we recommend setting the histtype parameter to 'bar' rather than 'barstacked' in that case. With high-cardinality targets, the overlapping distributions become harder to visualize and things break down, so we cap the number of target classes at 10.
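If your target has more than 10 classes, one workaround (a rough NumPy sketch, not part of the dtreeviz API; X and y are placeholders for your own data) is to keep only the 10 most frequent classes and re-encode them before fitting and visualizing:

import numpy as np

classes, counts = np.unique(y, return_counts=True)
top10 = classes[np.argsort(counts)[::-1][:10]]   # the 10 most frequent class codes
mask = np.isin(y, top10)
X10 = X[mask]
y10 = np.searchsorted(np.sort(top10), y[mask])   # re-encode as contiguous codes 0..9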
In [8]:
clas = tree.DecisionTreeClassifier(max_depth=2)
digits = load_digits()
X_train = digits.data
y_train = digits.target
clas.fit(X_train, y_train)

# "8x8 image of integer pixels in the range 0..16."
columns = [f'pixel[{i},{j}]' for i in range(8) for j in range(8)]

viz = dtreeviz(clas,
               X_train,
               y_train,
               target_name='number',
               feature_names=columns,
               class_names=[chr(c) for c in range(ord('0'), ord('9') + 1)],
               histtype='bar',
               orientation='TD')
viz
Out[8]:
[rendered classification tree with non-stacked (bar) histograms and a class legend]
Sometimes it is important to understand which decision path a specific test observation follows. The prediction path is typically used to interpret a prediction, i.e. to understand why the tree predicted a particular value for a particular observation.
In [9]:
clf = tree.DecisionTreeClassifier(max_depth=2)
wine = load_wine()
X_train = wine.data
y_train = wine.target
clf.fit(X_train, y_train)

# pick random X observation for demo
X = wine.data[np.random.randint(0, len(wine.data)), :]

viz = dtreeviz(clf,
               wine.data,
               wine.target,
               target_name='wine',
               feature_names=wine.feature_names,
               class_names=list(wine.target_names),
               X=X)  # pass the test observation
viz
Out[9]:
[rendered classification tree with the prediction path for the test observation highlighted; the observation's feature values are listed beneath the tree and the predicted class is class_0]
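The same path can also be inspected programmatically with plain scikit-learn, independent of dtreeviz; this minimal sketch uses DecisionTreeClassifier.decision_path on the classifier and observation defined above:

# node ids visited by the test observation X
node_indicator = clf.decision_path(X.reshape(1, -1))
visited = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
print("nodes on the prediction path:", visited)
print("predicted class:", wine.target_names[clf.predict(X.reshape(1, -1))[0]])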
We can turn on node id labelling, which is useful when trying to understand or explain how a decision tree works using dtreeviz visualizations.
In [10]:
regr = tree.DecisionTreeRegressor(max_depth=3)
diabetes = load_diabetes()
X_train = diabetes.data
y_train = diabetes.target
regr.fit(X_train, y_train)

X = diabetes.data[np.random.randint(0, len(diabetes.data)), :]

viz = dtreeviz(regr,
               X_train,
               y_train,
               target_name='progr',  # this name will be displayed at the leaf node
               feature_names=diabetes.feature_names,
               X=X,
               show_node_labels=True)
viz
Out[10]:
[rendered regression tree with node id labels (Node 0 through Node 14) and the prediction path highlighted; the observation's feature values are listed beneath the tree and the predicted value at the leaf is 108.80]
In [11]:
# data from https://archive.ics.uci.edu/ml/datasets/User+Knowledge+Modeling
clf = tree.DecisionTreeClassifier(max_depth=3)

if 'google.colab' in sys.modules:
    know = pd.read_csv("https://raw.githubusercontent.com/parrt/dtreeviz/master/testing/data/knowledge.csv")
else:
    know = pd.read_csv("../testing/data/knowledge.csv")

target_names = ['very_low', 'Low', 'Middle', 'High']
know['UNS'] = know['UNS'].map({n: i for i, n in enumerate(target_names)})

X_train, y_train = know.drop('UNS', axis=1), know['UNS']
clf = clf.fit(X_train, y_train)

viz = dtreeviz(clf,
               X_train,
               y_train,
               target_name='UNS',
               feature_names=X_train.columns.values,
               class_names=target_names,
               fancy=False)  # non-fancy: plain split labels instead of per-node plots
viz
Out[11]:
[rendered non-fancy decision tree; internal nodes show only the split feature and threshold (e.g. PEG@0.34), with a class legend]