Introduction to Scikit-Learn: Machine Learning with Python

  • Jake VanderPlas
  • ESAC statistics workshop, Oct 27-31 2014
  • Source available on github

This tutorial will cover the basics of Scikit-Learn, a popular package containing a collection of tools for machine learning written in Python. See more at http://scikit-learn.org.

Outline

Main Goal: To introduce the central concepts of machine learning, and how they can be applied in Python using the Scikit-learn Package.

  • Definition of machine learning
  • Data representation in scikit-learn
  • Introduction to the Scikit-learn API
  • Basics of Machine Learning
    • Supervised learning: Classification & Regression
    • Unsupervised learning: Dimensionality Reduction & Clustering
    • Model Validation
  • Application: Classifying handwritten digits

About Scikit-Learn

Scikit-Learn is a Python package designed to give access to well-known machine learning algorithms within Python code, through a clean, well-thought-out API. It has been built by hundreds of contributors from around the world, and is used across industry and academia.

Scikit-Learn is built upon Python's NumPy (Numerical Python) and SciPy (Scientific Python) libraries, which enable efficient in-core numerical and scientific computation within Python. As such, scikit-learn is not specifically designed for extremely large datasets, though there is some work in this area.

For this short introduction, I'm going to stick to questions of in-core processing of small to medium datasets with Scikit-learn.

What is Machine Learning?

In this section we will begin to explore the basic principles of machine learning. Machine Learning is about building programs with tunable parameters (typically an array of floating point values) that are adjusted automatically so as to improve their behavior by adapting to previously seen data.

Machine Learning can be considered a subfield of Artificial Intelligence, since these algorithms can be seen as building blocks to make computers learn to behave more intelligently by somehow generalizing rather than just storing and retrieving data items like a database system would do.

We'll take a look at two very simple machine learning tasks here. The first is a classification task: the figure shows a collection of two-dimensional data, colored according to two different class labels. A classification algorithm may be used to draw a dividing boundary between the two clusters of points:


In [1]:
# start the inline backend for plotting
%matplotlib inline

In [2]:
# Import the example plot from the figures directory
from fig_code import plot_sgd_separator
plot_sgd_separator()


This may seem like a trivial task, but it is a simple version of a very important concept. By drawing this separating line, we have learned a model which can generalize to new data: if you were to drop another point onto the plane which is unlabeled, this algorithm could now predict whether it's a blue or a red point.
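The `fig_code` helper is not shown here, so the following is only a minimal sketch of the same idea. It assumes the figure comes from a linear classifier trained by stochastic gradient descent (`SGDClassifier` with hinge loss) on two synthetic clusters from `make_blobs`. The blob and classifier settings are illustrative choices, not necessarily the ones used in `fig_code`.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import SGDClassifier

# Two well-separated 2D clusters (illustrative parameters, assumed, not from fig_code)
X, y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)

# Linear classifier with hinge loss, i.e. a linear SVM fit by SGD
clf = SGDClassifier(loss="hinge", alpha=0.01, max_iter=200, random_state=0)
clf.fit(X, y)

# The learned boundary is the line w0*x0 + w1*x1 + b = 0
w = clf.coef_[0]
b = clf.intercept_[0]
print("boundary: %.2f*x0 + %.2f*x1 + %.2f = 0" % (w[0], w[1], b))

# A new, unlabeled point is classified by which side of the line it falls on
new_point = np.array([[2.0, 2.0]])
print("predicted label:", clf.predict(new_point)[0])
print("training accuracy:", clf.score(X, y))
```

The sign of `clf.decision_function(new_point)` tells you which side of the line the point lies on; `predict` just turns that sign into a class label.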

If you'd like to see the source code used to generate this, you can either open the code in the figures directory, or you can load the code using the %load magic command:


In [3]:
# Uncomment the %load command to load the contents of the file
# %load fig_code/sgd_separator.py

The next simple task we'll look at is a regression task: a simple best-fit line to a set of data:


In [4]:
from fig_code import plot_linear_regression
plot_linear_regression()


Again, this is an example of fitting a model to data, such that the model can make generalizations about new data. The model has been learned from the training data, and can be used to predict the result of test data: here, we might be given an x-value, and the model would allow us to predict the y value. Again, this might seem like a trivial problem, but it is a basic example of a type of operation that is fundamental to machine learning tasks.

Representation of Data in Scikit-learn

Machine learning is about creating models from data: for that reason, we'll start by discussing how data can be represented in order to be understood by the computer. Along with this, we'll build on our matplotlib examples from the previous section and show some examples of how to visualize data.

Most machine learning algorithms implemented in scikit-learn expect data to be stored in a two-dimensional array or matrix. The arrays can be either numpy arrays, or in some cases scipy.sparse matrices. The shape of the array is expected to be [n_samples, n_features]:

  • n_samples: The number of samples: each sample is an item to process (e.g. classify). A sample can be a document, a picture, a sound, a video, an astronomical object, a row in a database or CSV file, or whatever you can describe with a fixed set of quantitative traits.
  • n_features: The number of features or distinct traits that can be used to describe each item in a quantitative manner. Features are generally real-valued, but may be boolean or discrete-valued in some cases.

The number of features must be fixed in advance. However, it can be very high-dimensional (e.g. millions of features), with most of the features being zero for a given sample. This is a case where scipy.sparse matrices can be useful, in that they are much more memory-efficient than numpy arrays.
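To make the memory argument concrete, here is a minimal sketch comparing the storage of a mostly-zero matrix in scipy's CSR (compressed sparse row) format against what a dense float64 numpy array of the same shape would need. The sizes (1,000 samples, 100,000 features, about 20 non-zeros per sample) are hypothetical, chosen to resemble a bag-of-words text dataset; the dense size is computed arithmetically rather than allocated.

```python
import numpy as np
from scipy import sparse

# Hypothetical sizes: 1,000 samples x 100,000 features, ~20 non-zeros per sample
rng = np.random.default_rng(0)
n_samples, n_features, nnz_per_row = 1000, 100_000, 20
rows = np.repeat(np.arange(n_samples), nnz_per_row)
cols = rng.integers(0, n_features, size=rows.size)
vals = rng.random(rows.size)

# CSR stores only the non-zero values plus their column indices and row pointers
# (any duplicate (row, col) pairs are summed during construction)
X_sparse = sparse.csr_matrix((vals, (rows, cols)), shape=(n_samples, n_features))

sparse_bytes = X_sparse.data.nbytes + X_sparse.indices.nbytes + X_sparse.indptr.nbytes
# What np.zeros((n_samples, n_features)) would need, computed without allocating it
dense_bytes = n_samples * n_features * np.dtype(np.float64).itemsize

print("shape:", X_sparse.shape)
print("sparse storage: %.2f MB" % (sparse_bytes / 1e6))
print("dense storage:  %.2f MB" % (dense_bytes / 1e6))
```

Many scikit-learn estimators accept such CSR matrices directly as the X argument to fit().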

A Simple Example: the Iris Dataset

As an example of a simple dataset, we're going to take a look at the iris data stored by scikit-learn. The data consist of measurements of three different species of irises, which we can picture here:


In [5]:
from IPython.display import Image, display
display(Image(filename='images/iris_setosa.jpg'))
print("Iris Setosa\n")

display(Image(filename='images/iris_versicolor.jpg'))
print("Iris Versicolor\n")

display(Image(filename='images/iris_virginica.jpg'))
print("Iris Virginica")


Iris Setosa

Iris Versicolor

Iris Virginica

Quick Question:

If we want to design an algorithm to recognize iris species, what might the data be?

Remember: we need a 2D array of size [n_samples x n_features].

  • What would the n_samples refer to?

  • What might the n_features refer to?

Remember that there must be a fixed number of features for each sample, and feature number i must be a similar kind of quantity for each sample.

Loading the Iris Data with Scikit-Learn

Scikit-learn has a very straightforward set of data on these iris species. The data consist of the following:

  • Features in the Iris dataset:

    1. sepal length in cm
    2. sepal width in cm
    3. petal length in cm
    4. petal width in cm
  • Target classes to predict:

    1. Iris Setosa
    2. Iris Versicolour
    3. Iris Virginica

scikit-learn embeds a copy of the iris CSV file along with a helper function to load it into numpy arrays:


In [6]:
from sklearn.datasets import load_iris
iris = load_iris()

In [7]:
iris.keys()


Out[7]:
dict_keys(['target_names', 'target', 'data', 'feature_names', 'DESCR'])

In [8]:
n_samples, n_features = iris.data.shape
print((n_samples, n_features))
print(iris.data[0])


(150, 4)
[ 5.1  3.5  1.4  0.2]

In [9]:
print(iris.data.shape)
print(iris.target.shape)


(150, 4)
(150,)

In [10]:
print(iris.target)


[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

In [11]:
print(iris.target_names)


['setosa' 'versicolor' 'virginica']
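Because target_names is a numpy array, the integer labels in target can be used directly as indices into it. The short sketch below uses this to turn labels into species names and to count the samples in each class.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Fancy indexing maps each integer label to its species name
species = iris.target_names[iris.target]
print(species[:3], species[-3:])

# Count how many samples belong to each class
counts = np.bincount(iris.target)
for name, count in zip(iris.target_names, counts):
    print(name, count)
```

This matches the printed target array above: 50 samples of each of the three species, stored in order.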

This data is four-dimensional, but we can visualize two of the dimensions at a time using a simple scatter plot:


In [12]:
import numpy as np
import matplotlib.pyplot as plt

x_index = 0
y_index = 1

# this formatter will label the colorbar with the correct target names
formatter = plt.FuncFormatter(lambda i, *args: iris.target_names[int(i)])

plt.scatter(iris.data[:, x_index], iris.data[:, y_index],
            c=iris.target, cmap=plt.get_cmap('RdYlBu', 3))
plt.colorbar(ticks=[0, 1, 2], format=formatter)
plt.clim(-0.5, 2.5)
plt.xlabel(iris.feature_names[x_index])
plt.ylabel(iris.feature_names[y_index]);


Quick Exercise:

Change x_index and y_index in the above script and find a combination of two features which best separates the three classes.

This exercise is a preview of dimensionality reduction, which we'll see later.
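If you want a number to go with your visual judgment, one option is to score every pair of features by how well a simple classifier does using only those two columns. The sketch below assumes cross-validated 5-nearest-neighbor accuracy as the measure of separation; that criterion is our choice for illustration, not part of the exercise, and another criterion could rank the pairs differently.

```python
from itertools import combinations

from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data, iris.target

scores = {}
for i, j in combinations(range(X.shape[1]), 2):
    # Keep only features i and j
    X_pair = X[:, [i, j]]
    # Assumed separation criterion: mean 5-fold accuracy of a 5-NN classifier
    knn = KNeighborsClassifier(n_neighbors=5)
    scores[(i, j)] = cross_val_score(knn, X_pair, y, cv=5).mean()

# Print the pairs from best to worst separated
for (i, j), score in sorted(scores.items(), key=lambda item: -item[1]):
    print("%-18s vs %-18s  accuracy = %.3f"
          % (iris.feature_names[i], iris.feature_names[j], score))
```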

Other Available Data

Scikit-learn provides many other datasets for testing learning algorithms. They come in three flavors:

  • Packaged Data: these small datasets are packaged with the scikit-learn installation, and can be loaded using the tools in sklearn.datasets.load_*
  • Downloadable Data: these larger datasets are available for download, and scikit-learn includes tools which streamline this process. These tools can be found in sklearn.datasets.fetch_*
  • Generated Data: there are several datasets which are generated from models based on a random seed. These are available via the tools in sklearn.datasets.make_*

You can explore the available dataset loaders, fetchers, and generators using IPython's tab-completion functionality. After importing the datasets submodule from sklearn, type

datasets.load_ + TAB

or

datasets.fetch_ + TAB

or

datasets.make_ + TAB

to see a list of available functions.


In [13]:
from sklearn import datasets

In [14]:
# Type datasets.fetch_<TAB> in IPython to see all possibilities
# datasets.load_
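Outside IPython, or to get the list as data, you can filter the module's names by prefix with dir(). The sketch below does that and then generates a small dataset with make_blobs; the blob parameters are arbitrary illustrative values.

```python
from sklearn import datasets

# Programmatic alternative to tab completion: list public helpers by prefix
for prefix in ("load_", "fetch_", "make_"):
    names = sorted(name for name in dir(datasets) if name.startswith(prefix))
    print(prefix, len(names), names[:5])

# A generated dataset: 100 samples, 2 features, 3 Gaussian blobs
X, y = datasets.make_blobs(n_samples=100, n_features=2, centers=3, random_state=0)
print(X.shape, y.shape)
```

Generated data follows the same [n_samples, n_features] convention as the packaged datasets, so it can be passed straight to any estimator.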

Side Note: a similar interface to datasets is available in astroML: see some examples at http://astroml.org


In [15]:
from astroML import datasets

In [16]:
# Use tab completion to explore datasets
# datasets.fetch_

Basic Principles of Machine Learning

Here we'll dive into the basic principles of machine learning, and how to utilize them via the Scikit-Learn API.

After briefly introducing scikit-learn's Estimator object, we'll cover supervised learning, including classification and regression problems, and unsupervised learning, including dimensionality reduction and clustering problems.

The Scikit-learn Estimator Object

Every algorithm is exposed in scikit-learn via an "Estimator" object. For instance, a linear regression is implemented as follows:


In [17]:
from sklearn.linear_model import LinearRegression

Estimator parameters: All the parameters of an estimator can be set when it is instantiated, and have suitable default values:


In [18]:
model = LinearRegression(normalize=True)
print(model.normalize)


True

In [19]:
print(model)


LinearRegression(copy_X=True, fit_intercept=True, normalize=True)

Estimated Model parameters: When data is fit with an estimator, parameters are estimated from the data at hand. All the estimated parameters are attributes of the estimator object ending with an underscore:


In [20]:
x = np.arange(10)
y = 2 * x + 1

In [21]:
print(x)
print(y)


[0 1 2 3 4 5 6 7 8 9]
[ 1  3  5  7  9 11 13 15 17 19]

In [22]:
plt.plot(x, y, 'o');



In [23]:
# The input data for sklearn is 2D: (samples == 10 x features == 1)
X = x[:, np.newaxis]
print(X)
print(y)


[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]
 [8]
 [9]]
[ 1  3  5  7  9 11 13 15 17 19]

In [24]:
# fit the model on our data
model.fit(X, y)


Out[24]:
LinearRegression(copy_X=True, fit_intercept=True, normalize=True)


In [25]:
# underscore at the end indicates a fit parameter
print(model.coef_)
print(model.intercept_)


[ 2.]
1.0

In [26]:
model.residues_


Out[26]:
3.944304526105059e-31

The model found a line with slope 2 and intercept 1, as we'd expect.
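The two kinds of parameters live in different places on the estimator object: hyperparameters passed to the constructor are available immediately through get_params(), while fitted attributes with a trailing underscore only appear after fit(). The sketch below repeats the same fit to show this; it omits the normalize argument used above, which newer scikit-learn releases no longer accept.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x = np.arange(10)
X = x[:, np.newaxis]          # shape (10, 1): 10 samples, 1 feature
y = 2 * x + 1

model = LinearRegression()
# Hyperparameters set at construction are available immediately...
print(model.get_params())
# ...but fitted attributes (trailing underscore) only exist after fit()
print(hasattr(model, "coef_"))   # → False

model.fit(X, y)
print(hasattr(model, "coef_"))   # → True
print(model.coef_, model.intercept_)

# The fitted line generalizes to x values it never saw
print(model.predict([[10], [20]]))
```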

Supervised Learning: Classification and Regression

In Supervised Learning, we have a dataset consisting of both features and labels. The task is to construct an estimator which is able to predict the label of an object given the set of features. A relatively simple example is predicting the species of iris given a set of measurements of its flower. Some more complicated examples are:

  • given a multicolor image of an object through a telescope, determine whether that object is a star, a quasar, or a galaxy.
  • given a photograph of a person, identify the person in the photo.
  • given a list of movies a person has watched and their personal rating of the movie, recommend a list of movies they would like (So-called recommender systems: a famous example is the Netflix Prize).

What these tasks have in common is that there are one or more unknown quantities associated with the object which need to be determined from other observed quantities.

Supervised learning is further broken down into two categories, classification and regression. In classification, the label is discrete, while in regression, the label is continuous. For example, in astronomy, the task of determining whether an object is a star, a galaxy, or a quasar is a classification problem: the label is from three distinct categories. On the other hand, we might wish to estimate the age of an object based on such observations: this would be a regression problem, because the label (age) is a continuous quantity.

Classification Example

K nearest neighbors (kNN) is one of the simplest learning strategies: given a new, unknown observation, look up in your reference database which ones have the closest features and assign the predominant class.

Let's try it out on our iris classification problem:


In [27]:
from sklearn import neighbors, datasets

iris = datasets.load_iris()
X, y = iris.data, iris.target

# create the model
knn = neighbors.KNeighborsClassifier(n_neighbors=5)

# fit the model
knn.fit(X, y)

# What kind of iris has 3cm x 5cm sepal and 4cm x 2cm petal?
# call the "predict" method:
result = knn.predict([[3, 5, 4, 2],])

print(iris.target_names[result])


['versicolor']

You can also do probabilistic predictions:


In [28]:
knn.predict_proba([[3, 5, 4, 2],])


Out[28]:
array([[ 0. ,  0.8,  0.2]])
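Both outputs follow directly from the description of kNN given earlier: find the k closest training points and count their classes. The sketch below is a minimal NumPy version of that rule, assuming plain Euclidean distance and uniform votes (which are also KNeighborsClassifier's defaults). It is written for illustration and is not scikit-learn's internal implementation, which uses tree-based neighbor searches.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier

def knn_predict_proba(X_train, y_train, X_new, k=5):
    """Fraction of the k nearest training points (Euclidean) in each class."""
    n_classes = y_train.max() + 1
    proba = np.zeros((len(X_new), n_classes))
    for row, x in enumerate(X_new):
        dist = np.sqrt(((X_train - x) ** 2).sum(axis=1))
        nearest = np.argsort(dist)[:k]          # indices of the k closest points
        proba[row] = np.bincount(y_train[nearest], minlength=n_classes) / k
    return proba

iris = load_iris()
X, y = iris.data, iris.target
query = np.array([[3, 5, 4, 2]])

proba = knn_predict_proba(X, y, query, k=5)
# The predicted class is simply the one with the most votes
print(proba, iris.target_names[proba.argmax(axis=1)])

# Compare with scikit-learn's implementation
knn = KNeighborsClassifier(n_neighbors=5).fit(X, y)
print(knn.predict_proba(query), iris.target_names[knn.predict(query)])
```

With k=5, each neighbor contributes a vote of 0.2, which is why the probabilities above come in multiples of 0.2.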

In [29]:
from fig_code import plot_iris_knn
plot_iris_knn()


Exercise: Now use sklearn.svm.SVC as the estimator on the same problem.

Note that you don't have to know what it is to use it.

If you finish early, do the same plot as above.


In [30]:
from sklearn.svm import SVC

In [31]:
model = SVC()
model.fit(X, y)
result = model.predict([[3, 5, 4, 2],])
print(iris.target_names[result])


['virginica']

Regression Example

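The simplest possible regression is fitting a line to data. The same interface also works for more flexible models; here we fit a random forest regressor to noisy linear data: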


In [32]:
# Create some simple data
import numpy as np
np.random.seed(0)
X = np.random.random(size=(20, 1))
y = 3 * X.squeeze() + 2 + np.random.randn(20)

# Fit a Random Forest
from sklearn.ensemble import RandomForestRegressor
model = RandomForestRegressor()
model.fit(X, y)

# Plot the data and the model prediction
X_test = np.linspace(0, 1, 100)[:, np.newaxis]
y_test = model.predict(X_test)

plt.plot(X.squeeze(), y, 'o')
plt.plot(X_test.squeeze(), y_test);
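Because every regressor shares the same fit/predict interface, it is easy to put the straight-line fit next to the random forest on the same data. In this sketch the forest's n_estimators and random_state values are illustrative assumptions; the data are generated exactly as in the cell above.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

# Same noisy linear data as above: y = 3x + 2 + Gaussian noise
np.random.seed(0)
X = np.random.random(size=(20, 1))
y = 3 * X.squeeze() + 2 + np.random.randn(20)
X_test = np.linspace(0, 1, 100)[:, np.newaxis]

# n_estimators and random_state are illustrative choices
models = {"linear": LinearRegression(),
          "random forest": RandomForestRegressor(n_estimators=100, random_state=0)}
for name, model in models.items():
    model.fit(X, y)                      # identical fit/predict interface
    plt.plot(X_test.squeeze(), model.predict(X_test), label=name)

plt.plot(X.squeeze(), y, 'o', color='gray')
plt.legend()

print("fitted slope and intercept:", models["linear"].coef_, models["linear"].intercept_)
```

The straight line summarizes the trend with two parameters, while the forest's piecewise-constant prediction follows individual noisy points more closely.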


Unsupervised Learning: Dimensionality Reduction and Clustering

Unsupervised Learning addresses a different sort of problem. Here the data has no labels, and we are interested in finding similarities between the objects in question. In a sense, you can think of unsupervised learning as a means of discovering labels from the data itself. Unsupervised learning comprises tasks such as dimensionality reduction, clustering, and density estimation. For example, in the iris data discussed above, we can use unsupervised methods to determine combinations of the measurements which best display the structure of the data. As we'll see below, such a projection of the data can be used to visualize the four-dimensional dataset in two dimensions. Some more involved unsupervised learning problems are:

  • given detailed observations of distant galaxies, determine which features or combinations of features best summarize the information.
  • given a mixture of two sound sources (for example, a person talking over some music), separate the two (this is called the blind source separation problem).
  • given a video, isolate a moving object and categorize it in relation to other moving objects which have been seen.

Sometimes the two may even be combined: e.g. unsupervised learning can be used to find useful features in heterogeneous data, and then these features can be used within a supervised framework.
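One way to express this combination in scikit-learn is a pipeline, which chains an unsupervised transformer and a supervised estimator into a single object with the usual fit/predict interface. The sketch below assumes PCA with two components feeding a 5-nearest-neighbor classifier on the iris data; both choices are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline

iris = load_iris()
X, y = iris.data, iris.target

# Unsupervised step (PCA, 2 components: an illustrative choice) feeds a supervised step (kNN)
pipeline = make_pipeline(PCA(n_components=2), KNeighborsClassifier(n_neighbors=5))

# PCA is refit inside each training fold, so the held-out fold never shapes the features
scores = cross_val_score(pipeline, X, y, cv=5)
print("accuracy per fold:", scores)
print("mean accuracy: %.3f" % scores.mean())
```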

Dimensionality Reduction: PCA

Principal Component Analysis (PCA) is a dimensionality reduction technique that finds the combinations of variables that explain the most variance.

Consider the iris dataset. It cannot be visualized in a single 2D plot, as it has 4 features. We are going to extract 2 combinations of sepal and petal dimensions to visualize it:


In [33]:
X, y = iris.data, iris.target
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
pca.fit(X)
X_reduced = pca.transform(X)
print("Reduced dataset shape:", X_reduced.shape)

import pylab as pl
pl.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y,
           cmap='RdYlBu')

print("Meaning of the 2 components:")
for component in pca.components_:
    print(" + ".join("%.3f x %s" % (value, name)
                     for value, name in zip(component,
                                            iris.feature_names)))


Reduced dataset shape: (150, 2)
Meaning of the 2 components:
0.362 x sepal length (cm) + -0.082 x sepal width (cm) + 0.857 x petal length (cm) + 0.359 x petal width (cm)
-0.657 x sepal length (cm) + -0.730 x sepal width (cm) + 0.176 x petal length (cm) + 0.075 x petal width (cm)
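The printed components can be reproduced with plain NumPy: center the data, take its singular value decomposition, and project onto the leading right singular vectors. This is a minimal sketch for illustration. scikit-learn's PCA may use a different solver internally and fixes the sign of each component by its own convention, so the comparison below is made up to the sign of each component.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X = load_iris().data

# 1. Center each feature at zero mean
X_centered = X - X.mean(axis=0)

# 2. SVD: rows of Vt are orthonormal directions, ordered by singular value
U, S, Vt = np.linalg.svd(X_centered, full_matrices=False)

# 3. Project onto the first two directions
components = Vt[:2]
X_projected = X_centered @ components.T

# Fraction of the total variance carried by each direction
explained_ratio = S**2 / np.sum(S**2)
print("explained variance ratio:", explained_ratio[:2])

# Same result as scikit-learn, up to the sign of each component
X_sklearn = PCA(n_components=2).fit_transform(X)
print(np.allclose(np.abs(X_projected), np.abs(X_sklearn), atol=1e-6))
```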

Clustering: K-means

Clustering groups together observations that are homogeneous with respect to a given criterion, finding "clusters" in the data.

Note that these clusters will uncover relevant hidden structure of the data only if the criterion used highlights it.


In [34]:
from sklearn.cluster import KMeans
k_means = KMeans(n_clusters=3, random_state=0) # Fixing the RNG in kmeans
k_means.fit(X)
y_pred = k_means.predict(X)

pl.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y_pred,
           cmap='RdYlBu');
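For K-means, the criterion is the within-cluster sum of squared distances to each cluster's center. The sketch below is a plain version of the classic Lloyd iteration (assign each point to its nearest center, then move each center to the mean of its points), starting from randomly chosen data points. scikit-learn's KMeans adds refinements such as k-means++ initialization by default, so its clusters can differ from this sketch's. The helper name kmeans_lloyd is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris

def kmeans_lloyd(X, n_clusters, n_iter=100, seed=0):
    """Hypothetical helper: plain Lloyd iterations from random starting points."""
    rng = np.random.default_rng(seed)
    centers = X[rng.choice(len(X), size=n_clusters, replace=False)]
    for _ in range(n_iter):
        # Assignment step: each point goes to its nearest center
        sq_dist = ((X[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=2)
        labels = sq_dist.argmin(axis=1)
        # Update step: move each center to the mean of its points
        # (an empty cluster keeps its previous center)
        new_centers = np.array([X[labels == k].mean(axis=0) if np.any(labels == k)
                                else centers[k] for k in range(n_clusters)])
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    # Final assignment, consistent with the returned centers
    sq_dist = ((X[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=2)
    labels = sq_dist.argmin(axis=1)
    inertia = sq_dist[np.arange(len(X)), labels].sum()
    return labels, centers, inertia

X = load_iris().data
labels, centers, inertia = kmeans_lloyd(X, n_clusters=3)
print("cluster sizes:", np.bincount(labels, minlength=3))
print("within-cluster sum of squares: %.2f" % inertia)
```

Each of the two steps can only lower (or keep) the sum of squares, which is why the loop settles down; the result can still be a local rather than global minimum, depending on the starting centers.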


Recap: Scikit-learn's estimator interface

Scikit-learn strives to have a uniform interface across all methods, and we'll see examples of these below. Given a scikit-learn estimator object named model, the following methods are available:

  • Available in all Estimators
    • model.fit() : fit training data. For supervised learning applications, this accepts two arguments: the data X and the labels y (e.g. model.fit(X, y)). For unsupervised learning applications, this accepts only a single argument, the data X (e.g. model.fit(X)).
  • Available in supervised estimators
    • model.predict() : given a trained model, predict the label of a new set of data. This method accepts one argument, the new data X_new (e.g. model.predict(X_new)), and returns the learned label for each object in the array.
    • model.predict_proba() : For classification problems, some estimators also provide this method, which returns the probability that a new observation has each categorical label. In this case, the label with the highest probability is returned by model.predict().
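    • model.score() : classifiers and regressors implement a score method, where a larger score indicates a better fit. For classifiers it returns the mean accuracy, which lies between 0 and 1; for regressors it returns the coefficient of determination R², which is at most 1 and can be negative for a poor model.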
  • Available in unsupervised estimators
    • model.transform() : given an unsupervised model, transform new data into the new basis. This also accepts one argument X_new, and returns the new representation of the data based on the unsupervised model.
    • model.fit_transform() : some estimators implement this method, which more efficiently performs a fit and a transform on the same input data.
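The following sketch exercises each of these methods on the iris data. The regression target (petal width predicted from the other three measurements) is an arbitrary illustrative choice, made only to show that score() means R² for a regressor.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# Supervised classifier: fit(X, y), then predict / predict_proba / score
clf = KNeighborsClassifier(n_neighbors=5).fit(X, y)   # fit() returns the estimator
print(clf.predict(X[:3]))
print(clf.predict_proba(X[:3]))
print("mean accuracy:", clf.score(X, y))

# Supervised regressor: score() is R^2 (illustrative task: petal width from the rest)
reg = LinearRegression().fit(X[:, :3], X[:, 3])
print("R^2:", reg.score(X[:, :3], X[:, 3]))

# Unsupervised: fit(X) then transform(X), or both at once with fit_transform(X)
X_a = PCA(n_components=2).fit(X).transform(X)
X_b = PCA(n_components=2).fit_transform(X)
print(np.allclose(X_a, X_b))
```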

Model Validation

An important piece of machine learning is model validation: that is, determining how well your model will generalize from the training data to future unlabeled data. Let's look at an example using the nearest neighbor classifier. This is a very simple classifier: it stores all training data, and for any unknown point returns the label of the closest training point.

With the iris data, it very easily returns the correct prediction for each of the input points:


In [35]:
from sklearn.neighbors import KNeighborsClassifier
X, y = iris.data, iris.target
clf = KNeighborsClassifier(n_neighbors=1)
clf.fit(X, y)
y_pred = clf.predict(X)
print(np.all(y == y_pred))


True

A more useful way to look at the results is to view the confusion matrix, or the matrix showing the frequency of inputs and outputs:


In [36]:
from sklearn.metrics import confusion_matrix
print(confusion_matrix(y, y_pred))


[[50  0  0]
 [ 0 50  0]
 [ 0  0 50]]

For each class, all 50 training samples are correctly identified. But this does not mean that our model is perfect! In particular, such a model generalizes extremely poorly to new data. We can simulate this by splitting our data into a training set and a testing set. Scikit-learn contains some convenient routines to do this:


In [37]:
from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y)
clf.fit(Xtrain, ytrain)
ypred = clf.predict(Xtest)
print(confusion_matrix(ytest, ypred))


[[12  0  0]
 [ 0  6  0]
 [ 0  2 18]]

This paints a better picture of the true performance of our classifier: apparently there is some confusion between the second and third species, which we might anticipate given what we've seen of the data above.

This is why it's extremely important to use a train/test split when evaluating your models. We'll go into more depth on model evaluation later in this tutorial.
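A single random split can be lucky or unlucky, so a common next step is k-fold cross-validation, where every sample is held out exactly once. The sketch below contrasts training accuracy, one reproducible held-out split (random_state=0 is an arbitrary choice), and 5-fold cross-validation for the same 1-nearest-neighbor classifier.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
clf = KNeighborsClassifier(n_neighbors=1)

# Training accuracy is optimistic: 1-NN memorizes the training set
print("training accuracy:", clf.fit(X, y).score(X, y))

# A single held-out split, made reproducible with an arbitrary random_state
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=0)
print("held-out accuracy:", clf.fit(Xtrain, ytrain).score(Xtest, ytest))

# 5-fold cross-validation: each sample is in the test fold exactly once
scores = cross_val_score(clf, X, y, cv=5)
print("cross-validated accuracy: %.3f +/- %.3f" % (scores.mean(), scores.std()))
```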

Flow Chart: How to Choose your Estimator

This is a flow chart created by scikit-learn super-contributor Andreas Mueller which gives a nice summary of which algorithms to choose in various situations. Keep it around as a handy reference!


In [38]:
from IPython.display import Image
Image("http://scikit-learn.org/dev/_static/ml_map.png")


Out[38]:

Original source on the scikit-learn website

Quick Application: Optical Character Recognition

One interesting application where we might apply machine learning is in exploring handwritten digit data. Scikit-learn has a small test set built-in:

Loading and visualizing the digits data

We'll use scikit-learn's data access interface and take a look at this data:


In [39]:
from sklearn import datasets
digits = datasets.load_digits()
digits.images.shape


Out[39]:
(1797, 8, 8)

Let's plot a few of these:


In [40]:
fig, axes = plt.subplots(10, 10, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)

for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary')
    ax.text(0.05, 0.05, str(digits.target[i]),
            transform=ax.transAxes, color='green')
    ax.set_xticks([])
    ax.set_yticks([])


Here the data is simply each pixel value within an 8x8 grid:


In [41]:
# The images themselves
print(digits.images.shape)
print(digits.images[0])


(1797, 8, 8)
[[  0.   0.   5.  13.   9.   1.   0.   0.]
 [  0.   0.  13.  15.  10.  15.   5.   0.]
 [  0.   3.  15.   2.   0.  11.   8.   0.]
 [  0.   4.  12.   0.   0.   8.   8.   0.]
 [  0.   5.   8.   0.   0.   9.   8.   0.]
 [  0.   4.  11.   0.   1.  12.   7.   0.]
 [  0.   2.  14.   5.  10.  12.   0.   0.]
 [  0.   0.   6.  13.  10.   0.   0.   0.]]

In [42]:
# The data for use in our algorithms
print(digits.data.shape)
print(digits.data[0])


(1797, 64)
[  0.   0.   5.  13.   9.   1.   0.   0.   0.   0.  13.  15.  10.  15.   5.
   0.   0.   3.  15.   2.   0.  11.   8.   0.   0.   4.  12.   0.   0.   8.
   8.   0.   0.   5.   8.   0.   0.   9.   8.   0.   0.   4.  11.   0.   1.
  12.   7.   0.   0.   2.  14.   5.  10.  12.   0.   0.   0.   0.   6.  13.
  10.   0.   0.   0.]

In [43]:
# The target label
print(digits.target)


[0 1 2 ..., 8 9 8]

So our data have 1797 samples in 64 dimensions.
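The 64 dimensions are just the 8x8 pixel grid laid out row by row. The sketch below checks that flattening digits.images gives exactly digits.data, and that reshaping a row of digits.data recovers its image.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

# Flatten each 8x8 image into a row of 64 pixel values
flattened = digits.images.reshape(len(digits.images), -1)
print(flattened.shape)                          # → (1797, 64)
print(np.array_equal(flattened, digits.data))   # → True

# Going the other way recovers the image
print(np.array_equal(digits.data[0].reshape(8, 8), digits.images[0]))
```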

Unsupervised Learning: Dimensionality Reduction

We'd like to visualize our points within the 64-dimensional parameter space, but it's difficult to plot points in 64 dimensions! Instead we'll use an unsupervised dimensionality reduction technique called Isomap:


In [44]:
from sklearn.manifold import Isomap

In [45]:
iso = Isomap(n_components=2)
data_projected = iso.fit_transform(digits.data)

In [46]:
data_projected.shape


Out[46]:
(1797, 2)

In [47]:
plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,
            edgecolor='none', alpha=0.5, cmap=plt.get_cmap('nipy_spectral', 10));
plt.colorbar(label='digit label', ticks=range(10))
plt.clim(-0.5, 9.5)


We see here that the digits are fairly well-separated in the parameter space; this tells us that a supervised classification algorithm should perform fairly well. Let's give it a try.

Classification on Digits

Let's try a classification task on the digits. The first thing we'll want to do is split the digits into a training and testing sample:


In [48]:
from sklearn.model_selection import train_test_split
Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target,
                                                random_state=2)
print(Xtrain.shape, Xtest.shape)


(1347, 64) (450, 64)

Let's use a simple logistic regression which (despite its name) is a classification algorithm:


In [49]:
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(penalty='l2')
clf.fit(Xtrain, ytrain)
ypred = clf.predict(Xtest)

We can check our classification accuracy by comparing the true values of the test set to the predictions:


In [50]:
from sklearn.metrics import accuracy_score
accuracy_score(ytest, ypred)


Out[50]:
0.94666666666666666
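Accuracy is simply the fraction of test labels predicted correctly, which is also the share of the confusion matrix that lies on its diagonal. The sketch below checks these equivalences. It raises max_iter on LogisticRegression (an assumption made so the default solver converges on the unscaled pixel values), so its accuracy may differ slightly from the number above.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

digits = load_digits()
Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target,
                                                random_state=2)

# max_iter=5000 is an assumption, chosen so the default solver converges
clf = LogisticRegression(max_iter=5000).fit(Xtrain, ytrain)
ypred = clf.predict(Xtest)

# Accuracy is the fraction of matching labels...
acc = np.mean(ypred == ytest)
print(acc, accuracy_score(ytest, ypred), clf.score(Xtest, ytest))

# ...which is also the diagonal share of the confusion matrix
cm = confusion_matrix(ytest, ypred)
print(np.trace(cm) / cm.sum())

# Per-class recall: correct predictions divided by the true count in each row
print(np.round(np.diag(cm) / cm.sum(axis=1), 3))
```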

This single number doesn't tell us where we've gone wrong; one nice way to see that is to look at the confusion matrix:


In [51]:
from sklearn.metrics import confusion_matrix
print(confusion_matrix(ytest, ypred))
plt.imshow(np.log1p(confusion_matrix(ytest, ypred)),  # log1p avoids log(0) for empty cells
           cmap='Blues', interpolation='nearest')
plt.ylabel('true')
plt.xlabel('predicted');


[[42  0  0  0  0  0  0  0  0  0]
 [ 0 45  0  1  0  0  0  0  3  1]
 [ 0  0 47  0  0  0  0  0  0  0]
 [ 0  0  0 42  0  2  0  3  1  0]
 [ 0  2  0  0 36  0  0  0  1  1]
 [ 0  0  0  0  0 52  0  0  0  0]
 [ 0  0  0  0  0  0 42  0  1  0]
 [ 0  0  0  0  0  0  0 48  1  0]
 [ 0  2  0  0  0  0  0  0 38  0]
 [ 0  0  0  1  0  1  0  1  2 34]]

We might also take a look at some of the outputs along with their predicted labels. We'll make the bad labels red:


In [52]:
fig, axes = plt.subplots(10, 10, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)

for i, ax in enumerate(axes.flat):
    ax.imshow(Xtest[i].reshape(8, 8), cmap='binary')
    ax.text(0.05, 0.05, str(ypred[i]),
            transform=ax.transAxes,
            color='green' if (ytest[i] == ypred[i]) else 'red')
    ax.set_xticks([])
    ax.set_yticks([])


The interesting thing is that even with this simple logistic regression algorithm, many of the mislabeled cases are ones that we ourselves might get wrong!

There are many ways to improve this classifier, but we're out of time here. To go further, we could use a more sophisticated model, use cross validation, etc.
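As a starting point for both suggestions, the sketch below compares two candidate models on the full digits data with 5-fold cross-validation. The candidates are our assumptions for illustration: logistic regression preceded by feature standardization, and a support vector classifier with scikit-learn's default RBF kernel. Neither is tuned.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_digits(return_X_y=True)

# Assumed, untuned candidates chosen for illustration
candidates = {
    "scaled logistic regression": make_pipeline(StandardScaler(),
                                                LogisticRegression(max_iter=1000)),
    "RBF support vector machine": SVC(),
}

results = {}
for name, model in candidates.items():
    # Each fold refits a fresh copy of the model on the remaining 4/5 of the data
    results[name] = cross_val_score(model, X, y, cv=5)
    print("%-28s %.3f +/- %.3f" % (name, results[name].mean(), results[name].std()))
```

Comparing mean cross-validated scores like this, rather than a single train/test split, gives a steadier basis for choosing between models.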