Viva la Factory

Factory is a way to train several different classifiers on the same dataset and compare the quality of their predictions.

First, enable plotting


In [1]:
%pylab inline


Populating the interactive namespace from numpy and matplotlib
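
%pylab inline pulls numpy and matplotlib into the interactive namespace; if you prefer explicit imports over the star-import it performs, a close (illustrative) alternative is:


In [ ]:
# optional, more explicit alternative to %pylab inline
%matplotlib inline
import numpy
import matplotlib.pyplot as plt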

Variables needed for analysis


In [2]:
loaded_variables = ["FlightDistance", "FlightDistanceError", "IP", "VertexChi2", "pt", "p0_pt", "p1_pt", "p2_pt", 'LifeTime', 'dira', 'mass']
train_variables = ["FL: FlightDistance/FlightDistanceError", "IP", "VertexChi2", "pt", "p0_pt", "p1_pt", "p2_pt", 'LifeTime', 'dira']
plot_variables = train_variables + ['mass']
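
Note the first entry of train_variables: REP understands strings of the form "name: expression" as derived features computed from the loaded columns, so "FL: FlightDistance/FlightDistanceError" adds the ratio as a new feature named FL. Purely for illustration, the pandas equivalent on a tiny made-up frame would be:


In [ ]:
# Illustration only: what the "FL: FlightDistance/FlightDistanceError" expression computes.
# REP evaluates such expressions itself during training; this is just the pandas analogue
# on a tiny made-up frame.
import pandas
toy = pandas.DataFrame({'FlightDistance': [10.0, 20.0], 'FlightDistanceError': [2.0, 4.0]})
toy['FL'] = toy['FlightDistance'] / toy['FlightDistanceError']
toy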

Loading data


In [3]:
import numpy, pandas
from rep.utils import train_test_split

In [4]:
sig_data = pandas.read_csv('toy_datasets/toyMC_sig_mass.csv', sep='\t', usecols=loaded_variables)
bck_data = pandas.read_csv('toy_datasets/toyMC_bck_mass.csv', sep='\t', usecols=loaded_variables)

labels = numpy.array([1] * len(sig_data) + [0] * len(bck_data))
data = pandas.concat([sig_data, bck_data])

# Get train and test data
train_data, test_data, train_labels, test_labels = train_test_split(data, labels, train_size=0.5)
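
A quick, optional sanity check of the split sizes and class balance:


In [ ]:
# optional sanity check of the split and the signal fraction in each part
print('train: %d events, test: %d events' % (len(train_data), len(test_data)))
print('signal fraction (train): %.3f' % train_labels.mean())
print('signal fraction (test):  %.3f' % test_labels.mean())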

Factory of different models

This class is an OrderedDict with an additional interface; its main methods are:

  • factory.add_classifier(name, classifier)

  • factory.fit(X, y, sample_weight=None, ipc_profile=None, features=None)
    trains all classifiers in the factory;
    if features is not None, all classifiers will be trained on these features;
    you can pass the name of an IPython cluster via ipc_profile for parallel training

  • factory.test_on_lds(lds) - tests all models on lds (rep.data.storage.LabeledDataStorage)
    and returns a report (rep.report.classification.ClassificationReport); a sketch of this variant appears after In [11] below


In [5]:
from rep.metaml import ClassifiersFactory
from rep.estimators import TMVAClassifier, SklearnClassifier, XGBoostClassifier
from sklearn.ensemble import AdaBoostClassifier

Define the classifiers that will be compared


In [6]:
factory = ClassifiersFactory()
# there are different ways to add classifiers to Factory:
factory.add_classifier('tmva', TMVAClassifier(NTrees=50, features=train_variables[:5], Shrinkage=0.05))
factory.add_classifier('ada', AdaBoostClassifier(n_estimators=10))
factory['xgb'] = XGBoostClassifier(features=train_variables[2:6])
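
SklearnClassifier (imported above) can also wrap any sklearn estimator explicitly, which lets you restrict its feature subset just like for the TMVA and XGBoost wrappers. An optional illustration, left commented out so the outputs below stay unchanged (the name 'ada_wrapped' is arbitrary):


In [ ]:
# optional: explicit wrapping of an sklearn estimator to control its features
# (commented out so the factory keeps exactly the three models shown below)
# factory['ada_wrapped'] = SklearnClassifier(AdaBoostClassifier(n_estimators=10),
#                                            features=train_variables[:5])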

Create a copy of the factory with all classifiers


In [7]:
from copy import deepcopy
factory_copy = deepcopy(factory)

Training

Pay attention:
in the first factory, each classifier uses the features specified in its constructor;
in the second factory, we pass the features to fit, and all classifiers will be trained on them.


In [8]:
%time factory.fit(train_data, train_labels)
pass


model tmva         was trained in 9.67 seconds
model ada          was trained in 1.39 seconds
model xgb          was trained in 8.58 seconds
Totally spent 19.64 seconds on training
CPU times: user 10.3 s, sys: 113 ms, total: 10.4 s
Wall time: 19.6 s
/Users/antares/code/xgboost/wrapper/xgboost.py:80: FutureWarning: comparison to `None` will result in an elementwise object comparison in the future.
  if label != None:
/Users/antares/code/xgboost/wrapper/xgboost.py:82: FutureWarning: comparison to `None` will result in an elementwise object comparison in the future.
  if weight !=None:

In [9]:
factory.predict_proba(train_data)


data was predicted by tmva         in 3.46 seconds
data was predicted by ada          in 0.09 seconds
data was predicted by xgb          in 0.77 seconds
Totally spent 4.32 seconds on prediction
Out[9]:
OrderedDict([('tmva', array([[ 0.34412062,  0.65587938],
       [ 0.68926048,  0.31073952],
       [ 0.14409667,  0.85590333],
       ..., 
       [ 0.84946364,  0.15053636],
       [ 0.78480178,  0.21519822],
       [ 0.24511415,  0.75488585]])), ('ada', array([[ 0.35275404,  0.64724596],
       [ 0.62142341,  0.37857659],
       [ 0.29552299,  0.70447701],
       ..., 
       [ 0.71210128,  0.28789872],
       [ 0.63381179,  0.36618821],
       [ 0.37981714,  0.62018286]])), ('xgb', array([[ 0.02789063,  0.97210938],
       [ 0.102053  ,  0.89794695],
       [ 0.06710427,  0.93289578],
       ..., 
       [ 0.18055718,  0.81944281],
       [ 0.44376361,  0.55623639],
       [ 0.01008851,  0.9899115 ]], dtype=float32))])
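
predict_proba returns an OrderedDict keyed by classifier name, each value being the usual (n_samples, 2) array of class probabilities, so a single model can be inspected directly:


In [ ]:
# signal-class probabilities (column 1) of one model from the factory
proba = factory.predict_proba(train_data)
proba['xgb'][:5, 1]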

In [10]:
%time factory_copy.fit(train_data, train_labels, features=train_variables)
pass


Overwriting features of estimator tmva
Overwriting features of estimator xgb
model tmva         was trained in 11.62 seconds
model ada          was trained in 1.30 seconds
model xgb          was trained in 14.67 seconds
Totally spent 27.59 seconds on training
CPU times: user 16.2 s, sys: 177 ms, total: 16.4 s
Wall time: 27.6 s
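
As mentioned in the interface description above, training can also be distributed over an IPython cluster by passing its profile name via ipc_profile. A sketch, left commented out because it requires a running cluster (the profile name 'default' is just an example):


In [ ]:
# requires e.g. `ipcluster start --profile=default` beforehand
# factory_copy.fit(train_data, train_labels, features=train_variables, ipc_profile='default')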

Everybody loves plots!

Visualizing training results with the factory

The ClassificationReport class provides classification summaries for comparing different models.
Below you can find the available functions that help you analyze results on an arbitrary dataset.

There are different plotting backends supported:

  • matplotlib (default, de-facto standard plotting library),
  • plotly (proprietary package with interactive plots, information is kept on the server),
  • ROOT (the library used by CERN people),
  • bokeh (open-source package with interactive plots)

Get ClassificationReport object

The report object has many useful methods.


In [11]:
report = factory.test_on(test_data, test_labels)
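
The test_on_lds variant from the interface description does the same thing via an explicit LabeledDataStorage; a minimal sketch, assuming the LabeledDataStorage(data, labels) constructor:


In [ ]:
# equivalent of the call above, going through an explicit LabeledDataStorage
# (constructor signature assumed to be LabeledDataStorage(data, labels))
from rep.data.storage import LabeledDataStorage
lds = LabeledDataStorage(test_data, test_labels)
report_from_lds = factory.test_on_lds(lds)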

Plot importances of features

Only the features used in training are compared


In [12]:
features_importances = report.feature_importance()
features_importances.plot()


Estimator tmva doesn't support feature importances

features_importances is an object that can be plotted not only with matplotlib, but also with the other supported backends (e.g. plotly, as shown below)


In [13]:
features_importances.plot_plotly('importances', figsize=(15, 6))



Plot learning curves to check the trained classifiers for possible overfitting

Learning curves are a simple and powerful tool for analyzing the behaviour of your models.


In [14]:
from sklearn.metrics import roc_auc_score, log_loss

def log_likelihood(y_true, y_pred, sample_weight=None):
    return log_loss(y_true, y_pred[:, 1])
    
def roc_auc(y_true, y_pred, sample_weight=None):
    return roc_auc_score(y_true, y_pred[:, 1], sample_weight=sample_weight)

learning_curve = report.learning_curve(log_likelihood, metric_label='log likelihood', steps=1)
learning_curve.plot(new_plot=True)


Estimator tmva doesn't support stage predictions

In [15]:
learning_curve = report.learning_curve(roc_auc, metric_label='roc auc', steps=1)
learning_curve.plot(new_plot=True)


Estimator tmva doesn't support stage predictions

In [16]:
learning_curve.plot_plotly(plotly_filename='learning curves', figsize=(18, 8))



Plot correlation between features


In [17]:
correlation_pairs = []
correlation_pairs.append((plot_variables[0], plot_variables[1]))
correlation_pairs.append((plot_variables[0], plot_variables[2]))

report.scatter(correlation_pairs, alpha=0.01).plot()


Plot data information: features correlation matrix


In [18]:
# plot correlations between variables for signal-like and bck-like events
report.features_correlation_matrix(features=loaded_variables).plot(new_plot=True, show_legend=False, figsize=(7, 5))



In [19]:
report.features_correlation_matrix_by_class(features=plot_variables).plot(new_plot=True, show_legend=False, figsize=(15, 5))



In [20]:
# plot correlations between variables just for bck-like events
corr = report.features_correlation_matrix_by_class(features=plot_variables[:4], labels_dict={0: 'background'}, grid_columns=1)
corr.plot_plotly(plotly_filename='correlations', show_legend=False, fontsize=8, figsize=(8, 6))



Plot distribution for each feature


In [21]:
# use only the features common to all classifiers
report.features_pdf().plot()



In [22]:
# use all features in data
report.features_pdf(data.columns).plot_plotly('distributions')



Plot the distributions of the predictions


In [23]:
report.prediction_pdf().plot(new_plot=True, figsize = (9, 4))



In [24]:
report.prediction_pdf(labels_dict={0: 'background'}, size=5).plot_plotly('models pdf')



ROC curves (receiver operating characteristic)

Plot the ROC curves for the test data (this is the same as a background rejection vs. signal efficiency plot)


In [25]:
report.roc().plot(xlim=(0.5, 1))



In [26]:
# plot the same distribution using interactive plot
report.roc().plot_plotly(plotly_filename='ROC')
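
For a single-number summary next to the curves, the ROC AUC of each model can be computed directly from the factory predictions on the test set (reusing roc_auc_score imported earlier):


In [ ]:
# numeric ROC AUC on the test set for every model in the factory
for name, proba in factory.predict_proba(test_data).items():
    print('%-10s ROC AUC: %.4f' % (name, roc_auc_score(test_labels, proba[:, 1])))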



Plot 'flatness' of classifier prediction

(that is, the dependence of the selection efficiency on variables of the dataset)


In [27]:
efficiencies = report.efficiencies(['mass'])
efficiencies_with_errors = report.efficiencies(['mass'], errors=True, bins=15, ignored_sideband=0.01)


In [28]:
efficiencies.plot(figsize=(18, 25), fontsize=12, show_legend=False)
efficiencies_with_errors.plot(figsize=(18, 25), fontsize=12, show_legend=False)



In [29]:
efficiencies.plot_plotly("efficiencies", show_legend=False, figsize=(18, 20))
efficiencies_with_errors.plot_plotly("efficiencies error", show_legend=False, figsize=(18, 20))



Quality on different metrics

See how easily you can estimate the quality with your own custom metrics.


In [31]:
# define metric functions
def AMS(s, b): 
    br = 0.01
    radicand = 2 *( (s+b+br) * numpy.log (1.0 + s/(b+br)) - s)
    return numpy.sqrt(radicand)

def significance(s, b): 
    br = 0.01
    radicand = s / numpy.sqrt(b + br)
    return radicand


metrics = report.metrics_vs_cut(AMS, metric_label='AMS')
metrics.plot(new_plot=True, figsize=(15, 6))



In [32]:
metrics = report.metrics_vs_cut(significance, metric_label='significance')
metrics.plot(new_plot=True, figsize=(15, 6))


The same using plotly


In [33]:
metrics.plot_plotly('metrics')


/Users/antares/.virtualenvs/rep/lib/python2.7/site-packages/plotly/plotly/plotly.py:186: UserWarning:

Woah there! Look at all those points! Due to browser limitations, Plotly has a hard time graphing more than 500k data points for line charts, or 40k points for other types of charts. Here are some suggestions:
(1) Trying using the image API to return an image instead of a graph URL
(2) Use matplotlib
(3) See if you can create your visualization with fewer data points

If the visualization you're using aggregates points (e.g., box plot, histogram, etc.) you can disregard this warning.


The draw time for this plot will be slow for all clients.
/Users/antares/.virtualenvs/rep/lib/python2.7/site-packages/plotly/plotly/plotly.py:674: UserWarning:

Estimated Draw Time Too Long

Exercises

Exercise 1. Create a weight column for the train and test datasets, then fit the factory using this weight column and get the model information using the weights.

Exercise 2. Train other classifiers; play with their parameters and feature sets.

Exercise 3. Try using your cluster (change the paths and configurations).
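
A possible starting point for Exercise 1; the uniform weights are only a placeholder, and passing sample_weight to test_on is assumed to be supported:


In [ ]:
# starter sketch for Exercise 1 (placeholder uniform weights)
train_weights = numpy.ones(len(train_data))
test_weights = numpy.ones(len(test_data))
factory_weighted = deepcopy(factory_copy)
factory_weighted.fit(train_data, train_labels, sample_weight=train_weights, features=train_variables)
report_weighted = factory_weighted.test_on(test_data, test_labels, sample_weight=test_weights)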