In [29]:
%matplotlib inline
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
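As always, we need to start with some data. Let's first generate a set of true labels $y$ and predicted labels $\hat{y}$ for a small, imbalanced binary problem (13 negatives and 2 positives), where the classifier finds both positives but also produces one false positive.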
In [44]:
from sklearn.metrics import confusion_matrix
y_true = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
y_pred = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]
cnf_matrix = confusion_matrix(y_true, y_pred)
print(cnf_matrix)
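scikit-learn puts the true labels on the rows and the predicted labels on the columns. So for binary labels {0, 1} the four cells can be unpacked with `ravel()` in the order TN, FP, FN, TP. The short sketch below does that and recomputes the same counts by hand with NumPy boolean masks, to make the meaning of each cell explicit. The variable name `manual` is just illustrative.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
y_pred = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]

# Rows = true label, columns = predicted label, so for labels {0, 1}
# ravel() returns the cells in the order TN, FP, FN, TP.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

# The same four counts computed directly from boolean masks
t = np.asarray(y_true)
p = np.asarray(y_pred)
manual = (int(np.sum((t == 0) & (p == 0))),   # true negatives
          int(np.sum((t == 0) & (p == 1))),   # false positives
          int(np.sum((t == 1) & (p == 0))),   # false negatives
          int(np.sum((t == 1) & (p == 1))))   # true positives

print(tn, fp, fn, tp)   # 12 1 0 2
print(manual)           # (12, 1, 0, 2)
```

With this data the single error is the false positive at index 11. Both positives are found, so FN is 0.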
Now let's define a function that will display the confusion matrix. The following is inspired by this example.
In [47]:
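import itertools
import numpy as np
import matplotlib.pyplot as plt

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # Divide each row by its total so cells show per-class fractions
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)

    # Write the value in each cell, white text on dark cells for contrast
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()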
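The normalization mentioned in the docstring is assumed here to be row normalization: each row is divided by the number of samples that truly belong to that class. Each cell then reads as "the fraction of class $i$ predicted as class $j$", and every row sums to 1. A minimal worked sketch on the counts from our example:

```python
import numpy as np

# Counts from the example: rows = true label, columns = predicted label
cm = np.array([[12, 1],
               [0, 2]])

# Divide each row by its sum (13 true negatives' class, 2 true positives' class)
cm_norm = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

# First row is [12/13, 1/13] (about 0.923 and 0.077), second row is [0, 1]
print(np.round(cm_norm, 3))
```

Row normalization keeps the large negative class from dominating the colour scale. This matters on imbalanced data like ours, where raw counts would make the positive row look almost empty.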
In [48]:
plt.figure()
plot_confusion_matrix(cnf_matrix, ['0', '1'])