Decision Tree

CART (Classification and Regression Trees) — scikit-learn's decision trees use an optimised version of this algorithm.

Training a Decision Tree with the Scikit-Learn Library


In [1]:
import pandas as pd

In [2]:
from sklearn import tree

In [3]:
# Two training samples with two features each, and their class labels.
X, y = [[0, 0], [1, 2]], [0, 1]

In [4]:
# Seeded for reproducibility. With the two toy samples, both features separate
# the classes equally well. scikit-learn breaks such ties through a random
# feature permutation, so without a fixed seed the predictions below can change
# from run to run. 42 matches the seed used for the iris model later on.
clf = tree.DecisionTreeClassifier(random_state=42)

In [5]:
# Fit the tree in place (fit returns the estimator itself); ';' hides its repr.
clf.fit(X, y);

In [6]:
# Predicted class for a single query point (2, 2).
clf.predict(X=[[2.0, 2.0]])


Out[6]:
array([1])

In [7]:
# Class-membership probabilities for the same query point (2, 2).
clf.predict_proba(X=[[2.0, 2.0]])


Out[7]:
array([[ 0.,  1.]])

In [8]:
# Query point (0.4, 1.2). With scikit-learn's midpoint thresholds (0.5 on
# feature 0, 1.0 on feature 1), this point falls on opposite sides of the two
# equally good splits. The answer therefore depends on which feature the tree
# chose to split on.
clf.predict(X=[[0.4, 1.2]])


Out[8]:
array([1])

In [9]:
# Probabilities for the query point (0.4, 1.2).
clf.predict_proba(X=[[0.4, 1.2]])


Out[9]:
array([[ 0.,  1.]])

In [10]:
# Probabilities for a point close to the first training sample (0, 0).
clf.predict_proba(X=[[0, 0.2]])


Out[10]:
array([[ 1.,  0.]])

DecisionTreeClassifier supports both binary classification (two class labels, e.g. [-1, 1] or, as above, [0, 1]) and multiclass classification (labels [0, …, K-1]).

Applying It to the Iris Dataset


In [11]:
from sklearn.datasets import load_iris
from sklearn import tree
# Load scikit-learn's iris dataset. The returned object exposes .data, .target,
# .feature_names and .target_names, all of which are used below.
iris = load_iris()

In [12]:
# First five samples (rows) of the feature matrix.
iris.data[:5]


Out[12]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2]])

In [13]:
# Names of the columns of iris.data, in column order.
iris.feature_names


Out[13]:
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']

In [14]:
# Keep only columns 2 and 3: petal length and petal width (see feature_names).
X = iris.data[:, 2:4]

In [15]:
# Integer class code of each sample (values 0, 1 and 2, as displayed below).
y = iris.target

In [16]:
# Display the target vector.
y


Out[16]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [17]:
# Gini-impurity tree with a fixed seed, so split tie-breaking is reproducible.
clf = tree.DecisionTreeClassifier(criterion='gini', random_state=42)

In [18]:
# Train on the two petal features; ';' suppresses the estimator repr.
clf.fit(X, y);

In [19]:
from sklearn.tree import export_graphviz

In [20]:
# Write the fitted tree as Graphviz DOT source to tree.dot. Only the petal
# columns were used for training, hence feature_names[2:].
export_graphviz(
    clf,
    out_file="tree.dot",
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    filled=True,
    rounded=True,
)

In [21]:
import graphviz

In [22]:
# Same export as above, but with out_file=None the DOT source is returned as a
# string instead of being written to disk.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    filled=True,
    rounded=True,
)

In [23]:
# Wrap the DOT source in a graphviz object and make it the cell's last
# expression, so Jupyter renders the tree inline. An assignment on its own
# would display nothing.
graph = graphviz.Source(dot_data)
graph

In [24]:
# Numerical and plotting libraries for the decision-region section below.
import numpy as np
import seaborn as sns
# White background with grid lines for figures drawn after this point.
sns.set_style('whitegrid')
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook.
%matplotlib inline

Start Here: Plotting the Decision Regions on the Petal Features


In [25]:
# Reload iris as a pandas DataFrame via seaborn. Species is a string column
# here rather than integer targets.
# NOTE(review): seaborn presumably fetches this dataset from its online data
# repository on first use, so this cell may need network access — confirm.
df = sns.load_dataset('iris')
df.head()


Out[25]:
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa

In [26]:
# Model features: the two petal measurements, in this column order.
col = ['petal_length', 'petal_width']
X = df[col]

In [27]:
# Encode species names as integer class codes, in the order setosa=0,
# versicolor=1, virginica=2, and store them in a helper column of df.
species_to_num = {name: code for code, name in
                  enumerate(['setosa', 'versicolor', 'virginica'])}
df['tmp'] = df.species.map(species_to_num)
y = df['tmp']

In [28]:
# Seeded for reproducibility. scikit-learn randomly permutes the features at
# each split. Both petal features separate setosa from the other species
# perfectly (see X.values below), so equally good splits can occur, and the
# chosen split, and hence the plotted regions, could vary between runs.
# 42 matches the seed of the earlier iris model.
clf = tree.DecisionTreeClassifier(random_state=42)
clf = clf.fit(X, y)

In [29]:
# First five rows of the feature frame.
X.head()


Out[29]:
petal_length petal_width
0 1.4 0.2
1 1.4 0.2
2 1.3 0.2
3 1.5 0.2
4 1.4 0.2

In [30]:
# Underlying NumPy array of the two petal features. This prints all 150 rows;
# X.values[:5] would give a more readable preview.
X.values


Out[30]:
array([[ 1.4,  0.2],
       [ 1.4,  0.2],
       [ 1.3,  0.2],
       [ 1.5,  0.2],
       [ 1.4,  0.2],
       [ 1.7,  0.4],
       [ 1.4,  0.3],
       [ 1.5,  0.2],
       [ 1.4,  0.2],
       [ 1.5,  0.1],
       [ 1.5,  0.2],
       [ 1.6,  0.2],
       [ 1.4,  0.1],
       [ 1.1,  0.1],
       [ 1.2,  0.2],
       [ 1.5,  0.4],
       [ 1.3,  0.4],
       [ 1.4,  0.3],
       [ 1.7,  0.3],
       [ 1.5,  0.3],
       [ 1.7,  0.2],
       [ 1.5,  0.4],
       [ 1. ,  0.2],
       [ 1.7,  0.5],
       [ 1.9,  0.2],
       [ 1.6,  0.2],
       [ 1.6,  0.4],
       [ 1.5,  0.2],
       [ 1.4,  0.2],
       [ 1.6,  0.2],
       [ 1.6,  0.2],
       [ 1.5,  0.4],
       [ 1.5,  0.1],
       [ 1.4,  0.2],
       [ 1.5,  0.2],
       [ 1.2,  0.2],
       [ 1.3,  0.2],
       [ 1.4,  0.1],
       [ 1.3,  0.2],
       [ 1.5,  0.2],
       [ 1.3,  0.3],
       [ 1.3,  0.3],
       [ 1.3,  0.2],
       [ 1.6,  0.6],
       [ 1.9,  0.4],
       [ 1.4,  0.3],
       [ 1.6,  0.2],
       [ 1.4,  0.2],
       [ 1.5,  0.2],
       [ 1.4,  0.2],
       [ 4.7,  1.4],
       [ 4.5,  1.5],
       [ 4.9,  1.5],
       [ 4. ,  1.3],
       [ 4.6,  1.5],
       [ 4.5,  1.3],
       [ 4.7,  1.6],
       [ 3.3,  1. ],
       [ 4.6,  1.3],
       [ 3.9,  1.4],
       [ 3.5,  1. ],
       [ 4.2,  1.5],
       [ 4. ,  1. ],
       [ 4.7,  1.4],
       [ 3.6,  1.3],
       [ 4.4,  1.4],
       [ 4.5,  1.5],
       [ 4.1,  1. ],
       [ 4.5,  1.5],
       [ 3.9,  1.1],
       [ 4.8,  1.8],
       [ 4. ,  1.3],
       [ 4.9,  1.5],
       [ 4.7,  1.2],
       [ 4.3,  1.3],
       [ 4.4,  1.4],
       [ 4.8,  1.4],
       [ 5. ,  1.7],
       [ 4.5,  1.5],
       [ 3.5,  1. ],
       [ 3.8,  1.1],
       [ 3.7,  1. ],
       [ 3.9,  1.2],
       [ 5.1,  1.6],
       [ 4.5,  1.5],
       [ 4.5,  1.6],
       [ 4.7,  1.5],
       [ 4.4,  1.3],
       [ 4.1,  1.3],
       [ 4. ,  1.3],
       [ 4.4,  1.2],
       [ 4.6,  1.4],
       [ 4. ,  1.2],
       [ 3.3,  1. ],
       [ 4.2,  1.3],
       [ 4.2,  1.2],
       [ 4.2,  1.3],
       [ 4.3,  1.3],
       [ 3. ,  1.1],
       [ 4.1,  1.3],
       [ 6. ,  2.5],
       [ 5.1,  1.9],
       [ 5.9,  2.1],
       [ 5.6,  1.8],
       [ 5.8,  2.2],
       [ 6.6,  2.1],
       [ 4.5,  1.7],
       [ 6.3,  1.8],
       [ 5.8,  1.8],
       [ 6.1,  2.5],
       [ 5.1,  2. ],
       [ 5.3,  1.9],
       [ 5.5,  2.1],
       [ 5. ,  2. ],
       [ 5.1,  2.4],
       [ 5.3,  2.3],
       [ 5.5,  1.8],
       [ 6.7,  2.2],
       [ 6.9,  2.3],
       [ 5. ,  1.5],
       [ 5.7,  2.3],
       [ 4.9,  2. ],
       [ 6.7,  2. ],
       [ 4.9,  1.8],
       [ 5.7,  2.1],
       [ 6. ,  1.8],
       [ 4.8,  1.8],
       [ 4.9,  1.8],
       [ 5.6,  2.1],
       [ 5.8,  1.6],
       [ 6.1,  1.9],
       [ 6.4,  2. ],
       [ 5.6,  2.2],
       [ 5.1,  1.5],
       [ 5.6,  1.4],
       [ 6.1,  2.3],
       [ 5.6,  2.4],
       [ 5.5,  1.8],
       [ 4.8,  1.8],
       [ 5.4,  2.1],
       [ 5.6,  2.4],
       [ 5.1,  2.3],
       [ 5.1,  1.9],
       [ 5.9,  2.3],
       [ 5.7,  2.5],
       [ 5.2,  2.3],
       [ 5. ,  1.9],
       [ 5.2,  2. ],
       [ 5.4,  2.3],
       [ 5.1,  1.8]])

In [31]:
# Flatten both feature columns into one column vector. As the output shows, the
# rows alternate petal_length, petal_width, so the two features are pooled.
# In this section the result is used only for its overall min/max.
X.values.reshape(-1,1)


Out[31]:
array([[ 1.4],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.3],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.7],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.1],
       [ 1.5],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.4],
       [ 0.1],
       [ 1.1],
       [ 0.1],
       [ 1.2],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1.3],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.7],
       [ 0.3],
       [ 1.5],
       [ 0.3],
       [ 1.7],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1. ],
       [ 0.2],
       [ 1.7],
       [ 0.5],
       [ 1.9],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.6],
       [ 0.4],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1.5],
       [ 0.1],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.2],
       [ 0.2],
       [ 1.3],
       [ 0.2],
       [ 1.4],
       [ 0.1],
       [ 1.3],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.3],
       [ 0.3],
       [ 1.3],
       [ 0.3],
       [ 1.3],
       [ 0.2],
       [ 1.6],
       [ 0.6],
       [ 1.9],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.6],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 4.7],
       [ 1.4],
       [ 4.5],
       [ 1.5],
       [ 4.9],
       [ 1.5],
       [ 4. ],
       [ 1.3],
       [ 4.6],
       [ 1.5],
       [ 4.5],
       [ 1.3],
       [ 4.7],
       [ 1.6],
       [ 3.3],
       [ 1. ],
       [ 4.6],
       [ 1.3],
       [ 3.9],
       [ 1.4],
       [ 3.5],
       [ 1. ],
       [ 4.2],
       [ 1.5],
       [ 4. ],
       [ 1. ],
       [ 4.7],
       [ 1.4],
       [ 3.6],
       [ 1.3],
       [ 4.4],
       [ 1.4],
       [ 4.5],
       [ 1.5],
       [ 4.1],
       [ 1. ],
       [ 4.5],
       [ 1.5],
       [ 3.9],
       [ 1.1],
       [ 4.8],
       [ 1.8],
       [ 4. ],
       [ 1.3],
       [ 4.9],
       [ 1.5],
       [ 4.7],
       [ 1.2],
       [ 4.3],
       [ 1.3],
       [ 4.4],
       [ 1.4],
       [ 4.8],
       [ 1.4],
       [ 5. ],
       [ 1.7],
       [ 4.5],
       [ 1.5],
       [ 3.5],
       [ 1. ],
       [ 3.8],
       [ 1.1],
       [ 3.7],
       [ 1. ],
       [ 3.9],
       [ 1.2],
       [ 5.1],
       [ 1.6],
       [ 4.5],
       [ 1.5],
       [ 4.5],
       [ 1.6],
       [ 4.7],
       [ 1.5],
       [ 4.4],
       [ 1.3],
       [ 4.1],
       [ 1.3],
       [ 4. ],
       [ 1.3],
       [ 4.4],
       [ 1.2],
       [ 4.6],
       [ 1.4],
       [ 4. ],
       [ 1.2],
       [ 3.3],
       [ 1. ],
       [ 4.2],
       [ 1.3],
       [ 4.2],
       [ 1.2],
       [ 4.2],
       [ 1.3],
       [ 4.3],
       [ 1.3],
       [ 3. ],
       [ 1.1],
       [ 4.1],
       [ 1.3],
       [ 6. ],
       [ 2.5],
       [ 5.1],
       [ 1.9],
       [ 5.9],
       [ 2.1],
       [ 5.6],
       [ 1.8],
       [ 5.8],
       [ 2.2],
       [ 6.6],
       [ 2.1],
       [ 4.5],
       [ 1.7],
       [ 6.3],
       [ 1.8],
       [ 5.8],
       [ 1.8],
       [ 6.1],
       [ 2.5],
       [ 5.1],
       [ 2. ],
       [ 5.3],
       [ 1.9],
       [ 5.5],
       [ 2.1],
       [ 5. ],
       [ 2. ],
       [ 5.1],
       [ 2.4],
       [ 5.3],
       [ 2.3],
       [ 5.5],
       [ 1.8],
       [ 6.7],
       [ 2.2],
       [ 6.9],
       [ 2.3],
       [ 5. ],
       [ 1.5],
       [ 5.7],
       [ 2.3],
       [ 4.9],
       [ 2. ],
       [ 6.7],
       [ 2. ],
       [ 4.9],
       [ 1.8],
       [ 5.7],
       [ 2.1],
       [ 6. ],
       [ 1.8],
       [ 4.8],
       [ 1.8],
       [ 4.9],
       [ 1.8],
       [ 5.6],
       [ 2.1],
       [ 5.8],
       [ 1.6],
       [ 6.1],
       [ 1.9],
       [ 6.4],
       [ 2. ],
       [ 5.6],
       [ 2.2],
       [ 5.1],
       [ 1.5],
       [ 5.6],
       [ 1.4],
       [ 6.1],
       [ 2.3],
       [ 5.6],
       [ 2.4],
       [ 5.5],
       [ 1.8],
       [ 4.8],
       [ 1.8],
       [ 5.4],
       [ 2.1],
       [ 5.6],
       [ 2.4],
       [ 5.1],
       [ 2.3],
       [ 5.1],
       [ 1.9],
       [ 5.9],
       [ 2.3],
       [ 5.7],
       [ 2.5],
       [ 5.2],
       [ 2.3],
       [ 5. ],
       [ 1.9],
       [ 5.2],
       [ 2. ],
       [ 5.4],
       [ 2.3],
       [ 5.1],
       [ 1.8]])

