Churn Prediction with PySpark using MLlib and ML Packages

Churn prediction is big business. It minimizes customer defection by predicting which customers are likely to cancel a subscription to a service. Though originally used within the telecommunications industry, it has become common practice across banks, ISPs, insurance firms, and other verticals.

The prediction process is heavily data driven and often utilizes advanced machine learning techniques. In this post, we'll take a look at what types of customer data are typically used, do some preliminary analysis of the data, and generate churn prediction models - all with PySpark and its machine learning frameworks. We'll also discuss the differences between MLlib and ML, the two machine learning frameworks shipped with Apache Spark 1.6.0.

Install and Run Jupyter on Spark

Installation on a single machine

To run this notebook tutorial, we'll need to install Spark and Jupyter/IPython, along with Python's Pandas and Matplotlib libraries.

For the sake of simplicity, let's run PySpark in local mode, using a single machine:

PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS=notebook /path/to/bin/pyspark --packages com.databricks:spark-csv_2.10:1.3.0 --master local[*]

Installation on a Cloudera CDH cluster using Spark on Yarn

Run Jupyter Notebook on Cloudera


In [49]:
# Disable warnings, set Matplotlib inline plotting and load Pandas package
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
import pandas as pd
pd.options.display.mpl_style = 'default'

Fetching and Importing Churn Data

For this tutorial, we'll be using the Orange Telecoms Churn Dataset. It consists of cleaned customer activity data (features), along with a churn label specifying whether the customer canceled their subscription or not. The data can be fetched from BigML's S3 bucket, churn-80 and churn-20. The two sets are from the same batch, but have been split by an 80/20 ratio. We'll use the larger set for training and cross-validation purposes, and the smaller set for final testing and model performance evaluation. The two data sets have been included in this repository for convenience.

In order to read the CSV data and parse it into Spark DataFrames, we'll use the CSV package. The library has already been loaded using the initial pyspark bin command call, so we're ready to go.

Let's load the two CSV data sets into DataFrames, keeping the header information and caching them into memory for quick, repeated access. We'll also print the schema of the sets.


In [2]:
CV_data = sqlContext.read.load('data/churn-bigml-80.csv', 
                          format='com.databricks.spark.csv', 
                          header='true', 
                          inferSchema='true')

final_test_data = sqlContext.read.load('data/churn-bigml-20.csv', 
                          format='com.databricks.spark.csv', 
                          header='true', 
                          inferSchema='true')
CV_data.cache()
CV_data.printSchema()


root
 |-- State: string (nullable = true)
 |-- Account length: integer (nullable = true)
 |-- Area code: integer (nullable = true)
 |-- International plan: string (nullable = true)
 |-- Voice mail plan: string (nullable = true)
 |-- Number vmail messages: integer (nullable = true)
 |-- Total day minutes: double (nullable = true)
 |-- Total day calls: integer (nullable = true)
 |-- Total day charge: double (nullable = true)
 |-- Total eve minutes: double (nullable = true)
 |-- Total eve calls: integer (nullable = true)
 |-- Total eve charge: double (nullable = true)
 |-- Total night minutes: double (nullable = true)
 |-- Total night calls: integer (nullable = true)
 |-- Total night charge: double (nullable = true)
 |-- Total intl minutes: double (nullable = true)
 |-- Total intl calls: integer (nullable = true)
 |-- Total intl charge: double (nullable = true)
 |-- Customer service calls: integer (nullable = true)
 |-- Churn: boolean (nullable = true)

By taking 5 rows of the CV_data variable and generating a Pandas DataFrame with them, we can get a display of what the rows look like. We're using Pandas instead of the Spark DataFrame.show() function because it produces prettier output.


In [3]:
pd.DataFrame(CV_data.take(5), columns=CV_data.columns)


Out[3]:
State Account length Area code International plan Voice mail plan Number vmail messages Total day minutes Total day calls Total day charge Total eve minutes Total eve calls Total eve charge Total night minutes Total night calls Total night charge Total intl minutes Total intl calls Total intl charge Customer service calls Churn
0 KS 128 415 No Yes 25 265.1 110 45.07 197.4 99 16.78 244.7 91 11.01 10.0 3 2.70 1 False
1 OH 107 415 No Yes 26 161.6 123 27.47 195.5 103 16.62 254.4 103 11.45 13.7 3 3.70 1 False
2 NJ 137 415 No No 0 243.4 114 41.38 121.2 110 10.30 162.6 104 7.32 12.2 5 3.29 0 False
3 OH 84 408 Yes No 0 299.4 71 50.90 61.9 88 5.26 196.9 89 8.86 6.6 7 1.78 2 False
4 OK 75 415 Yes No 0 166.7 113 28.34 148.3 122 12.61 186.9 121 8.41 10.1 3 2.73 3 False

Summary Statistics

Spark DataFrames include some built-in functions for statistical processing. The describe() function performs summary statistics calculations on all numeric columns, and returns them as a DataFrame.


In [4]:
CV_data.describe().toPandas().transpose()


Out[4]:
0 1 2 3 4
summary count mean stddev min max
Account length 2666 100.62040510127532 39.56397365334985 1 243
Area code 2666 437.43885971492875 42.521018019427174 408 510
Number vmail messages 2666 8.021755438859715 13.61227701829193 0 50
Total day minutes 2666 179.48162040510135 54.21035022086982 0.0 350.8
Total day calls 2666 100.31020255063765 19.988162186059512 0 160
Total day charge 2666 30.512404351087813 9.215732907163497 0.0 59.64
Total eve minutes 2666 200.38615903976006 50.95151511764598 0.0 363.7
Total eve calls 2666 100.02363090772693 20.16144511531889 0 170
Total eve charge 2666 17.033072018004518 4.330864176799864 0.0 30.91
Total night minutes 2666 201.16894223555968 50.780323368725206 43.7 395.0
Total night calls 2666 100.10615153788447 19.418458551101697 33 166
Total night charge 2666 9.052689422355604 2.2851195129157564 1.97 17.77
Total intl minutes 2666 10.23702175543886 2.7883485770512566 0.0 20.0
Total intl calls 2666 4.467366841710428 2.4561949030129466 0 20
Total intl charge 2666 2.764489872468112 0.7528120531228477 0.0 5.4
Customer service calls 2666 1.5626406601650413 1.3112357589949093 0 9

Correlations and Data Preparation

We can also perform our own statistical analyses, using the MLlib statistics package or other Python packages. Here, we'll use the Pandas library to examine correlations between the numeric columns by generating scatter plots of them.

For the Pandas workload, we don't want to pull the entire data set into the Spark driver, as that might exhaust the available RAM and throw an out-of-memory exception. Instead, we'll randomly sample a portion of the data (say 10%) to get a rough idea of how it looks.


In [5]:
numeric_features = [t[0] for t in CV_data.dtypes if t[1] == 'int' or t[1] == 'double']

sampled_data = CV_data.select(numeric_features).sample(False, 0.10).toPandas()

axs = pd.scatter_matrix(sampled_data, figsize=(12, 12));

# Rotate axis labels and remove axis ticks
n = len(sampled_data.columns)
for i in range(n):
    v = axs[i, 0]
    v.yaxis.label.set_rotation(0)
    v.yaxis.label.set_ha('right')
    v.set_yticks(())
    h = axs[n-1, i]
    h.xaxis.label.set_rotation(90)
    h.set_xticks(())


The scatter matrix reveals several highly correlated fields, e.g., Total day minutes and Total day charge. Such correlated data won't be very beneficial for our model training runs, so we're going to remove it. We'll do so by dropping one column of each pair of correlated fields, along with the State and Area code columns.

While we're in the process of manipulating the data sets, let's transform the categorical data into numeric as required by the machine learning routines, using a simple user-defined function that maps Yes/True and No/False to 1 and 0, respectively.


In [41]:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import UserDefinedFunction

binary_map = {'Yes':1.0, 'No':0.0, 'True':1.0, 'False':0.0, True:1.0, False:0.0}
toNum = UserDefinedFunction(lambda k: binary_map[k], DoubleType())

CV_data = CV_data.drop('State').drop('Area code') \
    .drop('Total day charge').drop('Total eve charge') \
    .drop('Total night charge').drop('Total intl charge') \
    .withColumn('Churn', toNum('Churn')) \
    .withColumn('International plan', toNum('International plan')) \
    .withColumn('Voice mail plan', toNum('Voice mail plan')).cache()

final_test_data = final_test_data.drop('State').drop('Area code') \
    .drop('Total day charge').drop('Total eve charge') \
    .drop('Total night charge').drop('Total intl charge') \
    .withColumn('Churn', toNum(final_test_data['Churn'])) \
    .withColumn('International plan', toNum('International plan')) \
    .withColumn('Voice mail plan', toNum('Voice mail plan')).cache()

Let's take a quick look at the resulting data set.


In [44]:
# CV_data.toPandas()
# final_test_data.toPandas()

In [47]:
pd.DataFrame(CV_data.take(5), columns=CV_data.columns)


Out[47]:
Account length International plan Voice mail plan Number vmail messages Total day minutes Total day calls Total eve minutes Total eve calls Total night minutes Total night calls Total intl minutes Total intl calls Customer service calls Churn
0 128 0.0 1.0 25 265.1 110 197.4 99 244.7 91 10.0 3 1 0.0
1 107 0.0 1.0 26 161.6 123 195.5 103 254.4 103 13.7 3 1 0.0
2 137 0.0 0.0 0 243.4 114 121.2 110 162.6 104 12.2 5 0 0.0
3 84 1.0 0.0 0 299.4 71 61.9 88 196.9 89 6.6 7 2 0.0
4 75 1.0 0.0 0 166.7 113 148.3 122 186.9 121 10.1 3 3 0.0

Using the Spark MLlib Package

The MLlib package provides a variety of machine learning algorithms for classification, regression, clustering and dimensionality reduction, as well as utilities for model evaluation. The decision tree is a popular classification algorithm, and we'll be using it extensively here.

Decision Tree Models

Decision trees have played a significant role in data mining and machine learning since the 1960s. They generate white-box classification and regression models which can be used for feature selection and sample prediction. The transparency of these models is a big advantage over black-box learners, because the models are easy to understand and interpret, and they can be readily extracted and implemented in any programming language (with nested if-else statements) for use in production environments. Furthermore, decision trees require almost no data preparation (e.g., normalization) and can handle both categorical and continuous data. To remedy overfitting and improve prediction accuracy, decision trees can also be limited to a certain depth or complexity, or bundled into ensembles of trees (e.g., random forests).
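
To illustrate this "extractable" quality, here's what a hypothetical depth-2 churn tree could look like once transcribed into plain nested if-else statements. The feature names come from our data set, but the split thresholds here are invented purely for illustration - they are not taken from the model we'll train below:

```python
# Hypothetical extracted decision tree (illustrative thresholds only,
# NOT the model trained later in this post).
def predict_churn(customer_service_calls, total_day_minutes):
    # Root split: frequent complainers tend to churn
    if customer_service_calls > 4:
        return True
    else:
        # Second-level split: very heavy daytime users tend to churn
        if total_day_minutes > 260.0:
            return True
        else:
            return False

print(predict_churn(5, 120.0))   # frequent complainer
print(predict_churn(1, 120.0))   # typical customer
```

A snippet like this can be dropped into any production codebase with no ML runtime dependency at all, which is exactly why white-box models are attractive for deployment.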

A decision tree is a predictive model which maps observations (features) about an item to conclusions about the item's label or class. The model is generated using a top-down approach, where the source data set is split into subsets using a statistical measure, often in the form of the Gini index or information gain via Shannon entropy. This process is applied recursively until a subset contains only samples with the same target class, or is halted by a predefined stopping criterion.
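
To make those split measures concrete, here's a minimal, Spark-independent sketch of the Gini index and Shannon entropy computed over a list of class labels (these are the standard textbook formulas, not Spark internals):

```python
from __future__ import division
from collections import Counter
from math import log

def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over the class proportions p_i."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def entropy(labels):
    """Shannon entropy: -sum(p_i * log2(p_i)) over the class proportions."""
    n = len(labels)
    return -sum((c / n) * log(c / n, 2) for c in Counter(labels).values())

# A skewed subset, much like our churn data: 6 loyal customers, 1 churner
labels = [False] * 6 + [True]
print(gini(labels))      # low impurity: most samples share a class
print(entropy(labels))
```

A pure subset (all one class) scores 0 under both measures; the tree builder greedily picks the split that reduces the chosen measure the most.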

Model Training

MLlib classifiers and regressors require data sets in a format of rows of type LabeledPoint, which separates row labels and feature lists, and names them accordingly. The custom labelData() function shown below performs the row parsing. We'll pass it the prepared data set (CV_data) and split it further into training and testing sets. A decision tree classifier model is then generated using the training data, using a maxDepth of 2, to build a "shallow" tree. The tree depth can be regarded as an indicator of model complexity.


In [48]:
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

def labelData(data):
    # label: row[end], features: row[0:end-1]
    # Go through the underlying RDD: DataFrame.map was removed in Spark 2.x,
    # while .rdd.map works on both Spark 1.6 and 2.x
    return data.rdd.map(lambda row: LabeledPoint(row[-1], row[:-1]))

training_data, testing_data = labelData(CV_data).randomSplit([0.8, 0.2])

model = DecisionTree.trainClassifier(training_data, numClasses=2, maxDepth=2,
                                     categoricalFeaturesInfo={1:2, 2:2},
                                     impurity='gini', maxBins=32)

print model.toDebugString()

The toDebugString() function prints the tree's decision nodes and the final prediction outcomes at the leaf nodes. We can see that features 12 and 4 are used for decision making and should thus be considered as having high predictive power for determining a customer's likelihood to churn. It's not surprising that these feature numbers map to the fields Customer service calls and Total day minutes. Decision trees are often used for feature selection because they provide an automated mechanism for determining the most important features (those closest to the tree root).


In [ ]:
print 'Feature 12:', CV_data.columns[12]
print 'Feature 4: ', CV_data.columns[4]

Model Evaluation

Predictions of the testing data's churn outcome are made with the model's predict() function, then paired with each customer's actual churn label using the getPredictionsLabels() function.

We'll use MLlib's MulticlassMetrics() for the model evaluation, which takes rows of (prediction, label) tuples as input. It provides metrics such as precision, recall, F1 score and confusion matrix, which have been bundled for printing with the custom printMetrics() function.


In [ ]:
from pyspark.mllib.evaluation import MulticlassMetrics

def getPredictionsLabels(model, test_data):
    predictions = model.predict(test_data.map(lambda r: r.features))
    return predictions.zip(test_data.map(lambda r: r.label))

def printMetrics(predictions_and_labels):
    metrics = MulticlassMetrics(predictions_and_labels)
    print 'Precision of True ', metrics.precision(1)
    print 'Precision of False', metrics.precision(0)
    print 'Recall of True    ', metrics.recall(1)
    print 'Recall of False   ', metrics.recall(0)
    print 'F-1 Score         ', metrics.fMeasure()
    print 'Confusion Matrix\n', metrics.confusionMatrix().toArray()

predictions_and_labels = getPredictionsLabels(model, testing_data)

printMetrics(predictions_and_labels)

The overall F-1 score seems quite good, but one troubling issue is the discrepancy between the recall measures. The recall (a.k.a. sensitivity) for the Churn=False samples is high, while the recall for the Churn=True samples is relatively low. Business decisions made using these predictions will be used to retain the customers most likely to leave, not those who are likely to stay. Thus, we need to ensure that our model is sensitive to the Churn=True samples.
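
This discrepancy can be read straight off a confusion matrix. A quick sketch with made-up counts (illustration only, not our model's actual output) shows how a class-skewed model can look good overall while missing most churners:

```python
from __future__ import division

# Hypothetical confusion matrix counts (NOT the actual model output):
# rows are actual classes, columns are predicted classes
actual_false = {'pred_false': 440, 'pred_true': 10}   # loyal customers
actual_true  = {'pred_false': 50,  'pred_true': 30}   # churners

recall_false = actual_false['pred_false'] / sum(actual_false.values())
recall_true  = actual_true['pred_true'] / sum(actual_true.values())
accuracy = (actual_false['pred_false'] + actual_true['pred_true']) / \
           (sum(actual_false.values()) + sum(actual_true.values()))

print(recall_false)  # high: the model rarely mislabels a loyal customer
print(recall_true)   # low: the model misses most churners
print(accuracy)      # still looks good, because the False class dominates
```

With these counts the model catches under half the churners, yet the headline accuracy stays high - exactly the failure mode we need to guard against.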

Perhaps the model's sensitivity bias toward Churn=False samples is due to a skewed distribution of the two types of samples. Let's try grouping the CV_data DataFrame by the Churn field and counting the number of instances in each group.


In [ ]:
CV_data.groupby('Churn').count().toPandas()

Stratified Sampling

There are roughly 6 times as many False churn samples as True churn samples. We can put the two sample types on the same footing using stratified sampling. The DataFrame sampleBy() function does this when provided with the fraction of each sample type to be returned.

Here we're keeping all instances of the Churn=True class, but downsampling the Churn=False class to a fraction of 388/2278.
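
That fraction simply scales the majority class down to the size of the minority class. A minimal sketch of deriving such fractions from class counts (using the counts implied by the 388/2278 ratio above):

```python
from __future__ import division

# Class counts as reported by CV_data.groupby('Churn').count()
counts = {0.0: 2278, 1.0: 388}   # Churn=False vs Churn=True samples

# Keep the minority class whole; downsample every other class to its size
minority = min(counts.values())
fractions = {label: minority / n for label, n in counts.items()}

print(fractions)   # Churn=True kept at 1.0, Churn=False downsampled
```

A dictionary built this way can be passed directly as the `fractions` argument of sampleBy().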


In [ ]:
stratified_CV_data = CV_data.sampleBy('Churn', fractions={0: 388./2278, 1: 1.0}).cache()

stratified_CV_data.groupby('Churn').count().toPandas()

Let's build a new model using the evenly distributed data set and see how it performs.


In [ ]:
training_data, testing_data = labelData(stratified_CV_data).randomSplit([0.8, 0.2])

model = DecisionTree.trainClassifier(training_data, numClasses=2, maxDepth=2,
                                     categoricalFeaturesInfo={1:2, 2:2},
                                     impurity='gini', maxBins=32)

predictions_and_labels = getPredictionsLabels(model, testing_data)
printMetrics(predictions_and_labels)

With these new recall values, we can see that the stratified data was helpful in building a less biased model, which will ultimately provide more generalized and robust predictions.

Using the Spark ML Package

The ML package is the newer library of machine learning routines. It provides an API for pipelining data transformers, estimators and model selectors. We'll use it here to perform cross-validation across several decision trees with various maxDepth parameters in order to find the optimal model.

Pipelining

The ML package needs data to be put in a (label: Double, features: Vector) DataFrame format with correspondingly named fields. The vectorizeData() function below performs this formatting.

Next we'll pass the data through a pipeline of two transformers, StringIndexer() and VectorIndexer() which index the label and features fields respectively. Indexing categorical features allows decision trees to treat categorical features appropriately, improving performance. The final element in our pipeline is an estimator (a decision tree classifier) training on the indexed labels and features.

Model Selection

Given the data set at hand, we would like to determine which parameter values of the decision tree produce the best model. We need a systematic approach to quantitatively measure the performance of the models and ensure that the results are reliable. This task of model selection is often done using cross-validation techniques. A common technique is k-fold cross-validation, where the data is randomly split into k partitions. Each partition is used once as the testing data set, while the rest are used for training. Models are then generated using the training sets and evaluated with the testing sets, resulting in k model performance measurements. The average of the performance scores is often taken to be the overall score of the model, given its build parameters.
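
A minimal, Spark-free sketch of the k-fold mechanics just described (the CrossValidator used below automates all of this; the sketch assumes the samples have already been shuffled):

```python
def kfold_splits(samples, k):
    """Yield (train, test) pairs; each fold serves as the test set once."""
    folds = [samples[i::k] for i in range(k)]   # k roughly equal partitions
    for i in range(k):
        test = folds[i]
        train = [s for j, fold in enumerate(folds) if j != i for s in fold]
        yield train, test

# Toy run: 12 pre-shuffled samples, 3 folds
scores = []
for train, test in kfold_splits(list(range(12)), 3):
    # In a real workflow we'd fit a model on `train` and score it on `test`,
    # then average the k scores; here we just record the test-fold sizes.
    scores.append(len(test))

print(scores)   # three folds of 4 samples each
```

Every sample lands in exactly one test fold, so each of the k performance measurements is computed on data the corresponding model never saw.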

For model selection we can search through the model parameters, comparing their cross validation performances. The model parameters leading to the highest performance metric produce the best model.

The ML package supports k-fold cross validation, which can be readily coupled with a parameter grid builder and an evaluator to construct a model selection workflow. Below, we'll use a transformation/estimation pipeline to train our models. The cross validator will use the ParamGridBuilder to iterate through the maxDepth parameter of the decision tree and evaluate the models using the F1-score, repeating 3 times per parameter value for reliable results.


In [ ]:
from pyspark.mllib.linalg import Vectors
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

def vectorizeData(data):
    # Use .rdd.map so this works on both Spark 1.6 and 2.x DataFrames
    return data.rdd.map(lambda r: [r[-1], Vectors.dense(r[:-1])]).toDF(['label','features'])

vectorized_CV_data = vectorizeData(stratified_CV_data)

# Index labels, adding metadata to the label column
labelIndexer = StringIndexer(inputCol='label',
                             outputCol='indexedLabel').fit(vectorized_CV_data)

# Automatically identify categorical features and index them
featureIndexer = VectorIndexer(inputCol='features',
                               outputCol='indexedFeatures',
                               maxCategories=2).fit(vectorized_CV_data)

# Train a DecisionTree model
dTree = DecisionTreeClassifier(labelCol='indexedLabel', featuresCol='indexedFeatures')

# Chain indexers and tree in a Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dTree])

# Search through decision tree's maxDepth parameter for best model
paramGrid = ParamGridBuilder().addGrid(dTree.maxDepth, [2,3,4,5,6,7]).build()

# Set F-1 score as evaluation metric for best model selection
evaluator = MulticlassClassificationEvaluator(labelCol='indexedLabel',
                                              predictionCol='prediction', metricName='f1')    

# Set up 3-fold cross validation
crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=evaluator,
                          numFolds=3)

CV_model = crossval.fit(vectorized_CV_data)

# Fetch best model
tree_model = CV_model.bestModel.stages[2]
print tree_model

We find that the best tree model produced using the cross-validation process is one with a depth of 5. So we can assume that our initial "shallow" tree of depth 2 in the previous section was not complex enough, while trees of depth higher than 5 overfit the data and will not perform well in practice.

Predictions and Model Evaluation

The actual performance of the model can be determined using the final_test_data set which has not been used for any training or cross-validation activities. We'll transform the test set with the model pipeline, which will map the labels and features according to the same recipe. The evaluator will provide us with the F-1 score of the predictions, and then we'll print them along with their probabilities. Predictions on new, unlabeled customer activity data can also be made using the same pipeline CV_model.transform() function.


In [ ]:
vectorized_test_data = vectorizeData(final_test_data)

transformed_data = CV_model.transform(vectorized_test_data)
print evaluator.getMetricName(), 'accuracy:', evaluator.evaluate(transformed_data)

predictions = transformed_data.select('indexedLabel', 'prediction', 'probability')
predictions.toPandas().head()

The prediction probabilities can be very useful in ranking customers by their likelihood of defecting. This way, the limited resources available to the business for retention can be focused on the appropriate customers.

Thank you for reading and I hope this tutorial was helpful. You can find me on Twitter @BenSadeghi.