Title: Decision Tree Classifier
Slug: decision_tree_classifier
Summary: Training a decision tree classifier in scikit-learn.
Date: 2017-09-19 12:00
Category: Machine Learning
Tags: Trees And Forests
Authors: Chris Albon

Preliminaries


In [1]:
# Load libraries
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

Load Iris Dataset


In [2]:
# Load data
iris = datasets.load_iris()
X = iris.data
y = iris.target
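Before fitting anything, it can help to look at what `X` and `y` contain. The sketch below only uses attributes that `load_iris()` returns (`data`, `target`, `feature_names`, `target_names`). It shows that each row is one flower with four measurements and that the labels are the integers 0, 1 and 2, with 50 flowers per class.

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target

# One row per flower, one column per feature
print(X.shape)             # (150, 4)

# Column order of the features (all measured in cm)
print(iris.feature_names)

# Integer class labels and how many flowers carry each label
print(np.unique(y))        # [0 1 2]
print(np.bincount(y))      # [50 50 50]

# Species name behind each integer label
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']
```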

Create Decision Tree Using Gini Impurity


In [3]:
# Create a decision tree classifier object that uses Gini impurity
clf = DecisionTreeClassifier(criterion='gini', random_state=0)
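`criterion='gini'` tells the tree to score candidate splits by Gini impurity, defined for a node as $G = 1 - \sum_k p_k^2$, where $p_k$ is the fraction of the node's samples in class $k$. The tree picks the split that most reduces the weighted impurity of the child nodes. The function below is a hand-written illustration (`gini_impurity` is our own helper, not a scikit-learn function). For the full Iris dataset, with three balanced classes, it gives $1 - 3 \cdot (1/3)^2 = 2/3$, and that matches the root-node impurity the fitted tree stores in `tree_.impurity`.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

def gini_impurity(labels):
    """Hypothetical helper: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / counts.sum()
    return 1.0 - np.sum(proportions ** 2)

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Root node: 50 samples of each of 3 classes -> 1 - 3 * (1/3)**2 = 2/3
root_gini = gini_impurity(y)
print(round(root_gini, 4))        # 0.6667

# A pure node (a single class) has impurity 0
print(gini_impurity([1, 1, 1]))   # 0.0

# The fitted tree records the impurity of every node; index 0 is the root
clf = DecisionTreeClassifier(criterion='gini', random_state=0).fit(X, y)
print(np.isclose(clf.tree_.impurity[0], root_gini))  # True
```

Passing `criterion='entropy'` instead scores splits by information gain. On many datasets both criteria produce similar trees.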

Train Model


In [4]:
# Train model
model = clf.fit(X, y)
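`fit()` returns the estimator itself, so `model` and `clf` refer to the same object. The sketch below also inspects what was learned. No `max_depth` was set, so the tree keeps splitting until its leaves are pure (or cannot be split further) and usually fits the training data almost perfectly. `export_text` is assumed to be available, which requires scikit-learn 0.21 or later.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, export_text

iris = datasets.load_iris()
X, y = iris.data, iris.target

clf = DecisionTreeClassifier(criterion='gini', random_state=0)
model = clf.fit(X, y)

# fit() returns self, so both names point to the same fitted estimator
print(model is clf)  # True

# Size of the fully grown tree
print(model.get_depth(), model.get_n_leaves())

# Accuracy on the training data (not an estimate of accuracy on new flowers)
print(model.score(X, y))

# Text rendering of the learned split rules
text = export_text(model, feature_names=list(iris.feature_names))
print(text)
```

Training accuracy says little about generalization. To estimate performance on unseen flowers, evaluate on a held-out split or with cross-validation.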

Create Observation To Predict


In [5]:
# Make a new observation (sepal length, sepal width, petal length, petal width, in cm)
observation = [[5, 4, 3, 2]]
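The observation is a nested list because scikit-learn estimators expect a 2D input of shape `(n_samples, n_features)`, even when only one sample is being predicted. Passing a flat 1D array to `predict` raises a `ValueError` in current versions. The values must also follow the column order of `iris.feature_names`. A minimal sketch with NumPy is below. The second row in `observations` is made-up data, included only to show batching.

```python
import numpy as np

# Feature order must match iris.feature_names:
# sepal length, sepal width, petal length, petal width (all in cm)
single_flower = np.array([5, 4, 3, 2])

# Reshape one observation into a 1 x 4 matrix (one row, four features)
observation = single_flower.reshape(1, -1)
print(observation.shape)   # (1, 4)

# Several observations can be predicted at once by stacking rows
# (the second row is hypothetical example data)
observations = np.array([[5, 4, 3, 2],
                         [6, 3, 5, 2]])
print(observations.shape)  # (2, 4)
```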

Predict Observation


In [6]:
# Predict the observation's class
model.predict(observation)


Out[6]:
array([1])
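The prediction is an integer label. Indexing `iris.target_names` with it recovers the species name. For a classifier like this one, `predict` returns the class with the highest predicted probability, and the sketch checks that against `predict_proba` directly. It refits the same model so that it runs on its own.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
model = DecisionTreeClassifier(criterion='gini', random_state=0).fit(iris.data, iris.target)

observation = [[5, 4, 3, 2]]
predicted = model.predict(observation)

# Translate the integer label into the species name
print(iris.target_names[predicted])

# predict() is the class whose predicted probability is largest
proba = model.predict_proba(observation)
print(np.array_equal(predicted, model.classes_[np.argmax(proba, axis=1)]))  # True
```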

View Predicted Probabilities


In [7]:
# View predicted class probabilities for the three classes
model.predict_proba(observation)


Out[7]:
array([[ 0.,  1.,  0.]])
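For a decision tree, the predicted probabilities are the class proportions among the training samples in the leaf the observation falls into. Because this tree was grown without depth limits, its leaves tend to be pure, which is why the output above is all zeros and a single one. The sketch reproduces `predict_proba` from the tree's internals: `apply()` gives the leaf index, and `tree_.value` holds the per-class training weight at each node. Depending on the scikit-learn version, `tree_.value` stores raw counts or fractions, so the sketch normalizes either way. Constraining the tree (for example with `max_depth` or `min_samples_leaf`) generally yields less extreme probabilities.

```python
import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
model = DecisionTreeClassifier(criterion='gini', random_state=0).fit(iris.data, iris.target)
observation = [[5, 4, 3, 2]]

# Index of the leaf this observation lands in
leaf = model.apply(observation)[0]

# Per-class training weight stored at that leaf (counts or fractions,
# depending on the scikit-learn version), normalized to sum to 1
leaf_values = model.tree_.value[leaf, 0]
leaf_proba = leaf_values / leaf_values.sum()

print(leaf_proba)
print(np.allclose(leaf_proba, model.predict_proba(observation)[0]))  # True
```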