Spark SQL pivot table

Versions: Apache Spark 3.0.1

If you came to data engineering after a BI career, you certainly know what a pivot is. That wasn't my case, and I was quite amazed by this operation that transforms values from rows into columns. If you want to understand how it's possible, this article presents some internals of pivoting data in Apache Spark.

The first part of the blog post shows a code snippet pivoting the values of a team name column. In the second one, you will see how Apache Spark constructs the columns. The last one sheds some light on the remaining parts of this operation.

Pivot example

In the example, you can see an in-memory dataset pivoted on the team name column. You can already notice an important property of the pivot operation: it's a group-based operation, so it involves moving the data across the network. But before going into the execution details, let's see the code:

import org.apache.spark.sql.SparkSession

object PivotExample extends App {

  val sparkSession = SparkSession.builder()
    .appName("Pivot example").master("local[*]")
    // Set this to see how Apache Spark prevents the OOM that could be caused
    // by too many distinct pivoted values:
    //.config("spark.sql.pivotMaxValues", 2)
    .getOrCreate()
  import sparkSession.implicits._

  val teams = Seq(
    Team("team1", "France", 3), Team("team2", "Poland", 4), Team("team3", "Germany", 8),
    Team("team4", "France", 3), Team("team5", "Poland", 5), Team("team6", "Germany", 9),
    Team("team7", "France", 3), Team("team5", "Poland", 6), Team("team3", "Germany", 1),
    Team("team1", "France", 3), Team("team1", "Poland", 7), Team("team6", "Germany", 2)
  ).toDF()

  val pivoted = teams.groupBy("country").pivot("name").sum("points")
  pivoted.show(false)

  println("Pivoted data execution plans")
  pivoted.explain(true)
}

case class Team(name: String, country: String, points: Int)

The generated result looks like this:

+-------+-----+-----+-----+-----+-----+-----+-----+
|country|team1|team2|team3|team4|team5|team6|team7|
+-------+-----+-----+-----+-----+-----+-----+-----+
|Germany|null |null |9    |null |null |11   |null |
|France |6    |null |null |3    |null |null |3    |
|Poland |7    |4    |null |null |11   |null |null |
+-------+-----+-----+-----+-----+-----+-----+-----+

Columns extraction

As you can see in the output above, Apache Spark found all the teams defined in the dataset and turned each of them into a column of the schema. It doesn't matter whether a team appears in the input for a given grouping value (country in our case): the pivoted columns are common to all rows, and the combinations absent from the input are set to null.
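To make these semantics concrete, here is a minimal plain-Scala sketch, with no Spark dependency, of what groupBy("country").pivot("name").sum("points") computes on the same dataset. The object name PivotSemantics is hypothetical, and None stands in for Spark's null in the absent (country, team) combinations; the columns are sorted like the pivot values Spark discovers.

```scala
// Plain-Scala sketch (no Spark) of the pivot semantics: one column per distinct team,
// one row per country, None where a (country, team) pair is absent from the input.
object PivotSemantics {
  case class Team(name: String, country: String, points: Int)

  val teams = Seq(
    Team("team1", "France", 3), Team("team2", "Poland", 4), Team("team3", "Germany", 8),
    Team("team4", "France", 3), Team("team5", "Poland", 5), Team("team6", "Germany", 9),
    Team("team7", "France", 3), Team("team5", "Poland", 6), Team("team3", "Germany", 1),
    Team("team1", "France", 3), Team("team1", "Poland", 7), Team("team6", "Germany", 2)
  )

  // Returns the sorted pivot columns and, per country, one Option[Long] per column
  def pivotSum(rows: Seq[Team]): (Seq[String], Map[String, Seq[Option[Long]]]) = {
    val columns = rows.map(_.name).distinct.sorted
    val table = rows.groupBy(_.country).map { case (country, countryRows) =>
      // sum(points) for every team present in this country
      val sums: Map[String, Long] = countryRows.groupBy(_.name).map { case (name, rs) =>
        name -> rs.map(_.points.toLong).sum
      }
      // Every country gets all the columns; missing teams become None (null in Spark)
      country -> columns.map(sums.get)
    }
    (columns, table)
  }
}
```

Calling PivotSemantics.pivotSum(PivotSemantics.teams) gives, for Germany, Some(9) under team3, Some(11) under team6 and None everywhere else, matching the row shown in the output above.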

But how does Apache Spark resolve these pivoted columns? The answer is hidden in the RelationalGroupedDataset#pivot(pivotColumn: Column) method, which executes... an Apache Spark job retrieving all distinct values of the pivotColumn, up to the limit specified in the spark.sql.pivotMaxValues property (10000 by default). The same property guards against the OOM errors that too many distinct pivoted values could cause:

    val maxValues = df.sparkSession.sessionState.conf.dataFramePivotMaxValues
    // Get the distinct values of the column and sort them so its consistent
    val values = df.select(pivotColumn)
      .distinct()
      .limit(maxValues + 1)
      .sort(pivotColumn)  // ensure that the output columns are in a consistent logical order
      .collect()
      .map(_.get(0))
      .toSeq

    if (values.length > maxValues) {
      throw new AnalysisException(
        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
          "this could indicate an error. " +
          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
          "to at least the number of distinct values of the pivot column.")
    }
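The discovery logic above boils down to distinct, limit(maxValues + 1), sort and a size check. The plain-Scala sketch below (no Spark; the object and method names are hypothetical) mimics it on an in-memory sequence: fetching a single extra value is enough to detect that the limit was exceeded, without materializing every distinct value. As a side note, RelationalGroupedDataset also has a pivot overload taking the values explicitly, for example pivot("name", Seq("team1", "team2")), which skips this extra job altogether.

```scala
// Plain-Scala sketch of the pivot values discovery: distinct -> limit(max + 1) -> sort -> check.
// discoverPivotValues is a hypothetical helper, not a Spark API.
object PivotValuesDiscovery {
  def discoverPivotValues(column: Seq[String], maxValues: Int): Either[String, Seq[String]] = {
    // Taking maxValues + 1 distinct values is enough to know if the limit is exceeded
    val values = column.distinct.take(maxValues + 1).sorted
    if (values.length > maxValues)
      Left(s"The pivot column has more than $maxValues distinct values")
    else
      Right(values) // sorted so that the output columns have a consistent order
  }
}
```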

Pivot analysis

If the number of distinct values is acceptable, they're passed to another pivot method of RelationalGroupedDataset, which creates a new RelationalGroupedDataset with them:

        new RelationalGroupedDataset(
          df,
          groupingExprs,
          RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs))

The PivotType is later pattern matched in org.apache.spark.sql.RelationalGroupedDataset#toDF(aggExprs: Seq[Expression]). After this parsing stage, the parsed logical plan looks like this:

== Parsed Logical Plan ==
'Pivot ArrayBuffer(country#4), 'name, [team1, team2, team3, team4, team5, team6, team7], [sum(points#5)]
+- LocalRelation [name#3, country#4, points#5]

As a result, you can see a Pivot node with, respectively, the grouping column, the pivot column, the pivoted values, and the aggregation to apply to every group. Remember, at this moment the plan is not resolved yet. And the resolution step is the key to understanding the pivot because, after its execution, the Pivot node... disappears! The Analyzer has a rule dedicated to the pivot, called ResolvePivot, which expands the Pivot node into aggregations according to one of two strategies.

The first strategy applies when the aggregation results (sum in our example) are supported by the pivot function. "Supported" means here that the data type of the aggregation result is handled by PivotFirst's updateFunction:

  private val updateFunction: PartialFunction[DataType, (InternalRow, Int, Any) => Unit] = {
    case DoubleType =>
      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
    case IntegerType =>
      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
// ...

Above, you can find only a part of the code. As of this writing, in addition to DoubleType and IntegerType, it also supports LongType, FloatType, BooleanType, ShortType, ByteType, and DecimalType. Let's focus on this first strategy because it's much simpler than the second one, which will be explained later. In this first strategy, the rule starts by generating an aggregation grouped by the groupBy columns plus the pivot column. In our example, this first aggregation looks like this:

Aggregate [country#4, name#3], [country#4, name#3, sum(points#5) AS sum(`points`)#61L]
+- LocalRelation [name#3, country#4, points#5]

Next, the second aggregation is created, this time grouped only by the groupBy column(s) and using PivotFirst aggregations:

val pivotAggs = namedAggExps.map { a =>
  Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues)
    .toAggregateExpression()
  , "__pivot_" + a.sql)()
}
val groupByExprsAttr = groupByExprs.map(_.toAttribute)
val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)

In our example, it creates the following node:

Aggregate [country#4], [country#4, pivotfirst(name#3, sum(`points`)#61L, team1, team2, team3, team4, team5, team6, team7, 0, 0) AS __pivot_sum(`points`) AS `sum(``points``)`#77]
+- Aggregate [country#4, name#3], [country#4, name#3, sum(points#5) AS sum(`points`)#61L]
   +- LocalRelation [name#3, country#4, points#5]

Finally, the second aggregation is wrapped in a Project node that extracts the result of every pivot value from the array produced by PivotFirst:

Project [country#4, __pivot_sum(`points`) AS `sum(``points``)`#77[0] AS team1#78L, __pivot_sum(`points`) AS `sum(``points``)`#77[1] AS team2#79L, __pivot_sum(`points`) AS `sum(``points``)`#77[2] AS team3#80L, __pivot_sum(`points`) AS `sum(``points``)`#77[3] AS team4#81L, __pivot_sum(`points`) AS `sum(``points``)`#77[4] AS team5#82L, __pivot_sum(`points`) AS `sum(``points``)`#77[5] AS team6#83L, __pivot_sum(`points`) AS `sum(``points``)`#77[6] AS team7#84L]
+- Aggregate [country#4], [country#4, pivotfirst(name#3, sum(`points`)#61L, team1, team2, team3, team4, team5, team6, team7, 0, 0) AS __pivot_sum(`points`) AS `sum(``points``)`#77]
   +- Aggregate [country#4, name#3], [country#4, name#3, sum(cast(points#5 as bigint)) AS sum(`points`)#61L]
      +- LocalRelation [name#3, country#4, points#5]
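The three nodes of this plan can be mimicked in plain Scala to see how the data flows through them. In the sketch below (no Spark; the object and method names are hypothetical), the first groupBy plays the role of the Aggregate on (country, name), the array filled through a pivot-value-to-index map stands for the PivotFirst buffer, and reading that array position by position corresponds to the final Project.

```scala
// Plain-Scala sketch of the PivotFirst-based plan:
// Aggregate(country, name) -> Aggregate(country) with one slot per pivot value -> Project.
object PivotFirstStrategySketch {
  case class Team(name: String, country: String, points: Int)

  def pivotSum(rows: Seq[Team], pivotValues: Seq[String]): Map[String, Seq[Option[Long]]] = {
    // 1st Aggregate: sum(points) grouped by (country, name)
    val firstAgg: Map[(String, String), Long] =
      rows.groupBy(t => (t.country, t.name)).map { case (key, rs) => key -> rs.map(_.points.toLong).sum }

    // Pivot value -> position in the output array
    val pivotIndex: Map[String, Int] = pivotValues.zipWithIndex.toMap

    // 2nd Aggregate: grouped by country only, each sum goes to its pivot value's slot
    firstAgg.groupBy { case ((country, _), _) => country }.map { case (country, entries) =>
      val buffer = Array.fill[Option[Long]](pivotValues.size)(None)
      entries.foreach { case ((_, name), sum) =>
        pivotIndex.get(name).foreach(i => buffer(i) = Some(sum))
      }
      // Project: buffer(0) AS team1, buffer(1) AS team2, ...
      country -> buffer.toList
    }
  }
}
```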

And what about the second strategy, the one applied to the data types not supported by PivotFirst? To test this scenario, you can use the first(column) aggregation, which returns the first value of the column, on a string column (or on a column of any other type not supported by PivotFirst).
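The choice between both strategies can be sketched as a partial function defined only for the supported types, plus a check that every aggregate's result type is covered. The DataType hierarchy below is a simplified, hypothetical stand-in for Spark's org.apache.spark.sql.types classes, and the sketch assumes the check applies to the result type of each aggregate: sum over an integer column returns a long, which is supported, whereas first over a string column returns a string, which is not.

```scala
// Plain-Scala sketch of the strategy selection. The DataType objects are simplified,
// hypothetical stand-ins for Spark's types, not the real classes.
object PivotStrategySelection {
  sealed trait DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType

  // Same idea as PivotFirst's updateFunction: defined only for the supported types,
  // each case writes a value into a buffer slot
  val updateFunction: PartialFunction[DataType, (Array[Any], Int, Any) => Unit] = {
    case IntegerType => (buf, offset, v) => buf(offset) = v.asInstanceOf[Int]
    case LongType => (buf, offset, v) => buf(offset) = v.asInstanceOf[Long]
    case DoubleType => (buf, offset, v) => buf(offset) = v.asInstanceOf[Double]
  }

  // PivotFirst only when every aggregate result type is supported
  def chooseStrategy(aggregateResultTypes: Seq[DataType]): String =
    if (aggregateResultTypes.forall(updateFunction.isDefinedAt)) "PivotFirst"
    else "If/EqualNullSafe"
}
```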

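In this second strategy, every aggregation expression is rewritten once per pivot value: its input is wrapped in an If expression that keeps the value only when the pivot column is null-safe equal (EqualNullSafe, <=>) to the pivot value, and the result is aliased after that value: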

def ifExpr(e: Expression) = {
  If(
    EqualNullSafe(
      pivotColumn,
      Cast(value, pivotColumn.dataType, Some(conf.sessionLocalTimeZone))),
    e, Literal(null))
}
// ...
case First(expr, _) =>
  First(ifExpr(expr), Literal(true))
case Last(expr, _) =>
  Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
  a.withNewChildren(a.children.map(ifExpr))
// ...
Alias(filteredAggregate, outputName(value, aggregate))()

As a result of this operation, a new aggregate expression is created for every pivot value:

first(if ((name#3 <=> cast(team1 as string))) points#5 else null, true) AS team1#22
first(if ((name#3 <=> cast(team2 as string))) points#5 else null, true) AS team2#24
first(if ((name#3 <=> cast(team3 as string))) points#5 else null, true) AS team3#26
first(if ((name#3 <=> cast(team4 as string))) points#5 else null, true) AS team4#28
first(if ((name#3 <=> cast(team5 as string))) points#5 else null, true) AS team5#30
first(if ((name#3 <=> cast(team6 as string))) points#5 else null, true) AS team6#32
first(if ((name#3 <=> cast(team7 as string))) points#5 else null, true) AS team7#34

It can be read as: "if the pivot column value (name) is null-safe equal to the cast pivot value, pass points to the first aggregate, otherwise pass null". In the end, for the code snippet below:

val teams = Seq(
  Team("team1", "France", 3), Team("team2", "Poland", 4), Team("team3", "Germany", 8),
  Team("team4", "France", 3), Team("team5", "Poland", 5), Team("team6", "Germany", 9),
  Team("team7", "France", 3), Team("team5", "Poland", 6), Team("team3", "Germany", 1),
  Team("team1", "France", 3), Team("team1", "Poland", 7), Team("team6", "Germany", 2)
)
.map(team => team.toStringPoints)
.toDF()

val pivoted = teams.groupBy("country").pivot("name").agg("points" -> "first")
pivoted.show(false)
pivoted.explain(true)

case class Team(name: String, country: String, points: Int) {
  def toStringPoints = TeamStringPoints(name, country, points.toString)
}
case class TeamStringPoints(name: String, country: String, points: String)

the optimized logical plan looks like this:

Aggregate [country#4], [country#4, first(if ((name#3 <=> team1)) points#5 else null, true) AS team1#22, first(if ((name#3 <=> team2)) points#5 else null, true) AS team2#24, first(if ((name#3 <=> team3)) points#5 else null, true) AS team3#26, first(if ((name#3 <=> team4)) points#5 else null, true) AS team4#28, first(if ((name#3 <=> team5)) points#5 else null, true) AS team5#30, first(if ((name#3 <=> team6)) points#5 else null, true) AS team6#32, first(if ((name#3 <=> team7)) points#5 else null, true) AS team7#34]
+- LocalRelation [name#3, country#4, points#5]
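This single Aggregate can be mimicked in plain Scala. In the sketch below (no Spark; the object and method names are hypothetical), every pivot value gets its own first(..., ignoreNulls = true) over an input where the rows of other teams are replaced by null. Scala's == on references is already null-safe (null == null is true), which matches the semantics of <=>. Keep in mind that Spark's first is not deterministic in general because the row order may change after a shuffle; the sketch simply follows the input order.

```scala
// Plain-Scala sketch of the If/EqualNullSafe strategy:
// per country and per pivot value, first(if (name <=> value) points else null, ignoreNulls = true).
object IfExprStrategySketch {
  case class TeamStringPoints(name: String, country: String, points: String)

  // first(expr, ignoreNulls = true): the first non-null value, if any
  def firstIgnoreNulls(values: Seq[String]): Option[String] = values.find(_ != null)

  def pivotFirst(rows: Seq[TeamStringPoints], pivotValues: Seq[String]): Map[String, Seq[Option[String]]] =
    rows.groupBy(_.country).map { case (country, countryRows) =>
      country -> pivotValues.map { value =>
        // if ((name <=> value)) points else null; Scala's == is null-safe like <=>
        firstIgnoreNulls(countryRows.map(r => if (r.name == value) r.points else null))
      }
    }
}
```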

Physical execution

The physical execution follows the plan :P For the sum example from the first part, it runs as a standard aggregation, with a partial evaluation on every partition followed by a final merge of the partial results:

HashAggregate(keys=[country#4], functions=[pivotfirst(name#3, sum(`points`)#21L, team1, team2, team3, team4, team5, team6, team7, 0, 0)], output=[country#4, team1#38L, team2#39L, team3#40L, team4#41L, team5#42L, team6#43L, team7#44L])
+- Exchange hashpartitioning(country#4, 200), true, [id=#117]
   +- HashAggregate(keys=[country#4], functions=[partial_pivotfirst(name#3, sum(`points`)#21L, team1, team2, team3, team4, team5, team6, team7, 0, 0)], output=[country#4, team1#29L, team2#30L, team3#31L, team4#32L, team5#33L, team6#34L, team7#35L])
      +- *(2) HashAggregate(keys=[country#4, name#3], functions=[sum(cast(points#5 as bigint))], output=[country#4, name#3, sum(`points`)#21L])
         +- Exchange hashpartitioning(country#4, name#3, 200), true, [id=#112]
            +- *(1) HashAggregate(keys=[country#4, name#3], functions=[partial_sum(cast(points#5 as bigint))], output=[country#4, name#3, sum#92L])
               +- *(1) LocalTableScan [name#3, country#4, points#5]
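The partial_sum / sum pair at the bottom of this plan follows the usual two-phase aggregation pattern. The plain-Scala sketch below (no Spark; the object and method names are hypothetical) pre-aggregates each simulated partition separately, then merges the partial results sharing the same (country, name) key, which is what happens after the Exchange.

```scala
// Plain-Scala sketch of a two-phase aggregation: partial_sum per partition,
// then a final sum merging the partial results that share a key after the shuffle.
object PartialFinalAggregationSketch {
  type Key = (String, String) // (country, name)

  // partial_sum: pre-aggregation inside a single partition
  def partialSum(partition: Seq[(Key, Int)]): Map[Key, Long] =
    partition.groupBy(_._1).map { case (k, rows) => k -> rows.map(_._2.toLong).sum }

  // sum: merge of the partial results coming from all partitions
  def finalSum(partials: Seq[Map[Key, Long]]): Map[Key, Long] =
    partials.flatten.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2).sum }
}
```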

An interesting thing to notice is that the execution starts with the final aggregation (sum) for every combination of the grouping and pivot columns. Only then is the pivotfirst aggregation applied. What does it do exactly? To understand it, let's start with its constructor's signature:

case class PivotFirst(
  pivotColumn: Expression,
  valueColumn: Expression,
  pivotColumnValues: Seq[Any],
  mutableAggBufferOffset: Int = 0,
  inputAggBufferOffset: Int = 0)

It takes as parameters the pivot column; the value column, which holds the result of the aggregation computed for every grouping and pivot column combination (the sum); the pivot values (team1, team2, ..., team7); and two technical offsets. These parameters are important, but the key logic is inside the class:

val pivotIndex: Map[Any, Int] = if (pivotColumn.dataType.isInstanceOf[AtomicType]) {
  HashMap(pivotColumnValues.zipWithIndex: _*)
} else {
  TreeMap(pivotColumnValues.zipWithIndex: _*)(
    TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
}

val indexSize = pivotIndex.size

private val updateRow: (InternalRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)

The pivotIndex is a map between the pivoted values and their indexes, indexSize stores the number of elements in that map, and updateRow is the update function. Every time a row is partially evaluated, the following happens:

val index = pivotIndex.getOrElse(pivotColValue, -1)
if (index >= 0) {
  val value = valueColumn.eval(inputRow)
  if (value != null) {
    updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
  }
}

The logic is quite straightforward: a lookup in the pivotIndex map, the evaluation of the value column (the already computed sum), and, if the pivot value is known and the value is not null, an update of the corresponding slot in the aggregation buffer. The update is a simple set operation at the offset passed as the second parameter of the updateRow function. For example, it executes as follows for the integer type:

case IntegerType =>
  (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
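The update step can be mirrored in plain Scala with an Array[Any] playing the role of the aggregation buffer. In this sketch (no Spark; PivotBuffer is a hypothetical class, and the buffer offset is assumed to be 0 for readability), unknown pivot values get the index -1 and are ignored, exactly like null values.

```scala
// Plain-Scala sketch of PivotFirst's update step with a simple Array[Any] buffer.
// PivotBuffer is hypothetical; the mutableAggBufferOffset is assumed to be 0.
object PivotFirstUpdateSketch {
  final class PivotBuffer(pivotValues: Seq[String]) {
    // pivot value -> slot index, like PivotFirst.pivotIndex
    private val pivotIndex: Map[String, Int] = pivotValues.zipWithIndex.toMap
    val slots: Array[Any] = Array.fill[Any](pivotIndex.size)(null)

    def update(pivotValue: String, value: Any): Unit = {
      val index = pivotIndex.getOrElse(pivotValue, -1)
      // Unknown pivot values (index -1) and null values leave the buffer unchanged
      if (index >= 0 && value != null) slots(index) = value
    }
  }
}
```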

Later, during the merge of two intermediary buffers, the non-null values of the input buffer are set in the mutableAggBuffer in a similar way:

for (i <- 0 until indexSize) {
  if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
    val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
    updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
  }
}
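A plain overwrite is enough here because the first aggregation already produced a single row per (country, name) pair: for a given country, each slot receives at most one non-null value across all partial buffers, so no two buffers compete for the same position. The plain-Scala sketch below (no Spark; the names are hypothetical and both offsets are assumed to be 0) reproduces this merge.

```scala
// Plain-Scala sketch of PivotFirst's merge step: copy every non-null slot of the input
// buffer into the mutable buffer. Both buffer offsets are assumed to be 0.
object PivotFirstMergeSketch {
  def merge(mutableBuffer: Array[Any], inputBuffer: Array[Any]): Unit =
    for (i <- inputBuffer.indices) {
      if (inputBuffer(i) != null) mutableBuffer(i) = inputBuffer(i)
    }
}
```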

I don't know if you agree, but the pivot operation is fascinating and surprising. On the one hand, it executes as an aggregation. On the other, it needs to know the data before running. To discover that data, an intermediary DataFrame is created and processed, which surprised me. Since I haven't seen many operations executed that way, with an Apache Spark job in the middle, I will be very glad to discover other ones! If you know any, let's catch 'em all in the comments below this article 👊