In [32]:
# Pooled single-column view of both petal features (used for the grid's x range).
Xv = np.reshape(X.values, (-1, 1))

In [33]:
# Display the pooled single-column array.
Xv


Out[33]:
array([[ 1.4],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.3],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.7],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.1],
       [ 1.5],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.4],
       [ 0.1],
       [ 1.1],
       [ 0.1],
       [ 1.2],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1.3],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.7],
       [ 0.3],
       [ 1.5],
       [ 0.3],
       [ 1.7],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1. ],
       [ 0.2],
       [ 1.7],
       [ 0.5],
       [ 1.9],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.6],
       [ 0.4],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.6],
       [ 0.2],
       [ 1.5],
       [ 0.4],
       [ 1.5],
       [ 0.1],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.2],
       [ 0.2],
       [ 1.3],
       [ 0.2],
       [ 1.4],
       [ 0.1],
       [ 1.3],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.3],
       [ 0.3],
       [ 1.3],
       [ 0.3],
       [ 1.3],
       [ 0.2],
       [ 1.6],
       [ 0.6],
       [ 1.9],
       [ 0.4],
       [ 1.4],
       [ 0.3],
       [ 1.6],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 1.5],
       [ 0.2],
       [ 1.4],
       [ 0.2],
       [ 4.7],
       [ 1.4],
       [ 4.5],
       [ 1.5],
       [ 4.9],
       [ 1.5],
       [ 4. ],
       [ 1.3],
       [ 4.6],
       [ 1.5],
       [ 4.5],
       [ 1.3],
       [ 4.7],
       [ 1.6],
       [ 3.3],
       [ 1. ],
       [ 4.6],
       [ 1.3],
       [ 3.9],
       [ 1.4],
       [ 3.5],
       [ 1. ],
       [ 4.2],
       [ 1.5],
       [ 4. ],
       [ 1. ],
       [ 4.7],
       [ 1.4],
       [ 3.6],
       [ 1.3],
       [ 4.4],
       [ 1.4],
       [ 4.5],
       [ 1.5],
       [ 4.1],
       [ 1. ],
       [ 4.5],
       [ 1.5],
       [ 3.9],
       [ 1.1],
       [ 4.8],
       [ 1.8],
       [ 4. ],
       [ 1.3],
       [ 4.9],
       [ 1.5],
       [ 4.7],
       [ 1.2],
       [ 4.3],
       [ 1.3],
       [ 4.4],
       [ 1.4],
       [ 4.8],
       [ 1.4],
       [ 5. ],
       [ 1.7],
       [ 4.5],
       [ 1.5],
       [ 3.5],
       [ 1. ],
       [ 3.8],
       [ 1.1],
       [ 3.7],
       [ 1. ],
       [ 3.9],
       [ 1.2],
       [ 5.1],
       [ 1.6],
       [ 4.5],
       [ 1.5],
       [ 4.5],
       [ 1.6],
       [ 4.7],
       [ 1.5],
       [ 4.4],
       [ 1.3],
       [ 4.1],
       [ 1.3],
       [ 4. ],
       [ 1.3],
       [ 4.4],
       [ 1.2],
       [ 4.6],
       [ 1.4],
       [ 4. ],
       [ 1.2],
       [ 3.3],
       [ 1. ],
       [ 4.2],
       [ 1.3],
       [ 4.2],
       [ 1.2],
       [ 4.2],
       [ 1.3],
       [ 4.3],
       [ 1.3],
       [ 3. ],
       [ 1.1],
       [ 4.1],
       [ 1.3],
       [ 6. ],
       [ 2.5],
       [ 5.1],
       [ 1.9],
       [ 5.9],
       [ 2.1],
       [ 5.6],
       [ 1.8],
       [ 5.8],
       [ 2.2],
       [ 6.6],
       [ 2.1],
       [ 4.5],
       [ 1.7],
       [ 6.3],
       [ 1.8],
       [ 5.8],
       [ 1.8],
       [ 6.1],
       [ 2.5],
       [ 5.1],
       [ 2. ],
       [ 5.3],
       [ 1.9],
       [ 5.5],
       [ 2.1],
       [ 5. ],
       [ 2. ],
       [ 5.1],
       [ 2.4],
       [ 5.3],
       [ 2.3],
       [ 5.5],
       [ 1.8],
       [ 6.7],
       [ 2.2],
       [ 6.9],
       [ 2.3],
       [ 5. ],
       [ 1.5],
       [ 5.7],
       [ 2.3],
       [ 4.9],
       [ 2. ],
       [ 6.7],
       [ 2. ],
       [ 4.9],
       [ 1.8],
       [ 5.7],
       [ 2.1],
       [ 6. ],
       [ 1.8],
       [ 4.8],
       [ 1.8],
       [ 4.9],
       [ 1.8],
       [ 5.6],
       [ 2.1],
       [ 5.8],
       [ 1.6],
       [ 6.1],
       [ 1.9],
       [ 6.4],
       [ 2. ],
       [ 5.6],
       [ 2.2],
       [ 5.1],
       [ 1.5],
       [ 5.6],
       [ 1.4],
       [ 6.1],
       [ 2.3],
       [ 5.6],
       [ 2.4],
       [ 5.5],
       [ 1.8],
       [ 4.8],
       [ 1.8],
       [ 5.4],
       [ 2.1],
       [ 5.6],
       [ 2.4],
       [ 5.1],
       [ 2.3],
       [ 5.1],
       [ 1.9],
       [ 5.9],
       [ 2.3],
       [ 5.7],
       [ 2.5],
       [ 5.2],
       [ 2.3],
       [ 5. ],
       [ 1.9],
       [ 5.2],
       [ 2. ],
       [ 5.4],
       [ 2.3],
       [ 5.1],
       [ 1.8]])

