PySpark and the JVM - introduction, part 2

Versions: Apache Spark 3.2.1

Last time I introduced Py4j, the bridge between the Apache Spark JVM codebase and Python client applications. Today is a great moment to take a deeper look at their interaction in the context of data processing defined with the RDD and DataFrame APIs.

DataFrame and RDD

When you define a DataFrame-based transformation, such as a filter or an aggregation, PySpark delegates its execution to the JVM DataFrame. How? Let's check this snippet from dataframe.py:

class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
    def __init__(self, jdf, sql_ctx):
        self._jdf = jdf
        self.sql_ctx = sql_ctx
    # ...

    def createOrReplaceTempView(self, name):
        self._jdf.createOrReplaceTempView(name)

    def exceptAll(self, other):
        return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)

    def limit(self, num):
        jdf = self._jdf.limit(num)
        return DataFrame(jdf, self.sql_ctx)

    def coalesce(self, numPartitions):
        return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)

    def filter(self, condition):
        if isinstance(condition, str):
            jdf = self._jdf.filter(condition)
        elif isinstance(condition, Column):
            jdf = self._jdf.filter(condition._jc)
        else:
            raise TypeError("condition should be string or Column")
        return DataFrame(jdf, self.sql_ctx)

As you can see, PySpark's DataFrame has a private _jdf attribute which is an instance of Py4j's JavaObject, referencing the DataFrame instance on the JVM side. A JavaObject makes it possible to call methods and access fields of its JVM counterpart. It means that whenever you invoke a method on the Python DataFrame, you automatically invoke the corresponding method on the Scala DataFrame. Therefore, calling such a transformation doesn't involve exchanging any data at all; the physical execution happens on the JVM side.
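To see the bridge in action, here is a minimal sketch assuming a local SparkSession; keep in mind that _jdf is an internal attribute, so don't rely on it in production code:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10).filter("id > 5")

print(type(df._jdf))            # <class 'py4j.java_gateway.JavaObject'>
print(df._jdf.schema().json())  # the call runs on the JVM; only the JSON string crosses the bridge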

The RDD API does the opposite. It involves a 2-way synchronization between the Python and JVM processes: the data is stored on the JVM side, but the computation happens in the Python interpreter. In a simplified view of the map(...) transformation, the PythonRDD physically stores the data produced by the mapped Python function, at the cost of 2 serialize/deserialize operations (one to send the input rows to the Python worker, one to get the results back).

RDD in detail

Let's take a closer look at the code behind this flow. On the Python side, map(...) simply wraps the user function and defers to mapPartitionsWithIndex(...), which produces a PipelinedRDD. When a job is submitted, the PipelinedRDD pickles the accumulated Python function and hands it to a JVM-side PythonRDD which, at execution time, streams the serialized rows to a Python worker process and reads the serialized results back.
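If you want to observe this from a local session, here is a minimal sketch; _jrdd and PipelinedRDD are internal names that may change between versions:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
sc = spark.sparkContext

# the lambda never leaves Python: it's pickled and executed by Python worker processes
doubled = sc.parallelize(range(10)).map(lambda x: x * 2)

print(type(doubled))                             # <class 'pyspark.rdd.PipelinedRDD'>
print(doubled._jrdd.rdd().getClass().getName())  # org.apache.spark.api.python.PythonRDD
print(doubled.collect())                         # deserialized back to Python objects: [0, 2, 4, ...]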

Why is map(...) not available in DataFrame?

PySpark's DataFrame and RDD share several methods, like filter(...). Since filter is one of the most common data transformations, we could naturally assume that another popular operation, map(...), would also be present in both APIs. Yet this function is only available on the RDD.

One of the reasons might be the nature of the processing. The filter(...) works on a fully structured condition, expressed either as a Column or as a string. That's not the case for map(...), which can contain arbitrary processing logic. Hence, it must run on the Python side with deserialized data, which is the domain of the RDD.

Of course, the lack of the map function also comes from the dynamic vs. static typing difference and the strongly typed character of Datasets. The filter method available in PySpark doesn't accept a custom function either.
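To make the difference concrete, here is a short sketch with a hypothetical amount column and a local session; the structured conditions stay on the JVM, while the lambda requires the RDD:

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(10).withColumnRenamed("id", "amount")

df.filter("amount > 5").show()           # string condition, evaluated on the JVM
df.filter(F.col("amount") > 5).show()    # Column condition, evaluated on the JVM

# arbitrary Python logic needs the RDD (or a UDF); rows are serialized to Python workers
print(df.rdd.map(lambda row: row.amount * 2).collect())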

Serialization in the DataFrame

Although the DataFrame API is not supposed to serialize/deserialize the data exchanged with the JVM, it's not an absolute rule. Any DataFrame method that brings the data back to the driver needs to ask its Java counterpart to serialize it beforehand. Let's check some of these methods:

class DataFrame(PandasMapOpsMixin, PandasConversionMixin):

    def collect(self):
        with SCCallSiteSync(self._sc) as css:
            sock_info = self._jdf.collectToPython()
        return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))

    def toLocalIterator(self, prefetchPartitions=False):
        with SCCallSiteSync(self._sc) as css:
            sock_info = self._jdf.toPythonIterator(prefetchPartitions)
        return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

    def tail(self, num):
        with SCCallSiteSync(self._sc):
            sock_info = self._jdf.tailToPython(num)
        return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))

All of these 3 methods share a similar workflow: they load the data from a socket and deserialize it into Python objects. They all call a Scala Dataset API method that registers the picklers, converts the rows to Java objects and pickles them before serving the result to the Python interpreter:

class Dataset[T] private[sql](
    @DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
    @DeveloperApi @Unstable @transient val encoder: Encoder[T])
  extends Serializable {

  private[sql] def collectToPython(): Array[Any] = {
    EvaluatePython.registerPicklers()
    withAction("collectToPython", queryExecution) { plan =>
      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
        plan.executeCollect().iterator.map(toJava))
      PythonRDD.serveIterator(iter, "serve-DataFrame")
    }
  }
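To close the loop, a short usage sketch (local session assumed): collect() is one of the few DataFrame methods that forces this JVM-to-Python serialization, because the rows have to leave the JVM and become Python Row objects on the driver:

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

rows = spark.range(3).collect()  # goes through collectToPython() + pickle deserialization
print(rows)                      # [Row(id=0), Row(id=1), Row(id=2)]
print(type(rows[0]))             # <class 'pyspark.sql.types.Row'>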

The 2-way serialization mentioned several times in this article is often the main reason why the DataFrame API is recommended over the RDD API. I hope that, like me, you now better understand what happens under the hood and why it has such a big impact on performance.