In [ ]:
import pandas as pd
%matplotlib inline
In [ ]:
from sklearn import datasets
# pandas.tools.plotting was deprecated in pandas 0.20 and later removed;
# scatter_matrix is now provided by pandas.plotting.
from pandas.plotting import scatter_matrix
In [ ]:
import matplotlib.pyplot as plt
In [ ]:
iris = datasets.load_iris(return_X_y=False)  # Iris data as an object exposing .data / .target (not an (X, y) tuple)
In [ ]:
# Features: every column from index 2 onward (in scikit-learn's Iris data
# these are petal length and petal width). Target: the class label per sample.
x, y = iris.data[:, 2:], iris.target
In [ ]:
from sklearn import tree
In [ ]:
# Pin random_state: scikit-learn randomly permutes the features at every split
# (even with the default splitter="best"), so ties between equally good splits
# could otherwise produce a different tree -- and a different rendered PDF --
# on each run.
dt = tree.DecisionTreeClassifier(random_state=0)
In [ ]:
dt = dt.fit(X=x, y=y)  # fit() trains in place and returns the estimator itself
In [ ]:
# sklearn.externals.six was removed in scikit-learn 0.23; on Python 3 its
# StringIO was simply io.StringIO, so import it from the standard library.
from io import StringIO
import pydotplus  # pip install pydotplus
In [ ]:
# Write the fitted tree's Graphviz DOT description to iris.dot.
# When out_file is given, export_graphviz writes into that handle and returns
# None, so its result is deliberately not assigned.
with open("iris.dot", "w") as f:
    tree.export_graphviz(dt, out_file=f)
In [ ]:
import os

# Delete iris.dot again (os.remove is the same function as os.unlink).
os.remove("iris.dot")
In [ ]:
# Render the fitted tree to iris.pdf; this needs the Graphviz `dot` executable
# (e.g. `brew install graphviz`). With out_file=None, export_graphviz returns
# the DOT source as a string, so no intermediate buffer is needed.
dot_source = tree.export_graphviz(dt, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_source)
graph.write_pdf("iris.pdf")
In [ ]:
from IPython.display import IFrame

# Embed the rendered PDF in the notebook; as the cell's last expression it is displayed.
pdf_viewer = IFrame(src="iris.pdf", width=800, height=800)
pdf_viewer
In [ ]: