Exploring Supervised Machine Learning: Classification

This notebook explores a classification task on astronomical data with scikit-learn. Much of the practice of machine learning relies on searching through documentation to learn (or remind yourself) how things are done. Recall that in IPython, you can see the documentation of any function using the ? character. For example:


In [ ]:
import numpy as np
np.linspace?

In addition, Google is your friend. Do you want to know how to import a Gaussian mixture model in scikit-learn? Search for "scikit-learn Gaussian Mixture" and go from there.

In this notebook, we'll look at a stellar photometry dataset for classification, where the labels distinguish between variable and non-variable stars (based on multiple observations for the training set).

(see also the Regression breakout)

We'll start with the normal notebook imports & preliminaries:


In [ ]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# set matplotlib figure style
mpl.style.use('ggplot') # only works in matplotlib 1.4+
mpl.rc('figure', figsize=(8, 6))

The Data: Stellar Photometry

The stellar photometry data is a combination of photometry for RR Lyrae and main sequence stars from SDSS Stripe 82, with RR Lyrae flagged based on their temporal variation.


In [ ]:
from astroML.datasets import fetch_rrlyrae_combined
# sklearn.cross_validation was removed in scikit-learn 0.20; use model_selection
from sklearn.model_selection import train_test_split

X, y = fetch_rrlyrae_combined()

The data consist of 93141 objects, each with four features representing the [u-g, g-r, r-i, i-z] colors.


In [ ]:
print(X.shape)
print(X[:10])

There are two labels, 0 = non-variable, 1 = variable star.


In [ ]:
print(y.shape)
print(np.unique(y))

The data have a large imbalance: that is, there are many more background stars than there are RR Lyrae stars. This is important to keep in mind when performing a machine learning task on them:


In [ ]:
np.sum(y == 0), np.sum(y == 1)

As before, we can plot the data to better understand what we're looking at:


In [ ]:
plt.scatter(X[:, 0], X[:, 1], c=y,
            linewidth=0, alpha=0.3)
plt.xlabel('u-g')
plt.ylabel('g-r')
plt.axis('tight');

Classification: Labeling Variable Sources

Here we'll explore a classification task using the variable sources shown above.

1. Explore the Data

It's good to start by visually exploring the data you're working with. Plot various combinations of the inputs: if you were doing a classic manual separation of the data based on this, what would you do? Treating the variable stars as your "signal" and the rest as your "background", roughly what would you expect for the completeness and contamination of such a method?
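As a rough illustration of what a manual separation looks like in code, the sketch below applies a rectangular cut in two colors and measures its completeness (fraction of true variables recovered) and contamination (fraction of selected objects that are not variables). The six data points and the cut thresholds are made up for illustration; they are not taken from the RR Lyrae dataset.

```python
import numpy as np

# Tiny hand-made stand-in for (u-g, g-r) colors and labels; NOT the real data.
colors = np.array([
    [1.15, 0.10],   # variable
    [1.20, 0.20],   # variable
    [1.00, 0.30],   # variable (falls outside the box below)
    [1.18, 0.15],   # background star that lands inside the box
    [1.60, 0.60],   # background
    [2.20, 1.10],   # background
])
labels = np.array([1, 1, 1, 0, 0, 0])

# Hypothetical rectangular selection; the thresholds are illustrative only.
selected = ((colors[:, 0] > 1.1) & (colors[:, 0] < 1.3) &
            (colors[:, 1] > 0.0) & (colors[:, 1] < 0.25))

true_pos = np.sum(selected & (labels == 1))
false_pos = np.sum(selected & (labels == 0))

completeness = true_pos / np.sum(labels == 1)   # recovered variables
contamination = false_pos / np.sum(selected)    # interlopers in the selection
print(completeness, contamination)
```

On the real data you would replace `colors` and `labels` with `X[:, :2]` and `y`, and choose the box by eye from your plots.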


In [ ]:


In [ ]:


In [ ]:

2. Applying a Fast & Simple Model

Import the Gaussian Naive Bayes estimator from scikit-learn (try a Google search – the sklearn documentation will pop up, and you should be able to find the correct import statement pretty easily). This is a particularly useful estimator as a first-pass: it essentially fits each input class with an axis-aligned normal distribution, and thus is very fast (though usually not very accurate).

Fit this model to the data, and then use cross-validation to determine the accuracy of the result.
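To make the "axis-aligned normal distribution" description concrete, here is a hand-rolled NumPy sketch of the idea. It is not scikit-learn's implementation (which, among other things, adds a small variance-smoothing term), the data are synthetic, and the helper names `fit_gnb` and `predict_gnb` are hypothetical.

```python
import numpy as np

rng = np.random.RandomState(0)

# Synthetic two-class data standing in for the colors (an assumption).
X0 = rng.normal(loc=[0.0, 0.0], scale=[1.0, 0.5], size=(500, 2))
X1 = rng.normal(loc=[3.0, 2.0], scale=[0.5, 1.0], size=(50, 2))
X_toy = np.vstack([X0, X1])
y_toy = np.concatenate([np.zeros(500, dtype=int), np.ones(50, dtype=int)])


def fit_gnb(X, y):
    """Estimate a prior plus per-feature mean and variance for each class."""
    classes = np.unique(y)
    priors = np.array([np.mean(y == c) for c in classes])
    means = np.array([X[y == c].mean(axis=0) for c in classes])
    variances = np.array([X[y == c].var(axis=0) for c in classes])
    return classes, priors, means, variances


def predict_gnb(X, classes, priors, means, variances):
    """Choose the class with the largest log prior + log likelihood."""
    # diff has shape (n_samples, n_classes, n_features)
    diff = X[:, np.newaxis, :] - means[np.newaxis, :, :]
    # No covariance terms: each feature contributes independently
    log_like = -0.5 * np.sum(np.log(2 * np.pi * variances)
                             + diff ** 2 / variances, axis=2)
    log_post = np.log(priors) + log_like
    return classes[np.argmax(log_post, axis=1)]


model = fit_gnb(X_toy, y_toy)
pred = predict_gnb(X_toy, *model)
print("training accuracy:", np.mean(pred == y_toy))
```

Because only a mean and a variance per feature and class are estimated, fitting is a single pass over the data, which is why the method is so fast.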


In [ ]:


In [ ]:


In [ ]:

Note that the typical accuracy score may not be relevant here: if our goal is to find RR Lyrae from the background, you'll want to use a different scoring method (the scoring argument to cross_val_score). The default scoring here is "accuracy", which basically means

$$ {\rm accuracy} \equiv \frac{\rm \#~correct~labels}{\rm total~labels} $$

But for unbalanced data, accuracy is actually a really bad metric! You can create a very accurate classifier by simply returning 0 for everything: the result will have an accuracy of $(93141 - 483) / 93141 = 99.5\%$. Needless to say, this model isn't very useful.
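The arithmetic can be checked directly. This short sketch uses only the two counts quoted above and assumes Python 3 division:

```python
n_total = 93141      # objects in the dataset (from the text above)
n_variable = 483     # RR Lyrae among them (from the text above)

# A "classifier" that labels every object as 0 (non-variable)
n_correct = n_total - n_variable
accuracy = n_correct / n_total

# It never identifies a single RR Lyrae, so none of the variables is recovered
fraction_recovered = 0 / n_variable

print("accuracy = {:.4f}, fraction of RR Lyrae recovered = {:.1f}".format(
    accuracy, fraction_recovered))
```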

Instead, we should look at the "precision" and "recall" scores.

This is what the terms mean:

$$ {\rm precision} \equiv \frac{\rm true~positives}{\rm true~positives + false~positives} $$
$$ {\rm recall} \equiv \frac{\rm true~positives}{\rm true~positives + false~negatives} $$
$$ F_1 \equiv 2\,\frac{\rm precision \cdot recall}{\rm precision + recall} $$

(Note that Wikipedia has a nice visual of these quantities.) The range for precision, recall, and F1 score is 0 to 1, with 1 being perfect. Often in astronomy, we call the recall the completeness, and (1 - precision) the contamination.
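As a worked example, the function below computes these quantities from raw counts, using the standard harmonic-mean definition of F1 (which carries a factor of 2). The function name `precision_recall_f1` and the counts passed to it are made up for illustration.

```python
def precision_recall_f1(true_pos, false_pos, false_neg):
    """Compute precision, recall, and F1 from confusion-matrix counts."""
    precision = true_pos / (true_pos + false_pos)
    recall = true_pos / (true_pos + false_neg)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical outcome: 400 variables found, 100 false alarms, 83 missed
precision, recall, f1 = precision_recall_f1(true_pos=400, false_pos=100,
                                            false_neg=83)

# The astronomer's vocabulary for the same numbers
completeness = recall
contamination = 1 - precision

print("precision={:.3f} recall={:.3f} F1={:.3f}".format(precision, recall, f1))
print("completeness={:.3f} contamination={:.3f}".format(completeness,
                                                        contamination))
```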

Use the scoring argument to the sklearn cross-validation routine to find the accuracy, precision, and recall attained with Gaussian Naive Bayes.
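One possible shape for this, sketched on a synthetic imbalanced dataset rather than the RR Lyrae data so that it runs without a download. The data-generation settings are arbitrary assumptions; on the real data you would pass `X` and `y` instead of `X_toy` and `y_toy`.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB

# Synthetic imbalanced stand-in (about 5% positives); not the RR Lyrae data
X_toy, y_toy = make_classification(n_samples=2000, n_features=4,
                                   n_informative=3, n_redundant=1,
                                   weights=[0.95, 0.05], random_state=0)

scores = {}
for scoring in ['accuracy', 'precision', 'recall']:
    # One cross-validated score per fold; average them for a summary
    cv_scores = cross_val_score(GaussianNB(), X_toy, y_toy,
                                cv=5, scoring=scoring)
    scores[scoring] = cv_scores.mean()
    print("{0:>9s}: {1:.3f}".format(scoring, scores[scoring]))
```

Comparing the three numbers side by side is usually more informative than any one of them alone, especially when one class is rare.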


In [ ]:


In [ ]:


In [ ]:

3. Applying a more Flexible Model

Gaussian Naive Bayes is an example of a generative classifier, meaning that each individual class distribution is modeled, and the classification is achieved by comparing these. Another family of classification algorithms is the discriminative classifiers, which, roughly speaking, try to draw a line or curve (or more generally, a hyperplane or manifold) through parameter space to separate the classes.

A good example of a discriminative classifier is the Support Vector Machine classifier, which you can read about in the SVM Notebook. One thing to note is that the SVM routines tend to be fairly slow for large amounts of data, so when experimenting with them below it may help to work with a subset of your data, for example:


In [ ]:
# subset consisting of 10% of full data
Xsub = X[::10]
ysub = y[::10]

You can always go back and run the algorithm on the full dataset once you've tried a few things.

Here, import the support vector machine classifier (again, Google is your friend) and fit it to this data. What are the accuracy, precision, and recall of the default estimator? How does changing the kernel parameter affect things?
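A minimal sketch of such a comparison, again on synthetic stand-in data (an assumption, chosen so the cell runs quickly and offline). It uses `cross_validate`, which accepts several scoring names at once; on the real data you might pass `Xsub` and `ysub`.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.svm import SVC

# Synthetic imbalanced stand-in (about 10% positives); not the RR Lyrae data
X_toy, y_toy = make_classification(n_samples=1000, n_features=4,
                                   n_informative=3, n_redundant=1,
                                   weights=[0.9, 0.1], random_state=0)

metrics = ['accuracy', 'precision', 'recall']
results = {}
for kernel in ['linear', 'rbf', 'poly']:
    cv = cross_validate(SVC(kernel=kernel), X_toy, y_toy,
                        cv=3, scoring=metrics)
    # cross_validate stores each metric under the key 'test_<name>'
    results[kernel] = {m: cv['test_' + m].mean() for m in metrics}
    print(kernel, results[kernel])
```

If a fold produces no positive predictions, scikit-learn reports a precision of 0 for that fold and emits a warning, so a low precision here can also signal that the classifier never predicted the rare class.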


In [ ]:


In [ ]:


In [ ]:

You may notice that there are some issues here: at times the SVC fails to predict even a single positive sample, which makes precision and recall ill-defined! One of the issues with this kind of discriminative classification is that, by default, it does not do a good job with dataset imbalance (i.e. where background objects dominate). This can be adjusted by weighting the instances of each class, here using the class_weight parameter. Set this to 'balanced' (the option was called 'auto' in older scikit-learn releases): how does this change the results?
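To see what class weighting does numerically, the sketch below reproduces the formula scikit-learn documents for `class_weight='balanced'`, namely `n_samples / (n_classes * np.bincount(y))`, using the class counts quoted earlier in this notebook. Note that recent scikit-learn releases spell the option `'balanced'`; `'auto'` was the older name.

```python
import numpy as np

# Class counts quoted in the text: background stars and RR Lyrae
counts = np.array([93141 - 483, 483])

n_samples = counts.sum()
n_classes = len(counts)

# Each class gets a weight inversely proportional to its frequency
balanced_weights = n_samples / (n_classes * counts)
print(balanced_weights)

# With these weights both classes carry the same total weight in the fit
print(balanced_weights * counts)
```

The rare class ends up weighted roughly two hundred times more heavily than the background, so misclassifying an RR Lyrae costs the classifier far more than misclassifying a background star.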


In [ ]:


In [ ]:


In [ ]:

4. Non-parametrics: Random Forests

Both the generative and discriminative models above are examples of parametric models. Either the data distribution or the separating hyperplane is parametrized in some way, which means that the model is not as flexible as it could be. Often better results can be found with a non-parametric model such as a Random Forest. For some insight into what's happening with this model, see the Random Forest Notebook.

Try a Random Forest classifier on this data (again: see Google). What accuracy, precision, and recall can you attain?


In [ ]:


In [ ]:


In [ ]:

One of the most important pieces of machine learning is hyperparameter optimization. Hyperparameters are the parameters given at model instantiation, which are set by the user rather than tuned by the fit to the data. Hyperparameter optimization is the process of determining the best choice of these parameters for your dataset. Keep in mind that the results of your fit may be extremely sensitive to these!

We've seen a hint of this previously, when we used cross-validation to compare different options for the SGD classifier. The random forest classifier has a number of hyperparameters to tune, but here let's focus on the max_depth parameter and the n_estimators parameter. Try a few values of these, compute the cross-validation score, and see how it changes. Do you notice any patterns?
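A manual scan might look like the sketch below. The synthetic data, the candidate depths, and the choice of the F1 score are all assumptions made for illustration; swap in `X`, `y`, and whichever metric you care about.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Synthetic imbalanced stand-in (about 10% positives); not the RR Lyrae data
X_toy, y_toy = make_classification(n_samples=1000, n_features=4,
                                   n_informative=3, n_redundant=1,
                                   weights=[0.9, 0.1], random_state=0)

depth_scores = {}
for max_depth in [2, 5, 10]:
    # Fix n_estimators and the random seed so only max_depth changes
    clf = RandomForestClassifier(n_estimators=20, max_depth=max_depth,
                                 random_state=0)
    depth_scores[max_depth] = cross_val_score(clf, X_toy, y_toy,
                                              cv=3, scoring='f1').mean()
    print("max_depth={0:2d}: F1={1:.3f}".format(max_depth,
                                                depth_scores[max_depth]))
```

Holding one hyperparameter fixed while varying the other makes any trend easier to read before moving on to a full grid.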


In [ ]:


In [ ]:

Searching through a grid of hyperparameters for the optimal model is a very common (and useful!) task in performing a machine learning analysis. Scikit-learn provides a Grid Search interface to enable this to be done quickly. Take a look at the scikit-learn Grid Search Documentation and use the grid search tools to find the best combination of n_estimators and max_depth for your particular data. What's the best accuracy/precision/recall that you can attain with Random Forests on this dataset?
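As a starting point, here is a small sketch using `GridSearchCV` from `sklearn.model_selection`. The grid values, the F1 scoring choice, and the synthetic data are assumptions for illustration, not recommended settings for the RR Lyrae data.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic imbalanced stand-in (about 10% positives); not the RR Lyrae data
X_toy, y_toy = make_classification(n_samples=1000, n_features=4,
                                   n_informative=3, n_redundant=1,
                                   weights=[0.9, 0.1], random_state=0)

# Every combination in this grid is scored with 3-fold cross-validation
param_grid = {'n_estimators': [10, 30],
              'max_depth': [3, 6, None]}

grid = GridSearchCV(RandomForestClassifier(random_state=0), param_grid,
                    scoring='f1', cv=3)
grid.fit(X_toy, y_toy)

print("best parameters:", grid.best_params_)
print("best mean F1:   ", grid.best_score_)
```

After fitting, `grid.best_estimator_` holds a model refit on all the supplied data with the winning parameters, and `grid.cv_results_` records the scores for every grid point.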


In [ ]:


In [ ]:


In [ ]:

5. Other classification models

If you have more time, take a look at scikit-learn's online documentation and read about the other classification models available. Try a few of them on this dataset: are there any that can outperform your random forest on one of these metrics?


In [ ]:


In [ ]:


In [ ]: