Writing custom optimization in Apache Spark SQL - custom parser

Versions: Apache Spark 2.4.0

Last time I presented ANTLR and how Apache Spark SQL uses it to convert textual SQL expressions into internal classes. In this post I will write a custom parser.

The post is composed of 2 parts, each describing a specific scenario that requires a custom parser. The first section shows how to build a completely new query DSL. The second one covers extending the existing parser with a new operator.

Alternative querying DSL

For this first example I will write a small, completely new data querying language. Its query syntax is:

[ (columns) ] <- dataset

The ANTLR grammar file looks like this:

grammar SqlExtended; 

dslQuery : '[' columnsList ']' '<-' dataset ;

columnsList: (IDENTIFIER)+ ;

dataset: IDENTIFIER ;


IDENTIFIER : [a-zA-Z0-9]+ ;
WS: [ \t\n\r]+ -> skip ;
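To make explicit which inputs this grammar accepts, here is a small hand-written recognizer in plain Scala that mirrors its three rules. It's only an illustration under simplifying assumptions, not what ANTLR generates: DslQuery and DslRecognizer are hypothetical names, and invalid input is reported as a plain Left instead of going through ANTLR's error reporting and recovery.

```scala
// Hypothetical, hand-written equivalent of the SqlExtended grammar (no ANTLR, no Spark).
final case class DslQuery(columns: List[String], dataset: String)

object DslRecognizer {
  // Mirrors the lexer: '[', ']' and '<-' literals, IDENTIFIER : [a-zA-Z0-9]+,
  // whitespace skipped (findAllIn ignores it); any other character becomes its own token.
  private val Token = """\[|\]|<-|[a-zA-Z0-9]+|\S""".r
  private val Identifier = "[a-zA-Z0-9]+"

  def tokenize(text: String): List[String] = Token.findAllIn(text).toList

  // dslQuery : '[' columnsList ']' '<-' dataset ;
  // columnsList : (IDENTIFIER)+ ;   dataset : IDENTIFIER ;
  def parse(text: String): Either[String, DslQuery] = tokenize(text) match {
    case "[" :: rest =>
      // columnsList: consume identifiers until the first non-identifier token
      val (columns, afterColumns) = rest.span(_.matches(Identifier))
      afterColumns match {
        case "]" :: "<-" :: dataset :: Nil if columns.nonEmpty && dataset.matches(Identifier) =>
          Right(DslQuery(columns, dataset))
        case _ => Left(s"Invalid DSL query: $text")
      }
    case _ => Left(s"Invalid DSL query: $text")
  }
}
```

For instance, `DslRecognizer.parse("[letter nr] <- dataset1")` returns `Right(DslQuery(List(letter, nr), dataset1))`, while `"[] <- dataset1"` is rejected because columnsList requires at least one identifier.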

To generate the lexer, parser and visitor, I used the ANTLR Maven plugin (antlr4-maven-plugin, with visitor generation enabled).

To parse queries written in this alternative syntax, I had to define 2 classes. The first one is the visitor, responsible for transforming ANTLR's parse tree into a LogicalPlan, the same way Apache Spark's AstBuilder does for regular SQL:

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

// SqlExtendedBaseVisitor and SqlExtendedParser are generated by ANTLR from the grammar above
class CustomSqlVisitor extends SqlExtendedBaseVisitor[LogicalPlan] {
  override def visitDslQuery(ctx: SqlExtendedParser.DslQueryContext): LogicalPlan = {
    // ANTLR generates one accessor per sub-rule, so there is no need to scan ctx.children
    val columnsListContext = ctx.columnsList()
    val datasetContext = ctx.dataset()

    // [letter nr] <- dataset1 becomes Project([letter, nr], UnresolvedRelation(dataset1))
    Project(getColumnsToSelect(columnsListContext),
      UnresolvedRelation(TableIdentifier(datasetContext.IDENTIFIER().getText)))
  }

  private def getColumnsToSelect(columnsListContext: SqlExtendedParser.ColumnsListContext): Seq[NamedExpression] = {
    // Each IDENTIFIER token of columnsList becomes a column to resolve later by the Analyzer
    columnsListContext.IDENTIFIER().asScala.map(column => UnresolvedAttribute(column.getText))
  }
}
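The plan built by visitDslQuery is a Project over an UnresolvedRelation. The sketch below reproduces that mapping with hypothetical stand-in classes (MiniProject, MiniAttribute, MiniRelation and the *Node parse-tree types are assumptions, not Spark or ANTLR types), so the shape of the result can be checked without a Spark dependency:

```scala
// Hypothetical stand-ins for Spark's Project, UnresolvedAttribute and UnresolvedRelation.
sealed trait MiniPlan
final case class MiniRelation(table: String) extends MiniPlan
final case class MiniAttribute(name: String)
final case class MiniProject(projectList: Seq[MiniAttribute], child: MiniPlan) extends MiniPlan

// Hypothetical parse-tree nodes mirroring the dslQuery, columnsList and dataset rules.
final case class ColumnsListNode(identifiers: Seq[String])
final case class DatasetNode(identifier: String)
final case class DslQueryNode(columns: ColumnsListNode, dataset: DatasetNode)

object MiniVisitor {
  // Same mapping as visitDslQuery: each column becomes an attribute, the dataset becomes
  // a relation, and a projection wraps the relation.
  def visitDslQuery(node: DslQueryNode): MiniPlan =
    MiniProject(node.columns.identifiers.map(MiniAttribute(_)), MiniRelation(node.dataset.identifier))
}
```

In the real plan, both the attributes and the relation stay unresolved at this stage: Spark's Analyzer binds them later to the dataset1 temporary view and its columns.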

Since our DSL is simple, there isn't a lot of magic here. More interesting things happen in the custom parser:

import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}

case class CustomDslParser(defaultParser: ParserInterface) extends ParserInterface {
  override def parsePlan(sqlText: String): LogicalPlan = {
    // All this is included in a parse method shared by all parse* methods
    // See org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse
    // Unlike that method, no custom error listener is registered here, so syntax errors are
    // reported by ANTLR's default console listener rather than converted into a ParseException
    val sqlExtendedLexer = new SqlExtendedLexer(CharStreams.fromString(sqlText))
    val tokenStream = new CommonTokenStream(sqlExtendedLexer)

    val parser = new SqlExtendedParser(tokenStream)
    val visitor = new CustomSqlVisitor()
    visitor.visitDslQuery(parser.dslQuery())
  }

  // Delegate to default parser for remaining parsing expressions
  override def parseExpression(sqlText: String): Expression = defaultParser.parseExpression(sqlText)

  override def parseTableIdentifier(sqlText: String): TableIdentifier = defaultParser.parseTableIdentifier(sqlText)

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = defaultParser.parseFunctionIdentifier(sqlText)

  override def parseTableSchema(sqlText: String): StructType = defaultParser.parseTableSchema(sqlText)

  override def parseDataType(sqlText: String): DataType = defaultParser.parseDataType(sqlText)
}

First and foremost, the parser delegates part of the job to the default parser associated with the current SparkSession. According to the comment of SparkSessionExtensions#injectParser(), the injected builder receives the existing parser so that it can "create a partial parser and to delegate to the underlying parser for completeness". The second point is the implementation of parsePlan. It uses code similar to the parse method of the AbstractSqlParser class: it creates a lexer and a parser, then visits the parsed tree with our custom visitor to build the LogicalPlan for the query.

Let's now check that it works as expected:

  "custom DSL" should "be converted to a SELECT statement" in {
    val sparkSession: SparkSession = SparkSession.builder().appName("Custom DSL parser")
      .master("local[*]")
      .withExtensions(extensions => {
        extensions.injectParser((session, defaultParser) => {
          CustomDslParser(defaultParser)
        })
      })
      .getOrCreate()
    import sparkSession.implicits._
    val dataset1 = Seq(("A", 1), ("B", 2), ("C", 3), ("D", 4), ("E", 5)).toDF("letter", "nr")
    dataset1.createOrReplaceTempView("dataset1")

    val customQuery = sparkSession.sql("[letter nr] <- dataset1")

    val result = customQuery.collect().map(row => s"${row.getAs[String]("letter")}-${row.getAs[Int]("nr")}")
    result should have size 5
    result should contain allOf("A-1", "B-2", "C-3", "D-4", "E-5")
  }
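Since CustomDslParser sends every parsePlan call to the DSL grammar, regular SQL statements are not understood by it once it's injected. A variant following the "partial parser" idea quoted above could handle only the input it recognizes and delegate the rest. The sketch below assumes that detecting DSL input with a regular expression is acceptable; PlanParser, DefaultSqlParser and FallbackDslParser are hypothetical simplifications of ParserInterface that return a String instead of a LogicalPlan:

```scala
// Hypothetical, simplified version of ParserInterface: only parsePlan is modeled.
trait PlanParser {
  def parsePlan(text: String): String
}

// Stand-in for Spark's default SQL parser, which would normally build a LogicalPlan.
object DefaultSqlParser extends PlanParser {
  override def parsePlan(text: String): String = s"SQL($text)"
}

// Partial parser: handles the [columns] <- dataset syntax, delegates everything else.
final class FallbackDslParser(delegate: PlanParser) extends PlanParser {
  // Full-match pattern: group 1 = space-separated columns, group 2 = dataset name
  private val Dsl = """\s*\[\s*([a-zA-Z0-9]+(?:\s+[a-zA-Z0-9]+)*)\s*\]\s*<-\s*([a-zA-Z0-9]+)\s*""".r

  override def parsePlan(text: String): String = text match {
    case Dsl(columns, dataset) =>
      s"Project(${columns.split("\\s+").mkString(", ")}) <- $dataset"
    case _ =>
      delegate.parsePlan(text)
  }
}
```

With this design, `parsePlan("[letter nr] <- dataset1")` takes the DSL branch while `parsePlan("SELECT letter FROM dataset1")` reaches the default parser. Another option would be to try the DSL grammar first and fall back on the default parser when it fails, at the cost of parsing some queries twice.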

Extending the existing parser

The previous parser was relatively easy since it didn't touch Apache Spark's default one. Now, let's do something more complicated and see how we could extend the current parser. During my research I found a method proposed by Ruby Tahboub in Writing a customized Spark SQL Parser. It works, but it seems to bring some maintenance burden whenever the Apache Spark version changes. That's why I tried to find an alternative, loosely coupled with Apache Spark files.

I first tried to extend the existing classes but unfortunately, it's quite hard since they're generated by ANTLR, and extending them would mean doing ANTLR's job by hand. So I tried to create an extension of the grammar file instead. Before writing it, I had to download the grammar version (SqlBase.g4) corresponding to the runtime version of Apache Spark. The extended grammar file is not complex since it relies on grammar imports, which work a bit like inheritance in object-oriented languages: the rules of the imported SqlBase grammar are reused, and any rule redefined in the importing grammar overrides its imported counterpart:

grammar SqlExtended;
import SqlBase;


queryOrganization
    : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
      (RETURNED BY returned+=sortItem (',' returned+=sortItem)*)?
      (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
      (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
      (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
      windows?
      (LIMIT (ALL | limit=expression))?
    ;

RETURNED: 'RETURNED';

I bet you already see what I'm trying to do. As an exercise, I want to add a new syntax for query ordering called "RETURNED BY". Two things to note. First, I didn't find a way to extend only a part of queryOrganization, which is why the whole rule is copied (if you know of one, I'll be happy to learn). The second is the need to declare an explicit lexer rule mapping RETURNED to its text. In my first try, I omitted it and the lexer didn't recognize this value properly ("warning(125): SqlExtended.g4:7:7: implicit definition of token RETURNED in parser" at compilation time): it treated "RETURNED" as an identifier instead of a keyword.
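Why the explicit rule matters can be illustrated with a simplified model of how an ANTLR lexer picks a token: the longest match wins and, on a tie, the rule defined first wins. Per the ANTLR documentation on grammar imports, lexer rules of the main grammar take precedence over imported ones, which is why RETURNED declared in SqlExtended can beat an identifier rule coming from SqlBase. LexerRule and MiniLexer are hypothetical names, and the model ignores details such as case-insensitive keyword matching:

```scala
import scala.util.matching.Regex

// Hypothetical model of ANTLR lexer rule selection, not ANTLR's actual implementation.
final case class LexerRule(name: String, pattern: Regex)

object MiniLexer {
  def tokenize(input: String, rules: Seq[LexerRule]): List[(String, String)] = {
    val tokens = List.newBuilder[(String, String)]
    var pos = 0
    while (pos < input.length) {
      if (input(pos).isWhitespace) pos += 1 // equivalent of a WS -> skip rule
      else {
        val rest = input.substring(pos)
        // Every rule matching at the current position, kept in definition order
        val candidates = rules.flatMap(rule => rule.pattern.findPrefixOf(rest).map(text => (rule.name, text)))
        require(candidates.nonEmpty, s"No rule matches at position $pos")
        val longest = candidates.map(_._2.length).max
        // Longest match wins; among equally long matches, the first defined rule wins
        val (name, text) = candidates.find(_._2.length == longest).get
        tokens += ((name, text))
        pos += text.length
      }
    }
    tokens.result()
  }
}
```

Without a RETURNED rule, `MiniLexer.tokenize("RETURNED BY x", Seq(LexerRule("BY", "BY".r), LexerRule("IDENTIFIER", "[a-zA-Z_][a-zA-Z0-9_]*".r)))` produces an IDENTIFIER token for "RETURNED"; putting `LexerRule("RETURNED", "RETURNED".r)` first in the sequence turns it into a RETURNED token.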

After redefining queryOrganization, I wanted to implement only the part constructing the ordering clause, or rather to add a simple case treating "RETURNED BY" as a synonym of ORDER BY (I'm aware that I could simply do a string replace("RETURNED BY", "ORDER BY") and use the default parser, but it's not the goal here). The method I should override was:

  /**
   * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
   * clauses determine the shape (ordering/partitioning/rows) of the query result.
   */
  private def withQueryResultClauses(
      ctx: QueryOrganizationContext,
      query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
// ...

As you can see, the method is private, so overriding it is impossible. Delegating the execution to it is impossible as well, since the method is called from inside AstBuilder itself, here:

  override def visitSingleInsertQuery(
      ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
    plan(ctx.queryTerm).
      // Add organization statements.
      optionalMap(ctx.queryOrganization)(withQueryResultClauses).
      // Add insert.
      optionalMap(ctx.insertInto())(withInsertInto)
  }

As you can notice, the solution mentioned at the beginning of this section, the one I wanted to simplify, seems to be the easiest, despite the overhead of copying classes and files.

The goal of this post was to show that neither extending the existing parser nor delegating a specific part of the work to it is an easy task. If you think that I missed something, please post a comment; I will be happy to see that writing a custom parser is much simpler than I experienced. Otherwise, you will need either to write a custom visitor or to copy the codebase. While the former solution is fine when developing a completely new query language, the latter is much harder to accept, especially when we simply want to add one new operator. Fortunately, we can still do this with a logical optimization rule and an appropriate custom physical plan.