Writing custom optimization in Apache Spark SQL - generated code

Versions: Apache Spark 2.4.0 https://github.com/bartosz25/spark-...ions/CustomProjectCodeGenTest.scala

In my previous post, I explained how to implement a custom physical plan execution. However, that first version didn't use generated code, which is also an interesting way to customize Apache Spark. And it's the feature I will cover in this post.

The post is divided into several short parts. Each of them focuses on a specific feature of code generation. By the end of this post, you should be able to assemble these parts and understand the proposed custom physical execution.

CodegenContext

I introduced this class in The who, when, how and what of Apache Spark SQL code generation post, but it needs a more detailed explanation now. Just to recall, this class acts as a container and a factory, i.e. it stores all variables and functions used in the generated code, and it also provides a lot of factory methods to create these variables and functions. For instance, you will use addMutableState(javaType: String, variableName: String, initFunc: String => String = _ => "", forceInline: Boolean = false, useFreshName: Boolean = true) to add a field holding some mutable state to the generated class. Among the typical uses of such fields you will find counters, input collections, or boolean flags.
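
For instance, the example developed at the end of this post declares its 2 mutable fields like this (the returned values are the variable names to use in the generated code):

// a field holding the input iterator, initialized from BufferedRowIterator's inputs array
val input = ctx.addMutableState("scala.collection.Iterator", "inputsVariable", v => s"$v = inputs[0];")
// a simple int counter field, left with its default initialization
val rowsCounter = ctx.addMutableState(CodeGenerator.JAVA_INT, "count")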

Another factory method you can use in the generated class is freshName(name: String). It returns a name that is unique within the given instance of the generation context. It guarantees that you will never have conflicting variable names, so you can safely reference them in the code you generate.

CodegenContext also gives you the possibility to declare functions through addNewFunction(funcName: String, funcCode: String, inlineToOuterClass: Boolean = false). The last argument controls an interesting splitting mechanism. By default, when the generated outer class becomes too big, the new function can be moved into a private nested class and exposed through a qualified name. Setting inlineToOuterClass to true forces the function to stay in the outer class, which is required when it's referenced from a place where such a qualified name wouldn't be usable.
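
Combined with freshName, a function declaration taken from the example at the end of this post looks like this (printProcessedRow is only a debugging helper):

// freshName guarantees the function name won't clash with other generated members
val debugFunctionName = ctx.freshName("printProcessedRow")
// addNewFunction returns the name to call in the generated code, possibly qualified
// with the name of a nested class when the function is not inlined to the outer class
val debugFunction = ctx.addNewFunction(debugFunctionName, s"""
  protected void $debugFunctionName(InternalRow row) {
    System.out.println("Processing " + row);
  }
""", inlineToOuterClass = true)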

Of course, there are many other factory methods. The pattern to find them is to look for functions beginning with the "add*" prefix. I'm stopping at these 3 because there are other things to cover.

Current vars vs INPUT_ROW

The second discussed point concerns CodegenContext's fields called currentVars and INPUT_ROW. According to the documentation:

  /**
   * Holding the variable name of the input row of the current operator, will be used by
   * `BoundReference` to generate code.
   *
   * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW`
   * to generate code. If you want to make sure the generated code use `INPUT_ROW`, you need to set
   * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
   * `Expression.genCode`.
   */
  var INPUT_ROW = "i"

  /**
   * Holding a list of generated columns as input of current operator, will be used by
   * BoundReference to generate code.
   */
  var currentVars: Seq[ExprCode] = null

So what is it all about? For a better understanding, we can simply check how BoundReference uses them to generate the code:

// org.apache.spark.sql.catalyst.expressions.BoundReference#doGenCode
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
      val oev = ctx.currentVars(ordinal)
      ev.isNull = oev.isNull
      ev.value = oev.value
      ev.copy(code = oev.code)
    } else {
      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
      val javaType = JavaCode.javaType(dataType)
      val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
      if (nullable) {
        ev.copy(code =
          code"""
             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
             |$javaType ${ev.value} = ${ev.isNull} ?
             |  ${CodeGenerator.defaultValue(dataType)} : ($value);
           """.stripMargin)
      } else {
        ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
      }
    }
  }

If you've ever read the code generated by Apache Spark, you will quickly make the link between these isNull branches and the generated code. The getValue function referenced through CodeGenerator is a simple pattern match over the available types and their getters:

  def getValue(input: String, dataType: DataType, ordinal: String): String = {
    val jt = javaType(dataType)
    dataType match {
      case _ if isPrimitiveType(jt) => s"$input.get${primitiveTypeName(jt)}($ordinal)"
      case t: DecimalType => s"$input.getDecimal($ordinal, ${t.precision}, ${t.scale})"

To sum up, we can say that these 2 variables are the 2 ways of reading the operator's input: currentVars points to columns that were already evaluated as local variables by the upstream operator of the same whole-stage block, whereas INPUT_ROW points to a row object whose values still have to be read with getters. If you see other subtleties, your comment is more than welcome!
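
This is exactly the pattern used by the example at the end of this post: INPUT_ROW is pointed at a freshly named row and currentVars is reset before the bound references are generated (slightly simplified here):

// make BoundReference generate row-based accessors instead of reusing local variables
val row1 = ctx.freshName("processedRow")
ctx.INPUT_ROW = row1
ctx.currentVars = null
// every reference now reads from the named row through isNullAt/getXXX calls
val outputVars = output.zipWithIndex.map { case (attr, i) =>
  BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
}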

Output vars

Unlike the input part, output vars are something that is passed as a parameter to a method called consume that I will cover in the last section. Output vars represent the generated columns which will be used to build the code writing an UnsafeRow. But that's true only when the row name is null. In that case, the code reading the input data will use getters and write the UnsafeRow like that:

/* 032 */         boolean unionjoinexecutor_isNull_0 = unionjoinexecutor_processedRow_0.isNullAt(0);
/* 033 */         UTF8String unionjoinexecutor_value_0 = unionjoinexecutor_isNull_0 ?
/* 034 */         null : (unionjoinexecutor_processedRow_0.getUTF8String(0));
/* 035 */         int unionjoinexecutor_value_1 = unionjoinexecutor_processedRow_0.getInt(1);
/* 036 */         int unionjoinexecutor_value_2 = unionjoinexecutor_processedRow_0.getInt(2);
/* 037 */         unionjoinexecutor_mutableStateArray_1[0].reset();
/* 038 */
/* 039 */         unionjoinexecutor_mutableStateArray_1[0].zeroOutNullBytes();
/* 040 */
/* 041 */         if (unionjoinexecutor_isNull_0) {
/* 042 */           unionjoinexecutor_mutableStateArray_1[0].setNullAt(0);
/* 043 */         } else {
/* 044 */           unionjoinexecutor_mutableStateArray_1[0].write(0, unionjoinexecutor_value_0);
/* 045 */         }
/* 046 */
/* 047 */         unionjoinexecutor_mutableStateArray_1[0].write(1, unionjoinexecutor_value_1);
/* 048 */
/* 049 */         unionjoinexecutor_mutableStateArray_1[0].write(2, unionjoinexecutor_value_2);
/* 050 */         append((unionjoinexecutor_mutableStateArray_1[0].getRow()));

On the other hand, when you use an explicit row name, the code will be much less verbose:

/* 025 */       InternalRow unionjoinexecutor_processedRow_0 = (InternalRow) unionjoinexecutor_mutableStateArray_0[0].next();
/* 026 */       if (unionjoinexecutor_count_0 % 2 == 0) {
/* 027 */         append(unionjoinexecutor_processedRow_0);
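
In terms of the operator's code, the difference boils down to the way consume is invoked from doProduce; a minimal sketch of the two variants (the parameter names are those of the example developed at the end of this post):

// 1) only generated columns: the parent rebuilds an UnsafeRow field by field (the verbose variant)
consume(ctx, outputVars)

// 2) an explicit row name: the parent reuses the row directly and simply calls append(...) on it
consume(ctx, outputVars, ctx.INPUT_ROW)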

BufferedRowIterator

Another part of the mystery I faced is about the BufferedRowIterator implementation. Why does the code generated by Apache Spark extend this iterator by default? The answer hides in WholeStageCodegenExec:

  /**
   * Generates code for this subtree.
   *
   * @return the tuple of the codegen context and the actual generated source.
   */
  def doCodeGen(): (CodegenContext, CodeAndComment) = {
    // …
          final class $className extends ${classOf[BufferedRowIterator].getName} {

        private Object[] references;
        private scala.collection.Iterator[] inputs;
        ${ctx.declareMutableStates()}

        public $className(Object[] references) {
          this.references = references;
        }

        public void init(int index, scala.collection.Iterator[] inputs) {
          partitionIndex = index;
          this.inputs = inputs;
          ${ctx.initMutableStates()}
          ${ctx.initPartition()}
        }

        ${ctx.emitExtraCode()}

        ${ctx.declareAddedFunctions()}
      }
      """.trim

As you can see, if whole-stage code generation is enabled, Apache Spark will go through this method and inject the custom code you wrote in your physical plan into this template. To see this method ignored, you can simply set sparkSession.conf.set("spark.sql.codegen.wholeStage", false) and check the apply method of the CollapseCodegenStages class.
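
If you prefer, the same flag can be set once at session construction time; a minimal sketch (the application name is illustrative):

val sparkSession = SparkSession.builder()
  .appName("Whole-stage codegen disabled") // illustrative name
  .master("local[1]")
  // with this flag off, CollapseCodegenStages won't wrap the plan in WholeStageCodegenExec
  .config("spark.sql.codegen.wholeStage", "false")
  .getOrCreate()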

The above snippet also explains why, whatever you write to initialize your mutable variables, you always have to deal with the this.inputs field and cannot override the name of the input RDDs array. The proof is in DataSourceScanExec's doProduce method:

// DataSourceScanExec

    // PhysicalRDD always just has one input
    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")


consume

I would like to finish with CodegenSupport's consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null) method, defined in WholeStageCodegenExec.scala. This function takes the columns or the row produced by the current physical plan and calls doConsume defined in the parent plan. First, it builds the row that will be processed by the parent directly from the expected output of the physical plan. Later, this row is passed to the parent's doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode) method to... be consumed.

The consuming may take different forms. It can serialize or deserialize an object, expand it, or simply add it to the BufferedRowIterator, which is the case for the example discussed here. The append call looks like this if the row name is null:

/* 045 */         unionjoinexecutor_mutableStateArray_1[0].write(1, unionjoinexecutor_value_1);
/* 046 */
/* 047 */         unionjoinexecutor_mutableStateArray_1[0].write(2, unionjoinexecutor_value_2);
/* 048 */         append((unionjoinexecutor_mutableStateArray_1[0].getRow()));

Or like this, if it's defined:

/* 025 */       InternalRow unionjoinexecutor_processedRow_0 = (InternalRow) unionjoinexecutor_mutableStateArray_0[0].next();
/* 026 */       if (unionjoinexecutor_count_0 % 2 == 0) {
/* 027 */         append(unionjoinexecutor_processedRow_0);
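
Both variants end up in WholeStageCodegenExec's doConsume which, besides an optional defensive copy, only emits the row's code and the append call (shown here slightly simplified):

// org.apache.spark.sql.execution.WholeStageCodegenExec#doConsume (simplified)
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  val doCopy = if (needCopyResult) ".copy()" else ""
  s"""
     |${row.code}
     |append(${row.value}$doCopy);
   """.stripMargin.trim
}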

The consume method also guarantees the consistency of the generated columns with the declared output. I did a test where I reduced the number of output variables (outputVars) to 1 and, when consume was used, I got this error:

java.lang.AssertionError: assertion failed
        at scala.Predef$.assert(Predef.scala:156)
        at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:147)
        at com.waitingforcode.sql.CustomProjectExec.consume(CustomProjectCodeGenTest.scala:66)
        at com.waitingforcode.sql.CustomProjectExec.doProduce(CustomProjectCodeGenTest.scala:120)
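
The failed assertion comes from the very beginning of consume, which verifies that the generated columns match the operator's declared output (shown here heavily abbreviated):

// the beginning of CodegenSupport#consume, abbreviated
if (outputVars != null) {
  // fails when outputVars has fewer elements than the plan's output (e.g. outputVars.take(1))
  assert(outputVars.length == output.length)
  // ...
} else {
  assert(row != null, "outputVars and row cannot both be null.")
  // the bound references are regenerated from the row in this case
  // ...
}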

If you do this and bypass consume by calling one of the parent's methods directly, the code probably won't fail but you will get a strange output like this one (here for this.append called instead of consume, with outputVars.take(1)):

    +------+------------+------+
    |letter|          nr|a_flag|
    +------+------------+------+
    |     A|            |      |
    |     C|            |      |
    |     E|            |      |
    |     A|            |      |
    |     F|            |      |
    |     H|            |      |
    +------+------------+------+

Wrap-up

After discovering all of this, it's a good moment to write a custom optimization with generated code. The example is here just to illustrate the previously covered points and it has no big extra value, since it only implements a weird filtering inside the code generated for a projection:

case class CustomProjectExec(outputAttrs: Seq[Attribute], child: SparkPlan) extends SparkPlan with CodegenSupport {

  override protected def doExecute(): RDD[InternalRow] = {
    children.head.execute()
  }

  override def output: Seq[Attribute] =  outputAttrs

  override def children: Seq[SparkPlan] = Seq(child)

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    children.map(c => c.execute())
  }

  override protected def doProduce(ctx: CodegenContext): String = {
    val input = ctx.addMutableState("scala.collection.Iterator", "inputsVariable", v => s"$v = inputs[0];")

    val row1 = ctx.freshName("processedRow")
    ctx.INPUT_ROW = row1
    ctx.currentVars = null

    val rowsCounter = ctx.addMutableState(CodeGenerator.JAVA_INT, "count")
    // Generates output variables
    // I take output attributes with indexes and generate the references for them
    // If you want to see what brings the use of consume, you can overwrite this part with
    // ...ref }.take(2)
    // It will throw an exception since the `outputVars` are different from the `def output: Seq[Attribute]`: assertion failed
    //java.lang.AssertionError: assertion failed
    //    at scala.Predef$.assert(Predef.scala:156)
    //    at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:147)
    // On the other side, it won't generate an error if you use this.append(ctx.INPUT_ROW) instead but the data you'll
    // collect will be invalid:
    // +------+------------+------+
    //|letter|          nr|a_flag|
    //+------+------------+------+
    //|     A|            |      |
    //|     C|            |      |
    //|     E|            |      |
    //|     A|            |      |
    //|     F|            |      |
    //|     H|            |      |
    //+------+------------+------+
    val outputVars = output.zipWithIndex.map { case (a, i) =>
      val ref = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
      ref
    }

    // using fresh name helps to avoid naming conflicts
    val debugFunctionName = ctx.freshName("printProcessedRow")
    val debugFunction = ctx.addNewFunction(debugFunctionName, s"""
      protected void $debugFunctionName(InternalRow row) {
        System.out.println("Processing " + row);
      }
    """, inlineToOuterClass = true)

    s"""
       |while ($input.hasNext()) {
       |  InternalRow ${ctx.INPUT_ROW} = (InternalRow) $input.next();
       |  ${debugFunction}(${ctx.INPUT_ROW});
       |  if ($rowsCounter % 2 == 0) {
       |    ${consume(ctx, outputVars, ctx.INPUT_ROW)}
       |  } else {
       |    System.out.println("Skipping row because of counter " + $rowsCounter);
       |  }
       |  $rowsCounter = $rowsCounter + 1;
       |  if (shouldStop()) return;
       |}
    """.stripMargin
  }
}

The following test shows how it works:

object CustomProjectTransformer extends Rule[LogicalPlan] {

  override def apply(logicalPlan: LogicalPlan): LogicalPlan = logicalPlan transformDown   {
    case project: Project => {
      CustomProject(project, project.child)
    }
  }
}

case class CustomProject(project: Project, child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = {
    val aliasesToMap = project.projectList.map(ne => ne.name)
    child.output.zipWithIndex.map {
      case (expression, index) => Alias(expression, aliasesToMap(index))().toAttribute
    }
  }
}

object CustomProjectStrategy extends SparkStrategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case project: CustomProject => {
      new CustomProjectExec(project.output, planLater(project.child)) :: Nil
    }
    case _ => Nil
  }
}

  "projection" should "be overwriten with a custom physical optimization" in {
    val sparkSession: SparkSession = SparkSession.builder().appName("Custom project rewritten")
      .master("local[1]") // limit parallelism to see counter from generated code in action
      .withExtensions(extensions => {
        extensions.injectResolutionRule(_ => CustomProjectTransformer)
        extensions.injectPlannerStrategy(_ => CustomProjectStrategy)
      })
      .getOrCreate()
    import sparkSession.implicits._
    val dataset = Seq(("A", 1, 1), ("B", 2, 1), ("C", 3, 1), ("D", 4, 1), ("E", 5, 1)).toDF("letter", "nr", "a_flag")

    val allRows = dataset.select("letter", "nr", "a_flag").collect()

    allRows should have size 3
    allRows.map(row => row.mkString(", ")) should contain allOf("A, 1, 1", "C, 3, 1", "E, 5, 1")
  }

This article presents how to write a custom physical strategy using generated code. It uses a bottom-up approach where the concepts implemented in the code example were explained one by one in each section. As you could see, the key here is to use CodegenContext as much as possible since it provides a safe way to generate names and evaluate variables. It's also important to keep everything consistent and, for instance, not to generate rows that are different from the exposed output.

