Chapter 1 – The Machine Learning landscape

This is the code used to generate some of the figures in chapter 1.

Setup

First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:


In [1]:
# To support both python 2 and python 3
from __future__ import division, print_function, unicode_literals

# Common imports
import numpy as np
import numpy.random as rnd
import os

# to make this notebook's output stable across runs
rnd.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

# Where to save the figures
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "fundamentals"

def save_fig(fig_id, tight_layout=True):
    path = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id + ".png")
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format='png', dpi=300)

Load and prepare Life satisfaction data


In [2]:
import pandas as pd

# Download CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI
datapath = "datasets/lifesat/"

oecd_bli = pd.read_csv(datapath+"oecd_bli_2015.csv", thousands=',')
oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]
oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")
oecd_bli.head(2)


Out[2]:
Indicator Air pollution Assault rate Consultation on rule-making Dwellings without basic facilities Educational attainment Employees working very long hours Employment rate Homicide rate Household net adjusted disposable income Household net financial wealth ... Long-term unemployment rate Personal earnings Quality of support network Rooms per person Self-reported health Student skills Time devoted to leisure and personal care Voter turnout Water quality Years in education
Country
Australia 13.0 2.1 10.5 1.1 76.0 14.02 72.0 0.8 31588.0 47657.0 ... 1.08 50449.0 92.0 2.3 85.0 512.0 14.41 93.0 91.0 19.4
Austria 27.0 3.4 7.1 1.0 83.0 7.61 72.0 0.4 31173.0 49887.0 ... 1.19 45199.0 89.0 1.6 69.0 500.0 14.46 75.0 94.0 17.0

2 rows × 24 columns


In [3]:
oecd_bli["Life satisfaction"].head()


Out[3]:
Country
Australia    7.3
Austria      6.9
Belgium      6.9
Brazil       7.0
Canada       7.3
Name: Life satisfaction, dtype: float64

Load and prepare GDP per capita data


In [4]:
# Download data from http://goo.gl/j1MSKe (=> imf.org)
gdp_per_capita = pd.read_csv(datapath+"gdp_per_capita.csv", thousands=',', delimiter='\t',
                             encoding='latin1', na_values="n/a")
gdp_per_capita.rename(columns={"2015": "GDP per capita"}, inplace=True)
gdp_per_capita.set_index("Country", inplace=True)
gdp_per_capita.head(2)


Out[4]:
Subject Descriptor Units Scale Country/Series-specific Notes GDP per capita Estimates Start After
Country
Afghanistan Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 599.994 2013.0
Albania Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 3995.383 2010.0

In [5]:
full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)
full_country_stats.sort_values(by="GDP per capita", inplace=True)
full_country_stats


