In [44]:
import sklearn.linear_model
import sklearn.datasets
import sklearn.model_selection

diabetes = sklearn.datasets.load_diabetes()
X = diabetes['data']
y = diabetes['target']

Preprocess data.


In [45]:
X = sklearn.preprocessing.PolynomialFeatures(5).fit_transform(X)

Now make a train/test/validation split.


In [46]:
train_X, test_X, validation_X = X[:300], X[300:400], X[400:]
train_y, test_y, validation_y = y[:300], y[300:400], y[400:]

Tune hyperparameters.


In [52]:
sgd = sklearn.linear_model.SGDRegressor()

parameters = {'alpha': [10 ** i for i in range(-5, 6)], 'eta0': [10 ** i for i in range(-3, 0)]}
clf = sklearn.model_selection.GridSearchCV(sgd, parameters)
clf.fit(train_X, train_y)


Out[52]:
{'mean_fit_time': array([ 0.03899145,  0.03139973,  0.03181307,  0.03123728,  0.03258602,
         0.03123943,  0.03194666,  0.0314192 ,  0.03141419,  0.03168527,
         0.03674324,  0.03147395,  0.03154318,  0.03158673,  0.03125326,
         0.03126216,  0.03243995,  0.03195294,  0.03139162,  0.03123784,
         0.03151615,  0.03153539,  0.03135864,  0.03283985,  0.03176967,
         0.03293435,  0.03277   ,  0.03289763,  0.03274067,  0.03275625,
         0.03322196,  0.03438965,  0.03366804]),
 'mean_score_time': array([ 0.00244633,  0.00190361,  0.00187556,  0.00180499,  0.00198142,
         0.00182501,  0.00186833,  0.00188708,  0.00187031,  0.00211676,
         0.0022347 ,  0.00182939,  0.00181079,  0.00184798,  0.00182001,
         0.00188581,  0.00187262,  0.00182748,  0.00189869,  0.00179744,
         0.00184218,  0.00183455,  0.00182867,  0.00182859,  0.00183241,
         0.00184059,  0.00194502,  0.00181595,  0.00184123,  0.00188835,
         0.00182637,  0.00189662,  0.00194216]),
 'mean_test_score': array([-1.47541777, -0.04239653, -0.00593487, -1.47699681, -0.04221579,
         0.04884508, -1.47454521, -0.04514547,  0.08023233, -1.47497606,
        -0.03988915,  0.0052768 , -1.486762  , -0.03840523, -0.01408577,
        -1.57291348, -0.0826831 , -0.05631788, -2.02005225, -0.08796999,
        -0.04417944, -2.33477068, -0.07349925, -0.05518288, -2.37049415,
        -0.07635296, -0.06586221, -2.3673979 , -0.07544157, -0.04893762,
        -2.37027145, -0.07301434, -0.05423084]),
 'mean_train_score': array([ -1.44051426e+00,   1.35385587e-02,   1.05506848e-01,
         -1.44141150e+00,   1.36190424e-02,   1.15114627e-01,
         -1.44013947e+00,   1.33744944e-02,   1.14023721e-01,
         -1.43969984e+00,   1.27193367e-02,   9.39772249e-02,
         -1.45120497e+00,   1.02241249e-02,   5.05302816e-02,
         -1.53631818e+00,  -4.21031321e-02,   4.06945218e-03,
         -1.97362383e+00,  -4.41840177e-02,  -2.09555412e-04,
         -2.28102030e+00,  -3.20381359e-02,  -5.07752602e-04,
         -2.31595797e+00,  -3.36652135e-02,  -4.92334356e-03,
         -2.31282060e+00,  -3.19352988e-02,  -4.24385967e-03,
         -2.31594870e+00,  -3.21759140e-02,  -7.02862809e-03]),
 'param_alpha': masked_array(data = [1e-05 1e-05 1e-05 0.0001 0.0001 0.0001 0.001 0.001 0.001 0.01 0.01 0.01
  0.1 0.1 0.1 1 1 1 10 10 10 100 100 100 1000 1000 1000 10000 10000 10000
  100000 100000 100000],
              mask = [False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False],
        fill_value = ?),
 'param_eta0': masked_array(data = [0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1
  0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1 0.001 0.01 0.1
  0.001 0.01 0.1],
              mask = [False False False False False False False False False False False False
  False False False False False False False False False False False False
  False False False False False False False False False],
        fill_value = ?),
 'params': ({'alpha': 1e-05, 'eta0': 0.001},
  {'alpha': 1e-05, 'eta0': 0.01},
  {'alpha': 1e-05, 'eta0': 0.1},
  {'alpha': 0.0001, 'eta0': 0.001},
  {'alpha': 0.0001, 'eta0': 0.01},
  {'alpha': 0.0001, 'eta0': 0.1},
  {'alpha': 0.001, 'eta0': 0.001},
  {'alpha': 0.001, 'eta0': 0.01},
  {'alpha': 0.001, 'eta0': 0.1},
  {'alpha': 0.01, 'eta0': 0.001},
  {'alpha': 0.01, 'eta0': 0.01},
  {'alpha': 0.01, 'eta0': 0.1},
  {'alpha': 0.1, 'eta0': 0.001},
  {'alpha': 0.1, 'eta0': 0.01},
  {'alpha': 0.1, 'eta0': 0.1},
  {'alpha': 1, 'eta0': 0.001},
  {'alpha': 1, 'eta0': 0.01},
  {'alpha': 1, 'eta0': 0.1},
  {'alpha': 10, 'eta0': 0.001},
  {'alpha': 10, 'eta0': 0.01},
  {'alpha': 10, 'eta0': 0.1},
  {'alpha': 100, 'eta0': 0.001},
  {'alpha': 100, 'eta0': 0.01},
  {'alpha': 100, 'eta0': 0.1},
  {'alpha': 1000, 'eta0': 0.001},
  {'alpha': 1000, 'eta0': 0.01},
  {'alpha': 1000, 'eta0': 0.1},
  {'alpha': 10000, 'eta0': 0.001},
  {'alpha': 10000, 'eta0': 0.01},
  {'alpha': 10000, 'eta0': 0.1},
  {'alpha': 100000, 'eta0': 0.001},
  {'alpha': 100000, 'eta0': 0.01},
  {'alpha': 100000, 'eta0': 0.1}),
 'rank_test_score': array([25,  9,  4, 26,  8,  2, 23, 11,  1, 24,  7,  3, 27,  6,  5, 28, 21,
        15, 29, 22, 10, 30, 18, 14, 33, 20, 16, 31, 19, 12, 32, 17, 13], dtype=int32),
 'split0_test_score': array([ -1.17011763e+00,  -8.86401924e-02,  -1.72178753e-01,
         -1.16475154e+00,  -9.09386199e-02,  -6.59543291e-02,
         -1.16624041e+00,  -9.27816873e-02,   7.52643335e-02,
         -1.16515726e+00,  -7.31289496e-02,  -5.44449101e-02,
         -1.18355456e+00,  -5.84160104e-02,  -7.96296006e-02,
         -1.26945587e+00,  -1.37756271e-03,  -9.25000337e-02,
         -1.75151802e+00,  -7.78977095e-03,  -8.29077547e-02,
         -2.09776805e+00,  -1.46441233e-02,  -1.03615313e-01,
         -2.13617512e+00,  -1.46560014e-02,  -1.70043292e-01,
         -2.13478891e+00,  -1.63392608e-02,  -5.62908198e-02,
         -2.13102304e+00,  -1.44777313e-02,  -8.86116383e-02]),
 'split0_train_score': array([ -1.46984904e+00,   1.69745429e-02,   9.78465658e-02,
         -1.46459838e+00,   1.76762412e-02,   1.29755457e-01,
         -1.46612126e+00,   1.72188134e-02,   1.26756214e-01,
         -1.46495556e+00,   1.60376280e-02,   1.21814030e-01,
         -1.48273600e+00,   1.26445651e-02,   5.91060210e-02,
         -1.56477556e+00,  -4.01499737e-02,   6.28051688e-03,
         -2.01497575e+00,  -4.35927693e-02,  -1.39044841e-03,
         -2.33034121e+00,  -3.42802575e-02,  -1.79947299e-04,
         -2.36501801e+00,  -3.43915279e-02,  -4.10328780e-03,
         -2.36378378e+00,  -3.22988778e-02,  -7.33661052e-03,
         -2.36040583e+00,  -3.46015176e-02,  -1.19015206e-03]),
 'split1_test_score': array([ -1.87325777e+00,  -4.16353889e-02,   5.57258848e-02,
         -1.88539783e+00,  -4.08545807e-02,   9.75130959e-02,
         -1.86965697e+00,  -4.73982902e-02,   7.05912490e-02,
         -1.87894368e+00,  -4.62969160e-02,   1.54608282e-02,
         -1.88868229e+00,  -5.47404119e-02,   2.10494608e-03,
         -1.97881936e+00,  -1.70188567e-01,  -7.66238372e-02,
         -2.45477783e+00,  -1.88960400e-01,  -4.40458462e-02,
         -2.78058229e+00,  -1.47509886e-01,  -5.86328642e-02,
         -2.81774329e+00,  -1.53804594e-01,  -2.80798978e-02,
         -2.81387377e+00,  -1.52160833e-01,  -7.94516163e-02,
         -2.82014748e+00,  -1.44708656e-01,  -3.29376781e-02]),
 'split1_train_score': array([ -1.34101832e+00,   1.22896507e-02,   1.16264148e-01,
         -1.35134089e+00,   1.15337908e-02,   1.20431926e-01,
         -1.33797059e+00,   1.12225088e-02,   1.16110891e-01,
         -1.34583604e+00,   1.18870034e-02,   9.32233371e-02,
         -1.35414004e+00,   9.44809222e-03,   4.94562364e-02,
         -1.43093627e+00,  -3.69233944e-02,   1.18515551e-04,
         -1.84131898e+00,  -4.80193715e-02,   4.91874172e-04,
         -2.12603497e+00,  -2.84759687e-02,  -7.20924235e-04,
         -2.15867968e+00,  -3.13317745e-02,  -2.32405167e-03,
         -2.15528191e+00,  -3.05860975e-02,  -4.39662377e-03,
         -2.16079587e+00,  -2.72815283e-02,  -4.91058617e-04]),
 'split2_test_score': array([ -1.38287791e+00,   3.08598695e-03,   9.86482635e-02,
         -1.38084105e+00,   5.14583605e-03,   1.14976480e-01,
         -1.38773827e+00,   4.74357176e-03,   9.48414215e-02,
         -1.38082723e+00,  -2.41587663e-04,   5.48144782e-02,
         -1.38804914e+00,  -2.05927402e-03,   3.52673556e-02,
         -1.47046521e+00,  -7.64831621e-02,   1.70238309e-04,
         -1.85386091e+00,  -6.71597850e-02,  -5.58473292e-03,
         -2.12596170e+00,  -5.83437380e-02,  -3.30047661e-03,
         -2.15756403e+00,  -6.05982958e-02,   5.36550805e-04,
         -2.15353103e+00,  -5.78246226e-02,  -1.10704367e-02,
         -2.15964383e+00,  -5.98566239e-02,  -4.11431995e-02]),
 'split2_train_score': array([ -1.51067542e+00,   1.13514825e-02,   1.02409831e-01,
         -1.50829523e+00,   1.16470952e-02,   9.51564988e-02,
         -1.51632655e+00,   1.16821609e-02,   9.92040589e-02,
         -1.50830792e+00,   1.02333788e-02,   6.68943078e-02,
         -1.51673887e+00,   8.57971737e-03,   4.30285874e-02,
         -1.61324271e+00,  -4.92360281e-02,   5.80932411e-03,
         -2.06457677e+00,  -4.09399123e-02,   2.69908004e-04,
         -2.38668472e+00,  -3.33581816e-02,  -6.22386273e-04,
         -2.42417622e+00,  -3.52723381e-02,  -8.34269121e-03,
         -2.41939611e+00,  -3.29209210e-02,  -9.98344706e-04,
         -2.42664442e+00,  -3.46446961e-02,  -1.94046736e-02]),
 'std_fit_time': array([  5.80643365e-03,   2.50607982e-04,   7.50145489e-04,
          9.68735368e-05,   9.05241245e-04,   9.49277943e-05,
          6.70867376e-04,   4.02696124e-05,   2.14967578e-04,
          5.05292067e-04,   3.29635511e-03,   3.66729335e-04,
          4.24646194e-04,   3.66423752e-04,   3.27715470e-05,
          7.68797951e-05,   1.21453618e-03,   4.63699817e-04,
          9.22900633e-05,   2.05436958e-05,   2.40985273e-04,
          4.86494187e-04,   1.80813349e-04,   1.91486124e-04,
          3.36317969e-04,   3.27446680e-04,   1.36301816e-04,
          2.42968406e-04,   2.47636717e-05,   3.19968827e-05,
          3.28112708e-04,   1.21716252e-03,   1.22846081e-03]),
 'std_score_time': array([  8.15746032e-04,   3.89381329e-05,   1.98297128e-05,
          3.35181438e-05,   1.61751305e-04,   3.40715267e-06,
          3.36974297e-05,   2.78878407e-05,   6.80019312e-05,
          1.37843791e-04,   4.84638453e-04,   7.23319167e-05,
          1.92583582e-05,   4.65170880e-06,   1.75658393e-05,
          6.23582972e-05,   3.99312963e-05,   3.41768411e-05,
          1.83043094e-05,   2.24577984e-05,   2.39499608e-05,
          9.32443924e-06,   5.22711552e-06,   4.18573827e-06,
          1.24120223e-05,   2.24982599e-05,   1.38867856e-04,
          1.82046632e-05,   2.25624554e-05,   1.03674775e-05,
          1.39487887e-05,   5.47930644e-05,   1.33533278e-04]),
 'std_test_score': array([ 0.29441946,  0.03745092,  0.11885104,  0.30195719,  0.03923812,
         0.08148792,  0.29365547,  0.03984638,  0.01050487,  0.29890998,
         0.03009913,  0.04518251,  0.29620877,  0.02574424,  0.04828341,
         0.29851954,  0.0690561 ,  0.04046557,  0.31022386,  0.07541219,
         0.03156713,  0.31544647,  0.05529072,  0.04102595,  0.31637342,
         0.0578892 ,  0.07458775,  0.31579882,  0.05683099,  0.02839658,
         0.3183249 ,  0.05397448,  0.02454061]),
 'std_train_score': array([ 0.07230159,  0.00245961,  0.00783137,  0.06614071,  0.00286925,
         0.01461682,  0.0750955 ,  0.00272481,  0.01134453,  0.06869093,
         0.00244157,  0.02242722,  0.07002499,  0.00174784,  0.00660738,
         0.07709854,  0.00521289,  0.00280035,  0.09572006,  0.00292026,
         0.00083992,  0.11197913,  0.00254681,  0.00023526,  0.11380471,
         0.00168872,  0.0025246 ,  0.11368674,  0.00098725,  0.00258984,
         0.11299318,  0.0034609 ,  0.00875584])}

Make a predictor and train.


In [47]:
clf.score(train_X, train_y)


Out[47]:
SGDRegressor(alpha=0.0001, average=False, epsilon=0.1, eta0=0.01,
       fit_intercept=True, l1_ratio=0.15, learning_rate='invscaling',
       loss='squared_loss', n_iter=5, penalty='l2', power_t=0.25,
       random_state=None, shuffle=True, verbose=0, warm_start=False)

In [48]:
pred.score(validation_X, validation_y)


Out[48]:
0.021233124890070117

In [ ]: