Shuffle writers: SortShuffleWriter

Versions: Apache Spark 3.1.1

In the beginning I thought that the mappers sent shuffle files to the reducers. After understanding that it was the opposite, I thought that a part of the shuffle data was kept in memory for performance purposes... Once I corrected all these misconceptions about shuffle, I noted a few points to explore. One of them is the shuffle writers, which I will present in a series of 3 blog posts.

…happens when it uses it for shuffling. You should also understand, if that's not already the case, what spilling is.

I ended up with quite a long article. I don't like this format, but the topic was more complex than I had thought. So grab a cup of coffee or tea before reading :)

SortShuffleManager

SortShuffleManager is the default and only shuffle manager shipped with vanilla Apache Spark. You can replace it with your own class set in the spark.shuffle.manager property, though. That class must implement the ShuffleManager interface. But that's not our concern here!

The ShuffleManager interface exposes the methods to write, read and manage shuffle files. Well, technically speaking, these methods return the classes responsible for writing, reading and management. Since this series focuses only on the writers, only this part is covered below.

The first important method for the writers is registerShuffle(shuffleId: Int, dependency: ShuffleDependency[K, V, C]). Surprisingly, it doesn't register the shuffle! Instead, it returns one of 3 implementations of the ShuffleHandle abstract class: BypassMergeSortShuffleHandle, SerializedShuffleHandle or BaseShuffleHandle. The returned object depends on multiple factors that I will detail in each blog post of the series. But I think you can already see where this is going: the ShuffleHandle type determines the shuffle writer type! If you don't take my word for it, let's check the getWriter method of the default SortShuffleManager:

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
// ...
    handle match {
      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
        new UnsafeShuffleWriter(...)
      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
        new BypassMergeSortShuffleWriter(...)
      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
        new SortShuffleWriter(...)
    }

Besides selecting the writer type, the getWriter method is also the place where Apache Spark registers the mapping between the shuffle id and the tasks of the map stage:

  override def getWriter[K, V](
      handle: ShuffleHandle,
      mapId: Long,
      context: TaskContext,
      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
    val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent(
      handle.shuffleId, _ => new OpenHashSet[Long](16))
    mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) }
// ...
}

The role of SortShuffleManager stops here. It's then a good moment to see under what conditions it generates the BaseShuffleHandle, associated with the SortShuffleWriter.

SortShuffleWriter - when?

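SortShuffleWriter is the last-resort shuffle writer. In other words, the shuffle manager uses it only when neither BypassMergeSortShuffleWriter nor UnsafeShuffleWriter can be used, meaning that neither of these sets of conditions holds:
- BypassMergeSortShuffleWriter: no map-side aggregation and at most spark.shuffle.sort.bypassMergeThreshold shuffle partitions (default: 200),
- UnsafeShuffleWriter: a serializer supporting the relocation of serialized objects, no map-side aggregation and at most 16777216 shuffle partitions.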
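The selection order can be summarized with a small model. It's a simplified sketch of the decision taken in registerShuffle, not Spark's code: HandleSelection, chooseHandle and its parameters are hypothetical names, and the rules are the ones I know for Apache Spark 3.1 (bypass first, then serialized, then the base handle).

```scala
// Hypothetical model of the handle selection made by SortShuffleManager.registerShuffle
object HandleSelection {
  // 2^24 partitions: the upper bound of the serialized (unsafe) mode
  val MaxPartitionsForSerializedMode: Int = 1 << 24

  def chooseHandle(
      mapSideCombine: Boolean,
      numPartitions: Int,
      serializerSupportsRelocation: Boolean,
      bypassMergeThreshold: Int = 200): String = {
    if (!mapSideCombine && numPartitions <= bypassMergeThreshold) {
      // Few partitions and no local aggregation: 1 file per partition, merged later
      "BypassMergeSortShuffleHandle"
    } else if (serializerSupportsRelocation && !mapSideCombine &&
        numPartitions <= MaxPartitionsForSerializedMode) {
      // Records can be sorted in their serialized form
      "SerializedShuffleHandle"
    } else {
      // Everything else ends up in SortShuffleWriter
      "BaseShuffleHandle"
    }
  }
}
```

In this model, map-side aggregation alone is enough to fall back to the BaseShuffleHandle, whatever the number of partitions.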

SortShuffleWriter - sorter

The key element of the SortShuffleWriter is the sorter field, an instance of the ExternalSorter class. The writer initializes it before physically writing the records of the map task. The initialization logic depends on the map-side aggregation (local aggregation). If it's enabled, the sorter is created with an aggregator and a key ordering:

private var sorter: ExternalSorter[K, V, _] = null
  override def write(records: Iterator[Product2[K, V]]): Unit = {
    sorter = if (dep.mapSideCombine) {
      new ExternalSorter[K, V, C](
        context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      new ExternalSorter[K, V, V](
        context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
    }

After this initialization step, all map records are passed to the ExternalSorter's insertAll(records: Iterator[Product2[K, V]]) method where the logic depends on the map-side aggregation:

  override def write(records: Iterator[Product2[K, V]]): Unit = {
// ...
    sorter.insertAll(records)
// ...
}

  def insertAll(records: Iterator[Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined

    if (shouldCombine) {
      // Logic#1; map-side aggregation
    } else {
      // Logic#2; no map-side aggregation
    }
  }

SortShuffleWriter - without partial aggregation

The logic of the scenario without map-side aggregation is quite straightforward. Input records are key-value pairs and the sorter iterates over all of them. For every record, it computes the shuffle partition number of the key and inserts the record into the PartitionedPairBuffer:

      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }

The insert method doesn't involve any sorting. It simply puts the (partition, key) tuple and the value at the next 2 available positions of the underlying array:

  def insert(partition: Int, key: K, value: V): Unit = {
    if (curSize == capacity) {
      growArray()
    }
    data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
    curSize += 1
    afterUpdate() 
  }
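To illustrate this flat layout, here is a minimal sketch of a PartitionedPairBuffer-like structure. SimplePairBuffer is a hypothetical class: it keeps the (partition, key) tuple at even indexes and the value at the next odd index of a single Array[AnyRef], and doubles the array when it's full. Unlike Spark's version, it doesn't estimate its size (SizeTracker) and ignores any maximum capacity.

```scala
// Hypothetical, simplified PartitionedPairBuffer: one flat array, no size tracking
class SimplePairBuffer[K, V](initialCapacity: Int = 64) {
  require(initialCapacity > 0, "initialCapacity must be positive")
  private var capacity = initialCapacity
  private var curSize = 0
  // Slot 2*i holds the (partition, key) tuple, slot 2*i + 1 holds the value
  private var data = new Array[AnyRef](2 * initialCapacity)

  def size: Int = curSize

  def insert(partition: Int, key: K, value: V): Unit = {
    if (curSize == capacity) growArray()
    data(2 * curSize) = (partition, key)
    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
    curSize += 1
  }

  // Doubles the capacity and copies the existing slots; no sorting happens here
  private def growArray(): Unit = {
    val newData = new Array[AnyRef](4 * capacity)
    System.arraycopy(data, 0, newData, 0, 2 * capacity)
    data = newData
    capacity *= 2
  }

  // Entries in insertion order
  def entries: Seq[((Int, K), V)] =
    (0 until curSize).map { i =>
      (data(2 * i).asInstanceOf[(Int, K)], data(2 * i + 1).asInstanceOf[V])
    }
}
```

Storing keys and values side by side in one array avoids allocating a wrapper object per record, which is also what makes the later in-place sort possible.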

SortShuffleWriter - with partial aggregation

The scenario involving the aggregation starts by defining the merge and combine functions for the map-side aggregation:

      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }

The mergeValue function merges the values accumulated so far with the new input entry, whereas createCombiner creates the initial value of the partial aggregation. Later, the sorter iterates over all input records and applies the update function to each of them. The result is put into the map field, a PartitionedAppendOnlyMap (a SizeTrackingAppendOnlyMap subclass) storing the data for the reducers:

      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
      }
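The combination logic can be sketched with a plain mutable.HashMap standing in for the PartitionedAppendOnlyMap. The changeValue helper below only mimics the contract visible in the snippet (the update function receives a hadValue flag and the old value). MapSideCombine and its methods are hypothetical, and Spark's map uses its own open-addressing table rather than a HashMap.

```scala
import scala.collection.mutable

// Hypothetical sketch: a HashMap replaces Spark's PartitionedAppendOnlyMap
object MapSideCombine {
  // Mimics changeValue: calls update(hadValue, oldValue) and stores the result
  def changeValue[K, C](map: mutable.HashMap[K, C], key: K, update: (Boolean, C) => C): C = {
    val newValue = map.get(key) match {
      case Some(oldValue) => update(true, oldValue)
      case None => update(false, null.asInstanceOf[C])
    }
    map.update(key, newValue)
    newValue
  }

  def combine[K, V, C](
      records: Iterator[(K, V)],
      getPartition: K => Int,
      createCombiner: V => C,
      mergeValue: (C, V) => C): Map[(Int, K), C] = {
    val map = mutable.HashMap.empty[(Int, K), C]
    var kv: (K, V) = null
    // Same shape as the update closure of insertAll
    val update = (hadValue: Boolean, oldValue: C) =>
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    while (records.hasNext) {
      kv = records.next()
      // The map key is the (partition id, record key) pair
      changeValue(map, (getPartition(kv._1), kv._1), update)
    }
    map.toMap
  }
}
```

With createCombiner = identity and mergeValue = sum, this is the word-count style partial aggregation performed on the map side by reduceByKey.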

SortShuffleWriter - spilling

If you check the 2 iteration snippets again, you will notice the call to the maybeSpillCollection(usingMap: Boolean) method. It's responsible for spilling in-memory data to disk. The method starts by estimating the size of the underlying in-memory structure after adding the new entry. It passes this number to the maybeSpill method, which returns true if the structure holds too much data in memory. When that happens, the structure (map or buffer) is spilled and replaced by a new, empty instance:

  private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    if (usingMap) {
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }
    // ...
  }

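How does the spilling work? Every 32 read elements, and only if the estimated size has reached the current memory threshold, the sorter tries to acquire more memory from the execution pool, asking for enough to double its current usage. If the granted memory doesn't raise the threshold above the current size, the items accumulated so far in memory must be written to disk: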

  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { 
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted 
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold

Note that spilling can also happen when the sorter has read more than spark.shuffle.spill.numElementsForceSpillThreshold records since the last spill (default: 2147483647, i.e. Integer.MAX_VALUE), represented by numElementsForceSpillThreshold in the snippet.
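The threshold arithmetic is easier to follow with a small model. SpillDecider is a hypothetical class where acquire stands in for the execution memory pool and returns the number of granted bytes. For brevity, elementsRead is incremented inside maybeSpill, whereas the sorter does it in insertAll. After a spill, the threshold goes back to its initial value, which mimics, as far as I know, what releaseMemory() does in Spillable.

```scala
// Hypothetical model of Spillable.maybeSpill; `acquire` stands in for the memory pool
class SpillDecider(
    initialThreshold: Long,
    acquire: Long => Long,
    forceSpillThreshold: Long = Int.MaxValue.toLong) {
  var memoryThreshold: Long = initialThreshold
  var elementsRead: Long = 0L

  def maybeSpill(currentMemory: Long): Boolean = {
    elementsRead += 1 // done by addElementsRead() in insertAll in Spark
    var shouldSpill = false
    // Only every 32 records and only once the threshold is reached
    if (elementsRead % 32 == 0 && currentMemory >= memoryThreshold) {
      // Ask for enough memory to double the current usage
      val amountToRequest = 2 * currentMemory - memoryThreshold
      memoryThreshold += acquire(amountToRequest)
      // Still at or above the threshold: not enough memory was granted
      shouldSpill = currentMemory >= memoryThreshold
    }
    shouldSpill = shouldSpill || elementsRead > forceSpillThreshold
    if (shouldSpill) {
      // The in-memory data goes to disk: reset the counter and give the extra memory back
      elementsRead = 0L
      memoryThreshold = initialThreshold
    }
    shouldSpill
  }
}
```

A pool granting everything lets the collection keep growing (the threshold jumps from 100 to 300 for a 150-byte collection), whereas a pool granting nothing forces a spill at the 32nd record.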

The spilling starts by logging this action with a message like Thread %d spilling in-memory map of %s to disk (%d time%s so far). After that, the ExternalSorter creates an instance of WritablePartitionedIterator (inMemoryIterator below) used to write the records of every shuffle partition to the spill file. Since this iterator is sorted by partition, the records are written partition after partition, i.e. all records of one partition are written before moving to the next one:

  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { 
// ...
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
// ...
}

class ExternalSorter {
  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    spills += spillFile
  }

  private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
      : SpilledFile = {

    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    val writer: DiskBlockObjectWriter =
      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
// ...
      while (inMemoryIterator.hasNext) {
        val partitionId = inMemoryIterator.nextPartition()
        require(partitionId >= 0 && partitionId < numPartitions,
          s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
        inMemoryIterator.writeNext(writer)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1

        if (objectsWritten == serializerBatchSize) {
          flush()
        }
      }
    // Flush the disk writer's contents to disk, and update relevant variables.
    // The writer is committed at the end of this process.
    def flush(): Unit = {
      val segment = writer.commitAndGet()
      batchSizes += segment.length
      _diskBytesSpilled += segment.length
      objectsWritten = 0
    }
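The batching in spillMemoryIteratorToDisk can be sketched without a real disk writer. In this hypothetical SpillSketch, a ByteArrayOutputStream plays the role of the spill file and writeUTF/writeInt replace Spark's serializer. Every batchSize records (serializerBatchSize in Spark, configured, if I'm not mistaken, by spark.shuffle.spill.batchSize), the current segment is committed and its length recorded in batchSizes, next to the per-partition record counts.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer

// Hypothetical sketch of the spill bookkeeping; records arrive sorted by partition id
object SpillSketch {
  final case class SpilledData(bytes: Array[Byte], batchSizes: Seq[Long], elementsPerPartition: Array[Long])

  def spill(sorted: Iterator[((Int, String), Int)], numPartitions: Int, batchSize: Int): SpilledData = {
    val out = new ByteArrayOutputStream()
    val writer = new DataOutputStream(out)
    val batchSizes = ArrayBuffer.empty[Long]
    val elementsPerPartition = new Array[Long](numPartitions)
    var objectsWritten = 0
    var committed = 0L // bytes already covered by committed batches

    // "Commits" the current segment and records its length
    def flush(): Unit = {
      writer.flush()
      batchSizes += out.size() - committed
      committed = out.size()
      objectsWritten = 0
    }

    while (sorted.hasNext) {
      val ((partitionId, key), value) = sorted.next()
      require(partitionId >= 0 && partitionId < numPartitions, s"invalid partition $partitionId")
      writer.writeUTF(key)
      writer.writeInt(value)
      elementsPerPartition(partitionId) += 1
      objectsWritten += 1
      if (objectsWritten == batchSize) flush()
    }
    if (objectsWritten > 0) flush() // the last, possibly incomplete, batch
    SpilledData(out.toByteArray, batchSizes.toSeq, elementsPerPartition)
  }
}
```

Recording the batch sizes pays off at read time: as far as I know, each batch is deserialized with its own stream, so the spill reader needs to know where every batch starts.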

SortShuffleWriter - sorting for no aggregation scenario

So far, we haven't seen any trace of sorting. That's normal because it happens before the loop presented above! When the sorter calls destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]]), the collection internally creates a partitioned destructive sorted iterator and wraps it in a WritablePartitionedIterator:

private[spark] trait WritablePartitionedPairCollection[K, V] {
  def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
    : WritablePartitionedIterator = {
    val it = partitionedDestructiveSortedIterator(keyComparator)
    new WritablePartitionedIterator {
      private[this] var cur = if (it.hasNext) it.next() else null

      def writeNext(writer: PairsWriter): Unit = {
        writer.write(cur._1._2, cur._2)
        cur = if (it.hasNext) it.next() else null
      }

      def hasNext(): Boolean = cur != null

      def nextPartition(): Int = cur._1._1
    }
  }

WritablePartitionedPairCollection is a trait with 2 implementations: PartitionedAppendOnlyMap and PartitionedPairBuffer. Which one is involved in the spilling? It depends. In the non-aggregation scenario, it's the PartitionedPairBuffer. It first sorts the records with TimSort (a hybrid of merge sort and insertion sort), implemented in the Sorter class. Later, the iterator's next() method returns the sorted key-value pairs one by one:

private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
  extends WritablePartitionedPairCollection[K, V] with SizeTracker {
// ...
// partitionKeyComparator comes from the WritablePartitionedPairCollection companion object
  def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
    (a: (Int, K), b: (Int, K)) => {
      val partitionDiff = a._1 - b._1
      if (partitionDiff != 0) {
        partitionDiff
      } else {
        keyComparator.compare(a._2, b._2)
      }
    }

  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
    : Iterator[((Int, K), V)] = {
    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
    new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
    iterator
  }

  private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
    var pos = 0
// ...

    override def next(): ((Int, K), V) = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V])
      pos += 1
      pair
    }
  }
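The two orderings can be reproduced in a few lines. PartitionSort is a hypothetical helper: it uses Integer.compare instead of the subtraction (equivalent for non-negative partition ids) and Scala's stable sortWith instead of Spark's TimSort-based Sorter, so only the resulting order, not the algorithm, matches the snippet.

```scala
import java.util.Comparator

// Hypothetical helper reproducing the partitionComparator / partitionKeyComparator orderings
object PartitionSort {
  // Orders (partitionId, key) pairs by partition id only
  def partitionComparator[K]: Comparator[(Int, K)] =
    (a: (Int, K), b: (Int, K)) => Integer.compare(a._1, b._1)

  // Orders by partition id first, then by key inside the same partition
  def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] =
    (a: (Int, K), b: (Int, K)) => {
      val partitionDiff = Integer.compare(a._1, b._1)
      if (partitionDiff != 0) partitionDiff else keyComparator.compare(a._2, b._2)
    }

  def sortRecords[K, V](
      records: Seq[((Int, K), V)],
      keyComparator: Option[Comparator[K]]): Seq[((Int, K), V)] = {
    val comparator = keyComparator.map(c => partitionKeyComparator(c)).getOrElse(partitionComparator[K])
    // sortWith is stable: without a key comparator, records keep their insertion order per partition
    records.sortWith((x, y) => comparator.compare(x._1, y._1) < 0)
  }
}
```

Without a key comparator, only the partition order is guaranteed, which is all the writer needs to produce contiguous partition segments.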

SortShuffleWriter - sorting for aggregation scenario

For the map-side aggregation scenario, the writer uses a PartitionedAppendOnlyMap. As with the PartitionedPairBuffer, it also sorts the records with TimSort before exposing them through the iterator:

private[spark] class PartitionedAppendOnlyMap[K, V]
  extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {

  def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
    : Iterator[((Int, K), V)] = {
    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
    destructiveSortedIterator(comparator)
  }
// ... (the sorting below comes from AppendOnlyMap.destructiveSortedIterator(keyComparator))
    new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)

    new Iterator[(K, V)] {
      var i = 0
      var nullValueReady = haveNullValue
      def hasNext: Boolean = (i < newIndex || nullValueReady)
      def next(): (K, V) = {
        if (nullValueReady) {
          nullValueReady = false
          (null.asInstanceOf[K], nullValue)
        } else {
          val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
          i += 1
          item
        }
      }
    }

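In both scenarios, the records are sorted by shuffle partition only, or by shuffle partition and key when the operation requires ordered keys (e.g. sortByKey). When map-side aggregation is enabled without any ordering, the spill path additionally sorts the keys of each partition by their hash codes, which makes it possible to combine equal keys when merging the spill files. You can notice the ordering logic in the ExternalSorter class: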

  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
 // ...
      if (ordering.isEmpty) {
        // The user hasn't requested sorted keys, so only sort by partition ID, not key
        groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
      } else {
        // We do need to sort by both partition ID and key
        groupByPartition(destructiveIterator(
          collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
      }
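The groupByPartition call above splits a single partition-sorted iterator into one lazy iterator per shuffle partition. Below is a minimal sketch of that idea, modeled on the IteratorForPartition class that also shows up in the merge code later; GroupByPartition and its members are hypothetical and not Spark's exact implementation.

```scala
// Hypothetical sketch of groupByPartition / IteratorForPartition
object GroupByPartition {
  // Returns records from the shared iterator as long as they belong to partitionId
  final class IteratorForPartition[K, C](partitionId: Int, data: BufferedIterator[((Int, K), C)])
    extends Iterator[(K, C)] {
    override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
    override def next(): (K, C) = {
      if (!hasNext) throw new NoSuchElementException
      val ((_, key), value) = data.next()
      (key, value)
    }
  }

  // The partition iterators share one buffered iterator, so they must be consumed in order
  def groupByPartition[K, C](
      numPartitions: Int,
      sorted: Iterator[((Int, K), C)]): Iterator[(Int, Iterator[(K, C)])] = {
    val buffered = sorted.buffered
    (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition[K, C](p, buffered)))
  }
}
```

The shared buffered iterator is why the writer must process partitions strictly in increasing order; an empty partition simply yields an empty iterator.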

Final output generation

So far, we potentially have 0 or more spill files on disk, each sorted by partition id, and one in-memory structure holding all the data not yet written to disk. When does the final output generation happen? Once all records are inserted, the SortShuffleWriter creates an instance of ShuffleMapOutputWriter and passes it to the ExternalSorter's writePartitionedMapOutput method:

private[spark] class SortShuffleWriter[K, V, C] {

  override def write(records: Iterator[Product2[K, V]]): Unit = {
// ...
    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
      dep.shuffleId, mapId, dep.partitioner.numPartitions)
    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
    val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)

Two scenarios are possible in writePartitionedMapOutput. The first one handles the case without prior spilling. Its logic is very similar to the spilling one you saw previously, though! The sorter creates an iterator with the destructiveSortedWritablePartitionedIterator method and iterates over all shuffle partitions to write them sequentially, each to its corresponding shuffle block (shuffle_${shuffleId}_${mapId}_${reduceId}):

      while (it.hasNext()) {
        val partitionId = it.nextPartition()
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          while (it.hasNext && it.nextPartition() == partitionId) {
            it.writeNext(partitionPairsWriter)
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = partitionId + 1
      }

1 file per shuffle partition or 1 file per mapper task?

As you can see in the snippet above, a ShuffleBlockId is created for every shuffle partition. It's identified by 3 different ids, though. The reduce id (partitionId in the snippet) is the shuffle partition number. The mapId identifies the map task (by default, its task attempt id). Finally, the shuffleId is the shuffleId attribute of the ShuffleDependency class, taken from a counter that the SparkContext increments for every new shuffle dependency:

private[spark] class SortShuffleWriter[K, V, C](...) {
  private val dep = handle.dependency

  override def write(records: Iterator[Product2[K, V]]): Unit = {
// ...
    val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
      dep.shuffleId, mapId, dep.partitioner.numPartitions)
    sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
// ...
}
}

private[spark] class BaseShuffleHandle[K, V, C](
    shuffleId: Int,
    val dependency: ShuffleDependency[K, V, C])
  extends ShuffleHandle(shuffleId)

class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](...) {

val shuffleId: Int = _rdd.context.newShuffleId()

}
class SparkContext(config: SparkConf) extends Logging {

  private val nextShuffleId = new AtomicInteger(0)

  private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()

However, it doesn't mean that the shuffle writer creates 1 data file and 1 index file for every shuffle partition! In fact, the data file is created inside LocalDiskShuffleMapOutputWriter:

public LocalDiskShuffleMapOutputWriter(int shuffleId, long mapId, int numPartitions, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) {
// ...
    this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
    this.outputTempFile = null;
  }

As you can notice, getDataFile is called only with the shuffle and map ids! In that case, Apache Spark builds the block id of the data file with a no-op reduce id:

private[spark] class IndexShuffleBlockResolver(
    conf: SparkConf,
    _blockManager: BlockManager = null)
  extends ShuffleBlockResolver
  with Logging {

  def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None)
 
  def getDataFile(shuffleId: Int, mapId: Long, dirs: Option[Array[String]]): File = {
    val blockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
    dirs
      .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name))
      .getOrElse(blockManager.diskBlockManager.getFile(blockId))
  }
}
private[spark] object IndexShuffleBlockResolver {
  // No-op reduce ID used in interactions with disk store.
  // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
  // shuffle outputs for several reduces are glommed into a single file.
  val NOOP_REDUCE_ID = 0
}
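Put differently, the 3-id ShuffleBlockId is only a logical name, whereas the physical files always use NOOP_REDUCE_ID. The sketch below makes this explicit. ShuffleFileNames is a hypothetical helper, and the naming pattern (shuffle_${shuffleId}_${mapId}_${reduceId} with a .data or .index suffix) is, as far as I know, the one used by ShuffleDataBlockId and ShuffleIndexBlockId.

```scala
// Hypothetical helper illustrating logical block names vs physical file names
object ShuffleFileNames {
  // Mirrors IndexShuffleBlockResolver.NOOP_REDUCE_ID
  val NoopReduceId = 0

  // Logical block of one shuffle partition, as built in writePartitionedMapOutput
  def logicalBlockName(shuffleId: Int, mapId: Long, reduceId: Int): String =
    s"shuffle_${shuffleId}_${mapId}_${reduceId}"

  // Physical files: 1 data file and 1 index file per map task, whatever the reduce id
  def dataFileName(shuffleId: Int, mapId: Long): String =
    logicalBlockName(shuffleId, mapId, NoopReduceId) + ".data"

  def indexFileName(shuffleId: Int, mapId: Long): String =
    logicalBlockName(shuffleId, mapId, NoopReduceId) + ".index"
}
```

However many shuffle partitions the map task writes, they all resolve to the same pair of files.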

Whenever a new shuffle partition is written, the local disk writer appends it through a FileOutputStream opened on a temporary version of this data file (hence outputTempFile). The temporary file becomes the final data file only when the map output is committed:

public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {

  @Override
  public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
    if (reducePartitionId <= lastPartitionId) {
      throw new IllegalArgumentException("Partitions should be requested in increasing order.");
    }
    lastPartitionId = reducePartitionId;
    if (outputTempFile == null) {
      outputTempFile = Utils.tempFileWith(outputFile);
    }
    if (outputFileChannel != null) {
      currChannelPosition = outputFileChannel.position();
    } else {
      currChannelPosition = 0L;
    }
    return new LocalDiskShufflePartitionWriter(reducePartitionId);
  }

  private void initStream() throws IOException {
    if (outputFileStream == null) {
      outputFileStream = new FileOutputStream(outputTempFile, true);
    }
    if (outputBufferedFileStream == null) {
      outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize);
    }
  }

// This class is exposed to the shuffle writer but as you can notice
// it uses the output streams defined at the level of
//  LocalDiskShuffleMapOutputWriter
  private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter {
    @Override
    public OutputStream openStream() throws IOException {
      if (partStream == null) {
        if (outputFileChannel != null) {
          throw new IllegalStateException("Requested an output channel for a previous write but" +
              " now an output stream has been requested. Should not be using both channels" +
              " and streams to write.");
        }
        initStream();
        partStream = new PartitionWriterStream(partitionId);
      }
      return partStream;
    }

    @Override
    public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
      if (partChannel == null) {
        if (partStream != null) {
          throw new IllegalStateException("Requested an output stream for a previous write but" +
              " now an output channel has been requested. Should not be using both channels" +
              " and streams to write.");
        }
        initChannel();
        partChannel = new PartitionWriterChannel(partitionId);
      }
      return Optional.of(partChannel);
    }
  }

  private class PartitionWriterStream extends OutputStream {
    @Override
    public void write(int b) throws IOException {
      verifyNotClosed();
      outputBufferedFileStream.write(b);
      count++;
    }
 
// ...

Final output generation after spilling

What happens for the spilling scenario? The ExternalSorter works on a partitionedIterator returning the elements for each shuffle partition:

      for ((id, elements) <- this.partitionedIterator) {
        val blockId = ShuffleBlockId(shuffleId, mapId, id)

If some data was spilled before, the partitionedIterator method takes care of merging it with the in-memory shuffle data. Once again, it uses the partitioned destructive sorted iterator:

  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
    val usingMap = aggregator.isDefined
    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
    if (spills.isEmpty) {
// ...
    } else {
      // Merge spilled and in-memory data
      merge(spills.toSeq, destructiveIterator(
        collection.partitionedDestructiveSortedIterator(comparator)))
    }
  }

  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    val readers = spills.map(new SpillReader(_))
    val inMemBuffered = inMemory.buffered

The merged iterator is built by looping over all shuffle partitions and merging the spilled and in-memory data of each of them. Remember, both the spill files and the in-memory iterator are sorted by shuffle partition id, so this operation is relatively easy: for every partition, the sorter reads the next partition's records from each spill file and takes the matching records from the head of the in-memory iterator:

    (0 until numPartitions).iterator.map { p =>
      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)

Finally, depending on the shuffle operation (with or without map-side aggregation, with or without ordering), the iterators' values are either aggregated, merge-sorted or simply concatenated:

      if (aggregator.isDefined) {
        // Perform partial aggregation across partitions
        (p, mergeWithAggregation(
          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
      } else if (ordering.isDefined) {
        // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
        // sort the elements without trying to merge them
        (p, mergeSort(iterators, ordering.get))
      } else {
        (p, iterators.iterator.flatten)
      }
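The mergeSort branch is a classic k-way merge. Below is a sketch modeled on it, assuming every input iterator (one per spill file plus the in-memory one) is already sorted by key for the given partition. MergeSortSketch is hypothetical and, unlike Spark, works on (K, C) tuples and a Scala Ordering. The aggregation branch relies on the same idea but, when no ordering is defined, compares keys by hash code and must then group the keys sharing a hash, which this sketch doesn't cover.

```scala
import scala.collection.mutable

// Hypothetical k-way merge of key-sorted iterators using a priority queue
object MergeSortSketch {
  def mergeSort[K, C](iterators: Seq[Iterator[(K, C)]], ordering: Ordering[K]): Iterator[(K, C)] = {
    type Iter = BufferedIterator[(K, C)]
    // PriorityQueue dequeues the greatest element first, hence the reversed comparison
    val heapOrdering: Ordering[Iter] = new Ordering[Iter] {
      override def compare(x: Iter, y: Iter): Int = ordering.compare(y.head._1, x.head._1)
    }
    val heap = mutable.PriorityQueue.empty[Iter](heapOrdering)
    // Empty iterators have no head to compare, so they're skipped
    iterators.filter(_.hasNext).foreach(it => heap.enqueue(it.buffered))
    new Iterator[(K, C)] {
      override def hasNext: Boolean = heap.nonEmpty
      override def next(): (K, C) = {
        if (!hasNext) throw new NoSuchElementException
        // Take the iterator with the smallest head, emit its head, put it back if not exhausted
        val first = heap.dequeue()
        val pair = first.next()
        if (first.hasNext) heap.enqueue(first)
        pair
      }
    }
  }
}
```

Only one record per input iterator is held in the heap at a time, so the merge stays cheap in memory regardless of the spill files' size.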

These iterators are then consumed by writePartitionedMapOutput and written to the shuffle file with the same classes as in the non-spilled scenario:

private[spark] class ExternalSorter[K, V, C](...) {

  def writePartitionedMapOutput(
      shuffleId: Int,
      mapId: Long,
      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
    var nextPartitionId = 0
    if (spills.isEmpty) {
// ...
    } else {
      for ((id, elements) <- this.partitionedIterator) {
        val blockId = ShuffleBlockId(shuffleId, mapId, id)
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(id)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          if (elements.hasNext) {
            for (elem <- elements) {
              partitionPairsWriter.write(elem._1, elem._2)
            }
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = id + 1
      }

Final output generation - index file

After creating the shuffle file with partition-ordered records, the SortShuffleWriter calls commitAllPartitions(), which creates the index file telling each reducer task where to find its shuffle data:

public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
// ...
  @Override
  public long[] commitAllPartitions() throws IOException {
// ...
    cleanUp();
    File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
    log.debug("Writing shuffle index file for mapId {} with length {}", mapId,
        partitionLengths.length);
    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
    return partitionLengths;
  }

How does this method work? The LocalDiskShuffleMapOutputWriter, which is the default shuffle output writer, tracks the number of bytes written for each shuffle partition in its private long[] partitionLengths field. This field is updated when a partition's stream is closed, i.e. once the shuffle partition writer has written all the partition's data to the file:

  private class PartitionWriterStream extends OutputStream {
// ...
    @Override
    public void write(byte[] buf, int pos, int length) throws IOException {
      verifyNotClosed();
      outputBufferedFileStream.write(buf, pos, length);
      count += length;
    }

    @Override
    public void close() {
      isClosed = true;
      partitionLengths[partitionId] = count;
      bytesWrittenToMergedFile += count;
    }
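The relation between partitionLengths and the index file can be sketched as follows. IndexSketch is hypothetical and works on in-memory byte arrays instead of files. It assumes the layout I know from IndexShuffleBlockResolver: numPartitions + 1 longs holding cumulative offsets starting with 0, so that the data of reducer r lives between offsets r and r + 1 of the data file.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical sketch of the data + index layout of one map task
object IndexSketch {
  // Appends every partition to one "data file" and records each partition's length
  def writeData(partitions: Seq[Array[Byte]]): (Array[Byte], Array[Long]) = {
    val data = new ByteArrayOutputStream()
    val partitionLengths = new Array[Long](partitions.size)
    partitions.zipWithIndex.foreach { case (bytes, partitionId) =>
      data.write(bytes)
      partitionLengths(partitionId) = bytes.length
    }
    (data.toByteArray, partitionLengths)
  }

  // Index: offset 0 followed by the running sum of the partition lengths
  def writeIndex(partitionLengths: Array[Long]): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val index = new DataOutputStream(out)
    var offset = 0L
    index.writeLong(offset)
    partitionLengths.foreach { length =>
      offset += length
      index.writeLong(offset)
    }
    index.flush()
    out.toByteArray
  }

  // What a reducer needs: the [start, end) byte range of its partition in the data file
  def rangeFor(indexBytes: Array[Byte], reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new ByteArrayInputStream(indexBytes))
    in.skipBytes(reduceId * 8)
    (in.readLong(), in.readLong())
  }
}
```

An empty partition produces two equal consecutive offsets, so a reducer can tell it has nothing to fetch without touching the data file.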

To thank you for staying with me, I created this schema summarizing all the classes and interactions presented in the blog post:

Class | Responsibilities
SortShuffleManager | returns ShuffleHandle, ShuffleReader and ShuffleWriter instances
BaseShuffleHandle | the ShuffleHandle used by SortShuffleWriter
SortShuffleWriter | the least preferred ShuffleWriter, used for BaseShuffleHandle
PartitionedAppendOnlyMap | in-memory buffer for the partial aggregation scenario
PartitionedPairBuffer | in-memory buffer for the scenario without partial aggregation
ExternalSorter | processes the map task data to shuffle, triggers spilling if the in-memory buffer is full

You can then notice that all the complexity is delegated to the ExternalSorter and that the SortShuffleWriter acts more as a façade for the shuffle files materialization. In the next articles I will present 2 other shuffle writers but if you want to see the SortShuffleWriter in action, please check the demo below. Thanks for reading!


