In [1]:
from __future__ import division
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn.apionly as sns

from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import validation_curve, GridSearchCV, cross_val_score, ParameterGrid

from composition.analysis.load_sim import load_sim
from composition.analysis.preprocessing import get_train_test_sets, LabelEncoder
from composition.analysis.pipelines import get_pipeline
from composition.analysis.features import get_training_features
from composition.analysis.plotting_functions import plot_decision_regions
import composition.analysis.data_functions as data_functions
from composition.support_functions.checkdir import checkdir

%matplotlib inline


/home/jbourbeau/.local/lib/python2.7/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment.
  warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')

In [2]:
sns.set_palette('muted')
sns.set_color_codes()

In [3]:
df, cut_dict = load_sim(return_cut_dict=True)
selection_mask = np.array([True] * len(df))
standard_cut_keys = ['reco_exists', 'reco_zenith', 'num_hits', 'IT_signal',
                     'StationDensity', 'max_charge_frac', 'reco_containment', 'energy_range']
for key in standard_cut_keys:
    selection_mask *= cut_dict[key]

df = df[selection_mask]

feature_list = get_training_features()
X_train, X_test, y_train, y_test, le = get_train_test_sets(df, feature_list)

print('events = ' + str(y_train.shape[0]))


/home/jbourbeau/composition/analysis/load_sim.py:67: RuntimeWarning: divide by zero encountered in log10
  df['reco_log_energy'] = np.nan_to_num(np.log10(df['reco_energy']))
/home/jbourbeau/composition/analysis/load_sim.py:68: RuntimeWarning: invalid value encountered in log10
  df['InIce_log_charge'] = np.nan_to_num(np.log10(df['InIce_charge']))
events = 72644

In [4]:
pipeline = get_pipeline('RF')
param_range = np.arange(1, 20)
train_scores, test_scores = validation_curve(
                estimator=pipeline, 
                X=X_train, 
                y=y_train, 
                param_name='classifier__max_depth', 
                param_range=param_range,
                cv=10,
                verbose=3,
                n_jobs=10)

train_mean = np.mean(train_scores, axis=1)
train_std = np.std(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

plt.plot(param_range, train_mean, 
         color='b', marker='o', 
         markersize=5, label='training accuracy')

plt.fill_between(param_range, train_mean + train_std,
                 train_mean - train_std, alpha=0.15,
                 color='b')

plt.plot(param_range, test_mean, 
         color='g', linestyle='None', 
         marker='s', markersize=5, 
         label='validation accuracy')

plt.fill_between(param_range, 
                 test_mean + test_std,
                 test_mean - test_std, 
                 alpha=0.15, color='g')

plt.grid()
# plt.xscale('log')
plt.legend(loc='lower right')
plt.xlabel('Max depth')
plt.ylabel('Accuracy [\%]')
# plt.ylim([0.8, 1.0])
# plt.tight_layout()
plt.savefig('/home/jbourbeau/public_html/figures/composition/parameter-tuning/RF-validation_curve.png', dpi=300)
# plt.show()


[CV] classifier__max_depth=1 .........................................
[CV] classifier__max_depth=2 .........................................
[CV] classifier__max_depth=3 .........................................
[CV] classifier__max_depth=4 .........................................
[CV] classifier__max_depth=5 .........................................
[CV] classifier__max_depth=6 .........................................
[CV] classifier__max_depth=7 .........................................
[CV] classifier__max_depth=8 .........................................
[CV] classifier__max_depth=9 .........................................
[CV] classifier__max_depth=10 ........................................
[CV] ................ classifier__max_depth=1, score=0.688507 -   0.6s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=2, score=0.700206 -   0.5s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=3, score=0.712044 -   0.4s
[CV] classifier__max_depth=13 ........................................
[CV] ................ classifier__max_depth=4, score=0.729112 -   0.6s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=6, score=0.760083 -   0.4s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=5, score=0.740812 -   0.3s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=7, score=0.766965 -   0.4s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.771783 -   0.4s
[CV] ............... classifier__max_depth=10, score=0.779491 -   0.2s
[CV] classifier__max_depth=18 ........................................
[CV] classifier__max_depth=19 ........................................
[CV] ................ classifier__max_depth=9, score=0.773572 -   0.4s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=11, score=0.775774 -   0.7s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=12, score=0.777426 -   0.5s
[CV] ................ classifier__max_depth=1, score=0.707226 -   0.5s
[CV] classifier__max_depth=3 .........................................
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=13, score=0.774673 -   0.6s
[CV] classifier__max_depth=5 .........................................
[CV] ............... classifier__max_depth=14, score=0.775774 -   0.3s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=15, score=0.775499 -   0.5s
[CV] classifier__max_depth=7 .........................................
[CV] ............... classifier__max_depth=16, score=0.775499 -   0.4s
[CV] classifier__max_depth=8 .........................................
[CV] ................ classifier__max_depth=2, score=0.721129 -   0.4s
[CV] classifier__max_depth=9 .........................................
[CV] ............... classifier__max_depth=19, score=0.774260 -   0.4s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=17, score=0.775499 -   0.4s
[CV] classifier__max_depth=11 ........................................
[CV] ............... classifier__max_depth=18, score=0.776187 -   0.4s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=3, score=0.728975 -   0.4s
[CV] classifier__max_depth=13 ........................................
[CV] ................ classifier__max_depth=4, score=0.742602 -   0.5s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=5, score=0.753338 -   0.4s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=6, score=0.768754 -   0.6s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=7, score=0.776875 -   0.5s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.783620 -   0.5s
[CV] classifier__max_depth=18 ........................................
[CV] ................ classifier__max_depth=9, score=0.787061 -   0.5s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=10, score=0.789677 -   0.6s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=11, score=0.790640 -   0.5s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=12, score=0.788438 -   0.5s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=13, score=0.788575 -   0.5s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=14, score=0.788438 -   0.4s
[CV] classifier__max_depth=5 .........................................
[CV] ............... classifier__max_depth=15, score=0.786098 -   0.5s
[CV] classifier__max_depth=6 .........................................
[CV] ................ classifier__max_depth=1, score=0.699381 -   0.5s
[CV] classifier__max_depth=7 .........................................
[CV] ............... classifier__max_depth=16, score=0.786511 -   0.5s
[CV] classifier__max_depth=8 .........................................
[CV] ................ classifier__max_depth=2, score=0.707364 -   0.5s
[CV] classifier__max_depth=9 .........................................
[CV] ................ classifier__max_depth=3, score=0.723193 -   0.4s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=17, score=0.786098 -   0.5s
[CV] classifier__max_depth=11 ........................................
[CV] ............... classifier__max_depth=18, score=0.785409 -   0.4s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=4, score=0.735719 -   0.3s
[CV] classifier__max_depth=13 ........................................
[CV] ................ classifier__max_depth=5, score=0.750860 -   0.3s
[CV] classifier__max_depth=14 ........................................
[CV] ............... classifier__max_depth=19, score=0.783620 -   0.4s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=6, score=0.766690 -   0.4s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=8, score=0.780867 -   0.5s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=7, score=0.774948 -   0.7s
[CV] classifier__max_depth=18 ........................................
[CV] ................ classifier__max_depth=9, score=0.786924 -   0.5s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=10, score=0.789539 -   0.6s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=11, score=0.789952 -   0.5s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=12, score=0.790089 -   0.5s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=13, score=0.790089 -   0.4s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=15, score=0.785960 -   0.5s
[CV] ............... classifier__max_depth=14, score=0.788162 -   0.4s
[CV] classifier__max_depth=5 .........................................
[CV] classifier__max_depth=6 .........................................
[CV] ................ classifier__max_depth=1, score=0.707915 -   0.5s
[CV] classifier__max_depth=7 .........................................
[CV] ............... classifier__max_depth=16, score=0.787612 -   0.4s
[CV] classifier__max_depth=8 .........................................
[CV] ............... classifier__max_depth=17, score=0.784721 -   0.5s
[CV] ................ classifier__max_depth=2, score=0.715072 -   0.3s
[CV] classifier__max_depth=9 .........................................
[CV] classifier__max_depth=10 ........................................
[CV] ................ classifier__max_depth=3, score=0.726910 -   0.3s
[CV] classifier__max_depth=11 ........................................
[CV] ............... classifier__max_depth=18, score=0.783207 -   0.4s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=4, score=0.739160 -   0.4s
[CV] classifier__max_depth=13 ........................................
[CV] ............... classifier__max_depth=19, score=0.782657 -   0.7s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=5, score=0.755403 -   0.5s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=6, score=0.770544 -   0.6s
[CV] classifier__max_depth=16 ........................................
[Parallel(n_jobs=10)]: Done  63 out of 190 | elapsed:   35.3s remaining:  1.2min
[CV] ................ classifier__max_depth=7, score=0.779491 -   0.4s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.784721 -   0.6s
[CV] classifier__max_depth=18 ........................................
[CV] ................ classifier__max_depth=9, score=0.787337 -   0.2s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=10, score=0.787887 -   0.5s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=11, score=0.789952 -   0.7s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=12, score=0.788300 -   0.6s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=13, score=0.790089 -   0.6s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=14, score=0.789814 -   0.3s
[CV] classifier__max_depth=5 .........................................
[CV] ............... classifier__max_depth=15, score=0.787474 -   0.4s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=16, score=0.787337 -   0.4s
[CV] classifier__max_depth=7 .........................................
[CV] ................ classifier__max_depth=1, score=0.690434 -   0.6s
[CV] classifier__max_depth=8 .........................................
[CV] ................ classifier__max_depth=2, score=0.701721 -   0.4s
[CV] classifier__max_depth=9 .........................................
[CV] ............... classifier__max_depth=17, score=0.784721 -   0.4s
[CV] classifier__max_depth=10 ........................................
[CV] ................ classifier__max_depth=3, score=0.717412 -   0.5s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=4, score=0.732003 -   0.3s
[CV] classifier__max_depth=12 ........................................
[CV] ............... classifier__max_depth=18, score=0.783895 -   0.3s
[CV] classifier__max_depth=13 ........................................
[CV] ............... classifier__max_depth=19, score=0.782381 -   0.5s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=5, score=0.748383 -   0.4s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=6, score=0.763524 -   0.5s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=7, score=0.776187 -   0.6s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.781968 -   0.5s
[CV] classifier__max_depth=18 ........................................
[CV] ................ classifier__max_depth=9, score=0.785272 -   0.5s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=10, score=0.789264 -   0.7s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=11, score=0.787887 -   0.6s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=12, score=0.789814 -   0.4s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=13, score=0.790640 -   0.3s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=14, score=0.789126 -   0.3s
[CV] classifier__max_depth=5 .........................................
[CV] ................ classifier__max_depth=1, score=0.694383 -   0.5s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=15, score=0.784171 -   0.6s
[CV] classifier__max_depth=7 .........................................
[CV] ................ classifier__max_depth=2, score=0.703056 -   0.5s
[CV] classifier__max_depth=8 .........................................
[CV] ............... classifier__max_depth=16, score=0.784446 -   0.5s
[CV] classifier__max_depth=9 .........................................
[CV] ............... classifier__max_depth=17, score=0.782106 -   0.6s
[CV] classifier__max_depth=10 ........................................
[CV] ................ classifier__max_depth=4, score=0.733756 -   0.3s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=3, score=0.719301 -   0.3s
[CV] classifier__max_depth=12 ........................................
[CV] ............... classifier__max_depth=18, score=0.781555 -   0.2s
[CV] ................ classifier__max_depth=5, score=0.746145 -   0.2s
[CV] classifier__max_depth=13 ........................................
[CV] classifier__max_depth=14 ........................................
[CV] ............... classifier__max_depth=19, score=0.780317 -   0.5s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=6, score=0.760325 -   0.5s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=7, score=0.771063 -   0.6s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.776569 -   0.6s
[CV] classifier__max_depth=18 ........................................
[CV] ............... classifier__max_depth=10, score=0.783315 -   1.4s
[CV] classifier__max_depth=19 ........................................
[CV] ................ classifier__max_depth=9, score=0.778634 -   0.3s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=13, score=0.782489 -   0.3s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=11, score=0.783866 -   0.7s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=14, score=0.781388 -   0.4s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=12, score=0.781250 -   0.4s
[CV] classifier__max_depth=5 .........................................
[CV] ............... classifier__max_depth=15, score=0.779598 -   0.5s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=16, score=0.782351 -   0.2s
[CV] classifier__max_depth=7 .........................................
[CV] ................ classifier__max_depth=3, score=0.717649 -   0.2s
[CV] classifier__max_depth=8 .........................................
[CV] ................ classifier__max_depth=1, score=0.692869 -   0.2s
[CV] classifier__max_depth=9 .........................................
[CV] ............... classifier__max_depth=17, score=0.779736 -   0.3s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=18, score=0.780562 -   0.4s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=2, score=0.702781 -   0.2s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=4, score=0.736233 -   0.2s
[CV] classifier__max_depth=13 ........................................
[CV] ................ classifier__max_depth=5, score=0.749312 -   0.2s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=6, score=0.763629 -   0.2s
[CV] classifier__max_depth=15 ........................................
[CV] ............... classifier__max_depth=19, score=0.779736 -   0.4s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=8, score=0.781525 -   0.2s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=7, score=0.775055 -   0.2s
[CV] classifier__max_depth=18 ........................................
[CV] ............... classifier__max_depth=11, score=0.783040 -   0.3s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=10, score=0.785105 -   0.3s
[CV] classifier__max_depth=1 .........................................
[CV] ................ classifier__max_depth=9, score=0.782076 -   0.4s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=18, score=0.777120 -   0.2s
[CV] ............... classifier__max_depth=17, score=0.776432 -   0.5s
[CV] classifier__max_depth=3 .........................................
[CV] classifier__max_depth=4 .........................................
[CV] ................ classifier__max_depth=1, score=0.715033 -   0.2s
[CV] classifier__max_depth=5 .........................................
[CV] ............... classifier__max_depth=13, score=0.782076 -   0.2s
[CV] classifier__max_depth=6 .........................................
[Parallel(n_jobs=10)]: Done 127 out of 190 | elapsed:  1.2min remaining:   35.8s
[CV] ............... classifier__max_depth=12, score=0.784416 -   0.4s
[CV] classifier__max_depth=7 .........................................
[CV] ................ classifier__max_depth=2, score=0.726597 -   0.6s
[CV] classifier__max_depth=8 .........................................
[CV] ................ classifier__max_depth=4, score=0.757159 -   0.2s
[CV] classifier__max_depth=9 .........................................
[CV] ............... classifier__max_depth=14, score=0.781663 -   0.3s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=19, score=0.777395 -   0.5s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=3, score=0.736646 -   0.2s
[CV] ................ classifier__max_depth=5, score=0.771338 -   0.4s
[CV] classifier__max_depth=12 ........................................
[CV] classifier__max_depth=13 ........................................
[CV] ............... classifier__max_depth=16, score=0.780699 -   0.3s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=7, score=0.787858 -   0.4s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=8, score=0.792676 -   0.5s
[CV] classifier__max_depth=16 ........................................
[CV] ............... classifier__max_depth=10, score=0.797081 -   0.2s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=9, score=0.793915 -   0.2s
[CV] classifier__max_depth=18 ........................................
[CV] ................ classifier__max_depth=6, score=0.781250 -   0.2s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=15, score=0.782351 -   0.4s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=13, score=0.795017 -   0.2s
[CV] classifier__max_depth=2 .........................................
[CV] ............... classifier__max_depth=11, score=0.795567 -   0.3s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=15, score=0.793227 -   0.3s
[CV] classifier__max_depth=4 .........................................
[CV] ............... classifier__max_depth=12, score=0.794466 -   0.3s
[CV] classifier__max_depth=5 .........................................
[CV] ................ classifier__max_depth=2, score=0.705534 -   0.2s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=18, score=0.789923 -   0.5s
[CV] classifier__max_depth=7 .........................................
[CV] ............... classifier__max_depth=19, score=0.788271 -   0.4s
[CV] classifier__max_depth=8 .........................................
[CV] ............... classifier__max_depth=16, score=0.794741 -   0.3s
[CV] classifier__max_depth=9 .........................................
[CV] ................ classifier__max_depth=1, score=0.696311 -   0.4s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=17, score=0.791300 -   0.5s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=5, score=0.749312 -   0.2s
[CV] classifier__max_depth=12 ........................................
[CV] ................ classifier__max_depth=4, score=0.738298 -   0.2s
[CV] classifier__max_depth=13 ........................................
[CV] ............... classifier__max_depth=14, score=0.791437 -   0.5s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=7, score=0.773541 -   0.2s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=3, score=0.722192 -   0.2s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=6, score=0.762941 -   0.2s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=8, score=0.782214 -   0.2s
[CV] classifier__max_depth=18 ........................................
[CV] ............... classifier__max_depth=11, score=0.787858 -   0.2s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=12, score=0.787996 -   0.3s
[CV] classifier__max_depth=1 .........................................
[CV] ............... classifier__max_depth=10, score=0.787720 -   0.4s
[CV] classifier__max_depth=2 .........................................
[CV] ................ classifier__max_depth=9, score=0.787307 -   0.3s
[CV] classifier__max_depth=3 .........................................
[CV] ............... classifier__max_depth=16, score=0.784141 -   0.3s
[CV] classifier__max_depth=4 .........................................
[CV] ................ classifier__max_depth=1, score=0.697370 -   0.2s
[CV] classifier__max_depth=5 .........................................
[CV] ................ classifier__max_depth=3, score=0.716646 -   0.2s
[CV] classifier__max_depth=6 .........................................
[CV] ............... classifier__max_depth=14, score=0.788271 -   0.3s
[CV] classifier__max_depth=7 .........................................
[CV] ............... classifier__max_depth=15, score=0.786344 -   0.3s
[CV] classifier__max_depth=8 .........................................
[CV] ............... classifier__max_depth=18, score=0.784416 -   0.2s
[CV] classifier__max_depth=9 .........................................
[CV] ................ classifier__max_depth=4, score=0.732480 -   0.2s
[CV] classifier__max_depth=10 ........................................
[CV] ............... classifier__max_depth=13, score=0.788271 -   0.3s
[CV] classifier__max_depth=11 ........................................
[CV] ................ classifier__max_depth=2, score=0.703566 -   0.2s
[CV] classifier__max_depth=12 ........................................
[CV] ............... classifier__max_depth=19, score=0.782902 -   0.2s
[CV] classifier__max_depth=13 ........................................
[CV] ................ classifier__max_depth=6, score=0.759879 -   0.2s
[CV] classifier__max_depth=14 ........................................
[CV] ................ classifier__max_depth=5, score=0.745835 -   0.2s
[CV] classifier__max_depth=15 ........................................
[CV] ................ classifier__max_depth=8, score=0.774886 -   0.2s
[CV] classifier__max_depth=16 ........................................
[CV] ................ classifier__max_depth=7, score=0.772821 -   0.4s
[CV] classifier__max_depth=17 ........................................
[CV] ................ classifier__max_depth=9, score=0.782459 -   0.3s
[CV] classifier__max_depth=18 ........................................
[CV] ............... classifier__max_depth=17, score=0.783728 -   0.5s
[CV] classifier__max_depth=19 ........................................
[CV] ............... classifier__max_depth=15, score=0.783423 -   0.3s
[CV] ............... classifier__max_depth=10, score=0.783147 -   0.2s
[CV] ............... classifier__max_depth=17, score=0.782321 -   0.3s
[CV] ............... classifier__max_depth=14, score=0.779430 -   0.2s
[CV] ............... classifier__max_depth=11, score=0.783698 -   0.2s
[CV] ............... classifier__max_depth=13, score=0.782734 -   0.2s
[CV] ............... classifier__max_depth=16, score=0.780118 -   0.2s
[CV] ............... classifier__max_depth=12, score=0.785488 -   0.2s
[CV] ............... classifier__max_depth=18, score=0.781358 -   0.1s
[CV] ............... classifier__max_depth=19, score=0.779705 -   0.2s
[Parallel(n_jobs=10)]: Done 190 out of 190 | elapsed:  1.8min finished

In [4]:
max_depth_list = [2, 8, 10, 20]

fig, axarr = plt.subplots(2,2)
for depth, ax in zip(max_depth_list, axarr.flatten()):
    pipeline = get_pipeline('RF')
    pipeline.named_steps['classifier'].set_params(max_depth=depth)
    pipeline.fit(X_train, y_train)
    scaler = pipeline.named_steps['scaler']
    clf = pipeline.named_steps['classifier']
    X_test_std = scaler.transform(X_test)
    plot_decision_regions(X_test_std, y_test, clf, scatter_fraction=None, ax=ax)
    ax.set_xlabel('Scaled energy')
    ax.set_ylabel('Scaled charge')
    ax.set_title('Max depth = {}'.format(depth))
    ax.legend()
plt.tight_layout()
plt.savefig('/home/jbourbeau/public_html/figures/composition/parameter-tuning/RF-decision-regions.png')


/home/jbourbeau/.local/lib/python2.7/site-packages/matplotlib/axes/_axes.py:519: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.
  warnings.warn("No labelled objects found. "

In [5]:
pipeline = get_pipeline('RF')
param_range = np.arange(1, 20)
param_grid = {'classifier__max_depth': param_range}
gs = GridSearchCV(estimator=pipeline, 
                  param_grid=param_grid, 
                  scoring='accuracy', 
                  cv=10,
                  n_jobs=10)
gs = gs.fit(X_train, y_train)
print(gs.best_score_)
print(gs.best_params_)


0.7872226199
{'classifier__max_depth': 10}

In [ ]: