Wednesday, September 6, 2017

Apache Spark - Cumulative Aggregations over Groups using Window Functions

Many tasks you want to perform over a data set amount to aggregating over groups of your data, but in a way that takes some ordering into account.

Take, for example, calculating the cumulative total sales per salesperson over the course of each financial year. You want the data grouped by salesperson and financial year, and then transformed to calculate the cumulative total up to and including each sale.
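Before turning to Spark, it may help to pin down the intended result. The following minimal plain-Scala sketch (no Spark involved) computes a cumulative total per salesperson and financial year. The `Record` class and the sales figures are hypothetical and chosen only for illustration. Rows are grouped by (person, year), each group is sorted by month, and `scanLeft` produces the running total up to and including each sale.

```scala
// Hypothetical records, for illustration only: (person, financial year, month, amount).
case class Record(person: String, year: Int, month: Int, amount: Int)

val sales = Seq(
  Record("Jane", 2016, 1, 10), Record("Tim", 2016, 2, 20),
  Record("Jane", 2016, 3, 30), Record("Jane", 2017, 1, 5),
  Record("Tim", 2016, 4, 40), Record("Jane", 2017, 2, 15)
)

// Group by (person, year), order each group by month, then scan the amounts:
// scanLeft(0)(_ + _) yields 0 followed by every prefix sum, so .tail keeps
// exactly one running total per sale.
val cumulative: Map[(String, Int), Seq[(Int, Int)]] =
  sales.groupBy(r => (r.person, r.year)).map { case (key, group) =>
    val ordered = group.sortBy(_.month)
    val totals  = ordered.map(_.amount).scanLeft(0)(_ + _).tail
    key -> ordered.map(_.month).zip(totals)
  }

// Print each (person, year) group as month=runningTotal pairs.
cumulative.toSeq.sortBy(_._1).foreach { case ((person, year), rows) =>
  println(s"$person $year: " + rows.map { case (m, t) => s"m$m=$t" }.mkString(", "))
}
```

Each (person, year) group restarts its running total from zero. This is the behaviour that a partitioned, ordered window gives you in Spark, without having to collect the data onto one machine.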

These kinds of queries can be written in SQL, but they tend to be complicated and difficult to read. Apache Spark, by comparison, has an elegant and easy-to-use API for producing these kinds of results.

Here is a simple example in which we compute each salesperson's running average sales, up to and including their current sale.

First, let's create a dummy data set and look at it. (In reality you will usually be loading an enormous data set from your cluster, but this lets us experiment with the API.)

// Dummy sales data as (id, salesperson, sale amount) tuples
val rdd = sc.parallelize(
    Seq((201601, "Jane", 10), (201602, "Tim", 20), (201603, "Jane", 30),
        (201604, "Tim", 40), (201605, "Jane", 50), (201606, "Tim", 60),
        (201607, "Jane", 70), (201608, "Jane", 80), (201609, "Tim", 90),
        (201610, "Tim", 100), (201611, "Jane", 110), (201612, "Tim", 120)
    )
)

// The case class supplies the DataFrame's column names and types
case class Sale(id: Int, name: String, sales: Int)
val rows = rdd.map { case (id, name, sales) => Sale(id, name, sales) }
val df = sqlContext.createDataFrame(rows)

df.show

...and this is what it looks like:

+------+----+-----+
|    id|name|sales|
+------+----+-----+
|201601|Jane|   10|
|201602| Tim|   20|
|201603|Jane|   30|
|201604| Tim|   40|
|201605|Jane|   50|
|201606| Tim|   60|
|201607|Jane|   70|
|201608|Jane|   80|
|201609| Tim|   90|
|201610| Tim|  100|
|201611|Jane|  110|
|201612| Tim|  120|
+------+----+-----+

Adding a column that aggregates over the grouped and ordered data is as simple as:

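import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.avg

// Average "sales" over each salesperson's rows, ordered by id,
// up to and including the current row
df.withColumn("avg_sales", avg(df("sales"))
   .over( Window.partitionBy("name").orderBy("id") )
).show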


...which will produce the following output:


+------+----+-----+------------------+
|    id|name|sales|         avg_sales|
+------+----+-----+------------------+
|201602| Tim|   20|              20.0|
|201604| Tim|   40|              30.0|
|201606| Tim|   60|              40.0|
|201609| Tim|   90|              52.5|
|201610| Tim|  100|              62.0|
|201612| Tim|  120| 71.66666666666667|
|201601|Jane|   10|              10.0|
|201603|Jane|   30|              20.0|
|201605|Jane|   50|              30.0|
|201607|Jane|   70|              40.0|
|201608|Jane|   80|              48.0|
|201611|Jane|  110|58.333333333333336|
+------+----+-----+------------------+
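The average is a running one, rather than one computed over the whole partition, because of the window's frame. When a window specification has an ordering but no explicit frame, Spark defaults to a frame running from the start of the partition up to the current row (SQL's RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW). The sketch below spells that frame out with `rowsBetween`. It assumes Spark 2.1 or later (for the `Window.unboundedPreceding` and `Window.currentRow` constants) and a local `SparkSession`, and it rebuilds the same twelve rows with `toDF` instead of going through an RDD. The names `runningWindow` and `result` are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{avg, col}

// Assumption: a local session; in spark-shell, getOrCreate returns the existing one.
val spark = SparkSession.builder.master("local[*]").appName("cumulative-avg").getOrCreate()
import spark.implicits._

// The same rows as the dummy data set above.
val df = Seq(
  (201601, "Jane", 10), (201602, "Tim", 20), (201603, "Jane", 30),
  (201604, "Tim", 40), (201605, "Jane", 50), (201606, "Tim", 60),
  (201607, "Jane", 70), (201608, "Jane", 80), (201609, "Tim", 90),
  (201610, "Tim", 100), (201611, "Jane", 110), (201612, "Tim", 120)
).toDF("id", "name", "sales")

// Frame spelled out: every row from the start of the partition up to and
// including the current row (a ROWS frame, so rows with equal ids are not peers).
val runningWindow = Window
  .partitionBy("name")
  .orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val result = df.withColumn("avg_sales", avg(col("sales")).over(runningWindow))
result.orderBy("name", "id").show()
```

The ids here are unique, so the ROWS frame gives the same numbers as the default RANGE frame. The two differ only when several rows share the same ordering value: RANGE treats those rows as peers and includes all of them, whereas ROWS counts strictly row by row.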


Voilà: an additional column containing each salesperson's average sales up to and including the current sale. You can adapt this by swapping in a different aggregation function, or by adding more columns to the partitioning or the ordering. It is clean, readable and fast.
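The following sketch applies those modifications to the cumulative total sales per salesperson per financial year from the introduction. `sum` replaces `avg`, and a `year` column joins `name` in the partitioning. It assumes Spark 2.1+ and a local `SparkSession`. Purely for illustration, it also assumes that the financial year is the calendar year encoded in the first four digits of the id. The 2017 rows are hypothetical and were added only so that each salesperson spans two years.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, sum}

val spark = SparkSession.builder.master("local[*]").appName("cumulative-sum").getOrCreate()
import spark.implicits._

// Two rows from the post's data plus hypothetical 2017 rows (illustration only).
val df = Seq(
  (201611, "Jane", 110), (201612, "Tim", 120),
  (201701, "Jane", 15), (201702, "Tim", 25), (201703, "Jane", 35)
).toDF("id", "name", "sales")
  // Assumption: financial year == calendar year, i.e. id / 100 truncated to an int.
  .withColumn("year", (col("id") / 100).cast("int"))

// Partition by salesperson AND year, so the running total restarts each year.
val byPersonAndYear = Window
  .partitionBy("name", "year")
  .orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val totals = df.withColumn("cumulative_sales", sum(col("sales")).over(byPersonAndYear))
totals.orderBy("name", "id").show()
```

Summing an integer column yields a long, so `cumulative_sales` is a `LongType` column. Jane's total restarts at her first 2017 sale instead of carrying over from 2016.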

