Collecting a part of data to the driver with RDD toLocalIterator

Versions: Spark 2.1.0

The golden rule, when you deal with a lot of data, is to avoid bringing all of it to a single node. Doing so can easily and pretty quickly lead to OOM errors, and Spark is no exception to this rule. But when moving data to the driver is mandatory, Spark provides a solution that can reduce the number of objects held there at once - the toLocalIterator method.

This short post describes the method. The first part explains what happens when it's called, while the second part, through a learning test checking the logs output, proves that toLocalIterator can reduce the number of objects sent to the driver.

toLocalIterator explained

The most popular Spark method used to bring data to the driver is collect(). It executes the given job on all partitions (executor side) and merges all results (driver side) with the Array.concat(results: _*) method. toLocalIterator does the contrary. Instead of launching the job simultaneously on all partitions, it executes the job on one partition at a time. So the driver only needs enough memory to store the biggest partition.
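
As a quick illustration of the difference, here is a minimal sketch (not part of the original post, assuming a live SparkContext called sparkContext):

val numbers = sparkContext.parallelize(1 to 100, 5)

// collect() launches a single job over all 5 partitions and keeps
// the whole dataset (100 elements) on the driver at once
val allNumbers: Array[Int] = numbers.collect()

// toLocalIterator launches one job per partition, on demand, so the driver
// only holds one partition (20 elements here) at a time
val numbersIterator: Iterator[Int] = numbers.toLocalIterator
numbersIterator.take(5).foreach(println)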

The implementation in Spark's RDD class looks like this:

def toLocalIterator: Iterator[T] = withScope {
  def collectPartition(p: Int): Array[T] = {
    sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
  }
  (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
}

Because (0 until partitions.length).iterator is lazy, collectPartition is only called when the consumer of the iterator asks for elements of the next partition, which is why only one partition is fetched at a time. It's important to note, however, that toLocalIterator doesn't protect against OOM problems on the driver side. As already mentioned, the driver must be ready to handle the biggest partition. With a lot of created objects, a small number of partitions or bad partitioning (e.g. one partition storing 90% of the data), the OOM risk is still real.
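
To make that last scenario concrete, here is a minimal sketch (the skewed RDD and the repartition() call are illustrative additions, not part of the original post): redistributing the records before iterating caps what the driver has to hold at any moment.

// one partition deliberately holds 90% of the records
val skewedRdd = sparkContext.parallelize(1 to 900000, 1)
  .union(sparkContext.parallelize(900001 to 1000000, 9))

// toLocalIterator would still bring the 900 000-element partition
// to the driver in a single job
// skewedRdd.toLocalIterator.foreach(println)

// redistributing the records first (at the price of a shuffle) limits what
// the driver has to hold at a time to roughly 100 000 elements
val rebalancedRdd = skewedRdd.repartition(10)
// rebalancedRdd.toLocalIterator.foreach(println)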

toLocalIterator example

The following code contains 2 test cases showing the differences between toLocalIterator and collect:

val ExpectedLogPatterns = Seq("Finished task 0.0 in stage 0.0 (TID 0) in ",
  "Finished task 1.0 in stage 0.0 (TID 1) in ",
  "Finished task 2.0 in stage 0.0 (TID 2) in ", 
  "Finished task 3.0 in stage 0.0 (TID 3) in ",
  "Finished task 4.0 in stage 0.0 (TID 4) in ") 

"only one task" should "be executed when the first 5 items must be retrieved" in {
  val inMemoryLogAppender = InMemoryLogAppender.createLogAppender(ExpectedLogPatterns)
  val numbersRdd = sparkContext.parallelize(1 to 100, 5)
  val numbersRddLocalIterator = numbersRdd.map(number => number * 2)
    .toLocalIterator

  // This condition could be expressed with the .filter() method
  // but writing it this way helps to show the difference between
  // toLocalIterator and collect
  var canRun = true
  while (numbersRddLocalIterator.hasNext && canRun) {
    val mappedNumber = numbersRddLocalIterator.next()
    if (mappedNumber == 10) {
      canRun = false
    }
  }

  inMemoryLogAppender.messages.size shouldEqual(1)
  val logMessages = inMemoryLogAppender.getMessagesText()
  val taskExecution = logMessages.count(msg => msg.startsWith("Finished task 0.0 "))
  taskExecution shouldEqual(1)
}
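
For reference, the same partial retrieval could be written more declaratively on the local iterator; this is only a sketch (the explicit while loop above makes it easier to see when a new per-partition job is triggered):

val firstFiveDoubledNumbers = numbersRdd.map(number => number * 2)
  .toLocalIterator
  .take(5)
  .toSeq
// only the task for partition 0 is executed, since the 5 requested elements
// all come from the first 20-element partition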

"collect invocation" should "move all data to the driver" in {
  val inMemoryLogAppender = InMemoryLogAppender.createLogAppender(ExpectedLogPatterns)
  val numbersRdd = sparkContext.parallelize(1 to 100, 5)

  val collectedNumbers = numbersRdd.map(number => number * 2).collect()

  inMemoryLogAppender.messages.size shouldEqual(5)
  val logMessages = inMemoryLogAppender.getMessagesText()
  for (i <- 0 to 4) {
    val taskExecution = logMessages.count(msg => msg.startsWith(s"Finished task ${i}.0"))
    taskExecution shouldEqual(1)
  }
}

The goal of this post was to show a less memory-intensive alternative to the collect() method. The first part explained the implementation details, where we could see that the driver memory only needs to be ready to hold the biggest partition. The second section showed the differences between collect() and toLocalIterator() through 2 test cases analyzing task execution from the logs.