Chapter 7 – Ensemble Learning and Random Forests

This notebook contains all the sample code and solutions to the exercices in chapter 7.

Setup

First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:


In [1]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "ensembles"

def image_path(fig_id):
    return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)

def save_fig(fig_id, tight_layout=True):
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(image_path(fig_id) + ".png", format='png', dpi=300)

Voting classifiers


In [2]:
heads_proba = 0.51
coin_tosses = (np.random.rand(10000, 10) < heads_proba).astype(np.int32)
cumulative_heads_ratio = np.cumsum(coin_tosses, axis=0) / np.arange(1, 10001).reshape(-1, 1)

In [3]:
plt.figure(figsize=(8,3.5))
plt.plot(cumulative_heads_ratio)
plt.plot([0, 10000], [0.51, 0.51], "k--", linewidth=2, label="51%")
plt.plot([0, 10000], [0.5, 0.5], "k-", label="50%")
plt.xlabel("Number of coin tosses")
plt.ylabel("Heads ratio")
plt.legend(loc="lower right")
plt.axis([0, 10000, 0.42, 0.58])
save_fig("law_of_large_numbers_plot")
plt.show()


Saving figure law_of_large_numbers_plot

In [4]:
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons

X, y = make_moons(n_samples=500, noise=0.30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [5]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

log_clf = LogisticRegression(random_state=42)
rnd_clf = RandomForestClassifier(random_state=42)
svm_clf = SVC(random_state=42)

voting_clf = VotingClassifier(
    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],
    voting='hard')
voting_clf.fit(X_train, y_train)


Out[5]:
VotingClassifier(estimators=[('lr', LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=42, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)), ('rf', RandomFor...f',
  max_iter=-1, probability=False, random_state=42, shrinking=True,
  tol=0.001, verbose=False))],
         n_jobs=1, voting='hard', weights=None)

In [6]:
from sklearn.metrics import accuracy_score

for clf in (log_clf, rnd_clf, svm_clf, voting_clf):
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print(clf.__class__.__name__, accuracy_score(y_test, y_pred))


LogisticRegression 0.864
RandomForestClassifier 0.872
SVC 0.888
VotingClassifier 0.896

In [7]:
log_clf = LogisticRegression(random_state=42)
rnd_clf = RandomForestClassifier(random_state=42)
svm_clf = SVC(probability=True, random_state=42)

voting_clf = VotingClassifier(
    estimators=[('lr', log_clf), ('rf', rnd_clf), ('svc', svm_clf)],
    voting='soft')
voting_clf.fit(X_train, y_train)


Out[7]:
VotingClassifier(estimators=[('lr', LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=42, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)), ('rf', RandomFor...bf',
  max_iter=-1, probability=True, random_state=42, shrinking=True,
  tol=0.001, verbose=False))],
         n_jobs=1, voting='soft', weights=None)

In [8]:
from sklearn.metrics import accuracy_score

for clf in (log_clf, rnd_clf, svm_clf, voting_clf):
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)
    print(clf.__class__.__name__, accuracy_score(y_test, y_pred))


LogisticRegression 0.864
RandomForestClassifier 0.872
SVC 0.888
VotingClassifier 0.912

Bagging ensembles


In [9]:
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier

bag_clf = BaggingClassifier(
    DecisionTreeClassifier(random_state=42), n_estimators=500,
    max_samples=100, bootstrap=True, n_jobs=-1, random_state=42)
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)

In [10]:
from sklearn.metrics import accuracy_score
print(accuracy_score(y_test, y_pred))


0.904

In [11]:
tree_clf = DecisionTreeClassifier(random_state=42)
tree_clf.fit(X_train, y_train)
y_pred_tree = tree_clf.predict(X_test)
print(accuracy_score(y_test, y_pred_tree))


0.856

In [12]:
from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.5, contour=True):
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap, linewidth=10)
    if contour:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", alpha=alpha)
    plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", alpha=alpha)
    plt.axis(axes)
    plt.xlabel(r"$x_1$", fontsize=18)
    plt.ylabel(r"$x_2$", fontsize=18, rotation=0)

