In [ ]:
// spark-sql provides SparkSession, DataFrame and Dataset used below (and pulls in spark-core transitively)
classpath.add(
  "org.apache.spark" %% "spark-sql" % "2.0.2"
)
In [2]:
import org.apache.spark.sql.{SparkSession, DataFrame, Dataset}
In [ ]:
val spark = SparkSession.builder().master("local[*]").getOrCreate()
In [49]:
// Let's assume we have some sorted data that we want to calculate the cumulative sum for
val data = Seq(1, 2, 3, 4, 5)
// Here's the expected cumulative sum
val expected = Seq(1, 3, 6, 10, 15)
// If this were a local collection we could just use scanLeft (and drop the leading 0 seed)
val local = data.scanLeft(0)(_ + _).drop(1)
// But what if it's distributed?
val rdd = spark.sparkContext.parallelize(data)
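As a quick sanity check (the assertion is added here purely for illustration), the local scanLeft result should line up with the expected cumulative sum:
In [ ]:
// Verify the local version before worrying about the distributed one
assert(local == expected)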
In [47]:
// Calculate the sum of each partition, keyed by partition index
val x = rdd.mapPartitionsWithIndex { (index, partition) =>
  Iterator((index, partition.sum))
}.collect().toMap
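To see what we collected, we can print the per-partition sums; the exact split depends on how many partitions local[*] created on this machine, so the output will vary:
In [ ]:
// Inspect the per-partition sums, ordered by partition index
x.toSeq.sortBy(_._1).foreach(println)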
In [48]:
rdd.mapPartitionsWithIndex { (index, partition) =>
  // For each partition, add up the sums of all the previous partitions
  val previousSum = (0 until index).map(x).sum
  // Scan left starting from that running total, dropping the seed element
  partition.scanLeft(previousSum)(_ + _).drop(1)
}.collect()
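A minimal sanity check (the `distributed` val is introduced here only for the comparison) that this matches the expected result:
In [ ]:
// Recompute the distributed cumulative sum and compare it to the expected output
val distributed = rdd.mapPartitionsWithIndex { (index, partition) =>
  partition.scanLeft((0 until index).map(x).sum)(_ + _).drop(1)
}.collect()
assert(distributed.toSeq == expected)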
In [45]:
// Better: compute the running total up to and including each partition index once, up front
val x = rdd.mapPartitionsWithIndex { (index, partition) =>
  Iterator((index, partition.sum))
}.collect().scanLeft((0, 0))((a, b) => (b._1, a._2 + b._2)).toMap
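To make that scanLeft easier to follow, here is a pure-Scala sketch with made-up per-partition sums (the partitionSums values are hypothetical, not what Spark produces for this RDD):
In [ ]:
// Hypothetical (index, partition sum) pairs, just to illustrate the prefix-sum trick
val partitionSums = Seq((0, 3), (1, 7), (2, 5))
// scanLeft threads a running total through the pairs; toMap then maps each index to the
// total up to and including that partition (the (0, 0) seed is overwritten by the real entry for 0)
partitionSums.scanLeft((0, 0))((a, b) => (b._1, a._2 + b._2)).toMap
// -> Map(0 -> 3, 1 -> 10, 2 -> 15)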
In [46]:
rdd.mapPartitionsWithIndex { (index, partition) =>
  // Seed each partition's scan with the running total of everything before it (0 for the first partition)
  partition.scanLeft(x.getOrElse(index - 1, 0))(_ + _).drop(1)
}.collect()
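The same sanity check as before (again, the `distributed` val is just for the comparison); the advantage of this version is that the previous-partition totals are computed once up front instead of being re-summed inside every partition:
In [ ]:
val distributed = rdd.mapPartitionsWithIndex { (index, partition) =>
  partition.scanLeft(x.getOrElse(index - 1, 0))(_ + _).drop(1)
}.collect()
assert(distributed.toSeq == expected)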
Thanks to jmorra for teaching me this.