In [34]:
h = 2e-2  # mesh step size in feature units (cm) for the evaluation grid

In [35]:
# Smallest value across both pooled petal features.
np.min(Xv)


Out[35]:
0.10000000000000001

In [36]:
# Largest pooled value plus one unit of padding.
np.max(Xv) + 1


Out[36]:
7.9000000000000004

In [37]:
# Horizontal extent of the grid.
# NOTE(review): Xv pools *both* petal columns, so x_min is actually the smallest
# petal_width. The range is still a superset of the petal_length values plotted
# on the x-axis, so the picture is unaffected.
x_min = np.min(Xv)
x_max = np.max(Xv) + 1

In [38]:
# Smallest class code.
# NOTE(review): `y` holds species codes 0-2, not a feature, so it does not
# describe the grid's vertical axis (petal_width).
y.min()


Out[38]:
0

In [39]:
# Largest class code plus one.
# NOTE(review): these are class labels, not petal_width values.
y.max() + 1


Out[39]:
3

In [40]:
# Vertical extent of the decision-surface grid. The grid's second coordinate is
# fed to `clf` as its second feature, petal_width (see `col`). The bounds must
# therefore come from that column, not from the class labels `y`, which only
# happened to span a similar numeric range. Bounds are snapped outward to whole
# units: floor of the minimum, and floor of the maximum plus one, because
# np.arange excludes its stop value.
y_min = int(np.floor(X['petal_width'].min()))
y_max = int(np.floor(X['petal_width'].max())) + 1

In [41]:
# Lower vertical bound of the grid.
y_min


Out[41]:
0

In [42]:
# Upper vertical bound of the grid (excluded by np.arange).
y_max


Out[42]:
3

In [43]:
# Horizontal grid coordinates: from x_min up to (but excluding) x_max in steps of h.
np.arange(x_min, x_max, h)