In [13]:
plt.figure(figsize=(11,4))
plt.subplot(121)
plot_decision_boundary(tree_clf, X, y)
plt.title("Decision Tree", fontsize=14)
plt.subplot(122)
plot_decision_boundary(bag_clf, X, y)
plt.title("Decision Trees with Bagging", fontsize=14)
save_fig("decision_tree_without_and_with_bagging_plot")
plt.show()


Saving figure decision_tree_without_and_with_bagging_plot

Random Forests


In [14]:
bag_clf = BaggingClassifier(
    DecisionTreeClassifier(splitter="random", max_leaf_nodes=16, random_state=42),
    n_estimators=500, max_samples=1.0, bootstrap=True, n_jobs=-1, random_state=42)

In [15]:
bag_clf.fit(X_train, y_train)
y_pred = bag_clf.predict(X_test)

In [16]:
from sklearn.ensemble import RandomForestClassifier

rnd_clf = RandomForestClassifier(n_estimators=500, max_leaf_nodes=16, n_jobs=-1, random_state=42)
rnd_clf.fit(X_train, y_train)

y_pred_rf = rnd_clf.predict(X_test)

In [17]:
np.sum(y_pred == y_pred_rf) / len(y_pred)  # almost identical predictions


Out[17]:
0.97599999999999998

In [18]:
from sklearn.datasets import load_iris
iris = load_iris()
rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1, random_state=42)
rnd_clf.fit(iris["data"], iris["target"])
for name, score in zip(iris["feature_names"], rnd_clf.feature_importances_):
    print(name, score)


sepal length (cm) 0.112492250999
sepal width (cm) 0.0231192882825
petal length (cm) 0.441030464364
petal width (cm) 0.423357996355

In [19]:
rnd_clf.feature_importances_


Out[19]:
array([ 0.11249225,  0.02311929,  0.44103046,  0.423358  ])

In [20]:
plt.figure(figsize=(6, 4))

for i in range(15):
    tree_clf = DecisionTreeClassifier(max_leaf_nodes=16, random_state=42 + i)
    indices_with_replacement = np.random.randint(0, len(X_train), len(X_train))
    tree_clf.fit(X[indices_with_replacement], y[indices_with_replacement])
    plot_decision_boundary(tree_clf, X, y, axes=[-1.5, 2.5, -1, 1.5], alpha=0.02, contour=False)

plt.show()


Out-of-Bag evaluation


In [21]:
bag_clf = BaggingClassifier(
    DecisionTreeClassifier(random_state=42), n_estimators=500,
    bootstrap=True, n_jobs=-1, oob_score=True, random_state=40)
bag_clf.fit(X_train, y_train)
bag_clf.oob_score_


Out[21]:
0.90133333333333332

In [22]:
bag_clf.oob_decision_function_


