The golden rule when dealing with large datasets is to avoid bringing all the data to a single node, since that can quickly lead to OOM errors. Spark is no exception to this rule. When moving data to the driver is unavoidable, though, Spark provides a way to reduce the number of objects brought there at once: the toLocalIterator method.
This short post describes this method. The first part explains what happens when it is called. The second part uses a learning test that checks the log output to show that toLocalIterator can reduce the number of objects sent to the driver.
toLocalIterator explained
The most popular Spark method for bringing data to the driver is collect(). It runs the job on all partitions (executor side) and gathers all the results (driver side) with Array.concat(results: _*). toLocalIterator takes the opposite approach. Instead of launching the job on all partitions simultaneously, it executes the job on one partition at a time. The driver therefore only needs enough memory to hold the biggest partition.
The implementation looks like this:
def toLocalIterator: Iterator[T] = withScope {
  def collectPartition(p: Int): Array[T] = {
    sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
  }
  (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
}
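The key detail in this implementation is that flatMap is called on an Iterator, so collectPartition runs only when the consumer actually reaches a given partition. The following is a minimal sketch of that behavior in plain Scala, without Spark. The `partitions` vector and the `fetched` buffer are hypothetical stand-ins introduced for illustration: the first replaces the RDD's partitions, the second records which partitions were "fetched" in place of real sc.runJob calls.

```scala
import scala.collection.mutable.ListBuffer

object LazyPartitionFetch {
  // Hypothetical stand-in for an RDD: the numbers 1 to 100, doubled,
  // split into 5 "partitions" of 20 elements each.
  val partitions: Vector[Vector[Int]] =
    (1 to 100).map(_ * 2).grouped(20).map(_.toVector).toVector

  // Records the index of every partition that was fetched, in order.
  val fetched: ListBuffer[Int] = ListBuffer.empty[Int]

  // Stand-in for sc.runJob(..., Seq(p)).head: materializes one whole partition.
  def collectPartition(p: Int): Array[Int] = {
    fetched += p
    partitions(p).toArray
  }

  // Same shape as the Spark method: a lazy iterator over partition indexes.
  // Nothing is fetched until the returned iterator is consumed.
  def toLocalIterator: Iterator[Int] =
    partitions.indices.iterator.flatMap(i => collectPartition(i).iterator)

  def main(args: Array[String]): Unit = {
    val firstFive = toLocalIterator.take(5).toList
    println(firstFive)      // List(2, 4, 6, 8, 10)
    println(fetched.toList) // List(0)
  }
}
```

Only the first partition is materialized because the consumer stops after five elements. In this model, a collect-like call would have to invoke collectPartition for all five indexes before returning anything.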
It's important to note, however, that toLocalIterator doesn't protect the driver against OOM problems. As already mentioned, the driver must be able to hold the biggest partition. With a lot of objects, a small number of partitions and bad partitioning (e.g. one partition storing 90% of the data), OOM problems remain a real risk.
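One way to spot such skew before iterating is to count the records of each partition on the executors, so that only one small pair per partition reaches the driver. The sketch below assumes Spark core is on the classpath and runs in local mode. The object name, the `partitionSizes` helper and the "hot"/"cold" keys are all hypothetical and exist only to reproduce the "one partition storing 90% of data" case. Rebalancing with repartition() is one possible remedy, not the only one, and it costs a shuffle.

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object PartitionSizeCheck {
  // Counts records per partition on the executors; the driver only
  // receives one (partitionIndex, recordCount) pair per partition.
  def partitionSizes[T](rdd: RDD[T]): Array[(Int, Int)] =
    rdd.mapPartitionsWithIndex((index, rows) => Iterator((index, rows.size))).collect()

  // Illustrative skew: 90 of the 100 records share one key, so hash
  // partitioning places all of them in the same partition.
  def skewedRdd(sc: SparkContext): RDD[String] = {
    val keys = Seq.fill(90)("hot") ++ (1 to 10).map(i => s"cold-$i")
    sc.parallelize(keys.map(key => (key, key)), 4)
      .partitionBy(new HashPartitioner(4))
      .values
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("partition-size-check")
    val sc = new SparkContext(conf)
    try {
      val skewed = skewedRdd(sc)
      val largestBefore = partitionSizes(skewed).map(_._2).max

      // repartition() shuffles the records and spreads them more evenly.
      val rebalanced = skewed.repartition(4)
      val largestAfter = partitionSizes(rebalanced).map(_._2).max

      println(s"largest partition before: $largestBefore, after: $largestAfter")

      // The driver now has to hold a smaller partition at any given time.
      rebalanced.toLocalIterator.take(3).foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```

The size of the largest partition is the number to compare against the driver's memory budget, since that is the most toLocalIterator has to hold at once.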
toLocalIterator example
The following code contains 2 test cases showing the differences between toLocalIterator and collect:
val ExpectedLogPatterns = Seq(
  "Finished task 0.0 in stage 0.0 (TID 0) in ",
  "Finished task 1.0 in stage 0.0 (TID 1) in ",
  "Finished task 2.0 in stage 0.0 (TID 2) in ",
  "Finished task 3.0 in stage 0.0 (TID 3) in ",
  "Finished task 4.0 in stage 0.0 (TID 4) in ")

"only one task" should "be executed when the first 5 items must be retrieved" in {
  val inMemoryLogAppender = InMemoryLogAppender.createLogAppender(ExpectedLogPatterns)
  val numbersRdd = sparkContext.parallelize(1 to 100, 5)

  val numbersRddLocalIterator = numbersRdd.map(number => number * 2)
    .toLocalIterator

  // This filter could be implemented in .filter() method
  // But used as here helps to show the difference between
  // toLocalIterator and collect
  var canRun = true
  while (numbersRddLocalIterator.hasNext && canRun) {
    val partitionNumber = numbersRddLocalIterator.next()
    if (partitionNumber == 10) {
      canRun = false
    }
  }

  inMemoryLogAppender.messages.size shouldEqual (1)
  val logMessages = inMemoryLogAppender.getMessagesText()
  val taskExecution = logMessages.filter(msg => msg.startsWith(s"Finished task 0.0 ")).size
  taskExecution shouldEqual(1)
}

"collect invocation" should "move all data to the driver" in {
  val inMemoryLogAppender = InMemoryLogAppender.createLogAppender(ExpectedLogPatterns)
  val numbersRdd = sparkContext.parallelize(1 to 100, 5)

  val collectedNumbers = numbersRdd.map(number => number * 2).collect

  inMemoryLogAppender.messages.size shouldEqual(5)
  val logMessages = inMemoryLogAppender.getMessagesText()
  for (i <- 0 to 4) {
    val taskExecution = logMessages.filter(msg => msg.startsWith(s"Finished task ${i}.0")).size
    taskExecution shouldEqual(1)
  }
}
The goal of this post was to show a less memory-intensive alternative to the collect() method. The first part explained the implementation details and showed that the driver only needs enough memory to hold the biggest partition. The second section showed the differences between collect() and toLocalIterator through 2 test cases analyzing task execution in the logs.