In this lab, we're going to look at association rule mining for grocery data. At the end of the lab, you should be able to use Spark's FPGrowth implementation to mine frequent itemsets and association rules from transaction data.
Let's start by importing the packages we'll need. As usual, we'll import pandas for exploratory analysis, but this week we're also going to use pyspark, a Python package that wraps Apache Spark and makes its functionality available in Python. Spark also supports frequent itemset generation using the FPGrowth algorithm, so we'll import this functionality too.
In [ ]:
import itertools
import pandas as pd
import pyspark
from pyspark.ml.fpm import FPGrowth
from pyspark.sql.functions import split
First, let's initialise a SparkContext object, which will represent our connection to the Spark cluster. To do this, we must first specify the URL of the master node to connect to. As we're only running this notebook for demonstration purposes, we can just run the cluster locally, as follows:
In [ ]:
sc = pyspark.SparkContext(master='local[*]')
Note: By specifying master='local[*]', we are instructing Spark to run with as many worker threads as there are logical cores available on the host machine. Alternatively, we could specify the number of threads directly, e.g. master='local[4]' to run four threads. However, we need to make sure to specify at least two threads, so that one is available for resource management and at least one is available for data processing.
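If you'd like to confirm how many worker threads Spark actually ended up with, one simple (optional) check is the context's defaultParallelism property:
In [ ]:
# Optional: reflects the number of worker threads requested via the master URL
# (for 'local[*]', this is the number of logical cores on the host machine)
sc.defaultParallelism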
Spark supports reading from CSV files via its SQLContext object, so let's create this next:
In [ ]:
sql = pyspark.SQLContext(sc)
Next, let's load the data. Write the path to your groceries.csv file in the cell below:
In [ ]:
path = 'data/groceries.csv'
In [ ]:
df = sql.read.text(path)
Similar to the head method in pandas, we can peek at the first few rows of the data frame via its show method:
In [ ]:
df.show(5, truncate=False) # Show the first five rows, and don't truncate the printout
As can be seen, the data consists of a collection of transactions from a supermarket, where each row corresponds to a transaction and the items in a row correspond to the items that were purchased in that transaction.
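As a quick (optional) sanity check, we can count the total number of transactions in the data set using the data frame's count method:
In [ ]:
df.count() # Total number of transactions (one per row)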
Currently, the rows in our data frame are CSV strings. We can see this more clearly using the take method of the data frame, which gives more detailed information about the data than the high-level show method above:
In [ ]:
df.take(1) # Take the first row
Before we can mine association rules, we'll need to split these strings into arrays of individual items. We can do this using the split function from Spark's SQL library, as follows:
In [ ]:
df = df.select(split('value', ',').alias('items')) # Split the values column by comma and label the result as 'items'
df.show(truncate=False)
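As an optional check that the split worked as expected, we can also look at how many items each transaction contains, using the size function from the same SQL library (a minimal sketch, not needed for the rest of the lab):
In [ ]:
from pyspark.sql.functions import size

# Summary statistics (count, mean, min, max, etc.) for the number of items per transaction
df.select(size('items').alias('n_items')).describe().show()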
Next, let's mine our transaction data to find interesting dependencies between itemsets. While there are a number of approaches available for mining frequently occurring itemsets (e.g. Apriori, Eclat), Spark supports the FPGrowth algorithm directly. To run the algorithm on our set of transactions, we need to specify two parameters:
- minSupport: A minimum support threshold, used to filter out itemsets that don't occur frequently enough.
- minConfidence: A minimum confidence threshold, used to filter out association rules that aren't strong enough.
Let's set the minimum support level at 1% and the minimum confidence level at 10%. We can then train a model using the fit method of the FPGrowth class (in a similar way to using scikit-learn), as follows:
In [ ]:
algorithm = FPGrowth(minSupport=0.01, minConfidence=0.1)
model = algorithm.fit(df)
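As an aside, the fitted model can also be applied back to our transactions via its transform method, which suggests additional items for each basket based on the rules it has mined. This isn't needed for the rest of the lab, but a quick sketch looks like this:
In [ ]:
# For each transaction, the 'prediction' column lists the consequents of rules
# whose antecedents appear in the basket (excluding items already present)
model.transform(df).show(5, truncate=False)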
We can extract the most frequent itemsets from the model using its freqItemsets attribute, which is just another data frame object that we can call show on:
In [ ]:
model.freqItemsets.show(10, truncate=False)
We can print the top ten most frequent itemsets by sorting the data frame before calling show, as follows:
In [ ]:
model.freqItemsets.sort('freq', ascending=False).show(10, truncate=False)
We can determine the total number of frequent itemsets found by counting the rows in the data frame via its count method:
In [ ]:
model.freqItemsets.count()
As can be seen, the FPGrowth algorithm has identified 332 frequent itemsets in the transaction history.
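To put this in context, with minSupport=0.01 an itemset only counts as frequent if it appears in at least 1% of all transactions. We can estimate the corresponding absolute count as follows (a rough sketch):
In [ ]:
# Minimum number of transactions an itemset must appear in to be considered frequent
df.count() * 0.01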
We can extract association rules from the model using its associationRules attribute, which is a further data frame object that we can call show on. As above, we can sort the data frame according to the computed confidence level to show the most significant rules first.
In [ ]:
model.associationRules.sort('confidence', ascending=False).show(10, truncate=False)
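Note that, depending on your Spark version, the rules data frame may also include lift and support columns. If lift is available, one way to focus on the more informative rules is to filter on it, for example (a sketch, assuming the lift column is present):
In [ ]:
# Keep only rules whose consequent is more than twice as likely to be bought
# when the antecedent is in the basket, and show the strongest first
model.associationRules.filter('lift > 2').sort('lift', ascending=False).show(truncate=False)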