Shuffle reading in Apache Spark SQL - wrapping iterators and beyond

Versions: Apache Spark 3.1.1

It's time for the 2nd blog post about the shuffle readers. In the previous one we discovered how Apache Spark fetches shuffle blocks from local and remote hosts. Today, I would like to talk about the wrapping iterators. Sounds mysterious? It won't be once we look at the iterators involved in processing the shuffle block files.

A high-level view of them looks like this:

private[spark] class BlockStoreShuffleReader[K, C](
// ...
) {
  override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      // ...
    ).toCompletionIterator

    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      // ...
    )

    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    // ... (aggregatedIter is built here from interruptibleIter and dep.aggregator)

    val resultIter = dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // ... (an ExternalSorter is created and fed with the records)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }

    resultIter match {
      case _: InterruptibleIterator[Product2[K, C]] => resultIter
      case _ =>
        new InterruptibleIterator[Product2[K, C]](context, resultIter)
    }
  }
}

That's a lot of different iterators! Let me introduce them in the next section.

Shuffle iterators

The iterators you've seen in the snippet are:

- ShuffleBlockFetcherIterator - fetches the shuffle blocks from local and remote block managers and exposes them as (BlockId, InputStream) pairs; the toCompletionIterator call wraps it so that the cleanup logic runs once everything has been consumed;
- the key/value iterator returned by deserializeStream(...).asKeyValueIterator - deserializes every fetched stream into key/value records;
- CompletionIterator - delegates to the wrapped iterator and invokes a completion callback when it's exhausted, for example to merge the shuffle read metrics or to stop the sorter;
- InterruptibleIterator - checks whether the task was killed before returning each record, so that a cancelled task stops consuming shuffle data;
- the ExternalSorter-backed iterator - created only when the ShuffleDependency defines a keyOrdering, to return the records sorted by key.

A simplified sketch of this wrapping pattern is presented below.
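To make the pattern more tangible, here is a minimal, self-contained sketch of the wrapping idea. The CompletionLikeIterator and InterruptibleLikeIterator classes are mine and only mimic the behavior of Spark's CompletionIterator and InterruptibleIterator:

// Runs a callback exactly once, when the wrapped iterator has been fully consumed;
// Spark uses this kind of wrapper to merge shuffle read metrics or to stop the sorter.
class CompletionLikeIterator[A](sub: Iterator[A], completion: () => Unit) extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val has = sub.hasNext
    if (!has && !completed) {
      completed = true
      completion()
    }
    has
  }
  override def next(): A = sub.next()
}

// Checks a cancellation flag before handing out the next record;
// that's the idea behind Spark's InterruptibleIterator.
class InterruptibleLikeIterator[A](isCancelled: () => Boolean, sub: Iterator[A]) extends Iterator[A] {
  override def hasNext: Boolean = {
    if (isCancelled()) throw new InterruptedException("task cancelled")
    sub.hasNext
  }
  override def next(): A = sub.next()
}

object WrappingIteratorsDemo extends App {
  val records = Iterator(("a", 1), ("b", 2), ("c", 3))
  val withCompletion = new CompletionLikeIterator(records, () => println("all records consumed"))
  val interruptible = new InterruptibleLikeIterator(() => false, withCompletion)
  interruptible.foreach(println) // prints the 3 records, then "all records consumed"
}

The real chain built in read() is of course richer (deserialization, metrics, optional aggregation and sorting), but the composition principle is the same.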

What is used instead of ShuffleDependency?

As I've mentioned, Apache Spark SQL doesn't use the metadata associated with the ShuffleDependency class to execute a sorting or aggregating post-shuffle operation. But how does it do that? The answer is hidden in the execution plan. Below you can see the plan generated for a groupByKey(...).mapGroups(...) operation. Can you spot the unusual node?
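To give the plan below some context, here is a hypothetical query of roughly that shape. The User case class and the grouping on login are reconstructed from the column names visible in the plan (id#7, login#8), not copied from the original example:

import org.apache.spark.sql.SparkSession

case class User(id: String, login: String)

object MapGroupsPlanExample {
  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().master("local[*]").getOrCreate()
    import sparkSession.implicits._

    val users = Seq(User("1", "user1"), User("2", "user2"), User("3", "user1")).toDS()

    // groupByKey becomes the AppendColumns node; mapGroups needs all rows of a group
    // to arrive together and sorted by the key, hence the Exchange followed by Sort.
    val labels = users
      .groupByKey(user => user.login)
      .mapGroups((login, usersInGroup) => s"$login: ${usersInGroup.size}")

    labels.explain()
  }
}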

== Physical Plan ==
*(2) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#15]
+- MapGroups org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1292/1990828041@342beaf6, value#12.toString, createexternalrow(id#7.toString, login#8.toString, StructField(id,StringType,true), StructField(login,StringType,true)), [value#12], [id#7, login#8], obj#14: java.lang.String
   +- *(1) Sort [value#12 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#12, 10), ENSURE_REQUIREMENTS, [id=#15]
         +- AppendColumns com.waitingforcode.AggregatorExample$$$Lambda$1287/131837504@5866731, createexternalrow(id#7.toString, login#8.toString, StructField(id,StringType,true), StructField(login,StringType,true)), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#12]
            +- LocalTableScan [id#7, login#8]

It's the Sort node sitting between MapGroups and the Exchange. Apache Spark SQL relies on the execution plan to provide a correct input to the post-shuffle operations, and the correct input for mapGroups requires the data to be sorted by the grouping key. It's explained in the comment of the GroupedIterator class used during the physical execution:

/**
 * Iterates over a presorted set of rows, chunking it up by the grouping expression.  Each call to
 * next will return a pair containing the current group and an iterator that will return all the
 * elements of that group.  Iterators for each group are lazily constructed by extracting rows
 * from the input iterator.  As such, full groups are never materialized by this class.
 * // ...
*/
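To illustrate what this comment describes, here is a simplified chunking function in the same spirit. It's a toy version of mine, not Spark's GroupedIterator, and it skips the grouping expression evaluation:

// Chunks a presorted iterator by key without materializing whole groups (toy version).
def chunkByKey[K, V](sorted: Iterator[(K, V)]): Iterator[(K, Iterator[V])] = {
  val buffered = sorted.buffered
  new Iterator[(K, Iterator[V])] {
    override def hasNext: Boolean = buffered.hasNext
    override def next(): (K, Iterator[V]) = {
      val currentKey = buffered.head._1
      // The group is evaluated lazily; like in Spark, it must be fully consumed
      // before asking for the next group.
      val group = new Iterator[V] {
        override def hasNext: Boolean = buffered.hasNext && buffered.head._1 == currentKey
        override def next(): V = buffered.next()._2
      }
      (currentKey, group)
    }
  }
}

// chunkByKey(Iterator(("a", 1), ("a", 2), ("b", 3))).foreach {
//   case (key, values) => println(s"$key -> ${values.mkString(", ")}")
// }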

Shuffle reading picture

Sometimes the shuffle data must be materialized in an intermediate stage, as in the mapGroups example presented above. Sometimes it doesn't need to be, and that's the case for accumulative operations like count or sum. The plan for the latter is missing the intermediate Sort node:

== Physical Plan ==
*(2) HashAggregate(keys=[value#19], functions=[count(1)], output=[key#24, count(1)#23L])
+- Exchange hashpartitioning(value#19, 10), ENSURE_REQUIREMENTS, [id=#50]
   +- *(1) HashAggregate(keys=[value#19], functions=[partial_count(1)], output=[value#19, count#30L])
      +- *(1) Project [value#19]
         +- AppendColumns com.waitingforcode.AggregatorExample$$$Lambda$2599/1548162287@21f91efa, createexternalrow(id#7.toString, login#8.toString, StructField(id,StringType,true), StructField(login,StringType,true)), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#19]
            +- LocalTableScan [id#7, login#8]
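A plan of this shape could come, for instance, from a count on the same hypothetical users Dataset defined in the earlier mapGroups sketch (again my reconstruction, not necessarily the author's code):

// Reusing the hypothetical users Dataset from the mapGroups sketch above.
// groupByKey(...).count() produces a partial_count before the Exchange and the final
// count after it; no Sort node is needed in between.
val countsPerLogin = users.groupByKey(user => user.login).count()
countsPerLogin.explain()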

How the shuffle blocks are consumed therefore depends on the operation. The mapGroups case needs to materialize them because of the sorting logic relying on the spillable UnsafeExternalSorter:

/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean sort_needToSort_0;
/* 010 */   private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter_0;

/* 019 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 020 */     partitionIndex = index;
/* 021 */     this.inputs = inputs;
/* 022 */     sort_needToSort_0 = true;
/* 023 */     sort_sorter_0 = ((org.apache.spark.sql.execution.SortExec) references[0] /* plan */).createSorter();
// ..
/* 028 */   }
/* 030 */   private void sort_addToSorter_0() throws java.io.IOException {
/* 031 */     while ( inputadapter_input_0.hasNext()) {
/* 032 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 034 */       sort_sorter_0.insertRow((UnsafeRow)inputadapter_row_0);
/* 035 */       // shouldStop check is eliminated
/* 036 */     }
/* 038 */   }
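Stripped of the codegen and spilling machinery, the pattern implemented by sort_addToSorter_0 boils down to consuming the whole input before emitting anything downstream. A toy equivalent, written in plain Scala rather than generated Java, could look like this:

// Toy equivalent of the sort-before-emit pattern (no spilling, unlike UnsafeExternalSorter):
// every row is buffered and sorted before the first one can be returned downstream.
def sortAll[A: Ordering](input: Iterator[A]): Iterator[A] = {
  val buffer = scala.collection.mutable.ArrayBuffer.empty[A]
  input.foreach(buffer += _) // the sort_addToSorter_0 step: consume everything first
  buffer.sorted.iterator     // only now can the downstream operator (MapGroups) start reading
}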

The count aggregation acts more like a fire-and-forget consumer. It processes the pre-shuffle partially aggregated rows and uses them to update the local aggregation buffers in ObjectAggregationIterator. It doesn't need to materialize the shuffle blocks beforehand.
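By contrast, merging partial counts can happen one incoming row at a time. Here is a toy version of that buffer update; it illustrates the principle only and has nothing to do with the real ObjectAggregationIterator internals:

// Merges pre-shuffle partial counts on the fly, without sorting or
// materializing the shuffled rows first (toy illustration only).
def mergePartialCounts(partialCounts: Iterator[(String, Long)]): Map[String, Long] = {
  val buffers = scala.collection.mutable.Map.empty[String, Long].withDefaultValue(0L)
  partialCounts.foreach { case (key, partialCount) =>
    buffers(key) += partialCount // one buffer update per incoming record
  }
  buffers.toMap
}

// mergePartialCounts(Iterator(("user1", 2L), ("user2", 1L), ("user1", 3L)))
// returns Map("user1" -> 5, "user2" -> 1)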

Reading sum-up

The shuffle readers series was quite long and, believe me, I haven't covered all the topics yet! However, we can already summarize the key points of the shuffle readers:

- the reader never exposes the fetched shuffle blocks directly; it wraps them in a chain of iterators responsible for deserialization, metrics updates, task interruption and, optionally, aggregation or sorting;
- Apache Spark SQL doesn't rely on the ShuffleDependency metadata for post-shuffle sorting or aggregation; that requirement lives in the physical plan, like the Sort node preceding MapGroups;
- whether the shuffle blocks have to be materialized depends on the consuming operation: mapGroups sorts them first with a spillable sorter, whereas accumulative operations like count or sum update their aggregation buffers on the fly.

I hope that, thanks to the recent blog posts, the Apache Spark shuffle holds fewer mysteries for you. But I also hope you have some questions to share that might inspire the next blog posts of the series!

