In this example, we'll plot the learning curve of a RandomForestClassifier to compare its training score with its cross-validation score as the training set grows. We'll call the scikitplot.estimators.plot_learning_curve function and pass it our classifier, our features, and our labels.
In [1]:
from __future__ import absolute_import
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer as load_data
# Import scikit-plot
import scikitplot as skplt
# Render figures inline (the %pylab magic is deprecated)
%matplotlib inline
plt.rcParams['figure.figsize'] = (14, 14)
In [2]:
# Load the data
X, y = load_data(return_X_y=True)
# Create classifier instance
rf = RandomForestClassifier()
In [3]:
# Plot training and cross-validation scores against training-set size
skplt.estimators.plot_learning_curve(rf, X, y)
plt.show()
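For comparison, here is a minimal sketch of the kind of computation behind this plot, written with scikit-learn's sklearn.model_selection.learning_curve and plain matplotlib instead of scikit-plot. It assumes plot_learning_curve averages per-fold scores in a similar way. The choices of 5-fold cross-validation, five training sizes, n_estimators=50 and random_state=0 are our own, not scikit-plot's, so the numbers may differ somewhat from the figure above.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import learning_curve

X, y = load_breast_cancer(return_X_y=True)

# Assumed settings for this sketch: fewer trees for speed, fixed seed for repeatability.
rf = RandomForestClassifier(n_estimators=50, random_state=0)

# Refit the model on 10%, 32.5%, 55%, 77.5% and 100% of each training fold and
# score it both on the rows it was trained on and on the held-out fold (5-fold CV).
train_sizes, train_scores, cv_scores = learning_curve(
    rf, X, y, cv=5, train_sizes=np.linspace(0.1, 1.0, 5)
)

# Each score array has shape (n_sizes, n_folds); average across the folds.
train_mean, train_std = train_scores.mean(axis=1), train_scores.std(axis=1)
cv_mean, cv_std = cv_scores.mean(axis=1), cv_scores.std(axis=1)

fig, ax = plt.subplots(figsize=(8, 6))
for mean, std, color, label in [
    (train_mean, train_std, "r", "Training score"),
    (cv_mean, cv_std, "g", "Cross-validation score"),
]:
    ax.plot(train_sizes, mean, "o-", color=color, label=label)
    # Shade one standard deviation around the mean score.
    ax.fill_between(train_sizes, mean - std, mean + std, color=color, alpha=0.1)
ax.set_xlabel("Training examples")
ax.set_ylabel("Score")
ax.set_title("Learning Curve")
ax.legend(loc="best")
```

Outside a notebook, call plt.show() to display the figure. A large, persistent gap between the two curves points to overfitting, while two curves that converge at a low score point to underfitting.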