sortWithinPartitions in Apache Spark SQL

Versions: Apache Spark 2.4.3

A few weeks ago, while preparing a talk for a local meetup, I wanted to list the most common Spark operations for newcomers. And I found one I had never used before, namely sortWithinPartitions.

The post is divided into two sections. The first one shows what happens with global ordering, i.e. the ordering applied to the whole dataset. The second part focuses on local ordering and shows how it differs from the global one.

Global ordering

To understand what happens for global ordering, let's take a very easy use case:

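  "global ordering" should "range partition the data" in {
    // Assumes a SparkSession whose implicits (import sparkSession.implicits._) are in scope
    val dataset = Seq(30, 40, 10, 20, 11, 25, 99, 109, 9, 2).toDF("nr")

    val executionPlan = dataset.sort($"nr".desc).queryExecution.executedPlan

    // nr#3 is an expression ID and 2 the number of shuffle partitions; both depend on the test session
    executionPlan.toString() should include("Exchange rangepartitioning(nr#3 DESC NULLS LAST, 2)")
  }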

The execution plan for it looks like:

*(1) Sort [nr#3 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(nr#3 DESC NULLS LAST, 2)
   +- LocalTableScan [nr#3]

As you can see, an interesting thing happens here: Spark adds a range partitioning exchange so that records with consecutive values end up in the same partition. Every partition is then sorted locally and, since the partitions cover ordered ranges, reading them one after another gives a globally sorted dataset. Spoiler alert: the same physical Sort node is also used for local ordering (the true flag in the plan above marks a global sort), but with a subtle difference.
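
To make the range partitioning idea concrete, here is a simplified, Spark-free Scala model of what the Exchange rangepartitioning + Sort pair achieves for the 10 numbers of the test. It's a sketch under assumptions, not Spark's RangePartitioner: the bounds are computed from the whole dataset instead of a sample, the initial split into two partitions is arbitrary, and the names (RangePartitionSketch, rangeBounds, partitionIndex, globalSortDesc) are hypothetical.

```scala
// Simplified, Spark-free model of sort($"nr".desc): range partition, then sort each partition.
// Assumption: the bounds come from the full dataset; Spark's RangePartitioner samples instead.
object RangePartitionSketch {

  private val desc: Ordering[Int] = Ordering[Int].reverse

  // numPartitions - 1 bounds in descending order; bound i is the smallest value of partition i
  def rangeBounds(values: Seq[Int], numPartitions: Int): Seq[Int] = {
    require(values.nonEmpty && numPartitions > 0)
    val sorted = values.sorted(desc)
    val step = math.ceil(sorted.size.toDouble / numPartitions).toInt
    (1 until numPartitions).map(i => sorted(math.min(i * step, sorted.size) - 1))
  }

  // A value goes to the partition whose range contains it: the number of bounds above it
  def partitionIndex(value: Int, bounds: Seq[Int]): Int = bounds.count(bound => value < bound)

  def globalSortDesc(partitions: Seq[Seq[Int]], numPartitions: Int): Seq[Seq[Int]] = {
    val all = partitions.flatten
    val bounds = rangeBounds(all, numPartitions)
    // "Exchange rangepartitioning": every record moves to the partition owning its range
    val shuffled = (0 until numPartitions).toList.map(p => all.filter(v => partitionIndex(v, bounds) == p))
    // "Sort": each partition is then sorted on its own
    shuffled.map(_.sorted(desc))
  }

  def main(args: Array[String]): Unit = {
    // How the input is split before the shuffle is arbitrary here
    val input = Seq(Seq(30, 40, 10, 20, 11), Seq(25, 99, 109, 9, 2))
    val output = globalSortDesc(input, numPartitions = 2)
    println(output) // List(List(109, 99, 40, 30, 25), List(20, 11, 10, 9, 2))
    // Partitions hold ordered ranges, so concatenating them gives a global order
    println(output.flatten == input.flatten.sorted(desc)) // true
  }
}
```

The shuffle is what makes the local sorts add up to a global order: without it, every partition would still hold an arbitrary mix of values.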

Local ordering

For local ordering I'll use the same code... well, almost the same. Instead of calling sort, I'll call sortWithinPartitions(sortExprs: Column*):

  "local ordering" should "avoid shuffle exchange" in {
    val dataset = Seq((30), (40), (10), (20), (11), (25), (99), (109), (9), (2)).toDF("nr")

    val executionPlan = dataset.sortWithinPartitions($"nr".desc).queryExecution.executedPlan

    println(s"${executionPlan}")
    val executionPlanLevels = executionPlan.toString.split("\n")
    executionPlanLevels should have size 2
    executionPlanLevels(0) shouldEqual "*(1) Sort [nr#3 DESC NULLS LAST], false, 0"
    executionPlanLevels(1) shouldEqual "+- LocalTableScan [nr#3]"
  }
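
The semantics of sortWithinPartitions can be approximated with the same kind of toy model. In this hypothetical sketch (LocalSortSketch, sortWithinPartitionsDesc and isSortedDesc are illustrative names, and the two input partitions are assumed), each partition is sorted independently and no record moves, which is why no exchange is needed. The flip side is that the concatenated output isn't globally ordered.

```scala
// Spark-free model of sortWithinPartitions($"nr".desc): every partition is sorted on its own
// and no record changes partition. A "partition" is modeled as a plain Seq[Int].
object LocalSortSketch {

  def sortWithinPartitionsDesc(partitions: Seq[Seq[Int]]): Seq[Seq[Int]] =
    partitions.map(_.sorted(Ordering[Int].reverse))

  def isSortedDesc(values: Seq[Int]): Boolean =
    values.zip(values.drop(1)).forall { case (a, b) => a >= b }

  def main(args: Array[String]): Unit = {
    // Hypothetical split of the test data into two partitions
    val partitions = Seq(Seq(30, 40, 10, 20, 11), Seq(25, 99, 109, 9, 2))
    val result = sortWithinPartitionsDesc(partitions)
    println(result)                        // List(List(40, 30, 20, 11, 10), List(109, 99, 25, 9, 2))
    println(result.forall(isSortedDesc))   // true: every partition is ordered
    println(isSortedDesc(result.flatten))  // false: the dataset as a whole is not
  }
}
```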

As you can see, the repartitioning step is missing. To understand why, let's take a look at the source code. The first thing to notice is that sortWithinPartitions sets the global attribute to false in the logical node representing sort operations, org.apache.spark.sql.catalyst.plans.logical.Sort (it's the true/false flag visible in both execution plans). This attribute is later passed to the physical execution node, org.apache.spark.sql.execution.SortExec, which uses it to control data distribution through this simple method:

  override def requiredChildDistribution: Seq[Distribution] =
    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

And that's the key to understanding why the shuffle doesn't happen for partition-based ordering. requiredChildDistribution is later used by the ensureDistributionAndOrdering(operator: SparkPlan) method of the EnsureRequirements class to figure out whether an extra shuffle step is required or not:

  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
    // For SortExec, requiredChildDistribution returns
    // "if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil"
    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
    var children: Seq[SparkPlan] = operator.children

    // Ensure that the operator's children satisfy their required distributions
    children = children.zip(requiredChildDistributions).map {
      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
        child // already satisfied: the child is kept as is, no shuffle
      case (child, BroadcastDistribution(mode)) =>
        BroadcastExchangeExec(mode, child)
      case (child, distribution) =>
        val numPartitions = distribution.requiredNumPartitions
          .getOrElse(defaultNumPreShufflePartitions)
        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
    }
    // ...
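
The distribution check above can be reduced to a toy model. The following Scala sketch is not Spark code: the type names mirror Spark's, but they're simplified stand-ins written for illustration (for example, RangePartitioning here only compares a column name, and requiredChildDistribution returns a single Distribution instead of a Seq). It shows how the global flag alone decides whether a shuffle gets inserted under a Sort.

```scala
// Toy model of the distribution step of EnsureRequirements (simplified stand-ins, not Spark's classes).
object EnsureDistributionSketch {

  sealed trait Distribution
  case object UnspecifiedDistribution extends Distribution
  final case class OrderedDistribution(column: String) extends Distribution

  sealed trait Partitioning {
    // Default behavior, like Partitioning.satisfies0: only "no requirement" is satisfied
    def satisfies(required: Distribution): Boolean = required match {
      case UnspecifiedDistribution => true
      case _                       => false
    }
  }
  final case class UnknownPartitioning(numPartitions: Int) extends Partitioning
  final case class RangePartitioning(column: String, numPartitions: Int) extends Partitioning {
    override def satisfies(required: Distribution): Boolean = required match {
      case OrderedDistribution(c) => c == column
      case other                  => super.satisfies(other)
    }
  }

  sealed trait Plan { def outputPartitioning: Partitioning }
  final case class LocalTableScan(column: String) extends Plan {
    // Like SparkPlan's default: nothing is known about the partitioning
    val outputPartitioning: Partitioning = UnknownPartitioning(0)
  }
  final case class ShuffleExchange(partitioning: Partitioning, child: Plan) extends Plan {
    val outputPartitioning: Partitioning = partitioning
  }
  final case class Sort(column: String, global: Boolean, child: Plan) extends Plan {
    // Sorting doesn't move rows, so the child's partitioning is kept
    def outputPartitioning: Partitioning = child.outputPartitioning
    def requiredChildDistribution: Distribution =
      if (global) OrderedDistribution(column) else UnspecifiedDistribution
  }

  // Adds a shuffle under the Sort only when the child can't satisfy the requirement
  def ensureDistribution(sort: Sort, numShufflePartitions: Int): Sort =
    sort.requiredChildDistribution match {
      case required if sort.child.outputPartitioning.satisfies(required) => sort
      case OrderedDistribution(c) =>
        sort.copy(child = ShuffleExchange(RangePartitioning(c, numShufflePartitions), sort.child))
      case UnspecifiedDistribution => sort // always satisfied above, kept for exhaustiveness
    }

  def main(args: Array[String]): Unit = {
    val scan = LocalTableScan("nr")
    println(ensureDistribution(Sort("nr", global = true, scan), 2))
    // Sort(nr,true,ShuffleExchange(RangePartitioning(nr,2),LocalTableScan(nr)))
    println(ensureDistribution(Sort("nr", global = false, scan), 2))
    // Sort(nr,false,LocalTableScan(nr))
  }
}
```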

For a local sort, the Sort node returns UnspecifiedDistribution as its required child distribution. The method then checks, for every child, whether its current partitioning satisfies the expected distribution. In my learning test, the outputPartitioning of the child is an instance of UnknownPartitioning, and it automatically satisfies UnspecifiedDistribution (that's the first case of satisfies0 below), so there is no need to add an extra ShuffleExchangeExec node:

trait Partitioning {
  // Number of partitions the data is split across
  val numPartitions: Int

  final def satisfies(required: Distribution): Boolean = {
    required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
  }
  protected def satisfies0(required: Distribution): Boolean = required match {
    case UnspecifiedDistribution => true
    case AllTuples => numPartitions == 1
    case _ => false
  }
}

// Doesn't override satisfies0, so it relies entirely on the default implementation above
case class UnknownPartitioning(numPartitions: Int) extends Partitioning

At the end of the ensureDistributionAndOrdering(operator: SparkPlan) method, after adding shuffles for all children that need them (requiredChildDistributions), the method checks the required orderings, and the initial operator (Sort in our example) gets its new children. For partition-based sorting, the child stays unchanged: LocalTableScan is returned by the if branch below because Sort doesn't require any child ordering. SortExec doesn't override requiredChildOrdering, so it inherits the SparkPlan default of an empty ordering per child, and orderingSatisfies always returns true when the required ordering is empty:

  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
    
    // ...

    // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
      // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
      if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
        child
      } else {
        SortExec(requiredOrdering, global = false, child = child)
      }
    }

    operator.withNewChildren(children)
  }

abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
  def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
}

object SortOrder {
  def orderingSatisfies(ordering1: Seq[SortOrder], ordering2: Seq[SortOrder]): Boolean = {
    if (ordering2.isEmpty) {
      true
    } else if (ordering2.length > ordering1.length) {
      false
    } else {
      ordering2.zip(ordering1).forall {
        case (o2, o1) => o1.satisfies(o2)
      }
    }
  }
}
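
The ordering check can be modeled the same way. This hypothetical Scala sketch reuses the shape of orderingSatisfies but replaces Spark's SortOrder.satisfies with plain equality on a simplified SortOrder(column, descending) type, and LocalSort stands in for SortExec(requiredOrdering, global = false, child). It shows that an empty required ordering, which is what SortExec inherits, always keeps the child as is.

```scala
// Toy model of the ordering step of EnsureRequirements (simplified stand-ins, not Spark's classes).
object EnsureOrderingSketch {

  final case class SortOrder(column: String, descending: Boolean)

  // Same shape as SortOrder.orderingSatisfies; plain equality stands in for o1.satisfies(o2)
  def orderingSatisfies(actual: Seq[SortOrder], required: Seq[SortOrder]): Boolean =
    if (required.isEmpty) true
    else if (required.length > actual.length) false
    else required.zip(actual).forall { case (r, a) => a == r }

  sealed trait Plan { def outputOrdering: Seq[SortOrder] }
  final case class Scan(name: String) extends Plan {
    val outputOrdering: Seq[SortOrder] = Nil // in this model, a scan guarantees no order
  }
  final case class LocalSort(order: Seq[SortOrder], child: Plan) extends Plan {
    def outputOrdering: Seq[SortOrder] = order
  }

  // Keeps the child when its ordering is good enough, otherwise adds a partition-local sort
  def ensureOrdering(child: Plan, requiredOrdering: Seq[SortOrder]): Plan =
    if (orderingSatisfies(child.outputOrdering, requiredOrdering)) child
    else LocalSort(requiredOrdering, child)

  def main(args: Array[String]): Unit = {
    val scan = Scan("LocalTableScan")
    val nrDesc = Seq(SortOrder("nr", descending = true))
    // SortExec's case: the inherited requiredChildOrdering is Nil, so the scan is kept as is
    println(ensureOrdering(scan, Nil))    // Scan(LocalTableScan)
    // A parent requiring "nr DESC" on an unsorted child gets a local sort
    println(ensureOrdering(scan, nrDesc)) // LocalSort(List(SortOrder(nr,true)),Scan(LocalTableScan))
    // An already sorted child is left untouched
    println(ensureOrdering(LocalSort(nrDesc, scan), nrDesc) == LocalSort(nrDesc, scan)) // true
  }
}
```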

And that's all for this short story explaining why global ordering requires an extra shuffle and local ordering doesn't. While exploring the code, I discovered a few other interesting places about the sorting implementation and the Partitioning trait. Soon, and if Spark 3.0 isn't released before then, I will write a blog post about them too!
