Writing Apache Spark SQL custom logical optimization - the first version

Versions: Apache Spark 2.4.0 https://github.com/bartosz25/spark-...mizations/UnionSimpleHintTest.scala

Last time I wrote about different hints present in RDBMS and Hive. Today it's time to implement one of them.

The post has 2 sections. In the first one, I will present the code of a "UNION hint" that translates a UNION operation into a FULL OUTER JOIN. In the second section, I will share some lessons I learned during the process. Spoiler alert: it's not the last post in the series about Spark SQL customization. In the next ones, I will try to approach planner strategy customization and extend the UNION hint.

Union hint

Over the next posts I will implement a kind of UNION hint that transforms a UNION execution plan into one or more JOIN operations. I found the inspiration in the "SQL Server query hints execution plan - Part3" post, where the author shows 2 different UNION strategies to combine both datasets: one using hashing (hash union) and another one using merging (merge union). Please note that the article was written in 2014 and things may have changed since then.
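The difference between the two strategies can be illustrated with a short Scala sketch on plain collections. It is a simplified illustration of the idea only, not SQL Server's (or Spark's) implementation; `UnionStrategies`, `hashUnion` and `mergeUnion` are hypothetical names.

```scala
import scala.collection.mutable

// Simplified illustration of two ways to compute a distinct UNION.
// hashUnion and mergeUnion are hypothetical helpers, not SQL Server or Spark code.
object UnionStrategies {

  // Hash union: remember every emitted row in a hash set and emit each row only once.
  // LinkedHashSet keeps the order of first appearance.
  def hashUnion[T](left: Seq[T], right: Seq[T]): List[T] = {
    val seen = mutable.LinkedHashSet.empty[T]
    (left ++ right).foreach(row => seen += row)
    seen.toList
  }

  // Merge union: both inputs must already be sorted. Walk them in parallel,
  // always take the smaller head and skip it if it equals the last emitted row.
  def mergeUnion[T](left: Seq[T], right: Seq[T])(implicit ord: Ordering[T]): List[T] = {
    val output = mutable.ListBuffer.empty[T]
    var lastEmitted: Option[T] = None
    def emit(row: T): Unit =
      if (!lastEmitted.exists(last => ord.equiv(last, row))) {
        output += row
        lastEmitted = Some(row)
      }
    var remainingLeft = left.toList
    var remainingRight = right.toList
    while (remainingLeft.nonEmpty || remainingRight.nonEmpty) {
      if (remainingRight.isEmpty ||
        (remainingLeft.nonEmpty && ord.lteq(remainingLeft.head, remainingRight.head))) {
        emit(remainingLeft.head)
        remainingLeft = remainingLeft.tail
      } else {
        emit(remainingRight.head)
        remainingRight = remainingRight.tail
      }
    }
    output.toList
  }
}
```

For example, `UnionStrategies.mergeUnion(Seq(1, 2, 2, 5), Seq(2, 3, 5, 7))` returns `List(1, 2, 3, 5, 7)`. The hash variant accepts unsorted inputs but keeps every distinct row in memory, while the merge variant needs sorted inputs and only remembers the last emitted row.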

Even though I explored some of the extension API internals in "Writing Apache Spark SQL custom logical optimization - API", I decided to start small and implement a very simple version of the UNION transformer. This version works only for a UNION of 2 datasets:

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Coalesce, EqualTo, Expression}
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Join, LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule

// Rewrites "SELECT ... UNION SELECT ..." (a distinct UNION of exactly 2 datasets)
// into a FULL OUTER JOIN on all columns followed by a COALESCE projection.
object UnionSimpleTransformer extends Rule[LogicalPlan] {

  override def apply(logicalPlan: LogicalPlan): LogicalPlan = logicalPlan transform {
    // A distinct UNION is parsed as Distinct(Union(...)). The rule runs during analysis,
    // so wait until the UNION is resolved and handle only the 2-datasets case.
    case Distinct(union: Union) if union.resolved && union.children.size == 2 =>
      // Output columns of each child, e.g. Seq(Seq(letter#1, nr#2, a_flag#3), Seq(letter#7, nr#8, a_flag#9))
      val joinColumns = union.children.map(_.output)
      // Positional pairs, e.g. Seq(Seq(letter#1, letter#7), Seq(nr#2, nr#8), Seq(a_flag#3, a_flag#9))
      val joinPairs = joinColumns.transpose
      // letter#1 = letter#7 AND nr#2 = nr#8 AND a_flag#3 = a_flag#9
      val equalToExpressions: Seq[Expression] = joinPairs.map(attributes => EqualTo(attributes(0), attributes(1)))
      val joinCondition = equalToExpressions.reduceLeft(And)

      val projection1 = Project(joinColumns(0), union.children(0))
      val projection2 = Project(joinColumns(1), union.children(1))
      val join = Join(projection1, projection2, FullOuter, Option(joinCondition))
      combineProjection(joinPairs, join)
  }

  // COALESCE(left column, right column) AS <left column name>, for every column pair
  private def combineProjection(joinAttributes: Seq[Seq[Attribute]], childPlan: Join): Project = {
    val fields = joinAttributes.map(attributes => {
      Alias(Coalesce(attributes), attributes(0).name)()
    })
    Project(fields, childPlan)
  }
}
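The column pairing and the chained AND condition built by UnionSimpleTransformer can be modelled without Spark. In this sketch, plain strings stand in for Catalyst attributes and expressions (a modelling assumption), and `JoinConditionModel` is a hypothetical name. It shows how `transpose` turns the per-child column lists into positional pairs and how the equality conditions end up in a left-associative AND chain.

```scala
// Spark-free model of the column pairing done by the rule.
// Strings like "l.letter" stand in for Catalyst attributes (illustration only).
object JoinConditionModel {

  // One inner Seq per UNION child, holding that child's output columns in order.
  def joinPairs(childOutputs: Seq[Seq[String]]): Seq[Seq[String]] = childOutputs.transpose

  // An equality per pair, chained left-associatively: ((a AND b) AND c).
  def joinCondition(childOutputs: Seq[Seq[String]]): String =
    joinPairs(childOutputs)
      .map { case Seq(left, right) => s"($left = $right)" }
      .reduceLeft((acc, next) => s"($acc AND $next)")

  // COALESCE(left, right) AS <left column name>, like the final projection.
  def projection(childOutputs: Seq[Seq[String]]): Seq[String] =
    joinPairs(childOutputs).map(pair => s"coalesce(${pair.mkString(", ")}) AS ${pair.head.split('.').last}")
}
```

With `Seq(Seq("l.letter", "l.nr", "l.a_flag"), Seq("r.letter", "r.nr", "r.a_flag"))`, `joinCondition` returns `(((l.letter = r.letter) AND (l.nr = r.nr)) AND (l.a_flag = r.a_flag))`. The positional pairing mirrors UNION semantics, where columns are matched by position rather than by name.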

The following test shows UnionSimpleTransformer in action:

import org.apache.spark.sql.SparkSession
import org.scalatest.{FlatSpec, Matchers}

// ScalaTest 3.0.x style (FlatSpec with Matchers), contemporary with Spark 2.4
class UnionSimpleHintTest extends FlatSpec with Matchers {

  private val sparkSession: SparkSession = SparkSession.builder().appName("Union Hint - simple version test")
    .master("local[*]")
    // Registers the rule in the analyzer's resolution rules
    .withExtensions(extensions => {
      extensions.injectResolutionRule(_ => UnionSimpleTransformer)
    })
    .getOrCreate()

  "UNION rewriter" should "transform a UNION of 2 Datasets into JOIN" in {
    import sparkSession.implicits._
    val dataset1 = Seq(("A", 1, 1), ("B", 2, 1), ("C", 3, 1), ("D", 4, 1), ("E", 5, 1)).toDF("letter", "nr", "a_flag")
    val dataset2 = Seq(("A", 1, 1), ("E", 5, 1), ("F", 10, 1), ("G", 11, 1), ("H", 12, 1)).toDF("letter", "nr", "a_flag")

    dataset1.createOrReplaceTempView("dataset_1")
    dataset2.createOrReplaceTempView("dataset_2")
    val rewrittenQuery = sparkSession.sql(
      """SELECT letter, nr, a_flag FROM dataset_1
        |UNION SELECT letter, nr, a_flag FROM dataset_2""".stripMargin)
    rewrittenQuery.explain(true)

    val unionData = rewrittenQuery.map(row => s"${row.getAs[String]("letter")}-${row.getAs[Int]("nr")}-${row.getAs[Int]("a_flag")}")
      .collect()
    unionData should have size 8
    unionData should contain allOf("A-1-1", "B-2-1", "C-3-1", "D-4-1", "E-5-1", "F-10-1", "G-11-1", "H-12-1")
    // The distinct UNION is now executed as a FULL OUTER JOIN
    rewrittenQuery.queryExecution.executedPlan.toString should include("SortMergeJoin")
  }
}

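As you can see, everything works as expected for this input: the query returns the 8 distinct rows and its physical plan contains a SortMergeJoin. Notice that the rule applies only to a distinct UNION, i.e. the one which doesn't return duplicates; UNION ALL is not rewritten. Also, this first version isn't strictly equivalent to UNION in every case: a FULL OUTER JOIN doesn't remove duplicates present within one dataset and, since NULL = NULL isn't true, a row containing NULLs that appears in both datasets is returned twice. The test data contains neither case.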
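The semantics of the rewrite can be checked on plain collections. This Spark-free model performs a FULL OUTER JOIN on all columns followed by a COALESCE per column and compares it with a distinct UNION. Rows are modelled as `Seq[Option[Any]]` with `None` playing the role of SQL NULL; `FullOuterJoinUnionModel` and its helpers are hypothetical names used only for illustration.

```scala
// Spark-free model of the rewrite: FULL OUTER JOIN on all columns, then COALESCE per column.
// A row is a Seq[Option[Any]] where None plays the role of SQL NULL (modelling assumption).
object FullOuterJoinUnionModel {
  type Row = Seq[Option[Any]]

  // Hypothetical helper: row("A", 1, null) builds Seq(Some("A"), Some(1), None)
  def row(values: Any*): Row = values.map(value => Option(value))

  // SQL equality: NULL = x is never true, so a row containing NULL matches nothing.
  private def joinConditionHolds(left: Row, right: Row): Boolean =
    left.zip(right).forall { case (l, r) => l.isDefined && r.isDefined && l == r }

  def fullOuterJoinCoalesce(left: Seq[Row], right: Seq[Row]): Seq[Row] = {
    val width = (left ++ right).headOption.map(_.size).getOrElse(0)
    val nullRow: Row = Seq.fill(width)(None)
    val matched = for (l <- left; r <- right if joinConditionHolds(l, r)) yield (l, r)
    val leftOnly = left.filterNot(l => right.exists(r => joinConditionHolds(l, r))).map(l => (l, nullRow))
    val rightOnly = right.filterNot(r => left.exists(l => joinConditionHolds(l, r))).map(r => (nullRow, r))
    // COALESCE(left column, right column) for every column of every joined pair
    (matched ++ leftOnly ++ rightOnly).map { case (l, r) => l.zip(r).map { case (a, b) => a.orElse(b) } }
  }

  // Reference semantics of a distinct UNION (None == None, so NULL rows are deduplicated).
  def unionDistinct(left: Seq[Row], right: Seq[Row]): Seq[Row] = (left ++ right).distinct
}
```

With the 2 datasets from the test, both functions return the same 8 rows. With a row duplicated inside one dataset, or with a row containing a NULL present in both datasets, the join-based version returns 2 rows where a distinct UNION returns 1.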

Takeaways

Extending the Catalyst optimizer is similar to writing an iterative tree traversal. Catalyst can apply your rule many times, until the plan stops changing, and that's probably one of the biggest difficulties. But it's not the only one. To help you discover the others, I prepared a short list of takeaways from my learning experience:
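To make the "many visits" point concrete, here is a toy model in plain Scala. The node names mirror Catalyst's, but they are not Spark classes, and `FixedPointModel` is a hypothetical name; the loop is only similar in spirit to a fixed-point batch of Catalyst's rule executor. It shows why a rule should stop matching once it has done its job: the UNION rewrite converges after one extra check, whereas a rule that always matches keeps changing the plan until the iteration limit.

```scala
object FixedPointModel {
  // Toy plan nodes: the names mirror Catalyst's but these are NOT Spark classes.
  sealed trait Plan
  final case class Relation(name: String) extends Plan
  final case class Union(children: Seq[Plan]) extends Plan
  final case class Distinct(child: Plan) extends Plan
  final case class Join(left: Plan, right: Plan, joinType: String) extends Plan
  final case class Project(columns: Seq[String], child: Plan) extends Plan

  // Bottom-up traversal: rewrite the children first, then try the rule on the node itself.
  def transformUp(plan: Plan)(rule: PartialFunction[Plan, Plan]): Plan = {
    val withNewChildren: Plan = plan match {
      case Union(children)         => Union(children.map(child => transformUp(child)(rule)))
      case Distinct(child)         => Distinct(transformUp(child)(rule))
      case Join(left, right, kind) => Join(transformUp(left)(rule), transformUp(right)(rule), kind)
      case Project(columns, child) => Project(columns, transformUp(child)(rule))
      case relation: Relation      => relation
    }
    rule.applyOrElse(withNewChildren, identity[Plan])
  }

  // Re-applies the rule until the plan stops changing or maxIterations is reached.
  // Returns the final plan and the number of iterations that were run.
  def executeToFixedPoint(plan: Plan, maxIterations: Int)(rule: PartialFunction[Plan, Plan]): (Plan, Int) = {
    var current = plan
    var iteration = 0
    var changed = true
    while (changed && iteration < maxIterations) {
      val next = transformUp(current)(rule)
      iteration += 1
      changed = next != current
      current = next
    }
    (current, iteration)
  }

  // Toy version of the UNION rule: after one rewrite nothing matches any more.
  val unionToJoin: PartialFunction[Plan, Plan] = {
    case Distinct(Union(Seq(left, right))) =>
      Project(Seq("coalesce(left.col, right.col)"), Join(left, right, "fullouter"))
  }
}
```

Running `executeToFixedPoint` on `Distinct(Union(Seq(Relation("dataset_1"), Relation("dataset_2"))))` with `unionToJoin` stops after 2 iterations: one rewrite and one pass confirming that nothing changes. A rule like `{ case r: Relation => Project(Seq("*"), r) }` adds a new Project at every pass and only stops at the iteration limit.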

  1. Read the logs - always check whether your rule was applied by looking for the "Applying Rule…" message in the logs. It will also help you understand what happened if your rule is overridden or ignored by a native one.
  2. Catch everything in the rule's partial function. When I first started with the optimizations, I wanted to see what happens when I replace transformDown with transformUp or with one of the resolveOperators* methods. I printed the parts I was interested in, but you can debug with breakpoints as well.
  3. If you don't know which operation nodes (Project, Join, Filter…) to use in the plan you want to build, start by writing the query you want to generate and retrieve the nodes from its plan. For the UNION rewrite, I started by exploring the plan of a FULL OUTER JOIN executed with the sort-merge strategy.
  4. Start small. It's a general software engineering truth, and it applies here too. Don't try to build your final transformation at once. Instead, go step by step and add a new feature at every successful iteration. Concretely, in my example, I started by transforming a UNION of 2 datasets into a FULL OUTER JOIN. In the next iteration, I added the projection returning the coalesced fields, and only at the end will I make it more dynamic and adapt it to more than 2 datasets (next post).
  5. Work in isolation. If you are not sure about your code, try to execute it against a smaller scope. In this exercise, I wasn't sure about my custom Project operation (the one with aliases). At first, I thought that I was using incorrect operations or attributes. Luckily, after some tests in isolation, i.e. where I only added a new projection on top of an existing one, I validated the projection code and could locate the issue elsewhere (see the next point).
  6. Don't be too granular. In the first version of my UNION to JOIN transformation, I added 2 optimization steps: one to translate the UNION into a JOIN, and one to add a custom projection on top of the existing one. After validating the custom projection (previous point), I moved it inside the translation step, and that was the solution. In hindsight it's quite logical, because the custom projection is an intrinsic part of the rewritten UNION. Hence, always try to think in terms of whole operations rather than isolated steps, even though isolation is helpful in many other places like testing and readability.
  7. Test at scale. In this article I simply tried, as an exercise, to change Apache Spark's execution plan. But the Catalyst optimizer already does a lot of useful work (Spark SQL operator optimizations - part 1, Spark SQL operator optimizations - part 2), so be careful when adding your own rules. If you do add one and decide to go to production, test the changes at scale to avoid bad surprises at runtime.
  8. It's hard to override a plan already transformed by native optimization rules. At first, I wanted to write a rule preserving the order of JOINs. However, in all my tests the plan had already been optimized by Apache Spark's native rules, and undoing that was harder than I expected. That's why, when you don't want to use some of the built-in rules, consider disabling them with the spark.sql.optimizer.excludedRules property (a comma-separated list of rule names), and always check the list of non-excludable rules exposed by the org.apache.spark.sql.catalyst.optimizer.Optimizer#nonExcludableRules() method.
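Takeaways 2, 3 and 8 can be combined in one sketch, assuming Spark 2.4. `PrintingRule` and `RuleDebuggingExample` are hypothetical names; `ReorderJoin` is only an example of an excludable built-in rule; and `sessionState` is marked as an unstable API, so reading `nonExcludableRules` through it is meant as a debugging aid only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Takeaway 2: a catch-all rule that only prints the visited nodes and changes nothing.
object PrintingRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case node =>
      println(s"Visiting ${node.nodeName}")
      node
  }
}

object RuleDebuggingExample {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder()
      .appName("Rule debugging example")
      .master("local[*]")
      .withExtensions(extensions => extensions.injectOptimizerRule(_ => PrintingRule))
      // Takeaway 8: disable built-in optimizer rules by their fully qualified names (comma-separated).
      // ReorderJoin is only an example of a rule that reorders joins.
      .config("spark.sql.optimizer.excludedRules", "org.apache.spark.sql.catalyst.optimizer.ReorderJoin")
      .getOrCreate()
    import sparkSession.implicits._

    Seq(("A", 1, 1), ("B", 2, 1)).toDF("letter", "nr", "a_flag").createOrReplaceTempView("dataset_1")
    Seq(("A", 1, 1), ("F", 10, 1)).toDF("letter", "nr", "a_flag").createOrReplaceTempView("dataset_2")

    // Takeaway 3: write the query your rule should produce and read the nodes from its analyzed plan.
    val targetQuery = sparkSession.sql(
      """SELECT COALESCE(d1.letter, d2.letter) AS letter, COALESCE(d1.nr, d2.nr) AS nr,
        |  COALESCE(d1.a_flag, d2.a_flag) AS a_flag
        |FROM dataset_1 d1 FULL OUTER JOIN dataset_2 d2
        |  ON d1.letter = d2.letter AND d1.nr = d2.nr AND d1.a_flag = d2.a_flag""".stripMargin)
    println(targetQuery.queryExecution.analyzed.treeString)
    // explain(true) triggers the optimization, so PrintingRule prints the visited nodes
    targetQuery.explain(true)

    // Rules that spark.sql.optimizer.excludedRules cannot disable
    println(sparkSession.sessionState.optimizer.nonExcludableRules.mkString("\n"))

    sparkSession.stop()
  }
}
```

The printing rule returns every node unchanged, so it never modifies the plan; expect a lot of output, because it is invoked at every iteration of its batch.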

In this first post about implementing custom optimizations in Apache Spark SQL, I presented a way to transform a distinct UNION into a JOIN query. Admittedly, it's a simplified version working only with 2 datasets. In the next post, I will try to extend it and make it work with more than 2 sources.