Writing custom optimization in Apache Spark SQL - Union rewriter MVP version

Versions: Apache Spark 2.4.0 https://github.com/bartosz25/spark-...ions/PhysicalOptimizationTest.scala

Last time I presented the basics of code generation in the physical plans of Apache Spark SQL. This time I will try to write a physical plan executing a UNION operation as a JOIN, without code generation.

The post starts by introducing the classes that I will use to implement this custom physical plan. In the second part, I will share some things I learned after playing with the API.

Custom node

My first try used an unchanged logical plan. As you can imagine, it didn't work well since everything was already optimized and the initial DISTINCT UNION operators became a hash aggregate. To disable this, I simply added a custom logical plan rule that replaces DISTINCT UNION nodes with a custom node called UnionToJoinRewritten:

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

case class UnionToJoinRewritten(left: LogicalPlan, right: LogicalPlan) extends LogicalPlan {
  override def output: Seq[Attribute] = left.output ++ right.output
  override def children: Seq[LogicalPlan] = Seq(left, right)
}
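
For reference, the rule itself (injected in the test at the end of this post as UnionToJoinLogicalPlanRule) can be sketched like this - the exact pattern matching is my assumption, the idea being simply to catch the DISTINCT UNION before the built-in rewrite turns it into an aggregate:

import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.rules.Rule

object UnionToJoinLogicalPlanRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Replace a 2-children DISTINCT UNION with the custom node so that the standard
    // rewrite into a hash aggregate doesn't apply anymore (matching logic assumed here)
    case Distinct(Union(children)) if children.size == 2 =>
      UnionToJoinRewritten(children.head, children.last)
  }
}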

Even though the code is quite short, it already gives interesting feedback. As you can see, the node returns the columns of both children as its output. Why? After all, if it's a union, returning the columns of a single plan should be enough. Indeed, but doing so exposes you to other optimizations which may invalidate your initial physical execution idea. That's what happened when I made this node return only the attributes of the first child. The projection of the second child, missing from the output, was simply wiped out by the ConvertToLocalRelation rule:

=== Applying Rule org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation ===
 UnionToJoinRewritten                                UnionToJoinRewritten
 :- LocalRelation [letter#7, nr#8, a_flag#9]         :- LocalRelation [letter#7, nr#8, a_flag#9]
!+- Project                                          +- LocalRelation
!   +- LocalRelation [letter#20, nr#21, a_flag#22]   
                 (org.apache.spark.sql.internal.BaseSessionStateBuilder$$anon$2:62)

What was the consequence of this optimization? The content of the right part of the query was empty. It had the expected number of rows (5) but none of the rows had any values, so doing a join was impossible.

Plan later

The second problem I faced concerned the SparkPlan of the children nodes. The idea of the custom physical plan was to use RDD's fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) function, so I needed to create an RDD for both sides of the query inside SparkPlan's doExecute() implementation.
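
As a quick reminder of fullOuterJoin's semantics, here is a small standalone sketch, unrelated to the plan itself (the demo object and its data are made up for the illustration):

import org.apache.spark.sql.SparkSession

object FullOuterJoinDemo extends App {
  val spark = SparkSession.builder().appName("fullOuterJoin demo").master("local[*]").getOrCreate()

  val left = spark.sparkContext.parallelize(Seq(("A", 1), ("B", 2)))
  val right = spark.sparkContext.parallelize(Seq(("B", 20), ("C", 30)))

  // Produces an RDD[(String, (Option[Int], Option[Int]))]: a key present on only one
  // side comes back with a None on the other side
  left.fullOuterJoin(right).collect().foreach(println)
  // Prints, in some order: (A,(Some(1),None)), (B,(Some(2),Some(20))), (C,(None,Some(30)))

  spark.stop()
}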

To create an RDD at this level, I had to transform the 2 logical plans into 2 physical plans, and at first I had no idea how to do so. I then checked the source code and found that SparkStrategy's planLater(LogicalPlan) method was used very often. The function creates a PlanLater physical execution node, a placeholder that will be replaced later by the real execution plan produced by one of the other existing SparkStrategies.

All the magic happens in the plan(LogicalPlan) method of the org.apache.spark.sql.catalyst.planning.QueryPlanner implementation, which first retrieves the physical plan candidates and then collects and replaces the PlanLater placeholders:

  def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = {
    // Obviously a lot to do here still...

    // Collect physical plan candidates.
    val candidates = strategies.iterator.flatMap(_(plan))

    // The candidates may contain placeholders marked as [[planLater]],
    // so try to replace them by their child plans.
    val plans = candidates.flatMap { candidate =>
      val placeholders = collectPlaceholders(candidate)

      if (placeholders.isEmpty) {
        // Take the candidate as is because it does not contain placeholders.
        Iterator(candidate)
      } else {
        // Plan the logical plan marked as [[planLater]] and replace the placeholders.
// …
}
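
The collectPlaceholders part is provided by SparkPlanner and could be implemented along these lines (a sketch written from memory rather than a verbatim quote - the important bit is the pattern match on PlanLater):

  override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
    // Gather every PlanLater placeholder together with the logical plan it wraps
    plan.collect {
      case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan
    }
  }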

All this learning led me to the following SparkStrategy:

object UnionToJoinRewrittenStrategy extends SparkStrategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case union: UnionToJoinRewritten => {
      new UnionJoinExecutorExec(planLater(union.children(0)), planLater(union.children(1))) :: Nil
    }
    case _ => Nil
  }
}
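
As a side note, besides the SparkSessionExtensions setup used in the test at the end of this post, the strategy alone could also be registered through the experimental methods (the logical rule would still need its own injection); a minimal sketch, assuming a simple local session:

import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder().appName("extraStrategies registration")
  .master("local[*]")
  .getOrCreate()
// Registers the custom strategy for this session only
sparkSession.experimental.extraStrategies = Seq(UnionToJoinRewrittenStrategy)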

Case class SparkPlan

The implementation of the last part taught me something pretty early on. I started by implementing it like this:

class UnionJoinExecutorExec(left: SparkPlan, right: SparkPlan) extends SparkPlan

The problem with that signature was that I ended up with a physical plan where the PlanLater nodes weren't planned:

== Physical Plan ==
UnionJoinExecutor
:- PlanLater LocalRelation [letter#7, nr#8, a_flag#9]
+- PlanLater LocalRelation [letter#20, nr#21, a_flag#22]

And quite logically, it led to the following exception:

java.lang.UnsupportedOperationException was thrown.
java.lang.UnsupportedOperationException
    at org.apache.spark.sql.execution.PlanLater.doExecute(SparkStrategies.scala:58)

When I started to investigate, 3 strange methods caught my attention: productElement, productArity and canEqual. They don't come from the Apache Spark contract but from Scala's Product trait, which Spark's TreeNode machinery relies on to copy and transform plan nodes, and so to replace the PlanLater placeholders. After that, I looked at the built-in physical operators and found they were all case classes. Moreover, I found the confirmation in one of the links quoted in the "Read also" section. Since I didn't want to deal with such low-level details, I preferred to transform my class into a case class:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan

case class UnionJoinExecutorExec(left: SparkPlan, right: SparkPlan) extends SparkPlan {
  override protected def doExecute(): RDD[InternalRow] = {
    // covered in the next part
  }
  override def output: Seq[Attribute] = left.output ++ right.output
  override def children: Seq[SparkPlan] = Seq(left, right)
}
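
To visualize the Product methods mentioned before, here is a quick Scala-only illustration of what the compiler generates for free with a case class (the Pair class is purely hypothetical and unrelated to Spark):

// Purely illustrative case class
case class Pair(left: String, right: String)

object ProductDemo extends App {
  val p = Pair("a", "b")
  // Compiler-generated Product implementation, the kind of methods the planner relies on
  println(p.productArity)      // 2
  println(p.productElement(0)) // a
  println(p.canEqual(p))       // true
}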

Final implementation

As I mentioned at the beginning, the idea behind this custom physical plan was to use RDD's fullOuterJoin method and return matched rows. The implementation is then quite straightforward:

  override protected def doExecute(): RDD[InternalRow] = {
    // Use the rows' hash codes as the join key for both sides
    val leftRdd = left.execute().groupBy(row => {
      row.hashCode()
    })
    val rightRdd = right.execute().groupBy(row => {
      row.hashCode()
    })

    leftRdd.fullOuterJoin(rightRdd).mapPartitions(matchedRows => {
      // 6 fields because the node's output exposes the columns of both children
      val rowWriter = new UnsafeRowWriter(6)
      matchedRows.map {
        case (_, (leftRows, rightRows)) => {
          // Reuse the writer for every record: reset the buffer and clear the null bits
          rowWriter.reset()
          rowWriter.zeroOutNullBytes()
          // Take the row from whichever side matched (at least one of them is defined)
          val matchedRow = leftRows.getOrElse(rightRows.get).toSeq.head
          val (letter, nr, flag) = (matchedRow.getUTF8String(0), matchedRow.getInt(1), matchedRow.getInt(2))
          rowWriter.write(0, letter)
          rowWriter.write(1, nr)
          rowWriter.write(2, flag)
          rowWriter.write(3, letter)
          rowWriter.write(4, nr)
          rowWriter.write(5, flag)
          rowWriter.getRow
        }
      }
    })
  }

To create the output row (UnsafeRow) I used UnsafeRowWriter. One important point here - always reset the buffer. Otherwise, you risk encountering a NegativeArraySizeException. By the way, it's officially advised in the comment of the BufferHolder class: "Note that for each incoming record, we should call `reset` of BufferHolder instance before write the record and reuse the data buffer.". Also, before writing a new row to the row buffer, you should call zeroOutNullBytes() to clear out the null bits.
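
Extracted from the plan above, the per-record pattern boils down to the following standalone sketch (3 hard-coded fields and made-up data, just to highlight the reset() and zeroOutNullBytes() calls):

import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
import org.apache.spark.unsafe.types.UTF8String

object UnsafeRowWriterDemo extends App {
  val writer = new UnsafeRowWriter(3)
  Seq(("A", 1, 1), ("B", 2, 1)).foreach { case (letter, nr, flag) =>
    writer.reset()            // rewind the shared buffer before writing the new record
    writer.zeroOutNullBytes() // clear the null bits left by the previous record
    writer.write(0, UTF8String.fromString(letter))
    writer.write(1, nr)
    writer.write(2, flag)
    val row = writer.getRow   // an UnsafeRow backed by the writer's buffer
    println(row.numFields)    // 3
  }
}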

To confirm that it works, here is the test case:

  private val sparkSession: SparkSession = SparkSession.builder().appName("Union as join - physical optimization")
    .master("local[*]")
    .withExtensions(extensions => {
      extensions.injectResolutionRule(_ => UnionToJoinLogicalPlanRule)
      extensions.injectPlannerStrategy(_ => UnionToJoinRewrittenStrategy)
    })
    .getOrCreate()


  "UNION" should "be executed as RDD's fullOuterJoin method" in {
    import sparkSession.implicits._
    val dataset1 = Seq(("A", 1, 1), ("B", 2, 1), ("C", 3, 1), ("D", 4, 1), ("E", 5, 1)).toDF("letter", "nr", "a_flag")
    val dataset2 = Seq(("A", 1, 1), ("E", 5, 1), ("F", 10, 1), ("G", 11, 1), ("H", 12, 1)).toDF("letter", "nr", "a_flag")

    dataset1.createOrReplaceTempView("dataset_1")
    dataset2.createOrReplaceTempView("dataset_2")
    val rewrittenQuery = sparkSession.sql(
      """SELECT letter, nr, a_flag FROM dataset_1
        |UNION SELECT letter, nr, a_flag FROM dataset_2""".stripMargin)

    val unionRows = rewrittenQuery.collect().map(r => s"${r.getAs[String]("letter")}-${r.getAs[Int]("nr")}")
    unionRows should have size 8
    unionRows should contain allOf("A-1", "B-2", "C-3", "D-4", "E-5", "F-10", "G-11", "H-12")
    rewrittenQuery.queryExecution.executedPlan.toString should include("UnionJoinExecutor")
  }

That's all for this post about custom physical plan optimization. I agree that the code is far from optimal. It only works for 2 datasets and returns the columns of both sides instead of a single aggregated projection. However, this small POC helped me to discover some tricks, like the one with PlanLater.