Writing Apache Spark SQL custom logical optimization - API

Versions: Apache Spark 2.4.0

In one of my previous posts I presented how to add a custom optimization to Apache Spark SQL. It was not a good moment to delve deeply into the topic because of its complexity. That's why I will try to do a better job here by showing the API of the native optimizations.

This post is a reverse engineering exercise where I analyze already existing optimization rules before building my own in one of the next posts. The post is composed of 5 sections. Each one explains one specific point of Apache Spark SQL logical optimization. The first of them defines a general template that you can meet in the source code. The next two sections focus on 2 families of LogicalPlan methods used in the template: transform and resolve. The 4th section shows the role of pattern matching guards. The last part terminates the post with a short recall of the elements that you can use in your custom optimization code.

The general template

The optimization rule must inherit from the Rule[TreeType <: TreeNode[_]] class where the type is one of the nodes of the query AST (Abstract Syntax Tree). Among the examples of TreeNodes you will find Literal for literal (constant) values, GreaterThan for the ">" expression, Filter for the predicates from the WHERE clause and so forth. I wrote a post about the basics of Catalyst Optimizer in Spark SQL where you can find more details about the TreeNodes. All of them have one important thing in common: they extend TreeNode, either through the Expression hierarchy (Literal, GreaterThan) or as children of the LogicalPlan abstract class (Filter).
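To make that contract concrete, below is a minimal sketch of a rule skeleton. The NoOpOptimizationRule name is hypothetical and the rule simply returns the plan unchanged:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A do-nothing rule showing the inheritance contract of Rule[LogicalPlan]
object NoOpOptimizationRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}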

After discovering that fact about the Rule base class, I started to analyze all the implementations available in the current version (2.4.0). From that analysis I derived a global pattern used to write the logical plan optimization rules. The pattern can be summarized as:

// {{TRANSFORMATION}} = one of LogicalPlan's transform* methods
def apply(plan: LogicalPlan): LogicalPlan = plan.{{TRANSFORMATION}} {
  case agg: Aggregate => // ...
  case projection: Project => // ...
}

As you can see, the rule starts with a call to one of LogicalPlan's transform-like methods. I will detail them in the next section. These methods take a partial function as a parameter. To recall, a partial function is a function that applies only to a subset of values. You can find more details in the dedicated post about Partial functions in Scala. The partial character explains why you see pattern matching inside the transform method. If the function is not supposed to handle some specific operator, its execution is simply skipped for it.
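To quickly illustrate that skipping behavior outside of Spark, the following sketch defines a partial function that applies only to even numbers:

// A partial function defined only for even numbers; other values are skipped
val doubleEvens: PartialFunction[Int, Int] = {
  case number if number % 2 == 0 => number * 2
}

// collect applies the function only where it's defined
Seq(1, 2, 3, 4).collect(doubleEvens) // List(4, 8)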

LogicalPlan and transform methods

Let's focus now on the transform* methods available on LogicalPlan. In fact, these methods are defined partially in the parent classes, TreeNode and QueryPlan.

The role of the transform-like methods is to apply the optimization rule to the AST nodes. In the API you will find 2 kinds of methods: pre-order and post-order. The former ones are suffixed with the Down keyword and apply the rule first to the current node and only then to its children. The latter ones are suffixed with Up and apply the rule first to all the children and only then to the current node. Let's see some test cases to get a better idea about it:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.{IntegerType, StringType}

// A plan equivalent to: SELECT letter, name FROM dataset(letter, name, age)
private val selectStatement = {
  val letterReference = AttributeReference("letter", StringType, false)()
  val nameReference = AttributeReference("name", StringType, false)()
  val ageReference = AttributeReference("age", IntegerType, false)()
  val selectExpressions = Seq[NamedExpression](
    letterReference, nameReference
  )

  val dataset = LocalRelation(Seq(letterReference, nameReference, ageReference))
  Project(selectExpressions, dataset)
}


"transformDown" should "only apply to the current node and children" in {
  val transformedPlans = new scala.collection.mutable.ListBuffer[String]()
  selectStatement.transformDown {
    case lp => {
      transformedPlans.append(lp.nodeName)
      lp
    }
  }

  transformedPlans should have size 2
  transformedPlans should contain inOrder("Project", "LocalRelation")
}

"transformUp" should "apply to the children and current node" in {
  val transformedPlans = new scala.collection.mutable.ListBuffer[String]()
  selectStatement.transformUp {
    case lp => {
      transformedPlans.append(lp.nodeName)
      lp
    }
  }

  transformedPlans should have size 2
  transformedPlans should contain inOrder("LocalRelation", "Project")
}
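The partial function doesn't only observe the nodes, it can also rewrite them. As a small sketch reusing the selectStatement fixture from above, the following call returns a new plan where the projection keeps only its first expression:

// transformDown rebuilds the plan with the rewritten Project node
val prunedPlan = selectStatement.transformDown {
  case Project(projectList, child) => Project(projectList.take(1), child)
}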

LogicalPlan and resolve operators

But the transform-like methods are not the only ones used in the optimization rules. Another, less frequently used, category groups the resolveOperators* methods, brought in Apache Spark 2.4 by the AnalysisHelper trait mixed into LogicalPlan. The resolve-like functions are very similar to the transformations. The single difference is that they skip the already analyzed nodes. The following test illustrates that difference:

// The selectStatement fixture is the same as in the previous section

"resolveOperatorsUp" should "should only apply to the not analyzed nodes" in {
  val resolvedPlans = new scala.collection.mutable.ListBuffer[String]()
  val resolvedPlan = selectStatement.resolveOperatorsUp {
    case project @ Project(selectList, child) => {
      resolvedPlans.append(project.nodeName)
      project
    }
  }
  // Mark the [[Project]] as already analyzed
  SimpleAnalyzer.checkAnalysis(resolvedPlan)

  // Check once again whether the [[Project]] or its children will be resolved once again
  resolvedPlan.resolveOperatorsUp {
      case project @ Project(selectList, child) => {
        resolvedPlans.append(project.nodeName)
        project
      }
      case lp => {
        resolvedPlans.append(lp.nodeName)
        lp
      }
    }

  resolvedPlans should have size 1
  resolvedPlans(0) shouldEqual ("Project")
}

"transformDown" should "should apply to analyzed and not analyzed nodes" in {
  val resolvedPlans = new scala.collection.mutable.ListBuffer[String]()
  val resolvedPlan = selectStatement.transformDown {
    case project @ Project(selectList, child) => {
      resolvedPlans.append(project.nodeName)
      project
    }
  }
  SimpleAnalyzer.checkAnalysis(resolvedPlan)

  resolvedPlan.transformDown {
    case project @ Project(selectList, child) => {
      resolvedPlans.append(project.nodeName)
      project
    }
  }

  resolvedPlans should have size 2
  resolvedPlans(0) shouldEqual ("Project")
  resolvedPlans(1) shouldEqual ("Project")
}

As you can see, in the test I called the checkAnalysis method. It's provided by the CheckAnalysis trait and it consists of checking the query written by the user against some syntax rules. To see all possible analysis errors, you can simply look for the usages of failAnalysis(msg: String). Below you can find some of them:

failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
        failAnalysis(
          "The number of aliases supplied in the AS clause does not match the number of columns " +
          s"output by the UDTF expected ${elementAttrs.size} aliases but got " +
          s"${names.mkString(",")} ")
            failAnalysis(s"invalid cast from ${c.child.dataType.catalogString} to " +
              c.dataType.catalogString)
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan")

Put another way, the analysis checks whether the query can be physically executed. The flag indicating the analyzed state of a node exists only to avoid working on already analyzed nodes.
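For instance, the following sketch should make checkAnalysis fail because the projected attribute doesn't exist in the relation. It's only an illustration, so the exact error message may differ between versions:

import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.StringType

val letterReference = AttributeReference("letter", StringType, false)()
val unresolvedSelect = Project(Seq(UnresolvedAttribute("missing")), LocalRelation(Seq(letterReference)))

// Throws an AnalysisException similar to: cannot resolve '`missing`' given input columns: [letter]
SimpleAnalyzer.checkAnalysis(unresolvedSelect)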

Resolution guards

Pattern matching comes with an interesting feature called guards, which add an extra condition on the matched expression. In the optimization rules, you will very often meet guards checking whether the given node or its children are resolved. In the next snippet, you can see some of these use cases from different places of Apache Spark source code:

// FixNullability
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
  case p if !p.resolved => p // Skip unresolved nodes.
  case p: LogicalPlan if p.resolved =>
    val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap {
      case (exprId, attributes) =>
        // ...

// TypeCoercionRule
def apply(plan: LogicalPlan): LogicalPlan = {
  val newPlan = coerceTypes(plan)
  if (plan.fastEquals(newPlan)) {
    plan
  } else {
    propagateTypes(newPlan)
  }
}
private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
  // No propagation required for leaf nodes.
  case q: LogicalPlan if q.children.isEmpty => q

  // Don't propagate types from unresolved children.
  case q: LogicalPlan if !q.childrenResolved => q

If you take a look at the lazy val resolved: Boolean field in the LogicalPlan class, you will see a pretty self-explanatory comment:

  /**
   * Returns true if this expression and all its children have been resolved to a specific schema
   * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
   * can override this (e.g.
   * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
   * should return `false`).
   */
  lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved
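The childrenResolved flag used there is defined alongside it in the same class, as a simple check on all children:

/**
 * Returns true if all its children of this query plan have been resolved.
 */
def childrenResolved: Boolean = children.forall(_.resolved)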

As you can see, it's the flag saying whether the expression and all its children were matched against some specific schema. And if you search for the use cases of that flag, you will find it used in a lot of places that rewrite the operators or rely on the schema.
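To show how such a guard could look in a custom rule, here is a hedged sketch. HypotheticalGuardedRule is an invented name and a real rule would obviously do more than returning the matched node:

import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

// Matches only the fully resolved Filter nodes thanks to the guard
object HypotheticalGuardedRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case filter: Filter if filter.resolved =>
      filter // a real rule would rewrite the node here
  }
}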

Other types of guards exist but they're less popular than the resolution-based ones. Among them, you will find guards checking whether the query is a streaming one or whether the plan has a specific number of children:

// ReplaceDeduplicateWithAggregate
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  case Deduplicate(keys, child) if !child.isStreaming =>
    val keyExprIds = keys.map(_.exprId)
    val aggCols = child.output.map { attr =>
      if (keyExprIds.contains(attr.exprId)) {
        attr
      } else {
        Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
      }
    }

// TimeWindowing
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
  case p: LogicalPlan if p.children.size == 1 =>
    val child = p.children.head
    val windowExpressions =
      p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet

LogicalPlan implementations

I've already mentioned the major LogicalPlan implementations but without much detail. The list below gives more context for some popular ones:

- Project: the projection built from the SELECT clause
- Filter: the predicate coming from the WHERE clause
- Aggregate: the grouping and aggregation logic of the GROUP BY clause
- Join: the combination of 2 child plans with an optional join condition
- LocalRelation: a relation backed by a local, in-memory collection of rows

That's all for this post. As you can see, we can define a universal template since almost all optimization rules are designed around LogicalPlan's transform- or resolve-like methods and a partial function matching one of the operators existing in the query. In one of the next posts, I will try to go deeper and write a custom optimization with a little bit more awareness than previously.

