Introduction to custom optimization in Apache Spark SQL

Versions: Apache Spark 2.4.0

In November 2018 bithw1 pointed out to me a feature that I hadn't used yet in Apache Spark - custom optimization. After some months devoted to learning Apache Spark GraphX, I finally found a moment to explore it. This post begins a new series about Apache Spark customization and it covers the basics, i.e. the 2 available methods for adding custom optimizations.


First of all, I will recall the basics of optimization rules in Apache Spark. It will be a short part since I have already published some posts about this topic in the Spark SQL optimization internals category. In the next section, I will focus on the first method we can use to define additional optimizations. This part will also contain an example of a simple optimization rule. The final section will show another method to define customized rules.

Optimization in Apache Spark

Apache Spark SQL executes the data processing logic in multiple steps. One of them, occurring directly after the query analysis stage, is logical optimization. The main goals of this stage are to reduce the number of operations and to delegate some of them to the data source. In short, it optimizes the data processing logic.

Among the applied optimizations you will find predicate pushdown, where the filtering is delegated to the data source, but also a lot of expression simplifications such as boolean simplification, filter clause concatenation or the substitution of regular expressions with String's startsWith or contains methods.
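To illustrate the idea behind these rewrites, here is a toy, Spark-free sketch of a tree-rewriting simplification: a minimal expression AST and a rule that drops additions of a literal zero. All the names (Expr, Lit, Col, Plus, foldZeroAdd) are invented for the example and are not Spark classes.

```scala
// Toy expression AST, standing in for Catalyst's Expression tree.
sealed trait Expr
case class Lit(value: Long) extends Expr
case class Col(name: String) extends Expr
case class Plus(left: Expr, right: Expr) extends Expr

// Rewrite the tree bottom-up, dropping `+ 0` on either side - the same
// pattern-match-and-rebuild style Catalyst optimizer rules use.
def foldZeroAdd(e: Expr): Expr = e match {
  case Plus(l, r) =>
    (foldZeroAdd(l), foldZeroAdd(r)) match {
      case (Lit(0), rr) => rr
      case (ll, Lit(0)) => ll
      case (ll, rr)     => Plus(ll, rr)
    }
  case other => other
}

println(foldZeroAdd(Plus(Col("nr"), Lit(0))))  // prints Col(nr)
```

Real rules work the same way, except that they match on Catalyst's Expression and LogicalPlan nodes and are driven by the engine rather than called directly.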

Extra optimizations

The Apache Spark 2.0.0 release brought a feature called extra optimizations. Its use is straightforward. After building the SparkSession, you call its experimental method and set the extraOptimizations field of the returned ExperimentalMethods instance. The engine later uses these additional optimizations when it constructs the rules to apply:

class SparkOptimizer(
    catalog: SessionCatalog,
    experimentalMethods: ExperimentalMethods)
  extends Optimizer(catalog) {

  override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
    Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
    Batch("Extract Python UDFs", Once,
      Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
    Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
    Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
    postHocOptimizationBatches :+
    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
}
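Note the fixedPoint in the last batch: the user-provided rules are applied repeatedly until the plan stops changing (or an iteration limit is reached). A minimal, Spark-free sketch of that driving loop, using an invented one-node AST (Column, Abs, collapseOnce and fixedPoint here are illustration-only names, not Spark API):

```scala
// Toy AST: just a column reference and an abs(...) wrapper.
sealed trait E
case class Column(name: String) extends E
case class Abs(child: E) extends E

// One rule application: collapse a single Abs(Abs(_)) pair, top-down.
def collapseOnce(e: E): E = e match {
  case Abs(Abs(inner)) => Abs(inner)
  case Abs(inner)      => Abs(collapseOnce(inner))
  case other           => other
}

// Fixed-point driver, mirroring what a fixedPoint batch does with its rules:
// keep re-applying until the tree no longer changes or maxIters is reached.
def fixedPoint(start: E, maxIters: Int = 100): E = {
  var current = start
  var iterations = 0
  var changed = true
  while (changed && iterations < maxIters) {
    val next = collapseOnce(current)
    changed = next != current
    current = next
    iterations += 1
  }
  current
}

println(fixedPoint(Abs(Abs(Abs(Column("col1"))))))  // prints Abs(Column(col1))
```

The iteration limit matters: a rule that keeps producing a structurally different plan on every pass would otherwise loop forever, which is why Spark caps fixed-point batches with a configurable maximum number of iterations.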

The following test shows how Apache Spark uses the extra optimization rule. In my case it's not a pure optimization, but I chose this code to show whether the rule is really applied during the execution:

"extra optimization rule" should "be added through extraOptimizations field" in {
  val testSparkSession: SparkSession = SparkSession.builder().appName("Extra optimization rules")
    .master("local[*]").getOrCreate()
  import testSparkSession.implicits._
  testSparkSession.experimental.extraOptimizations = Seq(Replace0With3Optimizer)
  Seq(-1, -2, -3).toDF("nr").write.mode("overwrite").json("./test_nrs")

  val optimizedResult = testSparkSession.read.json("./test_nrs").selectExpr("nr + 0")

  optimizedResult.queryExecution.optimizedPlan.toString() should include("Project [(nr#12L + 3) AS (nr + 0)#14L]")
  optimizedResult.collect().map(row => row.getAs[Long]("(nr + 0)")) should contain allOf(0, 1, 2)
}

object Replace0With3Optimizer extends Rule[LogicalPlan] {

  def apply(logicalPlan: LogicalPlan): LogicalPlan = {
    logicalPlan.transformAllExpressions {
      case Add(left, right) => {
        if (isStaticAdd(left)) {
          Add(Literal(3L), right)
        } else if (isStaticAdd(right)) {
          Add(left, Literal(3L))
        } else {
          Add(left, right)
        }
      }
    }
  }

  private def isStaticAdd(expression: Expression): Boolean = {
    expression.isInstanceOf[Literal] && expression.asInstanceOf[Literal].toString == "0"
  }
}


Extra optimization rules were only the first step. A more recent way to customize our Apache Spark workflows came with the 2.2.0 release and its SparkSessionExtensions. Previous versions only supported custom optimization rules and planning strategies. The new API completed them with custom analyzer rules (resolution, post-hoc resolution and check rules) and a custom parser.

Applying them is simple. You only need to register the customized rules during the SparkSession construction, through the builder's withExtensions method:

def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized {
  f(extensions)
  this
}
The builder passes the extensions to the SparkSession as one of its constructor parameters. After that, the framework appends them to the Analyzer and the Optimizer of the used BaseSessionStateBuilder. Let's see how to implement our previous optimization with this new method:

"extra optimization rule" should "be added through extensions" in {
  val testSparkSession: SparkSession = SparkSession.builder().appName("Extra optimization rules")
    .master("local[*]")
    .withExtensions(extensions => {
      extensions.injectOptimizerRule(session => Replace0With3Optimizer)
    })
    .getOrCreate()
  import testSparkSession.implicits._
  testSparkSession.experimental.extraOptimizations = Seq()
  Seq(-1, -2, -3).toDF("nr").write.mode("overwrite").json("./test_nrs")
  val optimizedResult = testSparkSession.read.json("./test_nrs").selectExpr("nr + 0")

  optimizedResult.queryExecution.optimizedPlan.toString() should include("Project [(nr#12L + 3) AS (nr + 0)#14L]")
  optimizedResult.collect().map(row => row.getAs[Long]("(nr + 0)")) should contain allOf(0, 1, 2)
}

Adding a custom optimization to the Apache Spark session is straightforward. But writing a good one is much harder. Initially, in this post, I wanted to simplify any idempotent operation, like, for instance, abs(abs(abs("col1"))), but it was too complicated for a first try. That's why this post focused only on the custom optimization injection methods. One of the next articles will cover the API used to write custom optimizations.