Taken from Google's Visualizing a Decision Tree - Machine Learning Recipes #2
In [1]:
from sklearn import tree
from sklearn.datasets import load_iris
In [2]:
# Load the iris dataset via scikit-learn's load_iris(); the following cells
# inspect the object that comes back.
iris = load_iris()
In [3]:
# Check what kind of object load_iris() returned.
type(iris)
Out[3]:
In [4]:
# Is the dataset object a dict (subclass)? It is indexed with ['target'] further down.
isinstance(iris, dict)
Out[4]:
In [5]:
# List the keys available on the dataset object.
iris.keys()
Out[5]:
In [6]:
# Feature names stored with the dataset.
iris.feature_names
Out[6]:
In [7]:
# Class (target) names stored with the dataset.
iris.target_names
Out[7]:
In [8]:
# The label array, accessed dict-style
# (presumably the same array as iris.target used below — confirm).
iris['target']
Out[8]:
In [9]:
# Check the container type of the labels.
type(iris['target'])
Out[9]:
In [10]:
# Print the first five examples: row index, class label, feature vector.
# Slicing bounds the loop to the rows actually shown, instead of walking the
# whole dataset and filtering on the index.
for i, (label, features) in enumerate(zip(iris.target[:5], iris.data[:5])):
    print('Example {}: label {}, features {}'.format(i, label, features))
In [11]:
import numpy as np
In [12]:
test_idx = [0, 50, 100] # rows held out for testing, i.e. removed from the training data
# Remove those rows from the feature data.
# Note: axis=0 deletes whole rows. Without axis=, np.delete works on the
# flattened array and returns a 1-D result, not rows of features,
# ie we want this:
# [[ 4.9, 3. , 1.4, 0.2],
# [ 4.7, 3.2, 1.3, 0.2],
# [ 4.6, 3.1, 1.5, 0.2], …]
# and not this:
# [4.9, 3. , 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, …]
train_data = np.delete(iris.data, test_idx, axis=0)
# np.delete() removes the same 3 indices from the label array iris.target.
# Note: here the axis= arg doesn't matter, as each item is a single integer
# (NOTE(review): i.e. iris.target is assumed to be 1-D — confirm).
train_target = np.delete(iris.target, test_idx)
In [13]:
# See how rows have been removed: start with the size of the full label array
len(iris.target)
Out[13]:
In [14]:
len(train_data) # training features: the three test rows have been taken out
Out[14]:
In [15]:
# Training labels: should match len(train_data), since the same indices were deleted.
len(train_target)
Out[15]:
In [16]:
# Labels of the held-out rows, picked out with the list of indices.
test_target = iris.target[test_idx]
In [17]:
test_target # only three labels, one per index in test_idx
Out[17]:
In [18]:
# Feature rows of the held-out examples, selected with the same indices as the labels.
test_data = iris.data[test_idx]
In [19]:
# Display the held-out feature rows.
test_data
Out[19]:
In [20]:
# Aside: a NumPy array can be indexed with a list of positions
l = [1, 4, 5, 6, 8, 999, 44, 6, 7, 10]
idx = [0, 4, 6]  # the indices/rows we want to pull out
a = np.array(l)
# indexing with the list returns just those items, in one step
a[idx]
Out[20]:
In [21]:
# Fit a decision tree (no constructor arguments) on the training split.
clf = tree.DecisionTreeClassifier()
clf.fit(X=train_data, y=train_target)
Out[21]:
In [22]:
# Predict labels for the held-out rows the model never saw during fit().
clf.predict(test_data)
Out[22]:
In [23]:
# Do the predictions match the true labels of the held-out rows?
# NOTE(review): assumes both sides are NumPy arrays, so == compares element
# by element and yields one boolean per test row — confirm.
clf.predict(test_data) == test_target
Out[23]:
In [25]:
# StringIO provides the in-memory text buffer that receives the DOT export.
# Use the standard-library version: sklearn.externals.six was a vendored copy
# of `six` that newer scikit-learn releases no longer ship, so importing from
# it fails there.
from io import StringIO
import pydotplus # note installed pydotplus for Py3 compatibility
In [26]:
# Export the fitted tree as Graphviz DOT text into an in-memory buffer,
# labelling nodes with the dataset's feature and class names.
dot_data = StringIO()
tree.export_graphviz(
    clf,
    out_file=dot_data,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    impurity=False,
)
In [27]:
# Build a pydotplus graph from the DOT text accumulated in the buffer.
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
In [29]:
# graphviz installed on mac with `brew install graphviz`
# Write the tree diagram to iris.pdf (a relative path).
graph.write_pdf('iris.pdf')
# to view the result on a mac: open -a preview ~/ipython/tensorflow/iris.pdf
Out[29]:
In [32]:
# now check the rows withheld for testing
# by following each one by hand through the rules in the graphic tree
test_data[0], test_target[0] # dataset row 0: we know it is a setosa
Out[32]:
In [34]:
# For reference while reading the rows: feature names and class names.
iris.feature_names, iris.target_names
Out[34]:
In [36]:
test_data[1], test_target[1] # dataset row 50: we know it is a versicolor
Out[36]:
In [38]:
test_data[2], test_target[2] # dataset row 100: we know it is a virginica
# all three held-out rows test true against the tree's rules
Out[38]: