Distributed Cumulative Sum

Getting Spark up and running


In [ ]:
// spark-core provides the RDD API used in the cells below
classpath.add(
  "org.apache.spark" %% "spark-core" % "2.0.2"
);

// SparkSession, DataFrame and Dataset (imported in the next cell) are published in the
// spark-sql artifact, so it has to be on the classpath as well; keep versions in sync
classpath.add(
  "org.apache.spark" %% "spark-sql" % "2.0.2"
);

In [2]:
import org.apache.spark.sql.{SparkSession, DataFrame, Dataset}


import org.apache.spark.sql.{SparkSession, DataFrame, Dataset}

In [ ]:
// Obtain a session (reusing an existing one if present) whose master is "local[*]"
val spark = SparkSession
  .builder()
  .master("local[*]")
  .getOrCreate()

In [49]:
// Sorted input for which we want the running (cumulative) total
val data = Seq(1, 2, 3, 4, 5)

// The running totals we expect for that input
val expected = Seq(1, 3, 6, 10, 15)

// On a local collection scanLeft is enough; its first element is the seed (0), so discard it
val local = data.scanLeft(0)((acc, n) => acc + n).tail

// The distributed case: the same data as an RDD
val rdd = spark.sparkContext.parallelize(data)


data: Seq[Int] = List(1, 2, 3, 4, 5)
expected: Seq[Int] = List(1, 3, 6, 10, 15)
local: Seq[Int] = List(1, 3, 6, 10, 15)
rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[39] at parallelize at Main.scala:36

Method 1


In [47]:
// Total of each partition, gathered on the driver as a Map keyed by partition index
val x = rdd.mapPartitionsWithIndex { (partitionId, values) =>
    Iterator(partitionId -> values.sum)
}.collect().toMap


x: Map[Int, Int] = Map(0 -> 0, 5 -> 0, 1 -> 1, 6 -> 4, 2 -> 0, 7 -> 5, 3 -> 2, 4 -> 3)

In [48]:
rdd.mapPartitionsWithIndex { (partitionId, values) =>
    // Offset = combined total of every partition that precedes this one,
    // looked up in the per-partition sums held in `x`
    val offset = (0 until partitionId).foldLeft(0)((acc, i) => acc + x(i))

    // Running total seeded with the offset; the seed is emitted first, so skip it
    values.scanLeft(offset)(_ + _).drop(1)
}.collect()


res47: Array[Int] = Array(1, 3, 6, 10, 15)

Method 2


In [45]:
// Calculate the cumulative sum at each partition index once
val x = rdd.mapPartitionsWithIndex { (index, partition) =>
    Iterator((index, partition.sum))
}.collect().scanLeft((0, 0))((a, b) => (b._1, a._2 + b._2)).toMap


x: Map[Int, Int] = Map(0 -> 0, 5 -> 6, 1 -> 1, 6 -> 10, 2 -> 1, 7 -> 15, 3 -> 3, 4 -> 6)

In [46]:
rdd.mapPartitionsWithIndex { (partitionId, values) =>
    // Seed with the value `x` holds for the preceding partition index (0 when there is none)
    val offset = x.getOrElse(partitionId - 1, 0)

    // The seed is the first element produced by scanLeft, so drop it
    values.scanLeft(offset)(_ + _).drop(1)
}.collect()


res45: Array[Int] = Array(1, 3, 6, 10, 15)

Thanks to jmorra for teaching me this.