Out[43]:
array([ 0.1 ,  0.12,  0.14,  0.16,  0.18,  0.2 ,  0.22,  0.24,  0.26,
        0.28,  0.3 ,  0.32,  0.34,  0.36,  0.38,  0.4 ,  0.42,  0.44,
        0.46,  0.48,  0.5 ,  0.52,  0.54,  0.56,  0.58,  0.6 ,  0.62,
        0.64,  0.66,  0.68,  0.7 ,  0.72,  0.74,  0.76,  0.78,  0.8 ,
        0.82,  0.84,  0.86,  0.88,  0.9 ,  0.92,  0.94,  0.96,  0.98,
        1.  ,  1.02,  1.04,  1.06,  1.08,  1.1 ,  1.12,  1.14,  1.16,
        1.18,  1.2 ,  1.22,  1.24,  1.26,  1.28,  1.3 ,  1.32,  1.34,
        1.36,  1.38,  1.4 ,  1.42,  1.44,  1.46,  1.48,  1.5 ,  1.52,
        1.54,  1.56,  1.58,  1.6 ,  1.62,  1.64,  1.66,  1.68,  1.7 ,
        1.72,  1.74,  1.76,  1.78,  1.8 ,  1.82,  1.84,  1.86,  1.88,
        1.9 ,  1.92,  1.94,  1.96,  1.98,  2.  ,  2.02,  2.04,  2.06,
        2.08,  2.1 ,  2.12,  2.14,  2.16,  2.18,  2.2 ,  2.22,  2.24,
        2.26,  2.28,  2.3 ,  2.32,  2.34,  2.36,  2.38,  2.4 ,  2.42,
        2.44,  2.46,  2.48,  2.5 ,  2.52,  2.54,  2.56,  2.58,  2.6 ,
        2.62,  2.64,  2.66,  2.68,  2.7 ,  2.72,  2.74,  2.76,  2.78,
        2.8 ,  2.82,  2.84,  2.86,  2.88,  2.9 ,  2.92,  2.94,  2.96,
        2.98,  3.  ,  3.02,  3.04,  3.06,  3.08,  3.1 ,  3.12,  3.14,
        3.16,  3.18,  3.2 ,  3.22,  3.24,  3.26,  3.28,  3.3 ,  3.32,
        3.34,  3.36,  3.38,  3.4 ,  3.42,  3.44,  3.46,  3.48,  3.5 ,
        3.52,  3.54,  3.56,  3.58,  3.6 ,  3.62,  3.64,  3.66,  3.68,
        3.7 ,  3.72,  3.74,  3.76,  3.78,  3.8 ,  3.82,  3.84,  3.86,
        3.88,  3.9 ,  3.92,  3.94,  3.96,  3.98,  4.  ,  4.02,  4.04,
        4.06,  4.08,  4.1 ,  4.12,  4.14,  4.16,  4.18,  4.2 ,  4.22,
        4.24,  4.26,  4.28,  4.3 ,  4.32,  4.34,  4.36,  4.38,  4.4 ,
        4.42,  4.44,  4.46,  4.48,  4.5 ,  4.52,  4.54,  4.56,  4.58,
        4.6 ,  4.62,  4.64,  4.66,  4.68,  4.7 ,  4.72,  4.74,  4.76,
        4.78,  4.8 ,  4.82,  4.84,  4.86,  4.88,  4.9 ,  4.92,  4.94,
        4.96,  4.98,  5.  ,  5.02,  5.04,  5.06,  5.08,  5.1 ,  5.12,
        5.14,  5.16,  5.18,  5.2 ,  5.22,  5.24,  5.26,  5.28,  5.3 ,
        5.32,  5.34,  5.36,  5.38,  5.4 ,  5.42,  5.44,  5.46,  5.48,
        5.5 ,  5.52,  5.54,  5.56,  5.58,  5.6 ,  5.62,  5.64,  5.66,
        5.68,  5.7 ,  5.72,  5.74,  5.76,  5.78,  5.8 ,  5.82,  5.84,
        5.86,  5.88,  5.9 ,  5.92,  5.94,  5.96,  5.98,  6.  ,  6.02,
        6.04,  6.06,  6.08,  6.1 ,  6.12,  6.14,  6.16,  6.18,  6.2 ,
        6.22,  6.24,  6.26,  6.28,  6.3 ,  6.32,  6.34,  6.36,  6.38,
        6.4 ,  6.42,  6.44,  6.46,  6.48,  6.5 ,  6.52,  6.54,  6.56,
        6.58,  6.6 ,  6.62,  6.64,  6.66,  6.68,  6.7 ,  6.72,  6.74,
        6.76,  6.78,  6.8 ,  6.82,  6.84,  6.86,  6.88,  6.9 ,  6.92,
        6.94,  6.96,  6.98,  7.  ,  7.02,  7.04,  7.06,  7.08,  7.1 ,
        7.12,  7.14,  7.16,  7.18,  7.2 ,  7.22,  7.24,  7.26,  7.28,
        7.3 ,  7.32,  7.34,  7.36,  7.38,  7.4 ,  7.42,  7.44,  7.46,
        7.48,  7.5 ,  7.52,  7.54,  7.56,  7.58,  7.6 ,  7.62,  7.64,
        7.66,  7.68,  7.7 ,  7.72,  7.74,  7.76,  7.78,  7.8 ,  7.82,
        7.84,  7.86,  7.88])

In [44]:
# Vertical grid coordinates: from y_min up to (but excluding) y_max in steps of h.
np.arange(y_min, y_max, h)


