In [20]:
import sys
sys.path.insert(0, "C:/Users/magaxels/AutoML")

from gazer import GazerMetaLearner

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits

Load some toy dataset and split into train and validation


In [24]:
X, y = load_digits(return_X_y=True)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)


(1437, 64) (1437,) (360, 64) (360,)

Define a learner object using method='select' and estimators=['neuralnet']


In [13]:
learner = GazerMetaLearner(method='select', estimators=['neuralnet'], verbose=1)


Available algorithms (use '.clf' attribute for access):
neuralnet

The entry point to network optimization is found in the optimization module


In [15]:
from gazer.optimization import grid_search

It expects the data to be shipped in the following format:


In [22]:
data = {'train': (X_train, y_train), 'val': (X_val, y_val)}

We also provide a dictionary of iterables to iterate over.


In [23]:
params = {
    'batch_norm': (True, False),
    'batch_size': 16,
    'dropout': True,
    'epochs': (10, 20),
    'input_units': np.linspace(250, 500, 6, dtype=int),
    'n_hidden': (2,3),
    'p': (0.1, 0.5),
    'validation_split': 0.0,
}

Perform grid search over "architectures"


In [16]:
config, df = grid_search(learner, params, data)


C:\Users\magaxels\anaconda3\lib\site-packages\keras\callbacks.py:972: RuntimeWarning: Reduce LR on plateau conditioned on metric `val_loss` which is not available. Available metrics are: loss,acc,lr
  (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
C:\Users\magaxels\anaconda3\lib\site-packages\keras\callbacks.py:526: RuntimeWarning: Early stopping conditioned on metric `val_loss` which is not available. Available metrics are: loss,acc,lr
  (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
360/360 [==============================] - ETA:  - 0s 381us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 61us/step
360/360 [==============================] - ETA:  - 0s 264us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 46us/step
360/360 [==============================] - ETA:  - 0s 513us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 59us/step
360/360 [==============================] - ETA:  - 0s 492us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 54us/step
360/360 [==============================] - ETA:  - 0s 616us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 56us/step
360/360 [==============================] - ETA:  - 0s 546us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 52us/step
360/360 [==============================] - ETA:  - 0s 686us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 67us/step
360/360 [==============================] - ETA:  - 0s 805us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 64us/step
360/360 [==============================] - ETA:  - 0s 867us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 69us/step
360/360 [==============================] - ETA:  - 0s 833us/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 64us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 60us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 77us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 57us/step
360/360 [==============================] - ETA:  - 1s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 80us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 68us/step
360/360 [==============================] - ETA:  - 0s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 84us/step
360/360 [==============================] - ETA:  - 1s 1ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 69us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 84us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 86us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 69us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 81us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 81us/step
360/360 [==============================] - ETA:  - 1s 2ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 85us/step
360/360 [==============================] - ETA: 10 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA: 10 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 96us/step
360/360 [==============================] - ETA:  - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 70us/step
360/360 [==============================] - ETA:  - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 81us/step
360/360 [==============================] - ETA: 11 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 85us/step
360/360 [==============================] - ETA: 10 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 101us/step
360/360 [==============================] - ETA:  - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 85us/step
360/360 [==============================] - ETA: 10 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 98us/step
360/360 [==============================] - ETA: 10 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 86us/step
360/360 [==============================] - ETA: 17 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 94us/step
360/360 [==============================] - ETA: 11 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 92us/step
360/360 [==============================] - ETA: 12 - 1s 3ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 87us/step
360/360 [==============================] - ETA: 12 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - 0s 90us/step
360/360 [==============================] - ETA: 12 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 104us/step
360/360 [==============================] - ETA: 13 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 85us/step
360/360 [==============================] - ETA: 13 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 104us/step
360/360 [==============================] - ETA: 13 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 98us/step
360/360 [==============================] - ETA: 14 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 96us/step
360/360 [==============================] - ETA: 17 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 107us/step
360/360 [==============================] - ETA: 14 - 1s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 112us/step
360/360 [==============================] - ETA: 15 - 2s 4ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 105us/step
360/360 [==============================] - ETA: 17 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 100us/step
360/360 [==============================] - ETA: 21 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 122us/step
360/360 [==============================] - ETA: 18 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 105us/step
360/360 [==============================] - ETA: 16 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 88us/step
360/360 [==============================] - ETA: 18 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 113us/step
360/360 [==============================] - ETA: 17 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 99us/step
360/360 [==============================] - ETA: 20 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 113us/step
360/360 [==============================] - ETA: 20 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 105us/step
360/360 [==============================] - ETA: 20 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 137us/step
360/360 [==============================] - ETA: 19 - 2s 5ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 107us/step
360/360 [==============================] - ETA: 20 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 128us/step
360/360 [==============================] - ETA: 21 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 109us/step
360/360 [==============================] - ETA: 22 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 122us/step
360/360 [==============================] - ETA: 21 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 112us/step
360/360 [==============================] - ETA: 22 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 117us/step
360/360 [==============================] - ETA: 23 - 2s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 107us/step
360/360 [==============================] - ETA: 23 - 2s 6ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 134us/step
360/360 [==============================] - ETA: 24 - 2s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 105us/step
360/360 [==============================] - ETA: 24 - 2s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 149us/step
360/360 [==============================] - ETA: 24 - 2s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 118us/step
360/360 [==============================] - ETA: 25 - 3s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 133us/step
360/360 [==============================] - ETA: 26 - 3s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 120us/step
360/360 [==============================] - ETA: 25 - 3s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 135us/step
360/360 [==============================] - ETA: 26 - 3s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 116us/step
360/360 [==============================] - ETA: 28 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 134us/step
360/360 [==============================] - ETA: 27 - 3s 7ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 128us/step
360/360 [==============================] - ETA: 27 - ETA: 0 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 133us/step
360/360 [==============================] - ETA: 28 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 113us/step
360/360 [==============================] - ETA: 31 - ETA: 0 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 142us/step
360/360 [==============================] - ETA: 29 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - 0s 128us/step
360/360 [==============================] - ETA: 30 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 152us/step
360/360 [==============================] - ETA: 30 - ETA: 0 - 3s 8ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 133us/step
360/360 [==============================] - ETA: 31 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 136us/step
360/360 [==============================] - ETA: 32 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 138us/step
360/360 [==============================] - ETA: 31 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 150us/step
360/360 [==============================] - ETA: 32 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 149us/step
360/360 [==============================] - ETA: 34 - 3s 9ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 148us/step
360/360 [==============================] - ETA: 37 - ETA: 0 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 137us/step
360/360 [==============================] - ETA: 35 - ETA: 0 - 3s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 165us/step
360/360 [==============================] - ETA: 36 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 148us/step
360/360 [==============================] - ETA: 37 - ETA: 0 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - ETA:  - 0s 172us/step
360/360 [==============================] - ETA: 36 - ETA: 0 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 137us/step
360/360 [==============================] - ETA: 37 - ETA: 0 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - ETA:  - 0s 187us/step
360/360 [==============================] - ETA: 37 - 4s 10ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - 0s 150us/step
360/360 [==============================] - ETA: 39 - ETA: 0 - 4s 11ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - ETA:  - 0s 173us/step
360/360 [==============================] - ETA: 38 - ETA: 0 - 4s 11ms/step
1437/1437 [==============================] - ETA:  - ETA:  - ETA:  - ETA:  - ETA:  - 0s 173us/step

Have a look at results


In [19]:
df.head()


Out[19]:
index batch_norm batch_size dropout epochs input_units n_hidden p train_loss train_score val_loss val_score validation_split
0 11 False 16 True 10 350 2 0.5 0.0378 0.9875 0.0480 0.9944 0.0
1 23 False 16 True 10 500 2 0.5 0.0208 0.9965 0.0554 0.9944 0.0
2 4 True 16 True 10 300 2 0.1 0.0058 0.9993 0.0566 0.9944 0.0
3 21 False 16 True 10 500 2 0.1 0.0042 0.9993 0.0626 0.9917 0.0
4 44 True 16 True 20 500 2 0.1 0.0069 0.9979 0.0745 0.9917 0.0

The best estimator parameters are found in the 'config' dictionary:


In [18]:
config


Out[18]:
{'batch_norm': False,
 'batch_size': 16,
 'dropout': True,
 'epochs': 10,
 'index': 11,
 'input_units': 350,
 'n_hidden': 2,
 'p': 0.5,
 'validation_split': 0.0}

End of demo