In this example, we plot a confusion matrix to describe the classification performance of a RandomForestClassifier
on the digits dataset from scikit-learn, evaluated with cross-validated predictions. The plot is produced with the
scikitplot.metrics.plot_confusion_matrix function.
In [1]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits as load_data
from sklearn.model_selection import cross_val_predict
import matplotlib.pyplot as plt
# Import scikit-plot
import scikitplot as skplt
%pylab inline
pylab.rcParams['figure.figsize'] = (14, 14)
In [2]:
# Fetch the dataset directly as a (samples, labels) pair instead of a
# container object, and unpack it into X and y.
X, y = load_data(
    return_X_y=True,
)
In [3]:
# Fix the seed so the forest (and therefore the confusion matrix plotted from
# its predictions) is identical on every Restart & Run All; an unseeded
# RandomForestClassifier gives slightly different results on each run.
RANDOM_STATE = 42
# Create an instance of the RandomForestClassifier
classifier = RandomForestClassifier(random_state=RANDOM_STATE)
# Perform predictions: cross_val_predict uses scikit-learn's default
# cross-validation split and predicts each sample while it is held out,
# so no prediction comes from a model that was trained on that sample.
predictions = cross_val_predict(classifier, X, y)
In [4]:
# Draw the confusion matrix of true labels vs. cross-validated predictions.
# normalize=True asks scikit-plot to normalize the matrix before plotting
# instead of showing raw counts.
plot = skplt.metrics.plot_confusion_matrix(
    y_true=y,
    y_pred=predictions,
    normalize=True,
)