Out[22]:
array([[ 0.31746032,  0.68253968],
       [ 0.34117647,  0.65882353],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.08379888,  0.91620112],
       [ 0.31693989,  0.68306011],
       [ 0.02923977,  0.97076023],
       [ 0.97687861,  0.02312139],
       [ 0.97765363,  0.02234637],
       [ 0.74404762,  0.25595238],
       [ 0.        ,  1.        ],
       [ 0.71195652,  0.28804348],
       [ 0.83957219,  0.16042781],
       [ 0.97777778,  0.02222222],
       [ 0.0625    ,  0.9375    ],
       [ 0.        ,  1.        ],
       [ 0.97297297,  0.02702703],
       [ 0.95238095,  0.04761905],
       [ 1.        ,  0.        ],
       [ 0.01704545,  0.98295455],
       [ 0.38947368,  0.61052632],
       [ 0.88700565,  0.11299435],
       [ 1.        ,  0.        ],
       [ 0.96685083,  0.03314917],
       [ 0.        ,  1.        ],
       [ 0.99428571,  0.00571429],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.64804469,  0.35195531],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.13402062,  0.86597938],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.36065574,  0.63934426],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.27093596,  0.72906404],
       [ 0.34146341,  0.65853659],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.00531915,  0.99468085],
       [ 0.98265896,  0.01734104],
       [ 0.91428571,  0.08571429],
       [ 0.97282609,  0.02717391],
       [ 0.97029703,  0.02970297],
       [ 0.        ,  1.        ],
       [ 0.06134969,  0.93865031],
       [ 0.98019802,  0.01980198],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.97790055,  0.02209945],
       [ 0.79473684,  0.20526316],
       [ 0.41919192,  0.58080808],
       [ 0.99473684,  0.00526316],
       [ 0.        ,  1.        ],
       [ 0.67613636,  0.32386364],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.87356322,  0.12643678],
       [ 1.        ,  0.        ],
       [ 0.56140351,  0.43859649],
       [ 0.16304348,  0.83695652],
       [ 0.67539267,  0.32460733],
       [ 0.90673575,  0.09326425],
       [ 0.        ,  1.        ],
       [ 0.16201117,  0.83798883],
       [ 0.89005236,  0.10994764],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.995     ,  0.005     ],
       [ 0.        ,  1.        ],
       [ 0.07272727,  0.92727273],
       [ 0.05418719,  0.94581281],
       [ 0.29533679,  0.70466321],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.81871345,  0.18128655],
       [ 0.01092896,  0.98907104],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.22513089,  0.77486911],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.9368932 ,  0.0631068 ],
       [ 0.76536313,  0.23463687],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.17127072,  0.82872928],
       [ 0.65306122,  0.34693878],
       [ 0.        ,  1.        ],
       [ 0.03076923,  0.96923077],
       [ 0.49444444,  0.50555556],
       [ 1.        ,  0.        ],
       [ 0.02673797,  0.97326203],
       [ 0.98870056,  0.01129944],
       [ 0.23121387,  0.76878613],
       [ 0.5       ,  0.5       ],
       [ 0.9947644 ,  0.0052356 ],
       [ 0.00555556,  0.99444444],
       [ 0.98963731,  0.01036269],
       [ 0.25641026,  0.74358974],
       [ 0.92972973,  0.07027027],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.80681818,  0.19318182],
       [ 1.        ,  0.        ],
       [ 0.0106383 ,  0.9893617 ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.98181818,  0.01818182],
       [ 1.        ,  0.        ],
       [ 0.01036269,  0.98963731],
       [ 0.97752809,  0.02247191],
       [ 0.99453552,  0.00546448],
       [ 0.01960784,  0.98039216],
       [ 0.18367347,  0.81632653],
       [ 0.98387097,  0.01612903],
       [ 0.29533679,  0.70466321],
       [ 0.98295455,  0.01704545],
       [ 0.        ,  1.        ],
       [ 0.00561798,  0.99438202],
       [ 0.75138122,  0.24861878],
       [ 0.38624339,  0.61375661],
       [ 0.42708333,  0.57291667],
       [ 0.86315789,  0.13684211],
       [ 0.92964824,  0.07035176],
       [ 0.05699482,  0.94300518],
       [ 0.82802548,  0.17197452],
       [ 0.01546392,  0.98453608],
       [ 0.        ,  1.        ],
       [ 0.02298851,  0.97701149],
       [ 0.96721311,  0.03278689],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.01041667,  0.98958333],
       [ 0.        ,  1.        ],
       [ 0.0326087 ,  0.9673913 ],
       [ 0.01020408,  0.98979592],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.93785311,  0.06214689],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.99462366,  0.00537634],
       [ 0.        ,  1.        ],
       [ 0.38860104,  0.61139896],
       [ 0.32065217,  0.67934783],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.31182796,  0.68817204],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.00588235,  0.99411765],
       [ 0.        ,  1.        ],
       [ 0.98387097,  0.01612903],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.62264151,  0.37735849],
       [ 0.92344498,  0.07655502],
       [ 0.        ,  1.        ],
       [ 0.99526066,  0.00473934],
       [ 1.        ,  0.        ],
       [ 0.98888889,  0.01111111],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.06451613,  0.93548387],
       [ 1.        ,  0.        ],
       [ 0.05154639,  0.94845361],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.03278689,  0.96721311],
       [ 1.        ,  0.        ],
       [ 0.95808383,  0.04191617],
       [ 0.79532164,  0.20467836],
       [ 0.55665025,  0.44334975],
       [ 0.        ,  1.        ],
       [ 0.18604651,  0.81395349],
       [ 1.        ,  0.        ],
       [ 0.93121693,  0.06878307],
       [ 0.97740113,  0.02259887],
       [ 1.        ,  0.        ],
       [ 0.00531915,  0.99468085],
       [ 0.        ,  1.        ],
       [ 0.44623656,  0.55376344],
       [ 0.86363636,  0.13636364],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.00558659,  0.99441341],
       [ 0.        ,  1.        ],
       [ 0.96923077,  0.03076923],
       [ 0.        ,  1.        ],
       [ 0.21649485,  0.78350515],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.98477157,  0.01522843],
       [ 0.8       ,  0.2       ],
       [ 0.99441341,  0.00558659],
       [ 0.        ,  1.        ],
       [ 0.08379888,  0.91620112],
       [ 0.98984772,  0.01015228],
       [ 0.01142857,  0.98857143],
       [ 0.        ,  1.        ],
       [ 0.02747253,  0.97252747],
       [ 1.        ,  0.        ],
       [ 0.79144385,  0.20855615],
       [ 0.        ,  1.        ],
       [ 0.90804598,  0.09195402],
       [ 0.98387097,  0.01612903],
       [ 0.20634921,  0.79365079],
       [ 0.19767442,  0.80232558],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.20338983,  0.79661017],
       [ 0.98181818,  0.01818182],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.98969072,  0.01030928],
       [ 0.        ,  1.        ],
       [ 0.48663102,  0.51336898],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.07821229,  0.92178771],
       [ 0.11176471,  0.88823529],
       [ 0.99415205,  0.00584795],
       [ 0.03015075,  0.96984925],
       [ 1.        ,  0.        ],
       [ 0.40837696,  0.59162304],
       [ 0.04891304,  0.95108696],
       [ 0.51595745,  0.48404255],
       [ 0.51898734,  0.48101266],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.59903382,  0.40096618],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.24157303,  0.75842697],
       [ 0.81052632,  0.18947368],
       [ 0.08717949,  0.91282051],
       [ 0.99453552,  0.00546448],
       [ 0.82142857,  0.17857143],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 0.125     ,  0.875     ],
       [ 0.04712042,  0.95287958],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.89150943,  0.10849057],
       [ 0.1978022 ,  0.8021978 ],
       [ 0.95238095,  0.04761905],
       [ 0.00515464,  0.99484536],
       [ 0.609375  ,  0.390625  ],
       [ 0.07692308,  0.92307692],
       [ 0.99484536,  0.00515464],
       [ 0.84210526,  0.15789474],
       [ 0.        ,  1.        ],
       [ 0.99484536,  0.00515464],
       [ 0.95876289,  0.04123711],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.26903553,  0.73096447],
       [ 0.98461538,  0.01538462],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.00574713,  0.99425287],
       [ 0.85142857,  0.14857143],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.76506024,  0.23493976],
       [ 0.8969697 ,  0.1030303 ],
       [ 1.        ,  0.        ],
       [ 0.73333333,  0.26666667],
       [ 0.47727273,  0.52272727],
       [ 0.        ,  1.        ],
       [ 0.92473118,  0.07526882],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.87709497,  0.12290503],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.74752475,  0.25247525],
       [ 0.09146341,  0.90853659],
       [ 0.44329897,  0.55670103],
       [ 0.22395833,  0.77604167],
       [ 0.        ,  1.        ],
       [ 0.87046632,  0.12953368],
       [ 0.78212291,  0.21787709],
       [ 0.00507614,  0.99492386],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.02884615,  0.97115385],
       [ 0.96571429,  0.03428571],
       [ 0.93478261,  0.06521739],
       [ 1.        ,  0.        ],
       [ 0.49756098,  0.50243902],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.01604278,  0.98395722],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.96987952,  0.03012048],
       [ 0.        ,  1.        ],
       [ 0.05747126,  0.94252874],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.98989899,  0.01010101],
       [ 0.01675978,  0.98324022],
       [ 1.        ,  0.        ],
       [ 0.13541667,  0.86458333],
       [ 0.        ,  1.        ],
       [ 0.00546448,  0.99453552],
       [ 0.        ,  1.        ],
       [ 0.41836735,  0.58163265],
       [ 0.11309524,  0.88690476],
       [ 0.22110553,  0.77889447],
       [ 1.        ,  0.        ],
       [ 0.97647059,  0.02352941],
       [ 0.22826087,  0.77173913],
       [ 0.98882682,  0.01117318],
       [ 0.        ,  1.        ],
       [ 0.        ,  1.        ],
       [ 1.        ,  0.        ],
       [ 0.96428571,  0.03571429],
       [ 0.33507853,  0.66492147],
       [ 0.98235294,  0.01764706],
       [ 1.        ,  0.        ],
       [ 0.        ,  1.        ],
       [ 0.99465241,  0.00534759],
       [ 0.        ,  1.        ],
       [ 0.06043956,  0.93956044],
       [ 0.97619048,  0.02380952],
       [ 1.        ,  0.        ],
       [ 0.03108808,  0.96891192],
       [ 0.57291667,  0.42708333]])

