Generators and PySpark

Versions: Apache Spark 3.2.1

I remember the first PySpark code I saw. It was pretty similar to the Scala code I used to work with, except for one small detail: the yield keyword. Since then, I've understood its purpose and have been actively looking for an occasion to blog about it. Growing the PySpark section is a great opportunity to do so!

Generators

To put it simply, generators generate data. They aren't collections because they don't store anything. Instead, they provide the items to the caller one element at a time, only when the caller asks for the next one.

If this overly simplified definition is not enough, below you can find more details:
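As a minimal plain-Python sketch of that behavior (numbers_generator is a hypothetical example, unrelated to Spark), the generator's body only runs when the caller requests the next element, and a generator can be consumed only once:

```python
def numbers_generator(limit):
    """Produces the numbers lazily; nothing is stored between two next() calls."""
    for number in range(limit):
        print(f"generating {number}")
        yield number


numbers = numbers_generator(3)  # nothing printed yet: the body hasn't started
print(type(numbers))            # <class 'generator'>
print(next(numbers))            # prints "generating 0", then 0
print(list(numbers))            # prints "generating 1", "generating 2", then [1, 2]
print(list(numbers))            # [] because the generator is already exhausted
```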

I think you got the idea! But how does it work with PySpark?

Generators in mapPartitions

One place where you will encounter generators is the mapPartitions transformation. Unlike map, it calls the mapping function once per partition, passing it an iterator over all the elements of that partition, and expects an iterable in return. As you might have already deduced, the lazy nature of generators avoids materializing the whole mapped partition in memory on the Python side. To understand it better, let's take the following code as an example:

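from pyspark.sql import SparkSession
import resource  # Unix-only module


def using(point=""):
    # ru_maxrss is the peak resident set size of the current (Python worker)
    # process; on Linux it's expressed in kilobytes, hence the / 1024.0
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return (f"{point}: usertime={usage.ru_utime} systime={usage.ru_stime} "
            f"mem={usage.ru_maxrss / 1024.0} mb")


spark = SparkSession.builder.master("local[*]")\
    .appName("Yield in mapPartitions")\
    .getOrCreate()

# 10 million numbers in a single partition
input_data = spark.sparkContext.parallelize(list(range(0, 10000000)), 1)


def map_numbers_with_list(numbers):
    # numbers is an iterator over all elements of the partition
    output_numbers = []
    for number in numbers:
        output_numbers.append(number * 200)
        # print the memory usage for the first and the two last elements
        if number == 0 or number >= 10000000 - 2:
            print(using(f"list{number}"))
    # the whole mapped partition is materialized before being returned
    return output_numbers


mapped_result = input_data.mapPartitions(map_numbers_with_list).count()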

Agreed, the code isn't very representative because it only counts the elements. However, this simplistic example is useful to understand the difference between the list-based approach and the generator-based one just below:

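def map_numbers_with_generator(numbers):
    for number in numbers:
        # hand over one mapped element at a time; nothing is accumulated
        yield number * 200
        if number == 0 or number >= 10000000 - 2:
            print(using(f"generator{number}"))


mapped_result = input_data.mapPartitions(map_numbers_with_generator).count()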

For the list-based approach, the memory usage printed by the using(...) function grows from 21MB in the first print to 406MB in the last one. This doesn't happen for the generator-based version, where the memory stays relatively stable at around 21MB. The memory pressure on the Python interpreter is therefore lower with the generator.
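The same trend can be observed without Spark. The following is a minimal sketch under that assumption: it swaps resource.getrusage for the standard library's tracemalloc, which measures the peak memory allocated for Python objects rather than the process's resident set size, so the absolute numbers won't match the 21MB/406MB above. The function names are hypothetical.

```python
import tracemalloc


def map_with_list(numbers):
    output = []
    for number in numbers:
        output.append(number * 200)
    return output


def map_with_generator(numbers):
    for number in numbers:
        yield number * 200


def peak_memory_of(mapping_function, size):
    """Counts the mapped elements and returns (count, peak traced bytes)."""
    tracemalloc.start()
    # iter(...) mimics the partition iterator that mapPartitions passes in
    count = sum(1 for _ in mapping_function(iter(range(size))))
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return count, peak


list_count, list_peak = peak_memory_of(map_with_list, 200_000)
generator_count, generator_peak = peak_memory_of(map_with_generator, 200_000)
print(f"list: {list_count} elements, peak={list_peak / 1024 / 1024:.2f} MB")
print(f"generator: {generator_count} elements, peak={generator_peak / 1024 / 1024:.2f} MB")
```

Both functions produce the same number of elements, but only the list-based one keeps all of them alive at the same time, which is what drives its peak up.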

Can you be more precise?

Yes, but I had to add some debugging code first. And I'm quite happy about it because figuring this out wasn't easy. First, I tried the same approach as with the Scala API and added some breakpoints in PyCharm, but it didn't work. Next, I tried logging, but nothing was printed. Fortunately, I took a break and found the solution: the code executed by the Python workers doesn't come from the .py files of the installed package but from the .py files inside the pyspark.zip archive located here:
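One way to check where a module really comes from is its __file__ attribute. Python can import modules directly from a zip archive listed on sys.path, and such modules report a path inside the archive. The sketch below builds a hypothetical demo_package.zip with a hypothetical demo_serializers module, standing in for pyspark.zip and its serializers.py:

```python
import importlib
import os
import sys
import tempfile
import zipfile

# Build a throwaway archive containing one module
archive_dir = tempfile.mkdtemp()
archive_path = os.path.join(archive_dir, "demo_package.zip")
with zipfile.ZipFile(archive_path, "w") as archive:
    archive.writestr("demo_serializers.py", "def dump():\n    return 'from the zip'\n")

# An archive on sys.path is searched like a directory (zipimport)
sys.path.insert(0, archive_path)
demo_serializers = importlib.import_module("demo_serializers")

print(demo_serializers.__file__)  # a path inside demo_package.zip
print(demo_serializers.dump())
```

Editing a same-named .py file elsewhere has no effect as long as the archive comes first on the search path, which is consistent with having to patch the file inside pyspark.zip.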

So I opened the archive, found serializers.py, and added some debug prints to the serialization and deserialization methods:

class AutoBatchedSerializer(BatchedSerializer):
    """
    Choose the size of batch automatically based on the size of object
    """

    def __init__(self, serializer, bestSize=1 << 16):
        BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
        self.bestSize = bestSize

    def dump_stream(self, iterator, stream):
        print(f'AutoBatchedSerializer types for iterator={type(iterator)} and stream={type(stream)}')
        print(f'AutoBatchedSerializer data for iterator={iterator} and stream={stream}')
        # start with 1 element per batch; bestSize defaults to 64 KB (1 << 16)
        batch, best = 1, self.bestSize
        print(f'AutoBatchedSerializer  batch, best = {batch}, {best}')
        iterator = iter(iterator)
        while True:
            # pull at most `batch` elements; for a generator, this is where
            # the body of the mapping function actually runs
            vs = list(itertools.islice(iterator, batch))
            print(f'AutoBatchedSerializer stream  VS for {vs}')
            if not vs:
                break

            # length-prefixed write of the serialized batch
            bytes = self.serializer.dumps(vs)
            write_int(len(bytes), stream)
            print(f'AutoBatchedSerializer Writing {len(bytes)} to {type(stream)}')
            stream.write(bytes)

            # double the batch while it stays under bestSize,
            # halve it when it gets more than 10x bigger
            size = len(bytes)
            if size < best:
                batch *= 2
            elif size > best * 10 and batch > 1:
                batch //= 2

Afterwards, I reduced the range size to 10 and reran the generator- and list-based apps, this time with the collect() action. Additionally, I added a print("yield") and a print("append") inside the for loop of the mapPartitions functions. Here is what I got:

# Generator-based app
AutoBatchedSerializer types for iterator=<class 'generator'> and stream=<class '_io.BufferedWriter'>
AutoBatchedSerializer data for iterator=<generator object map_letters_generator at 0x7efc26a6af90> and stream=<_io.BufferedWriter name=5>
AutoBatchedSerializer  batch, best = 1, 65536

yield
AutoBatchedSerializer stream  VS for [0] 
AutoBatchedSerializer Writing 17 to <class '_io.BufferedWriter'>
yield
yield
AutoBatchedSerializer stream  VS for [200, 400] 
AutoBatchedSerializer Writing 21 to <class '_io.BufferedWriter'>
yield
yield
yield
yield
AutoBatchedSerializer stream  VS for [600, 800, 1000, 1200] 
AutoBatchedSerializer Writing 28 to <class '_io.BufferedWriter'>
yield
yield
yield
0 milliseconds
AutoBatchedSerializer stream  VS for [1400, 1600, 1800] 
AutoBatchedSerializer Writing 25 to <class '_io.BufferedWriter'>

# List-based app 
append
append
append
append
append
append
append
append
append
append
AutoBatchedSerializer types for iterator=<class 'list'> and stream=<class '_io.BufferedWriter'>
AutoBatchedSerializer data for iterator=[0, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800] and stream=<_io.BufferedWriter name=5>
AutoBatchedSerializer  batch, best = 1, 65536
AutoBatchedSerializer stream  VS for [0]
AutoBatchedSerializer Writing 17 to <class '_io.BufferedWriter'>
AutoBatchedSerializer stream  VS for [200, 400]
AutoBatchedSerializer Writing 21 to <class '_io.BufferedWriter'>
AutoBatchedSerializer stream  VS for [600, 800, 1000, 1200]
AutoBatchedSerializer Writing 28 to <class '_io.BufferedWriter'>
AutoBatchedSerializer stream  VS for [1400, 1600, 1800]
AutoBatchedSerializer Writing 25 to <class '_io.BufferedWriter'>
AutoBatchedSerializer stream  VS for []

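The logs are similar except for the type of the iterator passed to dump_stream, which corresponds to the type returned by the mapping function (generator vs list). This implies a different data generation pattern, visible when comparing the "append" and "yield" prints: the list is fully built before the serializer starts, whereas the generator produces only as many elements as the current batch requests (1, then 2, then 4, and so on).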
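The interleaving can be reproduced without Spark. This is a simplified sketch, assuming that a copy of the dump_stream batching loop is enough to show it: pickle.dumps stands in for self.serializer.dumps, struct.pack("!i", ...) for write_int, and the hypothetical events lists replace the print calls.

```python
import io
import itertools
import pickle
import struct


def dump_stream(iterator, stream, events, best_size=1 << 16):
    """Simplified, Spark-free copy of the AutoBatchedSerializer.dump_stream loop."""
    batch = 1
    iterator = iter(iterator)
    while True:
        # islice pulls at most `batch` elements; for a generator, this runs its body
        vs = list(itertools.islice(iterator, batch))
        if not vs:
            break
        payload = pickle.dumps(vs)
        stream.write(struct.pack("!i", len(payload)))  # 4-byte length prefix
        stream.write(payload)
        events.append(("batch", vs))
        if len(payload) < best_size:
            batch *= 2
        elif len(payload) > best_size * 10 and batch > 1:
            batch //= 2


def map_with_generator(numbers, events):
    for number in numbers:
        events.append(("yield", number * 200))
        yield number * 200


def map_with_list(numbers, events):
    output = []
    for number in numbers:
        events.append(("append", number * 200))
        output.append(number * 200)
    return output


generator_events, list_events = [], []
dump_stream(map_with_generator(range(10), generator_events), io.BytesIO(), generator_events)
dump_stream(map_with_list(range(10), list_events), io.BytesIO(), list_events)

print([kind for kind, _ in generator_events])
print([kind for kind, _ in list_events])
```

The generator run interleaves "yield" and "batch" events with batches of 1, 2, 4, and 3 elements, as in the "VS for" lines above, while the list run records all ten "append" events before the first batch is written.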

I must admit that I'm not fully satisfied with this article because I'm not proficient in PySpark debugging. However, I hope it gives enough details to understand the difference between the generator- and list-based mapping functions and their impact on the job.