Out[44]:
array([ 0.  ,  0.02,  0.04,  0.06,  0.08,  0.1 ,  0.12,  0.14,  0.16,
        0.18,  0.2 ,  0.22,  0.24,  0.26,  0.28,  0.3 ,  0.32,  0.34,
        0.36,  0.38,  0.4 ,  0.42,  0.44,  0.46,  0.48,  0.5 ,  0.52,
        0.54,  0.56,  0.58,  0.6 ,  0.62,  0.64,  0.66,  0.68,  0.7 ,
        0.72,  0.74,  0.76,  0.78,  0.8 ,  0.82,  0.84,  0.86,  0.88,
        0.9 ,  0.92,  0.94,  0.96,  0.98,  1.  ,  1.02,  1.04,  1.06,
        1.08,  1.1 ,  1.12,  1.14,  1.16,  1.18,  1.2 ,  1.22,  1.24,
        1.26,  1.28,  1.3 ,  1.32,  1.34,  1.36,  1.38,  1.4 ,  1.42,
        1.44,  1.46,  1.48,  1.5 ,  1.52,  1.54,  1.56,  1.58,  1.6 ,
        1.62,  1.64,  1.66,  1.68,  1.7 ,  1.72,  1.74,  1.76,  1.78,
        1.8 ,  1.82,  1.84,  1.86,  1.88,  1.9 ,  1.92,  1.94,  1.96,
        1.98,  2.  ,  2.02,  2.04,  2.06,  2.08,  2.1 ,  2.12,  2.14,
        2.16,  2.18,  2.2 ,  2.22,  2.24,  2.26,  2.28,  2.3 ,  2.32,
        2.34,  2.36,  2.38,  2.4 ,  2.42,  2.44,  2.46,  2.48,  2.5 ,
        2.52,  2.54,  2.56,  2.58,  2.6 ,  2.62,  2.64,  2.66,  2.68,
        2.7 ,  2.72,  2.74,  2.76,  2.78,  2.8 ,  2.82,  2.84,  2.86,
        2.88,  2.9 ,  2.92,  2.94,  2.96,  2.98])

In [45]:
# Preview the two coordinate matrices. In the first, x varies along the
# columns; in the second, y varies along the rows.
np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))


