An example showing the plot_confusion_matrix function applied to a scikit-learn classifier

In this example, we plot a confusion matrix to describe the classification performance of a RandomForestClassifier on the digits dataset from scikit-learn, using the scikitplot.metrics.plot_confusion_matrix function.


In [1]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits as load_data
from sklearn.model_selection import cross_val_predict
import matplotlib.pyplot as plt

# Import scikit-plot
import scikitplot as skplt

# Render figures inline and enlarge the default figure size
%matplotlib inline
plt.rcParams['figure.figsize'] = (14, 14)

In [2]:
# Load the data
X, y = load_data(return_X_y=True)

In [3]:
# Create an instance of the RandomForestClassifier
classifier = RandomForestClassifier()

# Get a cross-validated (out-of-fold) prediction for every sample
predictions = cross_val_predict(classifier, X, y)
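
cross_val_predict returns one prediction per sample, each made by a model that did not see that sample during training. The sketch below spells this out with an explicit loop. The 5-fold stratified split, `n_estimators=20`, and `random_state=0` are assumptions chosen for illustration and speed, not settings taken from this notebook.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold

X, y = load_digits(return_X_y=True)

# Assumed settings: a small, seeded forest so the sketch runs quickly
classifier = RandomForestClassifier(n_estimators=20, random_state=0)

# Placeholder array; -1 marks samples that have not been predicted yet
out_of_fold = np.full(y.shape, -1, dtype=y.dtype)

# Assumed split: 5 stratified folds, no shuffling
folds = StratifiedKFold(n_splits=5)
for train_idx, test_idx in folds.split(X, y):
    # Fit a fresh copy on the training folds only
    model = clone(classifier)
    model.fit(X[train_idx], y[train_idx])
    # Predict the held-out fold and store the results in place
    out_of_fold[test_idx] = model.predict(X[test_idx])

print("out-of-fold accuracy:", np.mean(out_of_fold == y))
```

Because every entry of `out_of_fold` comes from a model trained without that sample, a confusion matrix built from such predictions reflects held-out performance rather than training fit.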

In [4]:
# Plot the normalized confusion matrix of true vs. predicted labels
plot = skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
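
To make explicit what a normalized confusion matrix contains, here is a minimal sketch that builds one by hand for a small, made-up set of labels and divides each row by the number of true samples in that class. The toy labels and the helper name `confusion_counts` are illustrative assumptions and not part of scikit-plot; the row-wise normalization is assumed to match what `normalize=True` displays.

```python
import numpy as np

def confusion_counts(y_true, y_pred, n_classes):
    """Hypothetical helper: counts[i, j] = samples of true class i predicted as class j."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        counts[t, p] += 1
    return counts

# Toy labels, made up for illustration only
y_true = [0, 0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 1, 1, 2, 2, 0, 1]

counts = confusion_counts(y_true, y_pred, n_classes=3)

# Divide each row by its total so that every row sums to 1
normalized = counts / counts.sum(axis=1, keepdims=True)

print(counts)
print(normalized)
```

Each diagonal entry of `normalized` is the fraction of a class that was predicted correctly (the per-class recall). With scikit-learn 0.22 or later, `sklearn.metrics.confusion_matrix(y, predictions, normalize='true')` computes the same row-normalized values for the notebook's data.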