In [23]:
from sklearn.metrics import accuracy_score
y_pred = bag_clf.predict(X_test)
accuracy_score(y_test, y_pred)


Out[23]:
0.91200000000000003

Feature importance


In [24]:
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')

In [25]:
rnd_clf = RandomForestClassifier(random_state=42)
rnd_clf.fit(mnist["data"], mnist["target"])


Out[25]:
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=10, n_jobs=1, oob_score=False, random_state=42,
            verbose=0, warm_start=False)

In [26]:
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = matplotlib.cm.hot,
               interpolation="nearest")
    plt.axis("off")

In [27]:
plot_digit(rnd_clf.feature_importances_)

cbar = plt.colorbar(ticks=[rnd_clf.feature_importances_.min(), rnd_clf.feature_importances_.max()])
cbar.ax.set_yticklabels(['Not important', 'Very important'])

save_fig("mnist_feature_importance_plot")
plt.show()


Saving figure mnist_feature_importance_plot

AdaBoost


In [28]:
from sklearn.ensemble import AdaBoostClassifier

ada_clf = AdaBoostClassifier(
    DecisionTreeClassifier(max_depth=1), n_estimators=200,
    algorithm="SAMME.R", learning_rate=0.5, random_state=42)
ada_clf.fit(X_train, y_train)