Out[45]:
[array([[ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
        [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
        [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
        ..., 
        [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
        [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
        [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88]]),
 array([[ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
        [ 0.02,  0.02,  0.02, ...,  0.02,  0.02,  0.02],
        [ 0.04,  0.04,  0.04, ...,  0.04,  0.04,  0.04],
        ..., 
        [ 2.94,  2.94,  2.94, ...,  2.94,  2.94,  2.94],
        [ 2.96,  2.96,  2.96, ...,  2.96,  2.96,  2.96],
        [ 2.98,  2.98,  2.98, ...,  2.98,  2.98,  2.98]])]

In [46]:
# Build the 2-D evaluation grid: xx varies along columns, yy along rows.
xx, yy = np.meshgrid(
    np.arange(x_min, x_max, h),  # x (petal_length) samples -> columns
    np.arange(y_min, y_max, h),  # y (petal_width) samples -> rows
)

In [47]:
# x-coordinate matrix: every row repeats the horizontal grid values.
xx


Out[47]:
array([[ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
       [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
       [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
       ..., 
       [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
       [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88],
       [ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88]])

In [48]:
# y-coordinate matrix: every column repeats the vertical grid values.
yy


Out[48]:
array([[ 0.  ,  0.  ,  0.  , ...,  0.  ,  0.  ,  0.  ],
       [ 0.02,  0.02,  0.02, ...,  0.02,  0.02,  0.02],
       [ 0.04,  0.04,  0.04, ...,  0.04,  0.04,  0.04],
       ..., 
       [ 2.94,  2.94,  2.94, ...,  2.94,  2.94,  2.94],
       [ 2.96,  2.96,  2.96, ...,  2.96,  2.96,  2.96],
       [ 2.98,  2.98,  2.98, ...,  2.98,  2.98,  2.98]])

In [49]:
# Flatten the x-coordinate matrix row by row into a 1-D array.
np.ravel(xx)


Out[49]:
array([ 0.1 ,  0.12,  0.14, ...,  7.84,  7.86,  7.88])

In [50]:
# Inspect ndarray.ravel's docstring (IPython help syntax; can be removed from a finished notebook).
xx.ravel?
Docstring: a.ravel([order]) Return a flattened array. Refer to `numpy.ravel` for full documentation. See Also -------- numpy.ravel : equivalent function ndarray.flat : a flat iterator on the array. Type: builtin_function_or_method

In [51]:
# Flatten the y-coordinate matrix row by row into a 1-D array.
np.ravel(yy)


Out[51]:
array([ 0.  ,  0.  ,  0.  , ...,  2.98,  2.98,  2.98])

In [52]:
# Pair the flattened coordinates into an (n_points, 2) array of (x, y) rows.
np.column_stack((xx.ravel(), yy.ravel()))


Out[52]:
array([[ 0.1 ,  0.  ],
       [ 0.12,  0.  ],
       [ 0.14,  0.  ],
       ..., 
       [ 7.84,  2.98],
       [ 7.86,  2.98],
       [ 7.88,  2.98]])

In [53]:
# IPython help for np.c_, the column-concatenation helper used around this cell.
np.c_?
Type: CClass String form: Length: 0 File: c:\anaconda3\lib\site-packages\numpy\lib\index_tricks.py Docstring: Translates slice objects to concatenation along the second axis. This is short-hand for ``np.r_['-1,2,0', index expression]``, which is useful because of its common occurrence. In particular, arrays will be stacked along their last axis after being upgraded to at least 2-D with 1's post-pended to the shape (column vectors made out of 1-D arrays). For detailed documentation, see `r_`. Examples -------- >>> np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])] array([[1, 2, 3, 0, 0, 4, 5, 6]])

In [54]:
# Tabular view of the grid points: column 0 is x, column 1 is y.
pd.DataFrame(np.column_stack((xx.ravel(), yy.ravel())))


Out[54]:
0 1
0 0.10 0.00
1 0.12 0.00
2 0.14 0.00
3 0.16 0.00
4 0.18 0.00
5 0.20 0.00
6 0.22 0.00
7 0.24 0.00
8 0.26 0.00
9 0.28 0.00
10 0.30 0.00
11 0.32 0.00
12 0.34 0.00
13 0.36 0.00
14 0.38 0.00
15 0.40 0.00
16 0.42 0.00
17 0.44 0.00
18 0.46 0.00
19 0.48 0.00
20 0.50 0.00
21 0.52 0.00
22 0.54 0.00
23 0.56 0.00
24 0.58 0.00
25 0.60 0.00
26 0.62 0.00
27 0.64 0.00
28 0.66 0.00
29 0.68 0.00
... ... ...
58470 7.30 2.98
58471 7.32 2.98
58472 7.34 2.98
58473 7.36 2.98
58474 7.38 2.98
58475 7.40 2.98
58476 7.42 2.98
58477 7.44 2.98
58478 7.46 2.98
58479 7.48 2.98
58480 7.50 2.98
58481 7.52 2.98
58482 7.54 2.98
58483 7.56 2.98
58484 7.58 2.98
58485 7.60 2.98
58486 7.62 2.98
58487 7.64 2.98
58488 7.66 2.98
58489 7.68 2.98
58490 7.70 2.98
58491 7.72 2.98
58492 7.74 2.98
58493 7.76 2.98
58494 7.78 2.98
58495 7.80 2.98
58496 7.82 2.98
58497 7.84 2.98
58498 7.86 2.98
58499 7.88 2.98

58500 rows × 2 columns


In [55]:
# Predict a class code for every grid point. Columns follow the training
# feature order (petal_length, petal_width).
# NOTE(review): clf was fitted on a DataFrame, so newer scikit-learn versions
# may warn that this bare array has no feature names — confirm for your version.
z = clf.predict(np.column_stack((xx.ravel(), yy.ravel())))

In [56]:
# Flat vector with the predicted class code of every grid point.
z


Out[56]:
array([0, 0, 0, ..., 2, 2, 2], dtype=int64)

In [57]:
# Grid shape: (number of y values, number of x values).
xx.shape


Out[57]:
(150, 390)

In [58]:
# One prediction per grid point: 150 * 390 = 58500 values, still flat.
z.shape


Out[58]:
(58500,)

In [59]:
# Fold the flat predictions back onto the 2-D grid so they line up with xx/yy.
z = np.reshape(z, xx.shape)

In [60]:
# Now matches xx.shape, as plt.contourf expects for 2-D coordinate arrays.
z.shape


Out[60]:
(150, 390)

In [61]:
# IPython help for plt.contourf, which draws the decision-region plots that follow.
plt.contourf?
Signature: plt.contourf(*args, **kwargs) Docstring: Plot contours. :func:`~matplotlib.pyplot.contour` and :func:`~matplotlib.pyplot.contourf` draw contour lines and filled contours, respectively. Except as noted, function signatures and return values are the same for both versions.

In [62]:
# Decision regions alone: filled contours of the predicted class over the grid.
# (`ax` actually holds the contour set returned by contourf, not an Axes.)
fig = plt.figure(figsize=(16, 10))
ax = plt.contourf(xx, yy, z, alpha=0.3, cmap='afmhot');



In [63]:
# Training samples alone (petal_length vs petal_width), coloured by class code.
fig = plt.figure(figsize=(16, 10))
plt.scatter(X.values[:, 0], X.values[:, 1],
            c=y, s=80, alpha=0.9, edgecolors='g');



In [64]:
# Final result: the decision regions of `clf` (filled contours over the xx/yy
# grid) with the training samples overlaid, coloured by class code. Axis labels
# and a title make the figure readable on its own.
# (`ax` holds the contour set returned by contourf, not an Axes.)
fig = plt.figure(figsize=(16,10))
ax = plt.contourf(xx, yy, z, cmap = 'afmhot', alpha=0.3);
plt.scatter(X.values[:, 0], X.values[:, 1], c=y, s=80,
            alpha=0.9, edgecolors='g');
plt.xlabel(X.columns[0])
plt.ylabel(X.columns[1])
plt.title('Decision tree decision regions on iris petal features');