Out[5]:
Air pollution Assault rate Consultation on rule-making Dwellings without basic facilities Educational attainment Employees working very long hours Employment rate Homicide rate Household net adjusted disposable income Household net financial wealth ... Time devoted to leisure and personal care Voter turnout Water quality Years in education Subject Descriptor Units Scale Country/Series-specific Notes GDP per capita Estimates Start After
Country
Brazil 18.0 7.9 4.0 6.7 45.0 10.41 67.0 25.5 11664.0 6844.0 ... 14.97 79.0 72.0 16.3 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 8669.998 2014.0
Mexico 30.0 12.8 9.0 4.2 37.0 28.83 61.0 23.4 13085.0 9056.0 ... 13.89 63.0 67.0 14.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 9009.280 2015.0
Russia 15.0 3.8 2.5 15.1 94.0 0.16 69.0 12.8 19292.0 3412.0 ... 14.97 65.0 56.0 16.0 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 9054.914 2015.0
Turkey 35.0 5.0 5.5 12.7 34.0 40.86 50.0 1.2 14095.0 3251.0 ... 13.42 88.0 62.0 16.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 9437.372 2013.0
Hungary 15.0 3.6 7.9 4.8 82.0 3.19 58.0 1.3 15442.0 13277.0 ... 15.04 62.0 77.0 17.6 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 12239.894 2015.0
Poland 33.0 1.4 10.8 3.2 90.0 7.41 60.0 0.9 17852.0 10919.0 ... 14.20 55.0 79.0 18.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 12495.334 2014.0
Chile 46.0 6.9 2.0 9.4 57.0 15.42 62.0 4.4 14533.0 17733.0 ... 14.41 49.0 73.0 16.5 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 13340.905 2014.0
Slovak Republic 13.0 3.0 6.6 0.6 92.0 7.02 60.0 1.2 17503.0 8663.0 ... 14.99 59.0 81.0 16.3 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 15991.736 2015.0
Czech Republic 16.0 2.8 6.8 0.9 92.0 6.98 68.0 0.8 18404.0 17299.0 ... 14.98 59.0 85.0 18.1 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 17256.918 2015.0
Estonia 9.0 5.5 3.3 8.1 90.0 3.30 68.0 4.8 15167.0 7680.0 ... 14.90 64.0 79.0 17.5 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 17288.083 2014.0
Greece 27.0 3.7 6.5 0.7 68.0 6.16 49.0 1.6 18575.0 14579.0 ... 14.91 64.0 69.0 18.6 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 18064.288 2014.0
Portugal 18.0 5.7 6.5 0.9 38.0 9.62 61.0 1.1 20086.0 31245.0 ... 14.95 58.0 86.0 17.6 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 19121.592 2014.0
Slovenia 26.0 3.9 10.3 0.5 85.0 5.63 63.0 0.4 19326.0 18465.0 ... 14.62 52.0 88.0 18.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 20732.482 2015.0
Spain 24.0 4.2 7.3 0.1 55.0 5.89 56.0 0.6 22477.0 24774.0 ... 16.06 69.0 71.0 17.6 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 25864.721 2014.0
Korea 30.0 2.1 10.4 4.2 82.0 18.72 64.0 1.1 19510.0 29091.0 ... 14.63 76.0 78.0 17.5 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 27195.197 2014.0
Italy 21.0 4.7 5.0 1.1 57.0 3.66 56.0 0.7 25166.0 54987.0 ... 14.98 75.0 71.0 16.8 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 29866.581 2015.0
Japan 24.0 1.4 7.3 6.4 94.0 22.26 72.0 0.3 26111.0 86764.0 ... 14.93 53.0 85.0 16.3 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 32485.545 2015.0
Israel 21.0 6.4 2.5 3.7 85.0 16.03 67.0 2.3 22104.0 52933.0 ... 14.48 68.0 68.0 15.8 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 35343.336 2015.0
New Zealand 11.0 2.2 10.3 0.2 74.0 13.87 73.0 1.2 23815.0 28290.0 ... 14.87 77.0 89.0 18.1 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 37044.891 2015.0
France 12.0 5.0 3.5 0.5 73.0 8.15 64.0 0.6 28799.0 48741.0 ... 15.33 80.0 82.0 16.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 37675.006 2015.0
Belgium 21.0 6.6 4.5 2.0 72.0 4.57 62.0 1.1 28307.0 83876.0 ... 15.71 89.0 87.0 18.9 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 40106.632 2014.0
Germany 16.0 3.6 4.5 0.1 86.0 5.25 73.0 0.5 31252.0 50394.0 ... 15.31 72.0 95.0 18.2 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 40996.511 2014.0
Finland 15.0 2.4 9.0 0.6 85.0 3.58 69.0 1.4 27927.0 18761.0 ... 14.89 69.0 94.0 19.7 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 41973.988 2014.0
Canada 15.0 1.3 10.5 0.2 89.0 3.94 72.0 1.5 29365.0 67913.0 ... 14.25 61.0 91.0 17.2 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 43331.961 2015.0
Netherlands 30.0 4.9 6.1 0.0 73.0 0.45 74.0 0.9 27888.0 77961.0 ... 15.44 75.0 92.0 18.7 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 43603.115 2014.0
Austria 27.0 3.4 7.1 1.0 83.0 7.61 72.0 0.4 31173.0 49887.0 ... 14.46 75.0 94.0 17.0 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 43724.031 2015.0
United Kingdom 13.0 1.9 11.5 0.2 78.0 12.70 71.0 0.3 27029.0 60778.0 ... 14.83 66.0 88.0 16.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 43770.688 2015.0
Sweden 10.0 5.1 10.9 0.0 88.0 1.13 74.0 0.7 29185.0 60328.0 ... 15.11 86.0 95.0 19.3 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 49866.266 2014.0
Iceland 18.0 2.7 5.1 0.4 71.0 12.25 82.0 0.3 23965.0 43045.0 ... 14.61 81.0 97.0 19.8 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 50854.583 2014.0
Australia 13.0 2.1 10.5 1.1 76.0 14.02 72.0 0.8 31588.0 47657.0 ... 14.41 93.0 91.0 19.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 50961.865 2014.0
Ireland 13.0 2.6 9.0 0.2 75.0 4.20 60.0 0.8 23917.0 31580.0 ... 15.19 70.0 80.0 17.6 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 51350.744 2014.0
Denmark 15.0 3.9 7.0 0.9 78.0 2.03 73.0 0.3 26491.0 44488.0 ... 16.06 88.0 94.0 19.4 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 52114.165 2015.0
United States 18.0 1.5 8.3 0.1 89.0 11.30 67.0 5.2 41355.0 145769.0 ... 14.27 68.0 85.0 17.2 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 55805.204 2015.0
Norway 16.0 3.3 8.1 0.3 82.0 2.82 75.0 0.6 33492.0 8797.0 ... 15.56 78.0 94.0 17.9 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 74822.106 2015.0
Switzerland 20.0 4.2 8.4 0.0 86.0 6.72 80.0 0.5 33491.0 108823.0 ... 14.98 49.0 96.0 17.3 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 80675.308 2015.0
Luxembourg 12.0 4.3 6.0 0.1 78.0 3.47 66.0 0.4 38951.0 61765.0 ... 15.12 91.0 86.0 15.1 Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 101994.093 2014.0

