Writing Apache Spark SQL custom logical optimization - improved code and summary

Versions: Apache Spark 2.4.0

In the previous post about Apache Spark SQL custom optimizations I presented a rule transforming the UNION operator into a JOIN. At that time I wrote only a simple version that worked with 2 datasets. In this post, I will share its improved version.

The post starts by showing the improved version, which applies to more than 2 UNIONs. The second part gives a summary of all the classes and abstractions I had to use to write the code.

Improved UNION rewriter

Initially I thought that adapting the code from the previous week's post to the case of multiple UNIONs would be complicated. First, I wanted to catch the topmost UNION operator and recursively transform all its children into JOINs. However, that was not a good idea because the transform() method already behaves like a tree traversal, so I can reach and transform every node without recursing myself. My second thought was "maybe I will catch each UNION and transform only its 2 closest children". The only thing I had to verify was the state of the transformed plan: I wanted to ensure that a node rewritten at one level is the one the rule sees at the next level of the traversal. This test case proves that it is:

  import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, NamedExpression}
  import org.apache.spark.sql.catalyst.plans.JoinType
  import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, Project}
  import org.apache.spark.sql.types.{IntegerType, StringType}

  "transformation method" should "move modified plan to the next level" in {
    val letterReference = AttributeReference("letter", StringType, false)()
    val nameReference = AttributeReference("name", StringType, false)()
    val ageReference = AttributeReference("age", IntegerType, false)()
    val selectExpressions = Seq[NamedExpression](
      letterReference, nameReference
    )
    val dataset = LocalRelation(Seq(letterReference, nameReference, ageReference))
    val project1 = Project(selectExpressions, dataset)
    val project2 = Project(selectExpressions, dataset)
    val join = Join(project1, project2, JoinType("fullouter"),
      Option(And(selectExpressions(0), selectExpressions(1))))

    val projectionsFromJoin = new scala.collection.mutable.ListBuffer[Project]()
    // transform is called only once;
    // the goal is to show you that the modified plan is passed to next levels
    join.transformUp {
      case project: Project => {
        // Duplicate the projected fields so that the change is easy to spot
        project.copy(projectList = project.projectList ++ project.projectList)
      }
      case join: Join => {
        // Collect the Project children exactly as the Join case sees them
        join.children.filter(childPlan => childPlan.isInstanceOf[Project])
          .foreach(childPlan => {
            projectionsFromJoin.append(childPlan.asInstanceOf[Project])
          })
        join
      }
    }

    projectionsFromJoin should have size 2
    val expectedFields = Seq("letter", "name", "letter", "name")
    projectionsFromJoin.foreach(projection => {
      projection.projectList.map(namedExpression => namedExpression.name) should contain allElementsOf expectedFields
    })
  }
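
The same bottom-up propagation can be reproduced without Spark. The sketch below is a toy model, not Catalyst: Node, Leaf, Proj and JoinNode are hypothetical stand-ins for the logical plan classes, and this transformUp is a simplified reimplementation written only to illustrate the post-order behavior.

```scala
object TransformUpSketch {
  // Toy plan nodes: hypothetical stand-ins, not Spark's LogicalPlan classes.
  sealed trait Node {
    def children: Seq[Node]
    def withChildren(newChildren: Seq[Node]): Node

    // Post-order rewrite: children are transformed first, then the rule
    // is applied to the node rebuilt from those transformed children.
    def transformUp(rule: PartialFunction[Node, Node]): Node = {
      val rebuilt = withChildren(children.map(_.transformUp(rule)))
      rule.applyOrElse(rebuilt, identity[Node])
    }
  }
  case class Leaf(name: String) extends Node {
    def children: Seq[Node] = Seq.empty
    def withChildren(newChildren: Seq[Node]): Node = this
  }
  case class Proj(fields: Seq[String], child: Node) extends Node {
    def children: Seq[Node] = Seq(child)
    def withChildren(newChildren: Seq[Node]): Node = copy(child = newChildren.head)
  }
  case class JoinNode(left: Node, right: Node) extends Node {
    def children: Seq[Node] = Seq(left, right)
    def withChildren(newChildren: Seq[Node]): Node =
      copy(left = newChildren(0), right = newChildren(1))
  }

  val plan: Node = JoinNode(
    Proj(Seq("letter", "name"), Leaf("left")),
    Proj(Seq("letter", "name"), Leaf("right")))

  // Field lists of the Proj children as seen from the JoinNode case.
  val seenFromJoin = scala.collection.mutable.ListBuffer[Seq[String]]()

  val rewritten: Node = plan.transformUp {
    case proj: Proj => proj.copy(fields = proj.fields ++ proj.fields)
    case join: JoinNode =>
      join.children.foreach {
        case proj: Proj => seenFromJoin += proj.fields
        case _ => ()
      }
      join
  }
}
```

In this model the JoinNode case receives Proj children whose field lists are already doubled, which is the behavior the Spark test checks on the real plan.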

As you can observe in the Spark test above, the transformed Project is passed directly to the upper levels of the logical plan. Thanks to this confirmation I could apply my idea of transforming the plans pair by pair:

import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Coalesce, EqualTo, Expression}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Join, LocalRelation, LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule

object UnionAdvancedOptimizer extends Rule[LogicalPlan] {

  override def apply(logicalPlan: LogicalPlan): LogicalPlan = logicalPlan transformUp {
    case distinct: Distinct if distinct.child.isInstanceOf[Union] => {
      val union = distinct.child.asInstanceOf[Union]
      val joinColumns = union.children.map {
        case localRelation: LocalRelation => localRelation.output
        case project: Project => project.output
      }
      // One group of attributes per column position
      val joinPairs = joinColumns.transpose
      val equalToExpressions = joinPairs.map(attributes => {
        EqualTo(attributes(0), attributes(1))
      })

      // Chains all equality predicates with AND
      def concatExpressions(expression: Expression, remainingExpressions: Seq[Expression]): Expression = {
        if (remainingExpressions.isEmpty) {
          expression
        } else {
          concatExpressions(And(expression, remainingExpressions.head), remainingExpressions.tail)
        }
      }
      val concatenatedExpressions = concatExpressions(equalToExpressions.head, equalToExpressions.tail)

      val projection1 = Project(joinColumns(0), union.children(0))
      val projection2 = Project(joinColumns(1), union.children(1))
      val join1 = Join(projection1, projection2, JoinType("fullouter"), Option(concatenatedExpressions))
      val projection = combineProjection(joinPairs, join1)
      projection
    }
  }

  private def combineProjection(joinAttributes: Seq[Seq[Attribute]], childPlan: Join): Project = {
    val fields = joinAttributes.map(attributes => {
      Alias(Coalesce(attributes), attributes(0).name)()
    })
    Project(fields, childPlan)
  }
}
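
To make the construction of the join condition easier to follow, here is a minimal sketch in plain Scala. Attr, EqualTo and And are hypothetical toy case classes that only mimic the shape of the Catalyst expressions, and the column names are illustrative. The sketch assumes each Union has exactly two children when the rule runs, which is why only positions 0 and 1 of every transposed group are compared.

```scala
object JoinConditionSketch {
  // Toy expression tree: hypothetical stand-ins for Catalyst's Attribute, EqualTo and And.
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr

  // Output columns of the two UNION children, in positional order.
  val joinColumns: Seq[Seq[Attr]] = Seq(
    Seq(Attr("left.letter"), Attr("left.nr"), Attr("left.a_flag")),
    Seq(Attr("right.letter"), Attr("right.nr"), Attr("right.a_flag")))

  // transpose regroups the columns by position: one (left, right) pair per column.
  val joinPairs: Seq[Seq[Attr]] = joinColumns.transpose

  val equalities: Seq[Expr] = joinPairs.map(pair => EqualTo(pair(0), pair(1)))

  // Left fold of the equalities into nested And nodes,
  // the same nesting the recursive concatExpressions helper produces.
  val condition: Expr = equalities.reduceLeft[Expr]((acc, next) => And(acc, next))
}
```

The recursive helper in the rule and reduceLeft build the same left-nested chain, so either style could be used. The recursion simply keeps the code close to the previous post.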


If you remember my previous post, you can see that the code didn't evolve a lot. The only difference is the transformation method used. Since I wanted to avoid recursively converting UNIONs into JOINs myself, this time I preferred to start from the bottom with transformUp and find the pair of projections at every level (remember: our rewriter returns a projection on top of the JOIN). You can see the optimizer in action in the following test:

  private val sparkSession: SparkSession = SparkSession.builder().appName("Advanced UNION hint test").master("local[*]")
    .withExtensions(extensions => {
      // An optimizer rule would be a better fit here, but it would have to match HashAggregate + Union,
      // which is less understandable than Union + Distinct
      extensions.injectResolutionRule(_ => UnionAdvancedOptimizer)
    })
    .getOrCreate()

  "multiple UNION" should "be rewritten to JOINs" in {
    import sparkSession.implicits._
    val dataset1 = Seq(("A", 1, 1), ("B", 2, 1), ("C", 3, 1), ("D", 4, 1), ("E", 5, 1)).toDF("letter", "nr", "a_flag")
    val dataset2 = Seq(("F", 10, 1), ("G", 11, 1), ("H", 12, 1)).toDF("letter", "nr", "a_flag")
    val dataset3 = Seq(("A", 20, 1), ("B", 21, 1), ("C", 22, 1)).toDF("letter", "nr", "a_flag")
    val dataset4 = Seq(("A", 20, 1), ("B", 21, 1), ("C", 22, 1)).toDF("letter", "nr", "a_flag")
    // Register the temporary views referenced by the SQL query below
    dataset1.createOrReplaceTempView("dataset_1")
    dataset2.createOrReplaceTempView("dataset_2")
    dataset3.createOrReplaceTempView("dataset_3")
    dataset4.createOrReplaceTempView("dataset_4")

    val rewrittenQuery = sparkSession.sql("SELECT letter, nr, a_flag FROM dataset_1 " +
      "UNION SELECT letter, nr, a_flag FROM dataset_2 " +
      "UNION SELECT letter, nr, a_flag FROM dataset_3 UNION SELECT letter, nr, a_flag FROM dataset_4")

    val unionDataset =
      rewrittenQuery.map(row => s"${row.getAs[String]("letter")}-${row.getAs[Int]("nr")}-${row.getAs[Int]("a_flag")}").collect()

    unionDataset should have size 11
    unionDataset should contain allOf("A-1-1", "A-20-1", "B-2-1", "B-21-1", "C-3-1", "C-22-1", "D-4-1",
      "E-5-1", "F-10-1", "G-11-1", "H-12-1")
    rewrittenQuery.queryExecution.executedPlan.toString should include("SortMergeJoin")
  }
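
Why the test expects 11 rows can be checked with plain Scala collections. The sketch below is only a model of the rewrite, not Spark code: fullOuterJoin and unionViaJoin are hypothetical helpers, and it assumes the four UNIONs are rewritten pairwise from left to right.

```scala
object UnionViaJoinSketch {
  type Row = (String, Int, Int)

  // Full outer join on equality of every column; an unmatched side is None.
  def fullOuterJoin(left: Seq[Row], right: Seq[Row]): Seq[(Option[Row], Option[Row])] = {
    val matched = for (l <- left; r <- right if l == r) yield (Option(l), Option(r))
    val leftOnly = left.filterNot(l => right.contains(l)).map(l => (Option(l), Option.empty[Row]))
    val rightOnly = right.filterNot(r => left.contains(r)).map(r => (Option.empty[Row], Option(r)))
    matched ++ leftOnly ++ rightOnly
  }

  // Coalesce: take the left side when present, otherwise the right one.
  def unionViaJoin(left: Seq[Row], right: Seq[Row]): Seq[Row] =
    fullOuterJoin(left, right).flatMap { case (l, r) => l.orElse(r) }

  val dataset1: Seq[Row] = Seq(("A", 1, 1), ("B", 2, 1), ("C", 3, 1), ("D", 4, 1), ("E", 5, 1))
  val dataset2: Seq[Row] = Seq(("F", 10, 1), ("G", 11, 1), ("H", 12, 1))
  val dataset3: Seq[Row] = Seq(("A", 20, 1), ("B", 21, 1), ("C", 22, 1))
  val dataset4: Seq[Row] = Seq(("A", 20, 1), ("B", 21, 1), ("C", 22, 1))

  // Pairwise rewrite, left to right, compared with a plain distinct union.
  val viaJoins: Seq[Row] =
    Seq(dataset1, dataset2, dataset3, dataset4).reduceLeft((acc, next) => unionViaJoin(acc, next))
  val viaUnion: Seq[Row] = (dataset1 ++ dataset2 ++ dataset3 ++ dataset4).distinct
}
```

In this model both approaches return the same 11 rows, because the 3 rows of the fourth dataset match the rows already contributed by the third one. The model also shows a limit: a row duplicated inside a single input survives the join, whereas UNION would drop it. Whether the Spark rule shares this limit is an inference from the toy model, not something the test above verifies.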

The code still has some points to improve, like filtering, but I won't go into them right now. Instead, I will move on to the physical plans.

Used classes and abstractions

Last time I shared with you some takeaways about writing a custom optimization. Now I would like to complete this list with the abstractions that you may need to implement it:

That's all for writing a custom resolution rule, but not for Apache Spark SQL customization. In the next post, I will explore other customization rules and practice implementing them.
