Apache Spark and window functions

Versions: Apache Spark 2.3.1 https://github.com/bartosz25/spark-...rcode/sql/WindowFunctionsTest.scala

One of the previous posts in the SQL category presented window functions, which compute values over groups of related rows. These analytic functions are also available in Apache Spark SQL.

The post is organized in 3 parts. The first one lists all window functions available in Apache Spark 2.3.1 and shows their use through small code samples. The next one focuses on the execution plan of such queries by explaining 3 main components of physical execution. Finally, the last section digs a little bit deeper and presents some internal details of the computation.

Window functions list

Window functions were described pretty clearly in the post about window functions in SQL. Here we'll just recall that they're functions applied on rows logically grouped in different window frames. To illustrate them in Apache Spark SQL, we'll use the example of the last football World Cup (2018) and the list of the best scorers and assist makers for 5 countries (France, Russia, Belgium, England and Portugal). The DataFrame used in the code snippets looks like:

// Columns match the ones visible in the execution plan (name, team, goals, assists).
// Defined at the top level so that Spark can derive an Encoder for it.
case class Player(name: String, team: String, goals: Int, assists: Int)

// sparkSession is an existing SparkSession (e.g. created in the test setup)
import sparkSession.implicits._
private val WorldCupPlayers = Seq(
  Player("Harry Kane", "England", 6, 0),
  Player("Cristiano Ronaldo", "Portugal", 4, 1),
  Player("Antoine Griezmann", "France", 4, 2),
  Player("Romelu Lukaku", "Belgium", 4, 1),
  Player("Denis Cheryshev", "Russia", 4, 0),
  Player("Kylian Mbappe", "France", 4, 0),
  Player("Eden Hazard", "Belgium", 3, 2),
  Player("Artem Dzyuba", "Russia", 3, 2),
  Player("John Stones", "England", 2, 0),
  Player("Kevin De Bruyne", "Belgium", 1, 3),
  Player("Aleksandr Golovin", "Russia", 1, 2),
  Player("Paul Pogba", "France", 1, 0),
  Player("Pepe", "Portugal", 1, 0),
  Player("Ricardo Quaresma", "Portugal", 1, 0),
  Player("Dele Alli", "England", 1, 0)
).toDF()

To see the list of available window functions, we can go through org.apache.spark.sql.functions and look for the methods annotated with @group window_funcs. Doing so gives a list similar to this one:

As you can see in the above examples, window functions are built with the org.apache.spark.sql.expressions.Window object, which exposes factory methods such as partitionBy(cols: Column*) and orderBy(cols: Column*) (plus rowsBetween and rangeBetween for explicit frames). They build an instance of org.apache.spark.sql.expressions.WindowSpec that is later passed to the over(window: WindowSpec) method of a window function column in select expressions.
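
A minimal sketch of the main window functions in action, assuming a local SparkSession and a subset of the players above. The object WindowFunctionsSketch and its rankedPlayers method are hypothetical names introduced only for this illustration; the window functions themselves (row_number, rank, dense_rank, percent_rank, cume_dist, ntile, lag, lead) come from org.apache.spark.sql.functions and are all applied over the same WindowSpec.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Same columns as the Player rows used in the post
case class Player(name: String, team: String, goals: Int, assists: Int)

object WindowFunctionsSketch {

  // Hypothetical helper: applies the main window functions to a subset of the players
  def rankedPlayers(sparkSession: SparkSession): DataFrame = {
    import sparkSession.implicits._
    val players = Seq(
      Player("Antoine Griezmann", "France", 4, 2),
      Player("Kylian Mbappe", "France", 4, 0),
      Player("Paul Pogba", "France", 1, 0),
      Player("Romelu Lukaku", "Belgium", 4, 1),
      Player("Eden Hazard", "Belgium", 3, 2),
      Player("Kevin De Bruyne", "Belgium", 1, 3)
    ).toDF()

    // One window partition per team, best scorers first
    val byTeamAndGoals = Window.partitionBy($"team").orderBy($"goals".desc)

    players.select(
      $"name", $"team", $"goals",
      // Ranking functions
      row_number().over(byTeamAndGoals).as("row_number"),
      rank().over(byTeamAndGoals).as("rank"),
      dense_rank().over(byTeamAndGoals).as("dense_rank"),
      // Distribution functions
      percent_rank().over(byTeamAndGoals).as("percent_rank"),
      cume_dist().over(byTeamAndGoals).as("cume_dist"),
      ntile(2).over(byTeamAndGoals).as("ntile"),
      // Offset functions: "-" is returned when there is no previous/next row
      lag($"name", 1, "-").over(byTeamAndGoals).as("previous_scorer"),
      lead($"name", 1, "-").over(byTeamAndGoals).as("next_scorer")
    )
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("window-functions-sketch")
      .master("local[*]")
      .getOrCreate()
    rankedPlayers(sparkSession).orderBy("team", "rank").show(false)
    sparkSession.stop()
  }
}
```

Griezmann and Mbappe are tied on goals, so both get rank 1 while their relative row_number order isn't guaranteed; Pogba then gets rank 3 but dense_rank 2.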

Execution plan

The execution plan for one of the presented examples looks like:

== Physical Plan ==
Window [lead(name#4, 1, -) windowspecdefinition(team#5, goals#6 DESC NULLS LAST, specifiedwindowframe(RowFrame, 1, 1)) AS next_scorer#13], [team#5], [goals#6 DESC NULLS LAST]
+- *(1) Sort [team#5 ASC NULLS FIRST, goals#6 DESC NULLS LAST], false, 0
   +- Exchange hashpartitioning(team#5, 200)
      +- LocalTableScan [name#4, team#5, goals#6, assists#7]

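It gives 3 important insights. The first one is about data movement. Unsurprisingly, window functions require a shuffle (Exchange hashpartitioning), here partitioned by the team field, so that all rows of the same window partition end up in the same Apache Spark partition. The second point is sorting. It's a required step because the Window operator expects the rows of each window partition to be contiguous and ordered by the window's ORDER BY expressions (hence the sort on team ASC, then goals DESC). These 2 properties alone show that executing window functions can be expensive in terms of computation time and resources. The last important concept is the Window operator itself, displayed at the top of the plan. As you can see, it's composed of 3 elements: the window expressions (lead(name#4, 1, -) ... AS next_scorer#13), the partition specification ([team#5]) and the order specification ([goals#6 DESC NULLS LAST]).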
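
The query behind the plan above isn't shown in the post, but judging by the Window node it's probably close to this sketch: lead(name, 1, "-") over a window partitioned by team and ordered by goals descending. NextScorerPlan and nextScorers are hypothetical names, and the data is a small subset of the players. On Spark 2.3.x, explain() should print a plan of the same shape as above; other versions may differ (for example, with adaptive query execution enabled the plan is wrapped in an AdaptiveSparkPlan node).

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lead

case class Player(name: String, team: String, goals: Int, assists: Int)

object NextScorerPlan {

  // Reconstructed from the plan: lead(name, 1, "-") over (partition by team order by goals desc)
  def nextScorers(sparkSession: SparkSession): DataFrame = {
    import sparkSession.implicits._
    val players = Seq(
      Player("Harry Kane", "England", 6, 0),
      Player("John Stones", "England", 2, 0),
      Player("Dele Alli", "England", 1, 0),
      Player("Romelu Lukaku", "Belgium", 4, 1),
      Player("Eden Hazard", "Belgium", 3, 2),
      Player("Kevin De Bruyne", "Belgium", 1, 3)
    ).toDF()
    val byTeamAndGoals = Window.partitionBy($"team").orderBy($"goals".desc)
    players.select($"name", $"team", $"goals",
      lead($"name", 1, "-").over(byTeamAndGoals).as("next_scorer"))
  }

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("next-scorer-plan")
      .master("local[*]")
      .getOrCreate()
    val nextScorers = NextScorerPlan.nextScorers(sparkSession)
    // Physical plan: Window <- Sort <- Exchange hashpartitioning(team, ...) <- LocalTableScan
    nextScorers.explain()
    nextScorers.show(false)
    sparkSession.stop()
  }
}
```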

Window internals

The class responsible for executing window functions is WindowExec. More exactly, its doExecute() method gives some insight into how window functions are executed. The processing consists of applying org.apache.spark.rdd.RDD#mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) on the shuffled and sorted partitions. The iterator returned by the mapping function moves from one partition group (rows sharing the same partitioning key) to the next and, for each row, applies all of the defined window frames:

val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
// One frame instance per frame definition, all writing into windowFunctionResult
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length

override final def next(): InternalRow = {
  // Load the next partition if we need to.
  if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
    fetchNextPartition()
  }
  // ...
  val current = bufferIterator.next()

  // Get the results for the window frames.
  var i = 0
  while (i < numFrames) {
    frames(i).write(rowIndex, current)
    i += 1
  }

  // 'Merge' the input row with the window function result
  join(current, windowFunctionResult)
  rowIndex += 1

  // Return the projection.
  result(join)
  // ...
}
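
To make this loop more concrete, here is a simplified, plain-Scala model of it with no Spark dependency. It assumes the input iterator is already shuffled and sorted like the output of the Sort node, buffers one partition group at a time (the role of fetchNextPartition) and asks every frame for its value at rowIndex. PlayerRow, Frame, LeadNameFrame, RunningGoalsFrame and WindowIterator are hypothetical names; Spark's real frames write into a mutable InternalRow instead of returning values.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical input row, already shuffled by team and sorted by goals descending
final case class PlayerRow(team: String, name: String, goals: Int)

// Hypothetical counterpart of a window frame: prepare() receives the buffered
// partition group, write() returns the value for the row at the given index
trait Frame {
  def prepare(partition: IndexedSeq[PlayerRow]): Unit
  def write(index: Int, current: PlayerRow): Any
}

// Offset-like frame, equivalent to lead(name, 1, default)
final class LeadNameFrame(default: String) extends Frame {
  private var rows: IndexedSeq[PlayerRow] = IndexedSeq.empty
  def prepare(partition: IndexedSeq[PlayerRow]): Unit = { rows = partition }
  def write(index: Int, current: PlayerRow): Any =
    if (index + 1 < rows.length) rows(index + 1).name else default
}

// Growing-like frame, equivalent to sum(goals) over
// ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW; relies on rows being written in order
final class RunningGoalsFrame extends Frame {
  private var sum = 0
  def prepare(partition: IndexedSeq[PlayerRow]): Unit = { sum = 0 }
  def write(index: Int, current: PlayerRow): Any = { sum += current.goals; sum }
}

// Simplified version of the iterator built inside mapPartitions
final class WindowIterator(sorted: Iterator[PlayerRow], frames: Array[Frame])
    extends Iterator[(PlayerRow, List[Any])] {
  private val input = sorted.buffered
  private var buffer: IndexedSeq[PlayerRow] = IndexedSeq.empty
  private var rowIndex = 0

  // Like fetchNextPartition(): buffer every row sharing the current partition key
  private def fetchNextPartition(): Unit = {
    val team = input.head.team
    val rows = ArrayBuffer.empty[PlayerRow]
    while (input.hasNext && input.head.team == team) rows += input.next()
    buffer = rows.toIndexedSeq
    rowIndex = 0
    frames.foreach(_.prepare(buffer))
  }

  override def hasNext: Boolean = rowIndex < buffer.length || input.hasNext

  override def next(): (PlayerRow, List[Any]) = {
    // Load the next partition group if the current one is exhausted
    if (rowIndex >= buffer.length && input.hasNext) fetchNextPartition()
    if (rowIndex >= buffer.length) throw new NoSuchElementException("no more rows")
    val current = buffer(rowIndex)
    // Get the results of every frame, then "join" them with the input row
    val results = frames.map(_.write(rowIndex, current)).toList
    rowIndex += 1
    (current, results)
  }
}

object WindowIteratorDemo {
  def main(args: Array[String]): Unit = {
    val sorted = Seq(
      PlayerRow("Belgium", "Romelu Lukaku", 4),
      PlayerRow("Belgium", "Eden Hazard", 3),
      PlayerRow("Belgium", "Kevin De Bruyne", 1),
      PlayerRow("France", "Antoine Griezmann", 4),
      PlayerRow("France", "Paul Pogba", 1)
    )
    val frames: Array[Frame] = Array(new LeadNameFrame("-"), new RunningGoalsFrame)
    new WindowIterator(sorted.iterator, frames).foreach { case (row, values) =>
      println(s"${row.team} ${row.name} -> ${values.mkString(", ")}")
    }
  }
}
```

Because prepare() is called for every new partition group, the running sum restarts at each team, exactly as a window function restarts at each window partition.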

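How the frames are computed is handled by windowFrameExpressionFactoryPairs, which pairs each frame expression with the factory method creating the corresponding computation. For the 2 previously described frame boundaries, 5 kinds of frames can be built: an offset frame (OffsetWindowFunctionFrame, used by lead and lag), an entire partition frame (UnboundedWindowFunctionFrame, both boundaries unbounded), a growing frame (UnboundedPrecedingWindowFunctionFrame, unbounded lower boundary), a shrinking frame (UnboundedFollowingWindowFunctionFrame, unbounded upper boundary) and a moving frame (SlidingWindowFunctionFrame, both boundaries bounded).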
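
The choice between these frames can be sketched in plain Scala as a match on the boundaries, where a missing boundary (None) stands for UNBOUNDED. This is a simplified model assuming ROWS frames only (Spark also takes the frame type into account, e.g. RANGE frames); FrameKind, Frames.classify and Frames.rowsFrameSum are hypothetical names. rowsFrameSum recomputes every frame from scratch, which is simpler but slower than Spark's frame classes (a growing frame, for instance, only needs to add the new rows).

```scala
// Hypothetical model of the 5 kinds of frames; each comment names the Spark 2.3.x
// frame class the kind corresponds to
sealed trait FrameKind
case object OffsetFrame extends FrameKind          // OffsetWindowFunctionFrame (lead, lag)
case object EntirePartitionFrame extends FrameKind // UnboundedWindowFunctionFrame
case object GrowingFrame extends FrameKind         // UnboundedPrecedingWindowFunctionFrame
case object ShrinkingFrame extends FrameKind       // UnboundedFollowingWindowFunctionFrame
case object MovingFrame extends FrameKind          // SlidingWindowFunctionFrame

object Frames {
  // A boundary is None for UNBOUNDED, otherwise an offset relative to the current row
  // (negative = PRECEDING, 0 = CURRENT ROW, positive = FOLLOWING)
  def classify(isOffsetFunction: Boolean, lower: Option[Int], upper: Option[Int]): FrameKind =
    if (isOffsetFunction) OffsetFrame
    else (lower, upper) match {
      case (None, None)       => EntirePartitionFrame
      case (None, Some(_))    => GrowingFrame
      case (Some(_), None)    => ShrinkingFrame
      case (Some(_), Some(_)) => MovingFrame
    }

  // Naive sum over a ROWS frame: every frame is recomputed from scratch
  def rowsFrameSum(values: IndexedSeq[Int], lower: Option[Int], upper: Option[Int]): IndexedSeq[Int] =
    values.indices.map { i =>
      val from = lower.fold(0)(offset => math.max(0, i + offset))
      val end = upper.fold(values.length)(offset => math.min(values.length, i + offset + 1))
      values.slice(from, end).sum
    }

  def main(args: Array[String]): Unit = {
    val belgiumGoals = Vector(4, 3, 1) // Lukaku, Hazard, De Bruyne, ordered by goals DESC
    println(rowsFrameSum(belgiumGoals, None, None))        // entire partition
    println(rowsFrameSum(belgiumGoals, None, Some(0)))     // growing
    println(rowsFrameSum(belgiumGoals, Some(0), None))     // shrinking
    println(rowsFrameSum(belgiumGoals, Some(-1), Some(1))) // moving
  }
}
```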

Apache Spark analytical window functions look similar to the aggregations applied on groups. As shown in the first section, these functions have a lot in common with the SQL-oriented ones. Among them we find lead, lag, rank, ntile and so forth. Even their physical execution is similar to grouped aggregations. After all, it starts by shuffling all rows with the same partitioning key to the same Apache Spark partition. However, as shown in the 3rd section about internals, analytical window functions are a little bit more flexible thanks to the 5 different frames they provide.