36 rows × 30 columns


In [6]:
full_country_stats[["GDP per capita", 'Life satisfaction']].loc["United States"]


Out[6]:
GDP per capita       55805.204
Life satisfaction        7.200
Name: United States, dtype: float64

In [7]:
remove_indices = [0, 1, 6, 8, 33, 34, 35]
keep_indices = list(set(range(36)) - set(remove_indices))

sample_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[keep_indices]
missing_data = full_country_stats[["GDP per capita", 'Life satisfaction']].iloc[remove_indices]

In [8]:
sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3))
plt.axis([0, 60000, 0, 10])
position_text = {
    "Hungary": (5000, 1),
    "Korea": (18000, 1.7),
    "France": (29000, 2.4),
    "Australia": (40000, 3.0),
    "United States": (52000, 3.8),
}
for country, pos_text in position_text.items():
    pos_data_x, pos_data_y = sample_data.loc[country]
    country = "U.S." if country == "United States" else country
    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,
            arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))
    plt.plot(pos_data_x, pos_data_y, "ro")
save_fig('money_happy_scatterplot')
plt.show()


Saving figure money_happy_scatterplot

In [9]:
sample_data.to_csv("life_satisfaction_vs_gdp_per_capita.csv")

In [10]:
sample_data.loc[list(position_text.keys())]


Out[10]:
GDP per capita Life satisfaction
Country
Hungary 12239.894 4.9
Korea 27195.197 5.8
Australia 50961.865 7.3
United States 55805.204 7.2
France 37675.006 6.5

In [11]:
import numpy as np

sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3))
plt.axis([0, 60000, 0, 10])
X=np.linspace(0, 60000, 1000)
plt.plot(X, 2*X/100000, "r")
plt.text(40000, 2.7, r"$\theta_0 = 0$", fontsize=14, color="r")
plt.text(40000, 1.8, r"$\theta_1 = 2 \times 10^{-5}$", fontsize=14, color="r")
plt.plot(X, 8 - 5*X/100000, "g")
plt.text(5000, 9.1, r"$\theta_0 = 8$", fontsize=14, color="g")
plt.text(5000, 8.2, r"$\theta_1 = -5 \times 10^{-5}$", fontsize=14, color="g")
plt.plot(X, 4 + 5*X/100000, "b")
plt.text(5000, 3.5, r"$\theta_0 = 4$", fontsize=14, color="b")
plt.text(5000, 2.6, r"$\theta_1 = 5 \times 10^{-5}$", fontsize=14, color="b")
save_fig('tweaking_model_params_plot')
plt.show()


