Shuffle reading in Apache Spark SQL

Versions: Apache Spark 3.1.1

So far I've covered the writing part of the shuffle files. You've learned about 3 different shuffle writers, but what happens to the files they generate? Who reads them, and how? Is the reading an in-memory operation? I will try to answer these and some other questions in this blog post.

ShuffledRowRDD

I mentioned "Apache Spark SQL" in the title of this article on purpose. Apache Spark has 2 abstractions responsible for dealing with shuffle files, the ShuffledRDD and the ShuffledRowRDD. The former works with the RDD API whereas the latter works with the Dataset API. Since the Dataset API is the recommended way to go in most cases, I've decided to focus only on it.

The ShuffledRowRDD is created by ShuffleExchangeExec with the following parameters:

class ShuffledRowRDD(
    var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
    metrics: Map[String, SQLMetric],
    partitionSpecs: Array[ShufflePartitionSpec])

The first important attribute is the ShuffleDependency instance. It brings the shuffle id together with an internal ShuffleHandle field. The second important attribute is the partitionSpecs array. It defines the type of the shuffle partitions, which in Apache Spark 3.1.1 can be a CoalescedPartitionSpec, a PartialReducerPartitionSpec or a PartialMapperPartitionSpec.

The partition specifications are important because they define the shuffle reader's behavior. Technically speaking, they all share the same reader type, an instance of ShuffleReader[K, C], but depending on the specification the reader gets different map and partition index ranges as parameters.
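
To make it more concrete, here is a minimal sketch of that mapping in plain Scala, without any Spark dependency. The case classes are simplified stand-ins for the partition specs as I understand them from the Spark 3.1 source (the real PartialReducerPartitionSpec also carries a data size). ReaderRange and readerRange are hypothetical names used only for the illustration; the four fields follow the getReader signature shown later in this post.

```scala
// Simplified model, NOT Spark's actual code: no Spark classes are used here.
object PartitionSpecSketch {
  sealed trait ShufflePartitionSpec
  // Reads all the map outputs for a range of reducer partitions.
  final case class CoalescedPartitionSpec(startReducerIndex: Int, endReducerIndex: Int)
    extends ShufflePartitionSpec
  // Reads one reducer partition, but only from a range of map outputs.
  final case class PartialReducerPartitionSpec(
      reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends ShufflePartitionSpec
  // Reads a range of reducer partitions from a single map output.
  final case class PartialMapperPartitionSpec(
      mapIndex: Int, startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec

  // Hypothetical holder for the getReader ranges; the end indexes are exclusive.
  final case class ReaderRange(
      startMapIndex: Int, endMapIndex: Int, startPartition: Int, endPartition: Int)

  def readerRange(spec: ShufflePartitionSpec): ReaderRange = spec match {
    case CoalescedPartitionSpec(start, end) =>
      ReaderRange(0, Int.MaxValue, start, end)
    case PartialReducerPartitionSpec(reducer, startMap, endMap) =>
      ReaderRange(startMap, endMap, reducer, reducer + 1)
    case PartialMapperPartitionSpec(map, start, end) =>
      ReaderRange(map, map + 1, start, end)
  }
}
```

The point of the sketch is that the reader class never changes; only the map and partition ranges do.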

Reader initialization

By calling SparkEnv.get.shuffleManager.getReader, the ShuffledRowRDD invokes the SortShuffleManager, which does 2 things. First, it identifies the location of the shuffle blocks to fetch, with the help of the map output tracker component that I will detail in the next section. After that, the manager creates an instance of BlockStoreShuffleReader that will be responsible for passing the shuffle data from the mappers to the reducer task.

The ShuffledRowRDD uses the reader's read() method to iterate over the shuffle data and return it to the client's code:

class ShuffledRowRDD(
    var dependency: ShuffleDependency[Int, InternalRow, InternalRow],
    metrics: Map[String, SQLMetric],
    partitionSpecs: Array[ShufflePartitionSpec])
  extends RDD[InternalRow](dependency.rdd.context, Nil) {

  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
    val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
    // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
    // as well as the `tempMetrics` for basic shuffle metrics.
    val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
    val reader = split.asInstanceOf[ShuffledRowRDDPartition].spec match {
// ...
    reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
  }

Map output tracker

Apache Spark has 2 map output trackers. The first of them is MapOutputTrackerMaster. It resides in the driver and keeps track of the map outputs for each stage. It mainly communicates with the DAGScheduler.

The second map output tracker is MapOutputTrackerWorker, located on the executors. It's responsible for fetching the shuffle metadata information from the master tracker. It happens when the worker tracker hasn't cached the shuffle information and has to send a GetMapOutputStatuses message to get it. The shuffle information contains a list of MapStatus. Each item defines the physical location and map task id. Thanks to this list, the tracker can communicate the fetchable elements to the shuffle reader:

class BlockManagerId private (
    private var executorId_ : String,
    private var host_ : String,
    private var port_ : Int,
    private var topologyInfo_ : Option[String])

private[spark] sealed trait MapStatus { 
  def location: BlockManagerId
// ...
  def mapId: Long
}

private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
// ...
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
     fetchingLock.withLock(shuffleId) {
        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
        if (fetchedStatuses == null) {
          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
        } 
        fetchedStatuses
// ...
}
}

The fetched information is passed to the BlockStoreShuffleReader as a list of shuffle blocks with their physical locations:

private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { 
  override def getReader[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int, endMapIndex: Int, startPartition: Int, endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
      handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
      shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
  }
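
To show what "a list of shuffle blocks with their physical locations" can look like, below is a minimal sketch in plain Scala. It is a simplified model and not Spark's implementation: Location, MapOutput, Block and blocksByAddress are hypothetical stand-ins for BlockManagerId, MapStatus and the result of getMapSizesByExecutorId, and I assume that empty blocks are skipped and that the block name follows the shuffle_<shuffleId>_<mapId>_<reduceId> pattern.

```scala
// Simplified model, NOT Spark's actual code.
object BlocksByAddressSketch {
  // Hypothetical stand-in for BlockManagerId.
  final case class Location(executorId: String, host: String, port: Int)
  // Hypothetical stand-in for MapStatus; sizes(i) is the size of reducer partition i.
  final case class MapOutput(location: Location, mapId: Long, sizes: Vector[Long])
  final case class Block(name: String, size: Long, mapIndex: Int)

  def blocksByAddress(
      shuffleId: Int,
      statuses: Seq[MapOutput],
      startPartition: Int,
      endPartition: Int): Map[Location, Seq[Block]] = {
    val located = for {
      (status, mapIndex) <- statuses.zipWithIndex
      partition <- startPartition until endPartition
      size = status.sizes(partition)
      if size > 0 // assumption: there is nothing to fetch for an empty block
    } yield status.location -> Block(s"shuffle_${shuffleId}_${status.mapId}_$partition", size, mapIndex)
    // Group the blocks by the executor that stores them.
    located.groupBy(_._1).map { case (location, pairs) => location -> pairs.map(_._2) }
  }
}
```

Grouping by location matters because the fetcher later builds its requests per remote executor.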

BlockStoreShuffleReader

At this point Apache Spark only knows the shuffle metadata. But thanks to it, it can now move on to the physical reading in the BlockStoreShuffleReader class. The reader initializes an instance of ShuffleBlockFetcherIterator that is responsible for getting the shuffle blocks from the executors (the external shuffle service is not covered in this blog post):

private[spark] class BlockStoreShuffleReader[K, C] {
// ...
  override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator( /* ... */ )
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }
// ...

The BlockStoreShuffleReader has a few more lines with other iterators wrapping the raw shuffle data. However, to keep this article reasonably short, I will cover this part in a follow-up blog post. Here, I'll focus only on the physical data retrieval, that is, on the ShuffleBlockFetcherIterator.

ShuffleBlockFetcherIterator

The first important observation about this iterator is its configuration. The constructor takes several parameters that throttle the fetch calls. These controls are size-based (max size of the pending reads, in bytes), count-based (max number of pending fetch requests and max number of blocks fetched at once from every remote host) and memory-based (max size of a remote request that can be kept in memory; bigger ones are fetched to disk).

When the shuffle reader creates the instance of the ShuffleBlockFetcherIterator, the iterator runs an initialization step defined in its initialize() method. It starts by building a list of fetch requests from the blocksByAddress returned by the map output tracker. For the blocks stored on the local host, no fetch request is needed because the iterator reads the shuffle data directly from disk. The requests targeting remote nodes are then randomly shuffled (BTW, now you know why the shuffle retrieval is not deterministic!) and saved into val fetchRequests = new Queue[FetchRequest]:

  private[this] def initialize(): Unit = {
    // ...
    // Partition blocks by the different fetch modes: local, host-local and remote blocks.
    val remoteRequests = partitionBlocksByFetchMode()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)
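
The split between local and remote blocks can be sketched as below. It is a simplified model in plain Scala and not Spark's code: FetchPlan, plan and the other names are hypothetical, the host-local mode mentioned in the comment above is left out, and every remote executor gets a single request, whereas I would not assume that Spark always does it this way.

```scala
import scala.util.Random

// Simplified model, NOT Spark's actual code.
object FetchPlanSketch {
  final case class Block(name: String, size: Long)
  final case class FetchRequest(executorId: String, blocks: Seq[Block]) {
    def size: Long = blocks.map(_.size).sum
  }
  final case class FetchPlan(localBlocks: Seq[Block], remoteRequests: Seq[FetchRequest])

  def plan(
      localExecutorId: String,
      blocksByExecutor: Map[String, Seq[Block]],
      random: Random): FetchPlan = {
    // Blocks written by the current executor can be read from disk, without any request.
    val (local, remote) = blocksByExecutor.partition {
      case (executorId, _) => executorId == localExecutorId
    }
    val requests = remote.toSeq.map {
      case (executorId, blocks) => FetchRequest(executorId, blocks)
    }
    // The remote requests are queued in a random order.
    FetchPlan(local.values.flatten.toSeq, random.shuffle(requests))
  }
}
```

Passing the Random instance as a parameter is only a convenience to make the sketch testable with a fixed seed.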

After building the list of fetch requests, the iterator starts the first reading loop, which stops sending requests once it reaches the max number of in-flight bytes or of pending requests.
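
The throttling idea can be sketched with a few lines of plain Scala. It is a simplified model based on my reading of the Spark source and not the real implementation: Throttle, enqueue and completed are hypothetical names, the per-host block limit is ignored, and I assume that a request is always sent when nothing is in flight, so that a single oversized request cannot block the reading forever.

```scala
import scala.collection.mutable

// Simplified model, NOT Spark's actual code.
object ThrottleSketch {
  final case class FetchRequest(id: String, size: Long)

  final class Throttle(maxBytesInFlight: Long, maxReqsInFlight: Int) {
    private val pending = mutable.Queue.empty[FetchRequest]
    private var bytesInFlight = 0L
    private var reqsInFlight = 0

    def enqueue(requests: Seq[FetchRequest]): Unit = pending ++= requests

    // Sends as many pending requests as the limits allow and returns the sent ones.
    def fetchUpToMaxBytes(): List[FetchRequest] = {
      val sent = mutable.ListBuffer.empty[FetchRequest]
      while (pending.nonEmpty && canSend(pending.head)) {
        val request = pending.dequeue()
        bytesInFlight += request.size
        reqsInFlight += 1
        sent += request
      }
      sent.toList
    }

    // To be called when the response for a request has arrived.
    def completed(request: FetchRequest): Unit = {
      bytesInFlight -= request.size
      reqsInFlight -= 1
    }

    private def canSend(request: FetchRequest): Boolean =
      bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + request.size <= maxBytesInFlight)
  }
}
```

With a limit of 100 bytes and requests of 60, 50 and 10 bytes, the first call sends only the first request; the 2 remaining ones leave the queue after the first response is marked as completed.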

Fetching

The physical fetching happens in the fetchUpToMaxBytes() and it consists in:

The client runs the fetching process in the fetchBlocks() method with the help of another class called OneForOneBlockFetcher. Since there is a lot of back and forth between the classes, I will detail the algorithm in the list below:

The fetched shuffle data is later put into the private[this] val results = new LinkedBlockingQueue[FetchResult] and returned to the caller from the next() method. After taking a result from the queue, the iterator calls the fetchUpToMaxBytes() presented before again, so that new requests can be sent as soon as the limits allow it:

  override def next(): (BlockId, InputStream) = {
    if (!hasNext) {
      throw new NoSuchElementException()
    }

    numBlocksProcessed += 1

    var result: FetchResult = null
    var input: InputStream = null
    var streamCompressedOrEncrypted: Boolean = false

    while (result == null) {
      val startFetchWait = System.nanoTime()
      result = results.take()
      val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
      shuffleMetrics.incFetchWaitTime(fetchWaitTime)

      result match {
        case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
// ...
          }

        case FailureFetchResult(blockId, mapIndex, address, e) =>
          throwFetchFailedException(blockId, mapIndex, address, e)
      }

      // Send fetch requests up to maxBytesInFlight
      fetchUpToMaxBytes()
    }

    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId,
      new BufferReleasingInputStream(
        input,
        this,
        currentResult.blockId,
        currentResult.mapIndex,
        currentResult.address,
        detectCorrupt && streamCompressedOrEncrypted))
  }
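
The interaction between the fetching side and next() is a classic producer-consumer exchange over a blocking queue. Below is a minimal sketch of it in plain Scala. It is not Spark's code: BlockIterator, onResult, Success and Failure are hypothetical names, and the metrics, buffer releasing and corruption detection of the real iterator are left out.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Simplified model, NOT Spark's actual code.
object ResultsQueueSketch {
  sealed trait FetchResult
  final case class Success(blockId: String, data: Array[Byte]) extends FetchResult
  final case class Failure(blockId: String, error: Throwable) extends FetchResult

  final class BlockIterator(expectedBlocks: Int) extends Iterator[(String, Array[Byte])] {
    private val results = new LinkedBlockingQueue[FetchResult]()
    private var processed = 0

    // Called by the fetching side, possibly from another thread.
    def onResult(result: FetchResult): Unit = results.put(result)

    override def hasNext: Boolean = processed < expectedBlocks

    override def next(): (String, Array[Byte]) = {
      if (!hasNext) throw new NoSuchElementException()
      processed += 1
      // take() blocks the calling thread until a result is available.
      results.take() match {
        case Success(blockId, data) => (blockId, data)
        case Failure(blockId, error) =>
          throw new RuntimeException(s"Failed to fetch $blockId", error)
      }
    }
  }
}
```

The consumer never polls in a loop; it simply waits in take(), which is also why the real iterator can measure the fetch wait time around this call.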

And I will stop here for today. As you can see, shuffle reading is not a straightforward action. It involves a lot of components controlling the executed logic and RPC calls. In the next blog post I'll focus on the part using the fetched shuffle blocks.
