Our Example: Classifying customers for car insurance


In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
%matplotlib inline
%pylab inline
import matplotlib.pyplot as plt


Populating the interactive namespace from numpy and matplotlib

In [3]:
import pandas as pd
print(pd.__version__)


0.23.3

In [4]:
import numpy as np
print(np.__version__)


1.15.0

In [5]:
import seaborn as sns
print(sns.__version__)


0.9.0

Loading and exploring our data set

This is a data set of customers of an insurance company. Each data point is one customer. The group encodes how many accidents the customer has been involved in so far:

  • 0 - red: many accidents
  • 1 - green: few or no accidents
  • 2 - yellow: in the middle

In [6]:
!curl -O https://raw.githubusercontent.com/DJCordhose/ai/master/notebooks/manning/data/insurance-customers-1500.csv


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 26822  100 26822    0     0  85967      0 --:--:-- --:--:-- --:--:-- 85967

In [7]:
df = pd.read_csv('./insurance-customers-1500.csv', sep=';')

In [8]:
df.describe()


Out[8]:
         max speed          age  thousand miles per year        group
count  1500.000000  1500.000000              1500.000000  1500.000000
mean    122.492667    44.952667                30.344000     0.998667
std      17.604333    17.191727                15.463152     0.816768
min      68.000000     9.000000               -21.000000     0.000000
25%     108.000000    32.000000                18.000000     0.000000
50%     120.000000    42.000000                29.000000     1.000000
75%     137.000000    55.000000                42.000000     2.000000
max     166.000000   102.000000                84.000000     2.000000
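
The summary reports a minimum of -21 for the yearly mileage and a minimum age of 9, neither of which looks like a plausible customer. A minimal sketch of how such rows could be counted, assuming the column names match the header of the summary; the toy rows, the group values, and the age threshold of 16 are hypothetical and only serve the illustration:

```python
import pandas as pd

# Hypothetical mini data set; column names are assumed to match the CSV header
toy_df = pd.DataFrame({
    'max speed': [104, 138, 119, 92],
    'age': [39, 49, 56, 10],
    'thousand miles per year': [8, -15, -6, 18],
    'group': [1, 0, 2, 2],
})

# A negative yearly mileage cannot be a real measurement
negative_miles = toy_df['thousand miles per year'] < 0
# Drivers below a plausible licensing age (16 is an assumed threshold)
too_young = toy_df['age'] < 16

# Keep every row that violates at least one of the two checks
suspicious = toy_df[negative_miles | too_young]
print(len(suspicious), "suspicious rows")
print(suspicious)
```

Run against the real `df`, the same two masks would show how many of the 1500 rows are affected before deciding whether to keep them.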

A pairplot of a few samples gives you a nice overview of your data


In [9]:
sample_df = df.sample(n=100, random_state=42)
sns.pairplot(sample_df, hue="group", palette={0: '#AA4444', 1: '#006000', 2: '#EEEE44'})


Out[9]:
<seaborn.axisgrid.PairGrid at 0x2799432e550>

Choose an intuitive view on your data - plot speed vs age only


In [10]:
# ignore this, it is just technical code to plot decision boundaries
# Adapted from:
# http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
# http://jponttuset.cat/xkcd-deep-learning/

from matplotlib.colors import ListedColormap

cmap_print = ListedColormap(['#AA8888', '#004000', '#FFFFDD'])
cmap_bold = ListedColormap(['#AA4444', '#006000', '#EEEE44'])
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#FFFFDD'])
font_size=25
title_font_size=40

def meshGrid(x_data, y_data):
    h = 1  # step size in the mesh
    x_min, x_max = x_data.min() - 1, x_data.max() + 1
    y_min, y_max = y_data.min() - 1, y_data.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    return (xx,yy)
    
def plot_prediction(clf, x_data, y_data, x_label, y_label, ground_truth, title="", 
                   mesh=True, fname=None, print=False):
    xx,yy = meshGrid(x_data, y_data)
    fig, ax = plt.subplots(figsize=(20,10))

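    if clf and mesh:
        # callers pass age as x_data and speed as y_data, but the classifiers
        # expect rows of (speed, age), so yy comes first here
        Z = clf.predict(np.c_[yy.ravel(), xx.ravel()])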
        # Put the result into a color plot
        Z = Z.reshape(xx.shape)
        ax.pcolormesh(xx, yy, Z, cmap=cmap_light)
    
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    if print:
        ax.scatter(x_data, y_data, c=ground_truth, cmap=cmap_print, s=200, marker='o', edgecolors='k')
    else:
        ax.scatter(x_data, y_data, c=ground_truth, cmap=cmap_bold, s=100, marker='o', edgecolors='k')
        
    ax.set_xlabel(x_label, fontsize=font_size)
    ax.set_ylabel(y_label, fontsize=font_size)
    ax.set_title(title, fontsize=title_font_size)
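    if fname:
        # make sure the output directory exists before saving
        import os
        os.makedirs('figures', exist_ok=True)
        fig.savefig('figures/'+fname)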
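
The plotting helper evaluates the classifier on every point of a grid. A small sketch of what that grid looks like may help; the toy age and speed values below are made up for illustration, and the stacking order mirrors the `np.c_[yy.ravel(), xx.ravel()]` call in `plot_prediction`:

```python
import numpy as np

# Toy feature values: ages go on the x axis, speeds on the y axis
ages = np.array([20, 22])
speeds = np.array([100, 101, 102])

# Both grids have shape (len(speeds), len(ages))
xx, yy = np.meshgrid(ages, speeds)

# Flatten the grid into rows of (speed, age), one row per grid point
grid_points = np.c_[yy.ravel(), xx.ravel()]

print(xx.shape)
print(grid_points)
```

Each row of `grid_points` has only two features, which is why a classifier used with this helper has to cope with a missing mileage value.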

In [11]:
sample_df = df.sample(n=300, random_state=42)
y = sample_df['group']
sample_df.drop('group', axis='columns', inplace=True)
X = sample_df.values

In [17]:
X


Out[17]:
array([[104.,  39.,   8.],
       [123.,  38.,  28.],
       [145.,  45.,  58.],
       [134.,  60.,  32.],
       [140.,  52.,  42.],
       [117.,  53.,  48.],
       [ 96.,  26.,  17.],
       [ 92.,  20.,  12.],
       [104.,  46.,  20.],
       [ 92.,  22.,  19.],
       [113.,  44.,  51.],
       [114.,  34.,  17.],
       [117.,  71.,  42.],
       [147.,  49.,  44.],
       [118.,  78.,  27.],
       [119.,  25.,  31.],
       [139.,  46.,  40.],
       [116.,  74.,  41.],
       [ 94.,  27.,  11.],
       [116.,  49.,  38.],
       [152.,  41.,  50.],
       [107.,  23.,  13.],
       [123.,  54.,  56.],
       [103.,  58.,   5.],
       [ 97.,  28.,  15.],
       [137.,  48.,  41.],
       [152.,  31.,  51.],
       [134.,  40.,  45.],
       [110.,  34.,  37.],
       [140.,  47.,  37.],
       [ 95.,  26.,  12.],
       [144.,  33.,  51.],
       [114.,  41.,  25.],
       [117.,  67.,  23.],
       [129.,  68.,  22.],
       [152.,  33.,  49.],
       [ 93.,  20.,  16.],
       [ 99.,  28.,  13.],
       [112.,  38.,  28.],
       [115.,  75.,  17.],
       [153.,  30.,  49.],
       [133.,  38.,  42.],
       [144.,  65.,  38.],
       [152.,  47.,  33.],
       [ 95.,  22.,  14.],
       [109.,  37.,   7.],
       [ 92.,  27.,  12.],
       [127.,  64.,  21.],
       [119.,  69.,  41.],
       [135.,  38.,  41.],
       [115.,  44.,  27.],
       [126.,  83.,  21.],
       [126.,  77.,  28.],
       [100.,  30.,  21.],
       [149.,  36.,  54.],
       [116.,  31.,  14.],
       [113.,  46.,  25.],
       [138.,  80.,  24.],
       [135.,  44.,  40.],
       [104.,  17.,  16.],
       [154.,  35.,  56.],
       [150.,  52.,  57.],
       [123.,  90.,  18.],
       [113.,  24.,  22.],
       [144.,  37.,  59.],
       [126.,  85.,  19.],
       [143.,  30.,  42.],
       [108.,  26.,  25.],
       [117.,  72.,  13.],
       [152.,  34.,  53.],
       [132.,  36.,  49.],
       [103.,  56.,  15.],
       [100.,  27.,  25.],
       [120.,  84.,  20.],
       [148.,  31.,  51.],
       [125.,  45.,  18.],
       [108.,  28.,  38.],
       [111.,  33.,   9.],
       [112.,  51.,  19.],
       [116.,  54.,  29.],
       [105.,  44.,  26.],
       [114.,  52.,  18.],
       [119.,  41.,  49.],
       [117.,  55.,  13.],
       [109.,  31.,  32.],
       [109.,  25.,  29.],
       [127.,  67.,  21.],
       [ 97.,  36.,   9.],
       [125.,  47.,  15.],
       [103.,  52.,  16.],
       [138.,  49., -15.],
       [113.,  37.,  14.],
       [134.,  47.,  35.],
       [100.,  41.,  26.],
       [156.,  34.,  61.],
       [113.,  42.,  13.],
       [146.,  29.,  55.],
       [146.,  39.,  36.],
       [117.,  27.,  37.],
       [119.,  37.,  15.],
       [134.,  65.,  27.],
       [105.,  56.,  34.],
       [151.,  33.,  48.],
       [145.,  40.,  42.],
       [128.,  84.,  38.],
       [118.,  44.,  21.],
       [115.,  25.,  44.],
       [105.,  42.,  21.],
       [151.,  34.,  50.],
       [118.,  88.,  24.],
       [148.,  27.,  48.],
       [111.,  27.,  32.],
       [104.,  26.,  17.],
       [108.,  43.,  13.],
       [106.,  31.,  -3.],
       [113.,  44.,   9.],
       [141.,  46.,  46.],
       [124.,  27.,  36.],
       [106.,  17.,  31.],
       [133.,  45.,  49.],
       [128.,  39.,  13.],
       [121.,  23.,  40.],
       [100.,  42.,  16.],
       [142.,  75.,  41.],
       [116.,  36.,  31.],
       [ 99.,  31.,   7.],
       [ 97.,  26.,  16.],
       [103.,  44.,  12.],
       [ 93.,  78.,  20.],
       [ 90.,  32.,  14.],
       [158.,  48.,  44.],
       [124.,  45.,  30.],
       [122.,  45.,   8.],
       [116.,  49.,  42.],
       [117.,  24.,  37.],
       [123.,  77.,  26.],
       [ 89.,  26.,  12.],
       [150.,  32.,  56.],
       [ 91.,  65.,  27.],
       [100.,  24.,  21.],
       [101.,  54.,  30.],
       [124.,  62.,  28.],
       [152.,  38.,  58.],
       [151.,  36.,  50.],
       [158.,  39.,  44.],
       [125.,  21.,  30.],
       [123.,  73.,  25.],
       [146.,  45.,  45.],
       [123.,  43.,  44.],
       [106.,  38.,  14.],
       [ 96.,  37.,  16.],
       [110.,  40.,  39.],
       [114.,  46.,  28.],
       [105.,  23.,  19.],
       [119.,  81.,  26.],
       [ 95.,  64.,  24.],
       [117.,  18.,  26.],
       [119.,  16.,  27.],
       [ 96.,  22.,  13.],
       [149.,  44.,  36.],
       [108.,  36.,  25.],
       [138.,  47.,  48.],
       [100.,  45.,  29.],
       [107.,  56.,   8.],
       [119.,  74.,  21.],
       [ 99.,  30.,  11.],
       [100.,  47.,  22.],
       [141.,  29.,  55.],
       [151.,  35.,  42.],
       [134.,  45.,  36.],
       [129.,  29.,  36.],
       [130.,  76.,  43.],
       [149.,  29.,  48.],
       [104.,  66.,  31.],
       [121.,  79.,  11.],
       [137.,  37.,  61.],
       [117.,  30.,  19.],
       [126.,  41.,  49.],
       [148.,  29.,  50.],
       [149.,  50.,  39.],
       [111.,  61.,  18.],
       [101.,  35.,  20.],
       [123.,  73.,  18.],
       [127.,  71.,  19.],
       [ 93.,  47.,   3.],
       [151.,  29.,  44.],
       [142.,  54.,  42.],
       [138.,  49.,  42.],
       [120.,  24.,  26.],
       [114.,  52.,  17.],
       [151.,  32.,  54.],
       [132.,  83.,  17.],
       [ 95.,  28.,  12.],
       [100.,  16.,  19.],
       [107.,  43.,   5.],
       [ 99.,  25.,   7.],
       [120.,  45.,  38.],
       [118.,  63.,  14.],
       [119.,  56.,  -6.],
       [ 99.,  23.,  18.],
       [133.,  65.,  15.],
       [120.,  42.,  35.],
       [110.,  76.,  34.],
       [121.,  23.,  23.],
       [100.,  32.,  51.],
       [114.,  69.,  22.],
       [155.,  69.,  53.],
       [ 95.,  45.,  17.],
       [100.,  29.,  59.],
       [135.,  30.,  32.],
       [101.,  23.,   6.],
       [123.,  35.,  39.],
       [117.,  42.,  26.],
       [105.,  38.,  26.],
       [151.,  36.,  47.],
       [117.,  67.,  49.],
       [ 92.,  10.,  18.],
       [152.,  32.,  58.],
       [102.,  45.,  32.],
       [126.,  66.,  37.],
       [103.,  15.,  11.],
       [121.,  48.,  40.],
       [118.,  77.,  22.],
       [122.,  52.,  30.],
       [124.,  34.,  49.],
       [121.,  86.,  12.],
       [125.,  59.,   8.],
       [120.,  24.,  16.],
       [ 98.,  25.,  17.],
       [128.,  54.,  47.],
       [106.,  58.,  55.],
       [ 99.,  23.,  15.],
       [153.,  38.,  55.],
       [ 90.,  53.,  12.],
       [104.,  25.,  17.],
       [123.,  31.,  13.],
       [101.,  25.,  21.],
       [152.,  37.,  65.],
       [146.,  33.,  46.],
       [127.,  46.,  57.],
       [ 96.,  33.,  14.],
       [110.,  62.,   3.],
       [138.,  41.,  36.],
       [122.,  54.,   2.],
       [146.,  50.,  37.],
       [113.,  39.,  38.],
       [138.,  43.,  71.],
       [152.,  31.,  49.],
       [129.,  86.,  19.],
       [118.,  45.,  42.],
       [148.,  34.,  46.],
       [120.,  30.,  24.],
       [114.,  36.,  19.],
       [149.,  49.,  42.],
       [111.,  76.,  20.],
       [147.,  36.,  56.],
       [147.,  35.,  52.],
       [150.,  39.,  51.],
       [124.,  92.,  18.],
       [129.,  56.,  39.],
       [108.,  36.,  23.],
       [101.,  31.,  12.],
       [129.,  24.,  37.],
       [150.,  37.,  53.],
       [126.,  18.,  41.],
       [116.,  72.,  51.],
       [ 99.,  25.,  17.],
       [117.,  75.,  18.],
       [110.,  37.,  44.],
       [139.,  44.,  63.],
       [134.,  45., -17.],
       [120.,  62.,  29.],
       [124.,  33.,  11.],
       [109.,  38.,  38.],
       [122.,  75.,  22.],
       [114.,  74.,  31.],
       [143.,  39.,  39.],
       [143.,  48.,  39.],
       [137.,  40.,  36.],
       [145.,  34.,  41.],
       [146.,  57.,  40.],
       [ 96.,  80.,  22.],
       [138.,  36.,  35.],
       [159.,  32.,  47.],
       [114.,  43.,  29.],
       [108.,  57.,   5.],
       [143.,  53.,  56.],
       [106.,  47.,  27.],
       [122.,  79.,   7.],
       [148.,  35.,  45.],
       [107.,  69.,  28.],
       [106.,  71.,  14.],
       [127.,  56.,  12.],
       [119.,  59.,  17.],
       [101.,  33.,  22.],
       [112.,  58.,  26.],
       [121.,  36.,  20.],
       [150.,  38.,  49.],
       [107.,  60.,  15.],
       [140.,  30.,  39.]])

In [12]:
plot_prediction(None, X[:, 1], X[:, 0], 
               'Age', 'Max Speed', y, mesh=False,
                title="Max Speed vs Age")
#                , fname='all.png')


Our Objective: Create a general model to predict the risk group from our features


In [13]:
# 0: red
# 1: green
# 2: yellow

class ClassifierBase:
    def predict(self, X):
        return np.array([ self.predict_single(x) for x in X])
    def score(self, X, y):
        n = len(y)
        correct = 0
        predictions = self.predict(X)
        for prediction, ground_truth in zip(predictions, y):
            if prediction == ground_truth:
                correct = correct + 1
        return correct / n
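
The `score` method computes plain accuracy: the share of samples whose prediction equals the ground truth. As a sketch, the same number can be computed without an explicit loop; the helper name `accuracy` is hypothetical and not part of the notebook:

```python
import numpy as np

def accuracy(predictions, ground_truth):
    """Fraction of positions where prediction and ground truth agree."""
    predictions = np.asarray(predictions)
    ground_truth = np.asarray(ground_truth)
    # The comparison yields booleans; their mean is the share of matches
    return float(np.mean(predictions == ground_truth))

# 3 of the 4 toy predictions match
print(accuracy([0, 1, 2, 1], [0, 1, 1, 1]))
```

Converting both inputs with `np.asarray` keeps the comparison positional even when the labels arrive as a pandas Series with a shuffled index, as `y` does after sampling.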

First, a random model as a baseline: how well do you think it will perform?


In [14]:
from random import randrange

class RandomClassifier(ClassifierBase):
    def predict_single(self, x):
        return randrange(3)

In [15]:
random_clf = RandomClassifier()

In [16]:
plot_prediction(random_clf, X[:, 1], X[:, 0], 
               'Age', 'Max Speed', y,
                title="Max Speed vs Age (Random)",
                fname='random.png')


By just guessing randomly among the three groups, we get approx. 1/3 right, which is what we expect


In [17]:
random_clf.score(X, y)


Out[17]:
0.3333333333333333
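
Uniform guessing among three groups is right with probability 1/3 for every sample, no matter how the groups are distributed. A single run on 300 samples still scatters around that value, with a standard deviation of about sqrt((1/3)(2/3)/300), roughly 0.027. A minimal simulation sketch, using hypothetical, deliberately imbalanced labels and an arbitrarily chosen seed:

```python
import numpy as np

rng = np.random.RandomState(0)  # fixed seed, chosen arbitrarily

# Hypothetical, deliberately imbalanced labels for 300 samples
labels = np.array([0] * 200 + [1] * 70 + [2] * 30)

# Repeat the uniform guessing experiment many times
scores = [
    np.mean(rng.randint(0, 3, size=labels.size) == labels)
    for _ in range(2000)
]

print("mean accuracy:", np.mean(scores))
print("spread between runs:", np.std(scores))
```

The mean lands close to 1/3 even though two thirds of the toy labels belong to one group, while individual runs differ by a few percentage points.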

Creating a naive classifier manually: how much better is it?


In [18]:
# 0: red
# 1: green
# 2: yellow

class BaseLineClassifier(ClassifierBase):
    def predict_single(self, x):
        try:
            speed, age, miles_per_year = x
        except ValueError:
            # only two features given (speed, age), as in the decision boundary plot
            speed, age = x
            miles_per_year = 0
        if age < 25:
            if speed > 140:
                return 0
            else:
                return 2
        if age > 75:
            return 0
        if miles_per_year > 30:
            return 0
        if miles_per_year > 20:
            return 2
        return 1
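
The fallback in `predict_single` relies on the fact that unpacking two values into three names raises a `ValueError`. A stand-alone sketch of just that step; the function name `unpack_features` is hypothetical and exists only for this illustration:

```python
def unpack_features(x):
    """Return (speed, age, miles_per_year), defaulting the mileage to 0."""
    try:
        speed, age, miles_per_year = x
    except ValueError:
        # Only (speed, age) was given, for example by a two-feature plot grid
        speed, age = x
        miles_per_year = 0
    return speed, age, miles_per_year

print(unpack_features((104, 39, 8)))
print(unpack_features((104, 39)))
```

With a mileage of 0 neither mileage rule fires, so for two-feature input the prediction depends on age and speed alone.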

In [19]:
base_clf = BaseLineClassifier()

In [21]:
plot_prediction(base_clf, X[:, 1], X[:, 0], 
               'Age', 'Max Speed', y,
                title="Max Speed vs Age with Classification",
                fname='manual.png')


This is the baseline we have to beat


In [22]:
base_clf.score(X, y)


Out[22]:
0.43666666666666665
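
A single accuracy number does not show which groups the hand-written rules get right. One way to break it down is a table of true versus predicted groups. The sketch below uses made-up labels and predictions, and the helper `confusion_counts` is hypothetical, so its output says nothing about the classifier above:

```python
import numpy as np

def confusion_counts(ground_truth, predictions, n_classes=3):
    """Count samples per (true group, predicted group) pair."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for truth, pred in zip(ground_truth, predictions):
        counts[truth, pred] += 1
    return counts

# Made-up labels and predictions, only to show the shape of the result
truth = np.array([0, 0, 1, 1, 2, 2])
preds = np.array([0, 2, 1, 1, 0, 2])

counts = confusion_counts(truth, preds)
print(counts)

# Share of correct predictions per true group: diagonal over row sums
per_group = counts.diagonal() / counts.sum(axis=1)
print(per_group)
```

Called as `confusion_counts(y, base_clf.predict(X))`, the same helper would show how the overall score is distributed over the three groups.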