Saving figure tweaking_model_params_plot

In [12]:
from sklearn import linear_model
lin1 = linear_model.LinearRegression()
Xsample = np.c_[sample_data["GDP per capita"]]
ysample = np.c_[sample_data["Life satisfaction"]]
lin1.fit(Xsample, ysample)
t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]
t0, t1


Out[12]:
(4.8530528002664362, 4.9115445891584838e-05)

In [13]:
sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3))
plt.axis([0, 60000, 0, 10])
X=np.linspace(0, 60000, 1000)
plt.plot(X, t0 + t1*X, "b")
plt.text(5000, 3.1, r"$\theta_0 = 4.85$", fontsize=14, color="b")
plt.text(5000, 2.2, r"$\theta_1 = 4.91 \times 10^{-5}$", fontsize=14, color="b")
save_fig('best_fit_model_plot')
plt.show()


Saving figure best_fit_model_plot

In [14]:
cyprus_gdp_per_capita = gdp_per_capita.loc["Cyprus"]["GDP per capita"]
print(cyprus_gdp_per_capita)
cyprus_predicted_life_satisfaction = lin1.predict(cyprus_gdp_per_capita)[0][0]
cyprus_predicted_life_satisfaction


22587.49
Out[14]:
5.9624474431881502

In [15]:
sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(5,3), s=1)
X=np.linspace(0, 60000, 1000)
plt.plot(X, t0 + t1*X, "b")
plt.axis([0, 60000, 0, 10])
plt.text(5000, 7.5, r"$\theta_0 = 4.85$", fontsize=14, color="b")
plt.text(5000, 6.6, r"$\theta_1 = 4.91 \times 10^{-5}$", fontsize=14, color="b")
plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], "r--")
plt.text(25000, 5.0, r"Prediction = 5.96", fontsize=14, color="b")
plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, "ro")
save_fig('cyprus_prediction_plot')
plt.show()


Saving figure cyprus_prediction_plot

In [16]:
sample_data[7:10]


Out[16]:
GDP per capita Life satisfaction
Country
Portugal 19121.592 5.1
Slovenia 20732.482 5.7
Spain 25864.721 6.5

In [17]:
(5.1+5.7+6.5)/3


Out[17]:
5.766666666666667

In [18]:
backup = oecd_bli, gdp_per_capita

def prepare_country_stats(oecd_bli, gdp_per_capita):
    return sample_data

In [19]:
# Code example
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn

# Load the data
oecd_bli = pd.read_csv(datapath + "oecd_bli_2015.csv", thousands=',')
gdp_per_capita = pd.read_csv(datapath + "gdp_per_capita.csv",thousands=',',delimiter='\t',
                             encoding='latin1', na_values="n/a")

# Prepare the data
country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Visualize the data
country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction')
plt.show()

# Select a linear model
model = sklearn.linear_model.LinearRegression()

# Train the model
model.fit(X, y)

# Make a prediction for Cyprus
X_new = [[22587]]  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.96242338]]


[[ 5.96242338]]

In [20]:
oecd_bli, gdp_per_capita = backup

In [21]:
missing_data


Out[21]:
GDP per capita Life satisfaction
Country
Brazil 8669.998 7.0
Mexico 9009.280 6.7
Chile 13340.905 6.7
Czech Republic 17256.918 6.5
Norway 74822.106 7.4
Switzerland 80675.308 7.5
Luxembourg 101994.093 6.9

In [22]:
position_text2 = {
    "Brazil": (1000, 9.0),
    "Mexico": (11000, 9.0),
    "Chile": (25000, 9.0),
    "Czech Republic": (35000, 9.0),
    "Norway": (60000, 3),
    "Switzerland": (72000, 3.0),
    "Luxembourg": (90000, 3.0),
}

In [23]:
sample_data.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(8,3))
plt.axis([0, 110000, 0, 10])

for country, pos_text in position_text2.items():
    pos_data_x, pos_data_y = missing_data.loc[country]
    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,
            arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))
    plt.plot(pos_data_x, pos_data_y, "rs")

