In [44]:
import sklearn.linear_model
import sklearn.datasets
import sklearn.model_selection
import sklearn.preprocessing  # needed for PolynomialFeatures below
diabetes = sklearn.datasets.load_diabetes()
X = diabetes['data']
y = diabetes['target']
Preprocess the data by expanding the raw features into degree-5 polynomial features.
In [45]:
X = sklearn.preprocessing.PolynomialFeatures(5).fit_transform(X)
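For reference, PolynomialFeatures(5) expands the 10 raw diabetes features into 3003 columns, i.e. all monomials of total degree at most 5 including the bias column. A quick shape check (illustrative cell, output not shown):
In [ ]:
# The degree-5 expansion maps 10 features to C(10 + 5, 5) = 3003 columns.
print(X.shape)  # (442, 3003) for the 442-sample diabetes dataset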
Now make a train/test/validation split. The 442 samples are sliced in stored order into 300 training, 100 test, and 42 validation rows (a shuffled alternative is sketched just below).
In [46]:
train_X, test_X, validation_X = X[:300], X[300:400], X[400:]
train_y, test_y, validation_y = y[:300], y[300:400], y[400:]
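A shuffled split is a common alternative to slicing the rows in stored order; a minimal sketch using train_test_split (the suffixed names are only for illustration and are not reused below):
In [ ]:
# Illustrative alternative: shuffled 300/100/42 split.
train_X_s, rest_X, train_y_s, rest_y = sklearn.model_selection.train_test_split(
    X, y, train_size=300, random_state=0)
test_X_s, validation_X_s, test_y_s, validation_y_s = sklearn.model_selection.train_test_split(
    rest_X, rest_y, train_size=100, random_state=0)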
Tune the hyperparameters alpha (regularisation strength) and eta0 (initial learning rate) with a cross-validated grid search.
In [52]:
sgd = sklearn.linear_model.SGDRegressor()
parameters = {'alpha': [10 ** i for i in range(-5, 6)], 'eta0': [10 ** i for i in range(-3, 0)]}
clf = sklearn.model_selection.GridSearchCV(sgd, parameters)
clf.fit(train_X, train_y)
clf.cv_results_
Out[52]:
{'mean_fit_time': array([ 0.03899145, 0.03139973, 0.03181307, 0.03123728, 0.03258602,
0.03123943, 0.03194666, 0.0314192 , 0.03141419, 0.03168527,
0.03674324, 0.03147395, 0.03154318, 0.03158673, 0.03125326,
0.03126216, 0.03243995, 0.03195294, 0.03139162, 0.03123784,
0.03151615, 0.03153539, 0.03135864, 0.03283985, 0.03176967,
0.03293435, 0.03277 , 0.03289763, 0.03274067, 0.03275625,
0.03322196, 0.03438965, 0.03366804]),
'mean_score_time': array([ 0.00244633, 0.00190361, 0.00187556, 0.00180499, 0.00198142,
0.00182501, 0.00186833, 0.00188708, 0.00187031, 0.00211676,
0.0022347 , 0.00182939, 0.00181079, 0.00184798, 0.00182001,
0.00188581, 0.00187262, 0.00182748, 0.00189869, 0.00179744,
0.00184218, 0.00183455, 0.00182867, 0.00182859, 0.00183241,
0.00184059, 0.00194502, 0.00181595, 0.00184123, 0.00188835,
0.00182637, 0.00189662, 0.00194216]),
'mean_test_score': array([-1.47541777, -0.04239653, -0.00593487, -1.47699681, -0.04221579,
0.04884508, -1.47454521, -0.04514547, 0.08023233, -1.47497606,
-0.03988915, 0.0052768 , -1.486762 , -0.03840523, -0.01408577,
-1.57291348, -0.0826831 , -0.05631788, -2.02005225, -0.08796999,
-0.04417944, -2.33477068, -0.07349925, -0.05518288, -2.37049415,
-0.07635296, -0.06586221, -2.3673979 , -0.07544157, -0.04893762,
-2.37027145, -0.07301434, -0.05423084]),
'mean_train_score': array([ -1.44051426e+00, 1.35385587e-02, 1.05506848e-01,
-1.44141150e+00, 1.36190424e-02, 1.15114627e-01,
-1.44013947e+00, 1.33744944e-02, 1.14023721e-01,
-1.43969984e+00, 1.27193367e-02, 9.39772249e-02,
-1.45120497e+00, 1.02241249e-02, 5.05302816e-02,
-1.53631818e+00, -4.21031321e-02, 4.06945218e-03,
-1.97362383e+00, -4.41840177e-02, -2.09555412e-04,
-2.28102030e+00, -3.20381359e-02, -5.07752602e-04,
-2.31595797e+00, -3.36652135e-02, -4.92334356e-03,
-2.31282060e+00, -3.19352988e-02, -4.24385967e-03,
-2.31594870e+00, -3.21759140e-02, -7.02862809e-03]),
'param_alpha': masked_array(data = [1e-05 1e-05 1e-05 0.0001 0.0001 0.0001 0.001 0.001 0.001 0.01 0.01 0.01
0.1 0.1 0.1 1 1 1 10 10 10 100 100 100 1000 1000 1000 10000 10000 10000
100000 100000 100000],
mask = [False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False],
fill_value = ?),
'param_eta0': masked_array(data = [0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1
0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1
0.001 0.01 0.1],
mask = [False False False False False False False False False False False False
False False False False False False False False False False False False
False False False False False False False False False],
fill_value = ?),
'params': ({'alpha': 1e-05, 'eta0': 0.001},
{'alpha': 1e-05, 'eta0': 0.01},
{'alpha': 1e-05, 'eta0': 0.1},
{'alpha': 0.0001, 'eta0': 0.001},
{'alpha': 0.0001, 'eta0': 0.01},
{'alpha': 0.0001, 'eta0': 0.1},
{'alpha': 0.001, 'eta0': 0.001},
{'alpha': 0.001, 'eta0': 0.01},
{'alpha': 0.001, 'eta0': 0.1},
{'alpha': 0.01, 'eta0': 0.001},
{'alpha': 0.01, 'eta0': 0.01},
{'alpha': 0.01, 'eta0': 0.1},
{'alpha': 0.1, 'eta0': 0.001},
{'alpha': 0.1, 'eta0': 0.01},
{'alpha': 0.1, 'eta0': 0.1},
{'alpha': 1, 'eta0': 0.001},
{'alpha': 1, 'eta0': 0.01},
{'alpha': 1, 'eta0': 0.1},
{'alpha': 10, 'eta0': 0.001},
{'alpha': 10, 'eta0': 0.01},
{'alpha': 10, 'eta0': 0.1},
{'alpha': 100, 'eta0': 0.001},
{'alpha': 100, 'eta0': 0.01},
{'alpha': 100, 'eta0': 0.1},
{'alpha': 1000, 'eta0': 0.001},
{'alpha': 1000, 'eta0': 0.01},
{'alpha': 1000, 'eta0': 0.1},
{'alpha': 10000, 'eta0': 0.001},
{'alpha': 10000, 'eta0': 0.01},
{'alpha': 10000, 'eta0': 0.1},
{'alpha': 100000, 'eta0': 0.001},
{'alpha': 100000, 'eta0': 0.01},
{'alpha': 100000, 'eta0': 0.1}),
'rank_test_score': array([25, 9, 4, 26, 8, 2, 23, 11, 1, 24, 7, 3, 27, 6, 5, 28, 21,
15, 29, 22, 10, 30, 18, 14, 33, 20, 16, 31, 19, 12, 32, 17, 13], dtype=int32),
'split0_test_score': array([ -1.17011763e+00, -8.86401924e-02, -1.72178753e-01,
-1.16475154e+00, -9.09386199e-02, -6.59543291e-02,
-1.16624041e+00, -9.27816873e-02, 7.52643335e-02,
-1.16515726e+00, -7.31289496e-02, -5.44449101e-02,
-1.18355456e+00, -5.84160104e-02, -7.96296006e-02,
-1.26945587e+00, -1.37756271e-03, -9.25000337e-02,
-1.75151802e+00, -7.78977095e-03, -8.29077547e-02,
-2.09776805e+00, -1.46441233e-02, -1.03615313e-01,
-2.13617512e+00, -1.46560014e-02, -1.70043292e-01,
-2.13478891e+00, -1.63392608e-02, -5.62908198e-02,
-2.13102304e+00, -1.44777313e-02, -8.86116383e-02]),
'split0_train_score': array([ -1.46984904e+00, 1.69745429e-02, 9.78465658e-02,
-1.46459838e+00, 1.76762412e-02, 1.29755457e-01,
-1.46612126e+00, 1.72188134e-02, 1.26756214e-01,
-1.46495556e+00, 1.60376280e-02, 1.21814030e-01,
-1.48273600e+00, 1.26445651e-02, 5.91060210e-02,
-1.56477556e+00, -4.01499737e-02, 6.28051688e-03,
-2.01497575e+00, -4.35927693e-02, -1.39044841e-03,
-2.33034121e+00, -3.42802575e-02, -1.79947299e-04,
-2.36501801e+00, -3.43915279e-02, -4.10328780e-03,
-2.36378378e+00, -3.22988778e-02, -7.33661052e-03,
-2.36040583e+00, -3.46015176e-02, -1.19015206e-03]),
'split1_test_score': array([ -1.87325777e+00, -4.16353889e-02, 5.57258848e-02,
-1.88539783e+00, -4.08545807e-02, 9.75130959e-02,
-1.86965697e+00, -4.73982902e-02, 7.05912490e-02,
-1.87894368e+00, -4.62969160e-02, 1.54608282e-02,
-1.88868229e+00, -5.47404119e-02, 2.10494608e-03,
-1.97881936e+00, -1.70188567e-01, -7.66238372e-02,
-2.45477783e+00, -1.88960400e-01, -4.40458462e-02,
-2.78058229e+00, -1.47509886e-01, -5.86328642e-02,
-2.81774329e+00, -1.53804594e-01, -2.80798978e-02,
-2.81387377e+00, -1.52160833e-01, -7.94516163e-02,
-2.82014748e+00, -1.44708656e-01, -3.29376781e-02]),
'split1_train_score': array([ -1.34101832e+00, 1.22896507e-02, 1.16264148e-01,
-1.35134089e+00, 1.15337908e-02, 1.20431926e-01,
-1.33797059e+00, 1.12225088e-02, 1.16110891e-01,
-1.34583604e+00, 1.18870034e-02, 9.32233371e-02,
-1.35414004e+00, 9.44809222e-03, 4.94562364e-02,
-1.43093627e+00, -3.69233944e-02, 1.18515551e-04,
-1.84131898e+00, -4.80193715e-02, 4.91874172e-04,
-2.12603497e+00, -2.84759687e-02, -7.20924235e-04,
-2.15867968e+00, -3.13317745e-02, -2.32405167e-03,
-2.15528191e+00, -3.05860975e-02, -4.39662377e-03,
-2.16079587e+00, -2.72815283e-02, -4.91058617e-04]),
'split2_test_score': array([ -1.38287791e+00, 3.08598695e-03, 9.86482635e-02,
-1.38084105e+00, 5.14583605e-03, 1.14976480e-01,
-1.38773827e+00, 4.74357176e-03, 9.48414215e-02,
-1.38082723e+00, -2.41587663e-04, 5.48144782e-02,
-1.38804914e+00, -2.05927402e-03, 3.52673556e-02,
-1.47046521e+00, -7.64831621e-02, 1.70238309e-04,
-1.85386091e+00, -6.71597850e-02, -5.58473292e-03,
-2.12596170e+00, -5.83437380e-02, -3.30047661e-03,
-2.15756403e+00, -6.05982958e-02, 5.36550805e-04,
-2.15353103e+00, -5.78246226e-02, -1.10704367e-02,
-2.15964383e+00, -5.98566239e-02, -4.11431995e-02]),
'split2_train_score': array([ -1.51067542e+00, 1.13514825e-02, 1.02409831e-01,
-1.50829523e+00, 1.16470952e-02, 9.51564988e-02,
-1.51632655e+00, 1.16821609e-02, 9.92040589e-02,
-1.50830792e+00, 1.02333788e-02, 6.68943078e-02,
-1.51673887e+00, 8.57971737e-03, 4.30285874e-02,
-1.61324271e+00, -4.92360281e-02, 5.80932411e-03,
-2.06457677e+00, -4.09399123e-02, 2.69908004e-04,
-2.38668472e+00, -3.33581816e-02, -6.22386273e-04,
-2.42417622e+00, -3.52723381e-02, -8.34269121e-03,
-2.41939611e+00, -3.29209210e-02, -9.98344706e-04,
-2.42664442e+00, -3.46446961e-02, -1.94046736e-02]),
'std_fit_time': array([ 5.80643365e-03, 2.50607982e-04, 7.50145489e-04,
9.68735368e-05, 9.05241245e-04, 9.49277943e-05,
6.70867376e-04, 4.02696124e-05, 2.14967578e-04,
5.05292067e-04, 3.29635511e-03, 3.66729335e-04,
4.24646194e-04, 3.66423752e-04, 3.27715470e-05,
7.68797951e-05, 1.21453618e-03, 4.63699817e-04,
9.22900633e-05, 2.05436958e-05, 2.40985273e-04,
4.86494187e-04, 1.80813349e-04, 1.91486124e-04,
3.36317969e-04, 3.27446680e-04, 1.36301816e-04,
2.42968406e-04, 2.47636717e-05, 3.19968827e-05,
3.28112708e-04, 1.21716252e-03, 1.22846081e-03]),
'std_score_time': array([ 8.15746032e-04, 3.89381329e-05, 1.98297128e-05,
3.35181438e-05, 1.61751305e-04, 3.40715267e-06,
3.36974297e-05, 2.78878407e-05, 6.80019312e-05,
1.37843791e-04, 4.84638453e-04, 7.23319167e-05,
1.92583582e-05, 4.65170880e-06, 1.75658393e-05,
6.23582972e-05, 3.99312963e-05, 3.41768411e-05,
1.83043094e-05, 2.24577984e-05, 2.39499608e-05,
9.32443924e-06, 5.22711552e-06, 4.18573827e-06,
1.24120223e-05, 2.24982599e-05, 1.38867856e-04,
1.82046632e-05, 2.25624554e-05, 1.03674775e-05,
1.39487887e-05, 5.47930644e-05, 1.33533278e-04]),
'std_test_score': array([ 0.29441946, 0.03745092, 0.11885104, 0.30195719, 0.03923812,
0.08148792, 0.29365547, 0.03984638, 0.01050487, 0.29890998,
0.03009913, 0.04518251, 0.29620877, 0.02574424, 0.04828341,
0.29851954, 0.0690561 , 0.04046557, 0.31022386, 0.07541219,
0.03156713, 0.31544647, 0.05529072, 0.04102595, 0.31637342,
0.0578892 , 0.07458775, 0.31579882, 0.05683099, 0.02839658,
0.3183249 , 0.05397448, 0.02454061]),
'std_train_score': array([ 0.07230159, 0.00245961, 0.00783137, 0.06614071, 0.00286925,
0.01461682, 0.0750955 , 0.00272481, 0.01134453, 0.06869093,
0.00244157, 0.02242722, 0.07002499, 0.00174784, 0.00660738,
0.07709854, 0.00521289, 0.00280035, 0.09572006, 0.00292026,
0.00083992, 0.11197913, 0.00254681, 0.00023526, 0.11380471,
0.00168872, 0.0025246 , 0.11368674, 0.00098725, 0.00258984,
0.11299318, 0.0034609 , 0.00875584])}
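The full cv_results_ dump above is hard to read at a glance; rank_test_score marks the ninth parameter set, {'alpha': 0.001, 'eta0': 0.1}, as the winner with a mean test R^2 of about 0.08. The fitted grid search exposes this directly:
In [ ]:
# Winning hyperparameters and their mean cross-validated R^2.
print(clf.best_params_)  # {'alpha': 0.001, 'eta0': 0.1}, per rank_test_score above
print(clf.best_score_)   # about 0.08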
Make a predictor and train it on the training set.
In [47]:
pred = sklearn.linear_model.SGDRegressor()  # default hyperparameters; the tuned values live in clf.best_params_
pred.fit(train_X, train_y)
Out[47]:
SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01,
fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',
loss='squared_loss', n_iter=5, penalty='l2', power_t=0.25,
random_state=None, shuffle=True, verbose=0, warm_start=False)
In [48]:
pred.score(validation_X, validation_y)
Out[48]:
0.021233124890070117
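For completeness, the held-out test split can be scored the same way; a minimal sketch that also reports the mean squared error (output not shown):
In [ ]:
import sklearn.metrics
print(pred.score(test_X, test_y))  # R^2 on the test split
print(sklearn.metrics.mean_squared_error(test_y, pred.predict(test_X)))  # MSE on the test split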
Content source: alasdairtran/mclearn