Shuffle in PySpark

Versions: Apache Spark 3.3.0

Shuffle is, for me, a never-ending story. Last year I spent long weeks analyzing the shuffle readers and writers, and I was hoping for some rest in 2022. However, it didn't happen. My recent PySpark investigation led me to the shuffle.py file and my first reaction was: "Oh, so PySpark has its own shuffle mechanism?". Let's check this out!



PySpark shuffle is not a new concept. It has been there since Apache Spark 1.1.0 (!) and was introduced in 2014 by Davies Liu as a part of SPARK-2538: External aggregation in Python.

The problem PySpark users faced at that time was job failures caused by OOM errors when the reduce tasks processed more data than the available memory could hold. The solution was to use the same spilling technique as on the JVM side, i.e. writing data from memory to disk.

If you read the JIRA title carefully, you should have noticed the word "external". It's not a coincidence. These aggregations work on Python functions, so externally to the Apache Spark JVM engine. If you look for them in the code, you'll quickly identify them because they all start with the External prefix.

To understand the purpose of these external aggregations, let's analyze three shuffle-based functions: groupByKey, combineByKey, and sortByKey.


The groupByKey operation is only a high-level interface. If you check the implementation out, you will notice that several other transformations occur. Each of them invokes an external aggregation, as shown in the following schema:

The Aggregator is simply a class holding the grouping logic (init, merge values, merge combiners); no shuffle happens in there. The ExternalMerger, on the other hand, is the first place where the shuffle, and eventually the spilling, occur. To be more precise, the groupByKey operation uses 2 ExternalMerger methods: mergeValues() followed by items().

After creating these groups, groupByKey uses yet another External* aggregator, the ExternalGroupBy. The first aggregator (ExternalMerger) groups the rows sharing the same key, while the ExternalGroupBy applies the user function on top of these groups. How the groups are created depends on the dataset:
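To make the ExternalMerger's role more tangible, here is a heavily simplified, hypothetical stand-in (SimpleMerger and Aggregator below are my own sketch, not PySpark's code). It mimics the mergeValues/items flow: partial aggregates are built in an in-memory map and spilled to a disk file once a threshold is reached. Note one simplification: the sketch spills on the number of keys, while PySpark spills on actual memory usage.

```python
import os
import pickle
import tempfile


class Aggregator:
    """Holds the grouping logic, like pyspark.shuffle.Aggregator."""
    def __init__(self, createCombiner, mergeValue, mergeCombiners):
        self.createCombiner = createCombiner
        self.mergeValue = mergeValue
        self.mergeCombiners = mergeCombiners


class SimpleMerger:
    """Hypothetical, simplified stand-in for ExternalMerger."""
    def __init__(self, agg, max_keys=1000):
        self.agg = agg
        self.max_keys = max_keys
        self.data = {}
        self.spills = []

    def mergeValues(self, iterator):
        # Build a partial aggregate per key; spill when the map grows too big.
        for key, value in iterator:
            if key in self.data:
                self.data[key] = self.agg.mergeValue(self.data[key], value)
            else:
                self.data[key] = self.agg.createCombiner(value)
            if len(self.data) > self.max_keys:
                self._spill()

    def _spill(self):
        fd, path = tempfile.mkstemp()
        with os.fdopen(fd, "wb") as spill_file:
            pickle.dump(self.data, spill_file)
        self.spills.append(path)
        self.data = {}

    def items(self):
        # Merge the in-memory map with every spill file.
        merged = dict(self.data)
        for path in self.spills:
            with open(path, "rb") as spill_file:
                for key, combiner in pickle.load(spill_file).items():
                    if key in merged:
                        merged[key] = self.agg.mergeCombiners(merged[key], combiner)
                    else:
                        merged[key] = combiner
            os.unlink(path)
        return merged.items()


# groupByKey-style aggregation: collect all values of a key into a list.
agg = Aggregator(lambda v: [v], lambda c, v: c + [v], lambda c1, c2: c1 + c2)
merger = SimpleMerger(agg, max_keys=2)  # tiny threshold to force a spill
merger.mergeValues([("a", 1), ("b", 2), ("a", 3), ("c", 4)])
print(dict(merger.items()))  # {'a': [1, 3], 'b': [2], 'c': [4]}
```

With a groupByKey-style Aggregator (create a list, append values, concatenate lists), the merger returns the same groups whether or not a spill happened along the way.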


The combineByKey operation uses only the ExternalMerger. The only difference from groupByKey is the lack of the ExternalGroupBy, replaced in the last step by ExternalMerger's mergeCombiners method.

The first part where the ExternalMerger is involved behaves the same as in the groupByKey; it's a mergeValues() call followed by the items():

    def combineByKey(
        self: "RDD[Tuple[K, V]]", createCombiner: Callable[[V], U],
        mergeValue: Callable[[U, V], U], mergeCombiners: Callable[[U, U], U],
        numPartitions: Optional[int] = None, partitionFunc: Callable[[K], int] = portable_hash,
    ) -> "RDD[Tuple[K, U]]":

# ...
        def combineLocally(iterator: Iterable[Tuple[K, V]]) -> Iterable[Tuple[K, U]]:
            merger = ExternalMerger(agg, memory * 0.9, serializer)
            merger.mergeValues(iterator)
            return merger.items()

        locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
        shuffled = locally_combined.partitionBy(numPartitions, partitionFunc)
# ...

Things are different in the final mapPartitions step, where combineByKey calls ExternalMerger's mergeCombiners followed by items():

    def combineByKey(
        self: "RDD[Tuple[K, V]]", createCombiner: Callable[[V], U],
        mergeValue: Callable[[U, V], U], mergeCombiners: Callable[[U, U], U],
        numPartitions: Optional[int] = None, partitionFunc: Callable[[K], int] = portable_hash,
    ) -> "RDD[Tuple[K, U]]":

# ...

        def _mergeCombiners(iterator: Iterable[Tuple[K, U]]) -> Iterable[Tuple[K, U]]:
            merger = ExternalMerger(agg, memory, serializer)
            merger.mergeCombiners(iterator)
            return merger.items()

        return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)

The mergeCombiners is the method applying the merge functions on top of the previously created pre-shuffle partial aggregates. It can also create spill files, which results in calling the aforementioned _recursive_merged_items for partitions too big to fit into memory.
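As an illustration, the two phases of combineByKey can be simulated with plain Python dictionaries. The combine_locally and merge_combiners helpers below are hypothetical, not PySpark code; they mimic what mergeValues does on the map side before the shuffle and what mergeCombiners does after it, on the classic per-key average example:

```python
def combine_locally(partition, createCombiner, mergeValue):
    # Phase 1 (map side): build a partial aggregate per key,
    # as ExternalMerger.mergeValues does before the shuffle.
    combiners = {}
    for key, value in partition:
        if key in combiners:
            combiners[key] = mergeValue(combiners[key], value)
        else:
            combiners[key] = createCombiner(value)
    return combiners


def merge_combiners(partials, mergeCombiners):
    # Phase 2 (reduce side): merge the pre-shuffle partial aggregates,
    # as ExternalMerger.mergeCombiners does after partitionBy.
    merged = {}
    for combiners in partials:
        for key, combiner in combiners.items():
            if key in merged:
                merged[key] = mergeCombiners(merged[key], combiner)
            else:
                merged[key] = combiner
    return merged


# Per-key average expressed as (sum, count) combiners.
create = lambda v: (v, 1)
merge_value = lambda c, v: (c[0] + v, c[1] + 1)
merge_comb = lambda c1, c2: (c1[0] + c2[0], c1[1] + c2[1])

p1 = combine_locally([("a", 1), ("a", 3)], create, merge_value)
p2 = combine_locally([("a", 5), ("b", 2)], create, merge_value)
totals = merge_combiners([p1, p2], merge_comb)
averages = {k: s / n for k, (s, n) in totals.items()}
print(averages)  # {'a': 3.0, 'b': 2.0}
```

Only the small (sum, count) pairs cross the simulated shuffle boundary, not the raw values, which is exactly why combineByKey is usually cheaper than a groupByKey followed by a map.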


The logic for the last key-based operation doesn't involve an explicit spill; the code doesn't mention one, but it definitely performs one. After all, the sorting logic also relies on an external-like class called ExternalSorter.

Under the hood, PySpark first partitions the data and, only after bringing records with the same keys together, applies the sorting logic from the ExternalSorter. The logic does the following:

  1. Divide the iterator to sort into multiple groups and accumulate them in memory as long as the memory usage doesn't reach the allowed threshold.
  2. After reaching the threshold, sort each group in memory with Python's classic .sort() method and serialize the result into a temporary sort file.
        def sorted(self, iterator, key=None, reverse=False):
    # ...
                if used_memory > limit:
                    # sort them inplace will save memory
                    current_chunk.sort(key=key, reverse=reverse)
                    path = self._get_path(len(chunks))
                    with open(path, "wb") as f:
                        self.serializer.dump_stream(current_chunk, f)

                    def load(f):
                        for v in self.serializer.load_stream(f):
                            yield v
                        # close the file explicit once we consume all the items
                        # to avoid ResourceWarning in Python3
                        f.close()

                    chunks.append(load(open(path, "rb")))
                    current_chunk = []
    This part also adds the sorted data to the chunks array as an open stream, returning each entry lazily with a generator (I've already blogged about generators in PySpark).
  3. Use the heap queue (priority queue) to merge the spilled files and return their records in order. This step calls the heapq.merge(*chunks, key=key, reverse=reverse) generator function and, as before with the yield, doesn't materialize the whole dataset in memory. Instead, it reads one record at a time from each chunk stream and returns the smallest (or biggest) one to PySpark. It works because the chunks are already sorted, so there is no need to compare more than one item per chunk at a time.
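The whole sort-spill-merge flow from the steps above can be condensed into a short, hypothetical sketch. The external_sort function is my own, not PySpark's; as in the earlier sketch, it spills on a record count while the real ExternalSorter spills on memory usage, and it uses pickle where PySpark uses its serializer:

```python
import heapq
import os
import pickle
import tempfile


def external_sort(iterator, chunk_size=1000, key=None, reverse=False):
    """Toy version of ExternalSorter.sorted: sort fixed-size chunks in
    memory, spill each one to disk, then heapq.merge the sorted streams."""
    chunks = []
    current = []
    for record in iterator:
        current.append(record)
        if len(current) >= chunk_size:
            # Sort the chunk in place and spill it to a temporary file.
            current.sort(key=key, reverse=reverse)
            fd, path = tempfile.mkstemp()
            with os.fdopen(fd, "wb") as f:
                pickle.dump(current, f)

            def load(p):
                # Yield the spilled records lazily, then clean up the file.
                with open(p, "rb") as f:
                    yield from pickle.load(f)
                os.unlink(p)

            chunks.append(load(path))
            current = []
    # The last, possibly partial, chunk stays in memory.
    current.sort(key=key, reverse=reverse)
    chunks.append(iter(current))
    # heapq.merge compares only the head of each chunk at a time: since
    # every chunk is already sorted, that is enough to emit global order.
    return heapq.merge(*chunks, key=key, reverse=reverse)


print(list(external_sort([5, 1, 4, 2, 3], chunk_size=2)))  # [1, 2, 3, 4, 5]
```

The key property mirrors PySpark's behavior: at no point do all records sit in memory simultaneously; only one chunk during the spill phase and one head record per chunk during the merge phase.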

This article brings another topic to the table: PySpark shuffle relies on the memory assigned to the PySpark worker, and I haven't covered that memory division yet!