X=np.linspace(0, 110000, 1000)
plt.plot(X, t0 + t1*X, "b:")

lin_reg_full = linear_model.LinearRegression()
Xfull = np.c_[full_country_stats["GDP per capita"]]
yfull = np.c_[full_country_stats["Life satisfaction"]]
lin_reg_full.fit(Xfull, yfull)

t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]
X = np.linspace(0, 110000, 1000)
plt.plot(X, t0full + t1full * X, "k")

save_fig('representative_training_data_scatterplot')
plt.show()


Saving figure representative_training_data_scatterplot

In [24]:
full_country_stats.plot(kind='scatter', x="GDP per capita", y='Life satisfaction', figsize=(8,3))
plt.axis([0, 110000, 0, 10])

from sklearn import preprocessing
from sklearn import pipeline

poly = preprocessing.PolynomialFeatures(degree=60, include_bias=False)
scaler = preprocessing.StandardScaler()
lin_reg2 = linear_model.LinearRegression()

pipeline_reg = pipeline.Pipeline([('poly', poly), ('scal', scaler), ('lin', lin_reg2)])
pipeline_reg.fit(Xfull, yfull)
curve = pipeline_reg.predict(X[:, np.newaxis])
plt.plot(X, curve)
save_fig('overfitting_model_plot')
plt.show()


/home/ageron/dev/py/envs/ml/lib/python3.5/site-packages/numpy/core/_methods.py:116: RuntimeWarning: overflow encountered in multiply
  x = um.multiply(x, x, out=x)
Saving figure overfitting_model_plot

In [25]:
full_country_stats.loc[[c for c in full_country_stats.index if "W" in c.upper()]]["Life satisfaction"]


Out[25]:
Country
New Zealand    7.3
Sweden         7.2
Norway         7.4
Switzerland    7.5
Name: Life satisfaction, dtype: float64

In [26]:
gdp_per_capita.loc[[c for c in gdp_per_capita.index if "W" in c.upper()]].head()


Out[26]:
Subject Descriptor Units Scale Country/Series-specific Notes GDP per capita Estimates Start After
Country
Botswana Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 6040.957 2008.0
Kuwait Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 29363.027 2014.0
Malawi Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 354.275 2011.0
New Zealand Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 37044.891 2015.0
Norway Gross domestic product per capita, current prices U.S. dollars Units See notes for: Gross domestic product, curren... 74822.106 2015.0

In [27]:
plt.figure(figsize=(8,3))

plt.xlabel("GDP per capita")
plt.ylabel('Life satisfaction')

plt.plot(list(sample_data["GDP per capita"]), list(sample_data["Life satisfaction"]), "bo")
plt.plot(list(missing_data["GDP per capita"]), list(missing_data["Life satisfaction"]), "rs")

X = np.linspace(0, 110000, 1000)
plt.plot(X, t0full + t1full * X, "r--", label="Linear model on all data")
plt.plot(X, t0 + t1*X, "b:", label="Linear model on partial data")

ridge = linear_model.Ridge(alpha=10**9.5)
Xsample = np.c_[sample_data["GDP per capita"]]
ysample = np.c_[sample_data["Life satisfaction"]]
ridge.fit(Xsample, ysample)
t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]
plt.plot(X, t0ridge + t1ridge * X, "b", label="Regularized linear model on partial data")

plt.legend(loc="lower right")
plt.axis([0, 110000, 0, 10])
save_fig('ridge_model_plot')
plt.show()


Saving figure ridge_model_plot

In [28]:
backup = oecd_bli, gdp_per_capita

def prepare_country_stats(oecd_bli, gdp_per_capita):
    return sample_data

In [29]:
# Replace this linear model:
model = sklearn.linear_model.LinearRegression()

In [30]:
# with this k-neighbors regression model:
model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)

In [31]:
X = np.c_[country_stats["GDP per capita"]]
y = np.c_[country_stats["Life satisfaction"]]

# Train the model
model.fit(X, y)

# Make a prediction for Cyprus
X_new = np.array([[22587.0]])  # Cyprus' GDP per capita
print(model.predict(X_new)) # outputs [[ 5.76666667]]


[[ 5.76666667]]

In [ ]: