In [1]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

In [2]:
import pandas as pd

Importing data

Getting the data from Kaggle first (the classic Titanic dataset, bundled here with the deepforest package):


In [3]:
import pkg_resources

In [4]:
raw_data = pd.read_csv(pkg_resources.resource_stream('deepforest', 'data/train.csv'))

In [5]:
# Drop identifier columns and the sparsely populated Cabin column
clean_data = raw_data.drop(["Cabin", "Name", "PassengerId", "Ticket"], axis=1)

In [6]:
# One-hot encode the categorical columns and flag missing values with -1
clean_data = pd.get_dummies(clean_data).fillna(-1)

In [7]:
train, test = train_test_split(clean_data)

In [8]:
def split_x_y(dataframe, target):
    return dataframe.drop(target, axis=1), dataframe[target]

In [9]:
X_train, y_train = split_x_y(train, "Survived")
X_test, y_test = split_x_y(test, "Survived")

Baseline model


In [10]:
rf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
rf.fit(X_train, y_train)


Out[10]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=100, n_jobs=-1, oob_score=False,
            random_state=None, verbose=0, warm_start=False)

In [11]:
y_pred = rf.predict_proba(X_test)

In [12]:
auc = roc_auc_score(y_true=y_test, y_score=y_pred[:, 1])

In [13]:
auc


Out[13]:
0.86788399570354469

Deep Forest

By Hand


In [14]:
from sklearn.model_selection import StratifiedKFold

In [15]:
rf1 = RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=4)
rf2 = RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=10)

In [16]:
rf1.fit(X_train, y_train)
y_pred_train_1 = rf1.predict_proba(X_train)
y_pred_test_1 = rf1.predict_proba(X_test)

y_pred_train_1 = pd.DataFrame(y_pred_train_1, columns=["rf1_0", "rf1_1"], index=X_train.index)
y_pred_test_1 = pd.DataFrame(y_pred_test_1, columns=["rf1_0", "rf1_1"], index=X_test.index)

In [17]:
rf2.fit(X_train, y_train)
y_pred_train_2 = rf2.predict_proba(X_train)
y_pred_test_2 = rf2.predict_proba(X_test)

y_pred_train_2 = pd.DataFrame(y_pred_train_2, columns=["rf2_0", "rf2_1"], index=X_train.index)
y_pred_test_2 = pd.DataFrame(y_pred_test_2, columns=["rf2_0", "rf2_1"], index=X_test.index)

In [18]:
hidden_train_1 = pd.concat([X_train, y_pred_train_1, y_pred_train_2], axis=1)
hidden_test_1 = pd.concat([X_test, y_pred_test_1, y_pred_test_2], axis=1)

In [19]:
hidden_train_1.head()


Out[19]:
Pclass Age SibSp Parch Fare Sex_female Sex_male Embarked_C Embarked_Q Embarked_S rf1_0 rf1_1 rf2_0 rf2_1
786 3 18.0 0 0 7.4958 1 0 0 0 1 0.410273 0.589727 0.256744 0.743256
627 1 21.0 0 0 77.9583 1 0 0 0 1 0.070023 0.929977 0.000000 1.000000
695 2 52.0 0 0 13.5000 0 1 0 0 1 0.863169 0.136831 0.938984 0.061016
379 3 19.0 0 0 7.7750 0 1 0 0 1 0.890193 0.109807 0.931481 0.068519
708 1 22.0 0 0 151.5500 1 0 0 0 1 0.087613 0.912387 0.030000 0.970000

In [20]:
rf3 = RandomForestClassifier(n_estimators=300, n_jobs=-1)
rf3.fit(hidden_train_1, y_train)


Out[20]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=300, n_jobs=-1, oob_score=False,
            random_state=None, verbose=0, warm_start=False)

In [21]:
y_pred3 = rf3.predict_proba(hidden_test_1)

In [22]:
roc_auc_score(y_test, y_pred3[:, 1])


Out[22]:
0.86040995345506621

This is not very handy at all. There is already a lot of code duplication, and one may feel there should be a way to abstract the logic happening here into something more flexible and powerful than all this boilerplate code.
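Before switching to the package, here is a rough sketch of the kind of abstraction we are after. This is a hypothetical reconstruction written for intuition, not the actual deepforest source: each layer wraps a few models, and every layer after the first feeds on the previous layer's outputs stacked next to the original features, exactly as we did by hand above.

import numpy as np

class InputLayer:
    def __init__(self, *models):
        self.models = models

    def fit(self, X, y):
        for model in self.models:
            model.fit(X, y)
        return self

    def predict(self, X):
        # Concatenate the class-probability outputs of every model.
        return np.hstack([model.predict_proba(X) for model in self.models])

class Layer(InputLayer):
    def __init__(self, previous_layer, *models):
        super().__init__(*models)
        self.previous_layer = previous_layer

    def _augment(self, X):
        # Original features plus the previous layer's predictions.
        return np.hstack([X, self.previous_layer.predict(X)])

    def fit(self, X, y):
        # Fitting the last layer recursively fits the whole cascade.
        self.previous_layer.fit(X, y)
        return super().fit(self._augment(X), y)

    def predict(self, X):
        return super().predict(self._augment(X))

With something like this, the whole by-hand section above collapses into a couple of constructor calls, which is essentially what the real API below provides.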

With the API


In [23]:
from deepforest.layer import Layer, InputLayer

In [24]:
input_layer = InputLayer(RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=4),
                         RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=10))

In [25]:
hidden_layer = Layer(input_layer,
                     RandomForestClassifier(n_estimators=50, n_jobs=-1, max_depth=4),
                     RandomForestClassifier(n_estimators=50, n_jobs=-1, max_depth=10))

In [26]:
hidden_layer.fit(X_train, y_train)


Out[26]:
<deepforest.layer.Layer at 0x104522d30>

In [27]:
pd.DataFrame(hidden_layer.predict(X_test), index=X_test.index)


Out[27]:
0 1 2 3
81 0.886403 0.113597 0.840710 0.159290
889 0.176103 0.823897 0.020000 0.980000
188 0.929972 0.070028 0.993103 0.006897
742 0.029003 0.970997 0.000000 1.000000
146 0.892927 0.107073 0.921950 0.078050
99 0.902629 0.097371 0.755414 0.244586
197 0.893365 0.106635 0.892192 0.107808
814 0.876441 0.123559 0.781518 0.218482
836 0.890743 0.109257 0.918327 0.081673
13 0.909326 0.090674 0.936667 0.063333
465 0.907381 0.092619 0.971794 0.028206
860 0.923719 0.076281 0.977759 0.022241
156 0.130872 0.869128 0.004038 0.995962
674 0.911772 0.088228 0.978350 0.021650
790 0.918126 0.081874 0.883977 0.116023
567 0.161277 0.838723 0.021143 0.978857
133 0.094557 0.905443 0.006000 0.994000
709 0.893387 0.106613 0.936368 0.063632
843 0.885771 0.114229 0.968222 0.031778
734 0.891792 0.108208 0.922608 0.077392
537 0.028675 0.971325 0.000000 1.000000
92 0.843012 0.156988 0.957742 0.042258
729 0.747858 0.252142 0.976667 0.023333
694 0.871790 0.128210 0.955844 0.044156
89 0.882282 0.117718 0.878046 0.121954
269 0.050097 0.949903 0.001143 0.998857
606 0.891427 0.108573 0.886614 0.113386
552 0.911842 0.088158 0.848512 0.151488
576 0.117489 0.882511 0.079613 0.920387
386 0.932624 0.067376 0.997895 0.002105
... ... ... ... ...
49 0.725637 0.274363 0.988571 0.011429
384 0.899280 0.100720 0.957731 0.042269
303 0.099159 0.900841 0.035139 0.964861
839 0.890895 0.109105 0.992737 0.007263
162 0.895626 0.104374 0.944614 0.055386
492 0.885562 0.114438 0.977491 0.022509
753 0.897268 0.102732 0.954851 0.045149
65 0.893387 0.106613 0.936368 0.063632
72 0.858684 0.141316 0.903854 0.096146
755 0.157551 0.842449 0.004000 0.996000
749 0.915763 0.084237 0.978374 0.021626
69 0.900999 0.099001 0.943801 0.056199
221 0.889687 0.110313 0.912508 0.087492
267 0.908515 0.091485 0.958836 0.041164
590 0.907381 0.092619 0.980683 0.019317
175 0.913077 0.086923 0.925646 0.074354
704 0.906358 0.093642 0.954353 0.045647
292 0.874089 0.125911 0.895357 0.104643
888 0.727689 0.272311 0.940000 0.060000
2 0.743507 0.256493 0.980000 0.020000
346 0.116100 0.883900 0.519613 0.480387
95 0.879062 0.120938 0.894773 0.105227
260 0.918126 0.081874 0.883977 0.116023
475 0.877184 0.122816 0.981473 0.018527
47 0.147430 0.852570 0.147257 0.852743
599 0.180075 0.819925 0.020000 0.980000
644 0.112588 0.887412 0.040000 0.960000
166 0.044899 0.955101 0.008000 0.992000
214 0.927508 0.072492 0.911442 0.088558
622 0.901662 0.098338 0.977701 0.022299

223 rows × 4 columns

Going Further


In [28]:
def random_forest_generator():
    for i in range(2, 15, 2):
        yield RandomForestClassifier(n_estimators=100,
                                     n_jobs=-1,
                                     min_samples_leaf=5,
                                     max_depth=i)
    for i in range(2, 15, 2):
        yield RandomForestClassifier(n_estimators=100,
                                     n_jobs=-1,
                                     max_features=1,
                                     min_samples_leaf=5,
                                     max_depth=i)

In [29]:
def paper_like_generator():
    for i in range(2):
        yield RandomForestClassifier(n_estimators=1000,
                                     n_jobs=-1,
                                     min_samples_leaf=10)
    for i in range(2):
        yield RandomForestClassifier(n_estimators=1000,
                                     n_jobs=-1,
                                     max_features=1,
                                     min_samples_leaf=10)

In [30]:
def build_input_layer():
    return InputLayer(*paper_like_generator())

In [31]:
def build_hidden_layer(layer):
    return Layer(layer, *paper_like_generator())

In [32]:
def build_output_layer(layer):
    return Layer(layer,
                 RandomForestClassifier(n_estimators=500,
                                        n_jobs=-1,
                                        min_samples_leaf=5,
                                        max_depth=10))

In [33]:
input_l = build_input_layer()
hidden_1 = build_hidden_layer(input_l)
hidden_2 = build_hidden_layer(hidden_1)
hidden_3 = build_hidden_layer(hidden_2)
hidden_4 = build_hidden_layer(hidden_3)
output_l = build_output_layer(hidden_4)

In [34]:
output_l.fit(X_train, y_train)


Out[34]:
<deepforest.layer.Layer at 0x104512978>

In [35]:
y_pred = output_l.predict(X_test)

In [36]:
y_pred


Out[36]:
array([[  9.29638699e-01,   7.03613014e-02],
       [  1.72212883e-01,   8.27787117e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.27478992e-03,   9.92725210e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.52790052e-01,   2.47209948e-01],
       [  9.59129256e-01,   4.08707444e-02],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  3.09158979e-01,   6.90841021e-01],
       [  9.97466667e-01,   2.53333333e-03],
       [  8.88917085e-01,   1.11082915e-01],
       [  3.52073705e-01,   6.47926295e-01],
       [  3.72293325e-01,   6.27706675e-01],
       [  2.88891001e-01,   7.11108999e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.42849243e-01,   8.57150757e-01],
       [  9.94265995e-01,   5.73400488e-03],
       [  5.31304247e-01,   4.68695753e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  7.79309549e-01,   2.20690451e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  9.97466667e-01,   2.53333333e-03],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  2.85085104e-01,   7.14914896e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  9.76749260e-01,   2.32507404e-02],
       [  9.97733117e-01,   2.26688312e-03],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e-03,   9.99000000e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  4.18106720e-01,   5.81893280e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  3.68057044e-01,   6.31942956e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.43477107e-01,   2.56522893e-01],
       [  1.00000000e-03,   9.99000000e-01],
       [  3.33333333e-04,   9.99666667e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  5.89930916e-01,   4.10069084e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  8.00000000e-04,   9.99200000e-01],
       [  2.75012123e-01,   7.24987877e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  4.21895897e-01,   5.78104103e-01],
       [  9.97238528e-01,   2.76147186e-03],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.60791699e-01,   3.92083007e-02],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  6.05880231e-03,   9.93941198e-01],
       [  6.03921085e-02,   9.39607891e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.57124841e-01,   4.28751587e-02],
       [  9.37848142e-02,   9.06215186e-01],
       [  1.05903796e-02,   9.89409620e-01],
       [  9.71861676e-01,   2.81383244e-02],
       [  1.00000000e+00,   0.00000000e+00],
       [  3.72293325e-01,   6.27706675e-01],
       [  1.00121662e-01,   8.99878338e-01],
       [  1.48621515e-01,   8.51378485e-01],
       [  8.75704622e-03,   9.91242954e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.89456509e-01,   8.10543491e-01],
       [  9.99666667e-01,   3.33333333e-04],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  3.47143943e-01,   6.52856057e-01],
       [  9.30603195e-01,   6.93968050e-02],
       [  8.50221529e-01,   1.49778471e-01],
       [  6.38185437e-01,   3.61814563e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.05779498e-01,   2.94220502e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  9.92754884e-01,   7.24511600e-03],
       [  9.57180501e-04,   9.99042819e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  4.20217831e-01,   5.79782169e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.42777254e-01,   8.57222746e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.46889453e-01,   8.53110547e-01],
       [  9.98304545e-01,   1.69545455e-03],
       [  7.53396290e-01,   2.46603710e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  6.96086107e-01,   3.03913893e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.35718050e-03,   9.98642819e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  3.19960474e-03,   9.96800395e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.87874913e-03,   9.98121251e-01],
       [  5.07416827e-02,   9.49258317e-01],
       [  9.79580854e-01,   2.04191457e-02],
       [  9.99666667e-01,   3.33333333e-04],
       [  5.00000000e-04,   9.99500000e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  6.52558301e-01,   3.47441699e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.57180501e-04,   9.99042819e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  3.48450710e-01,   6.51549290e-01],
       [  8.74391634e-02,   9.12560837e-01],
       [  3.33333333e-04,   9.99666667e-01],
       [  7.78466518e-01,   2.21533482e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  2.97478992e-03,   9.97025210e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.97466667e-01,   2.53333333e-03],
       [  1.05016455e-01,   8.94983545e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.80700119e-01,   2.19299881e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  3.68494014e-01,   6.31505986e-01],
       [  9.99142857e-01,   8.57142857e-04],
       [  9.99682353e-01,   3.17647059e-04],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.39823314e-03,   9.98601767e-01],
       [  3.33333333e-04,   9.99666667e-01],
       [  1.25000000e-03,   9.98750000e-01],
       [  0.00000000e+00,   1.00000000e+00],
       [  6.48604955e-01,   3.51395045e-01],
       [  9.69221691e-01,   3.07783094e-02],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  5.38431373e-03,   9.94615686e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  8.13453221e-01,   1.86546779e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  8.00000000e-04,   9.99200000e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.29638699e-01,   7.03613014e-02],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  5.20721759e-01,   4.79278241e-01],
       [  8.88917085e-01,   1.11082915e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  4.24548278e-01,   5.75451722e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  7.90513834e-04,   9.99209486e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  2.88891001e-01,   7.11108999e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  4.00000000e-04,   9.99600000e-01],
       [  9.99666667e-01,   3.33333333e-04],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  5.99625042e-01,   4.00374958e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.00000000e+00,   0.00000000e+00],
       [  9.99666667e-01,   3.33333333e-04],
       [  9.43014139e-02,   9.05698586e-01],
       [  3.91634773e-01,   6.08365227e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  8.88917085e-01,   1.11082915e-01],
       [  1.00000000e+00,   0.00000000e+00],
       [  1.61653007e-01,   8.38346993e-01],
       [  9.23118030e-01,   7.68819699e-02],
       [  0.00000000e+00,   1.00000000e+00],
       [  0.00000000e+00,   1.00000000e+00],
       [  9.99666667e-01,   3.33333333e-04],
       [  9.96500000e-01,   3.50000000e-03]])

In [37]:
roc_auc_score(y_test, y_pred[:, 1])


Out[37]:
0.84586466165413543

Well, the result is not that satisfactory yet, but let's not lose hope: there is still a lot of room for improvement. First item on my todo list: make sure all the intermediate models are trained using cross-validation techniques, to reduce overfitting.
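As a hedged sketch of that idea (presumably what the so far unused StratifiedKFold import above was destined for): scikit-learn's cross_val_predict can produce out-of-fold probabilities, so that each intermediate model only ever scores samples it was not trained on. The oof_proba helper below is hypothetical, not part of deepforest.

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict

def oof_proba(model, X, y, n_splits=5):
    # Each row of the result is predicted by a model that never saw that
    # row during training, unlike the plain fit/predict_proba used above.
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
    return cross_val_predict(model, X, y, cv=cv, method="predict_proba")

# Hypothetical usage for one model of the input layer: the layer's
# training-time outputs come from out-of-fold predictions, while its
# test-time outputs come from the same model refit on all of X_train.
rf = RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=4)
hidden_features_train = oof_proba(rf, X_train, y_train)
hidden_features_test = rf.fit(X_train, y_train).predict_proba(X_test)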