Out[28]:
AdaBoostClassifier(algorithm='SAMME.R',
          base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=1,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best'),
          learning_rate=0.5, n_estimators=200, random_state=42)

In [29]:
plot_decision_boundary(ada_clf, X, y)



In [30]:
m = len(X_train)

plt.figure(figsize=(11, 4))
for subplot, learning_rate in ((121, 1), (122, 0.5)):
    sample_weights = np.ones(m)
    for i in range(5):
        plt.subplot(subplot)
        svm_clf = SVC(kernel="rbf", C=0.05, random_state=42)
        svm_clf.fit(X_train, y_train, sample_weight=sample_weights)
        y_pred = svm_clf.predict(X_train)
        sample_weights[y_pred != y_train] *= (1 + learning_rate)
        plot_decision_boundary(svm_clf, X, y, alpha=0.2)
        plt.title("learning_rate = {}".format(learning_rate - 1), fontsize=16)

plt.subplot(121)
plt.text(-0.7, -0.65, "1", fontsize=14)
plt.text(-0.6, -0.10, "2", fontsize=14)
plt.text(-0.5,  0.10, "3", fontsize=14)
plt.text(-0.4,  0.55, "4", fontsize=14)
plt.text(-0.3,  0.90, "5", fontsize=14)
save_fig("boosting_plot")
plt.show()


Saving figure boosting_plot

In [31]:
list(m for m in dir(ada_clf) if not m.startswith("_") and m.endswith("_"))


Out[31]:
['base_estimator_',
 'classes_',
 'estimator_errors_',
 'estimator_weights_',
 'estimators_',
 'feature_importances_',
 'n_classes_']

Gradient Boosting


In [32]:
np.random.seed(42)
X = np.random.rand(100, 1) - 0.5
y = 3*X[:, 0]**2 + 0.05 * np.random.randn(100)

In [33]:
from sklearn.tree import DecisionTreeRegressor

tree_reg1 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg1.fit(X, y)


Out[33]:
DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
           max_leaf_nodes=None, min_impurity_split=1e-07,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, presort=False, random_state=42,
           splitter='best')

In [34]:
y2 = y - tree_reg1.predict(X)
tree_reg2 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg2.fit(X, y2)


Out[34]:
DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
           max_leaf_nodes=None, min_impurity_split=1e-07,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, presort=False, random_state=42,
           splitter='best')

In [35]:
y3 = y2 - tree_reg2.predict(X)
tree_reg3 = DecisionTreeRegressor(max_depth=2, random_state=42)
tree_reg3.fit(X, y3)


Out[35]:
DecisionTreeRegressor(criterion='mse', max_depth=2, max_features=None,
           max_leaf_nodes=None, min_impurity_split=1e-07,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, presort=False, random_state=42,
           splitter='best')

In [36]:
X_new = np.array([[0.8]])

In [37]:
y_pred = sum(tree.predict(X_new) for tree in (tree_reg1, tree_reg2, tree_reg3))

In [38]:
y_pred


Out[38]:
array([ 0.75026781])

In [39]:
def plot_predictions(regressors, X, y, axes, label=None, style="r-", data_style="b.", data_label=None):
    x1 = np.linspace(axes[0], axes[1], 500)
    y_pred = sum(regressor.predict(x1.reshape(-1, 1)) for regressor in regressors)
    plt.plot(X[:, 0], y, data_style, label=data_label)
    plt.plot(x1, y_pred, style, linewidth=2, label=label)
    if label or data_label:
        plt.legend(loc="upper center", fontsize=16)
    plt.axis(axes)

plt.figure(figsize=(11,11))

plt.subplot(321)
plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h_1(x_1)$", style="g-", data_label="Training set")
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.title("Residuals and tree predictions", fontsize=16)

plt.subplot(322)
plot_predictions([tree_reg1], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h(x_1) = h_1(x_1)$", data_label="Training set")
plt.ylabel("$y$", fontsize=16, rotation=0)
plt.title("Ensemble predictions", fontsize=16)

plt.subplot(323)
plot_predictions([tree_reg2], X, y2, axes=[-0.5, 0.5, -0.5, 0.5], label="$h_2(x_1)$", style="g-", data_style="k+", data_label="Residuals")
plt.ylabel("$y - h_1(x_1)$", fontsize=16)

plt.subplot(324)
plot_predictions([tree_reg1, tree_reg2], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h(x_1) = h_1(x_1) + h_2(x_1)$")
plt.ylabel("$y$", fontsize=16, rotation=0)

plt.subplot(325)
plot_predictions([tree_reg3], X, y3, axes=[-0.5, 0.5, -0.5, 0.5], label="$h_3(x_1)$", style="g-", data_style="k+")
plt.ylabel("$y - h_1(x_1) - h_2(x_1)$", fontsize=16)
plt.xlabel("$x_1$", fontsize=16)

plt.subplot(326)
plot_predictions([tree_reg1, tree_reg2, tree_reg3], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="$h(x_1) = h_1(x_1) + h_2(x_1) + h_3(x_1)$")
plt.xlabel("$x_1$", fontsize=16)
plt.ylabel("$y$", fontsize=16, rotation=0)

save_fig("gradient_boosting_plot")
plt.show()


Saving figure gradient_boosting_plot

In [40]:
from sklearn.ensemble import GradientBoostingRegressor

gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=3, learning_rate=1.0, random_state=42)
gbrt.fit(X, y)


Out[40]:
GradientBoostingRegressor(alpha=0.9, criterion='friedman_mse', init=None,
             learning_rate=1.0, loss='ls', max_depth=2, max_features=None,
             max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, n_estimators=3, presort='auto',
             random_state=42, subsample=1.0, verbose=0, warm_start=False)

In [41]:
gbrt_slow = GradientBoostingRegressor(max_depth=2, n_estimators=200, learning_rate=0.1, random_state=42)
gbrt_slow.fit(X, y)


Out[41]:
GradientBoostingRegressor(alpha=0.9, criterion='friedman_mse', init=None,
             learning_rate=0.1, loss='ls', max_depth=2, max_features=None,
             max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, n_estimators=200,
             presort='auto', random_state=42, subsample=1.0, verbose=0,
             warm_start=False)

In [42]:
plt.figure(figsize=(11,4))

plt.subplot(121)
plot_predictions([gbrt], X, y, axes=[-0.5, 0.5, -0.1, 0.8], label="Ensemble predictions")
plt.title("learning_rate={}, n_estimators={}".format(gbrt.learning_rate, gbrt.n_estimators), fontsize=14)

plt.subplot(122)
plot_predictions([gbrt_slow], X, y, axes=[-0.5, 0.5, -0.1, 0.8])
plt.title("learning_rate={}, n_estimators={}".format(gbrt_slow.learning_rate, gbrt_slow.n_estimators), fontsize=14)

save_fig("gbrt_learning_rate_plot")
plt.show()


Saving figure gbrt_learning_rate_plot

Gradient Boosting with Early stopping


In [43]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=49)

gbrt = GradientBoostingRegressor(max_depth=2, n_estimators=120, random_state=42)
gbrt.fit(X_train, y_train)

errors = [mean_squared_error(y_val, y_pred)
          for y_pred in gbrt.staged_predict(X_val)]
bst_n_estimators = np.argmin(errors)

gbrt_best = GradientBoostingRegressor(max_depth=2,n_estimators=bst_n_estimators, random_state=42)
gbrt_best.fit(X_train, y_train)


Out[43]:
GradientBoostingRegressor(alpha=0.9, criterion='friedman_mse', init=None,
             learning_rate=0.1, loss='ls', max_depth=2, max_features=None,
             max_leaf_nodes=None, min_impurity_split=1e-07,
             min_samples_leaf=1, min_samples_split=2,
             min_weight_fraction_leaf=0.0, n_estimators=55, presort='auto',
             random_state=42, subsample=1.0, verbose=0, warm_start=False)

In [44]:
min_error = np.min(errors)

In [45]:
plt.figure(figsize=(11, 4))

plt.subplot(121)
plt.plot(errors, "b.-")
plt.plot([bst_n_estimators, bst_n_estimators], [0, min_error], "k--")
plt.plot([0, 120], [min_error, min_error], "k--")
plt.plot(bst_n_estimators, min_error, "ko")
plt.text(bst_n_estimators, min_error*1.2, "Minimum", ha="center", fontsize=14)
plt.axis([0, 120, 0, 0.01])
plt.xlabel("Number of trees")
plt.title("Validation error", fontsize=14)

plt.subplot(122)
plot_predictions([gbrt_best], X, y, axes=[-0.5, 0.5, -0.1, 0.8])
plt.title("Best model (%d trees)" % bst_n_estimators, fontsize=14)

save_fig("early_stopping_gbrt_plot")
plt.show()


Saving figure early_stopping_gbrt_plot

In [46]:
gbrt = GradientBoostingRegressor(max_depth=2, warm_start=True, random_state=42)

min_val_error = float("inf")
error_going_up = 0
for n_estimators in range(1, 120):
    gbrt.n_estimators = n_estimators
    gbrt.fit(X_train, y_train)
    y_pred = gbrt.predict(X_val)
    val_error = mean_squared_error(y_val, y_pred)
    if val_error < min_val_error:
        min_val_error = val_error
        error_going_up = 0
    else:
        error_going_up += 1
        if error_going_up == 5:
            break  # early stopping

In [47]:
print(gbrt.n_estimators)


61

Exercise solutions

Coming soon


In [ ]: