Spark's Singleton to be or not to be dilemma

Versions: Spark 2.1.0

Some time ago I was wondering why an object created once in the driver is recreated on the executors with every new stage - even if this object is sent through a broadcast variable. After some code digging, the answer turned out to be related to Java serialization.

This post tries to answer the question that arose in my head some time ago. It's composed of 2 parts. The first part contains the code illustrating the problem. The second part gives the explanation.

New instance created every time

The following tests show the situation analyzed in this post:

"new instance" should "be created every time in stage" in {
  for (i <- 1 to 10) {
    dataQueue += streamingContext.sparkContext.makeRDD(Seq(i), 1)
  }
  val instanceClass = new InstanceClass()
  val seenInstancesAccumulator = streamingContext.sparkContext.collectionAccumulator[Int]("seen instances")
  val queueStream = streamingContext.queueStream(dataQueue, true)
  queueStream.foreachRDD((rdd, time) => {
    rdd.foreachPartition(numbers => {
      seenInstancesAccumulator.add(instanceClass.hashCode())
      println(s"Instance of ${instanceClass} with hash code ${instanceClass.hashCode()}")
    })
  })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  val seenClassesHashes = seenInstancesAccumulator.value.stream().collect(Collectors.toSet())
  seenClassesHashes.size() should be > 1
}

"a single instance" should "be kept in every stage" in {
  for (i <- 1 to 10) {
    dataQueue += streamingContext.sparkContext.makeRDD(Seq(i), 1)
  }
  val instanceClass = SingletonObject
  val seenInstancesAccumulator = streamingContext.sparkContext.collectionAccumulator[Int]("seen instances")
  val queueStream = streamingContext.queueStream(dataQueue, true)
  queueStream.foreachRDD((rdd, time) => {
    rdd.foreachPartition(numbers => {
      seenInstancesAccumulator.add(instanceClass.hashCode())
      println(s"Instance of ${instanceClass} with hash code ${instanceClass.hashCode()}")
    })
  })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  val seenClassesHashes = seenInstancesAccumulator.value.stream().collect(Collectors.toSet())
  seenClassesHashes.size() shouldEqual(1)
}

"a single instance coming from pool of singletons" should "be kept in every stage" in {
  for (i <- 1 to 10) {
    dataQueue += streamingContext.sparkContext.makeRDD(Seq(i), 1)
  }
  val seenInstancesAccumulator = streamingContext.sparkContext.collectionAccumulator[Int]("seen instances")
  val queueStream = streamingContext.queueStream(dataQueue, true)
  queueStream.foreachRDD((rdd, time) => {
    rdd.foreachPartition(numbers => {
      val lazyLoadedInstanceClass = LazyLoadedInstanceClass.getInstance(1)
      seenInstancesAccumulator.add(lazyLoadedInstanceClass.hashCode())
      println(s"Instance of ${lazyLoadedInstanceClass} with hash code ${lazyLoadedInstanceClass.hashCode()}")
    })
  })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  val seenClassesHashes = seenInstancesAccumulator.value.stream().collect(Collectors.toSet())
  seenClassesHashes.size() shouldEqual(1)

}


class InstanceClass extends Serializable {}

object SingletonObject extends Serializable {}


class LazyLoadedInstanceClass(val id:Int) extends Serializable {}

object LazyLoadedInstanceClass extends Serializable {

  private val InstancesMap = mutable.Map[Int, LazyLoadedInstanceClass]()

  def getInstance(id: Int): LazyLoadedInstanceClass = {
    InstancesMap.getOrElseUpdate(id, new LazyLoadedInstanceClass(id))
  }

}

As you can see, the singleton and the lazy-loaded instance represent, as expected, always the same object. The lazy-loaded one keeps its identity because getInstance is called inside foreachPartition, so InstancesMap lives in the executor's JVM and is reused across tasks (these tests run everything in a single local JVM). We could expect the same thing for the instance of InstanceClass, initialized once on the driver just before the processing. But surprisingly, it's not the case.

Why is a new instance created every time?

To discover the reason, some code digging must be done. Every time foreachRDD runs, the closure defined inside foreachPartition is serialized by the driver and deserialized by the executors, once per task.

Under the hood, Java serialization is used by default to rebuild the objects referenced in the processing. The deserialization is done by org.apache.spark.serializer.JavaDeserializationStream and its method below:

def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
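
For illustration, the same round trip can be reproduced with Spark's JavaSerializer directly. The snippet below is only a sketch: SimpleInstance and SerializationRoundTripDemo are made-up names used for this demonstration, not part of Spark or of the tests in this post.

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// Hypothetical class used only for this illustration
class SimpleInstance extends Serializable

object SerializationRoundTripDemo extends App {
  val serializerInstance = new JavaSerializer(new SparkConf()).newInstance()

  // Serialize once, as the driver does when it ships a closure
  val bytes = serializerInstance.serialize(new SimpleInstance())

  // Deserialize twice, as two tasks running on executors would do
  val firstCopy = serializerInstance.deserialize[SimpleInstance](bytes.duplicate())
  val secondCopy = serializerInstance.deserialize[SimpleInstance](bytes.duplicate())

  // The same bytes produce two distinct JVM objects
  println(firstCopy eq secondCopy) // false
}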

What creates the new instance is the objIn.readObject() call itself, not the asInstanceOf cast. For a regular class (the InstanceClass case), Java deserialization builds a brand new instance every time. For a Scala object, which is a singleton by definition, the compiler generates a readResolve method that is invoked during deserialization and returns the already existing instance instead of a fresh copy. The following tests, using plain Java serialization, prove both behaviors:

  "JVM singleton" should "be serialized and deserialized as the same object instance" in {
  val seenHashCodes = mutable.Set[Int]()
  val serializedObjectBytes = serializeObject(SingletonClass)

  for (i <- 1 to 10) {
    val objectInputStream = new ObjectInputStream(new ByteArrayInputStream(serializedObjectBytes))
    val deserializedSingleton = objectInputStream.readObject().asInstanceOf[SingletonClass.type]
    seenHashCodes += deserializedSingleton.hashCode()
    objectInputStream.close()
  }

  seenHashCodes.size shouldEqual(1)
}

"Java object" should "be serialized and deserialized as new instance" in {
  val seenHashCodes = mutable.Set[Int]()
  val instanceClass = new InstanceClass
  val serializedObjectBytes = serializeObject(instanceClass)

  for (i <- 1 to 10) {
    val objectInputStream = new ObjectInputStream(new ByteArrayInputStream(serializedObjectBytes))
    val deserializedInstanceClass = objectInputStream.readObject().asInstanceOf[InstanceClass]
    seenHashCodes.add(deserializedInstanceClass.hashCode())
    objectInputStream.close()
  }

  seenHashCodes.size should be > 1
}

private def serializeObject(toSerialize: Any): Array[Byte] = {
  val outputStream = new ByteArrayOutputStream()
  val objectOutputStream = new ObjectOutputStream(outputStream)
  objectOutputStream.writeObject(toSerialize)
  objectOutputStream.close()
  outputStream.toByteArray
}


class InstanceClass extends Serializable {}

object SingletonClass extends Serializable {}
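
As mentioned above, the singleton survives the round trip thanks to the compiler-generated readResolve. A plain class can opt into the same mechanism by hand. Below is a minimal sketch of that idea; SingletonLikeClass and ReadResolveDemo are made-up names, not part of the tests above.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical class mimicking roughly what the Scala compiler generates for an object:
// readResolve swaps the freshly deserialized copy for the canonical instance.
class SingletonLikeClass private () extends Serializable {
  private def readResolve(): AnyRef = SingletonLikeClass.Instance
}

object SingletonLikeClass {
  val Instance = new SingletonLikeClass()
}

object ReadResolveDemo extends App {
  val output = new ByteArrayOutputStream()
  val objectOutput = new ObjectOutputStream(output)
  objectOutput.writeObject(SingletonLikeClass.Instance)
  objectOutput.close()

  val objectInput = new ObjectInputStream(new ByteArrayInputStream(output.toByteArray))
  val deserialized = objectInput.readObject().asInstanceOf[SingletonLikeClass]
  objectInput.close()

  // readResolve returned the existing instance instead of the deserialized copy
  println(deserialized eq SingletonLikeClass.Instance) // true
}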

Does it mean that we must always use objects to guarantee uniqueness? Not at all. Logical uniqueness can also be provided by implementing the equality methods (equals/hashCode): the deserialized copies remain distinct JVM objects, but they compare as equal and produce the same hash code, which is exactly what these tests measure. The last example shows how it works:

"Java object with implemented equality" should "be serialized and deserialized as the same intance" in {
  val seenHashCodes = mutable.Set[Int]()
  val instanceClass = new InstanceClassWithEquality(1)
  val serializedObjectBytes = serializeObject(instanceClass)

  for (i <- 1 to 10) {
    val objectInputStream = new ObjectInputStream(new ByteArrayInputStream(serializedObjectBytes))
    val deserializedInstanceClass = objectInputStream.readObject().asInstanceOf[InstanceClassWithEquality]
    seenHashCodes.add(deserializedInstanceClass.hashCode())
    objectInputStream.close()
  }

  seenHashCodes.size shouldEqual(1)
}

class InstanceClassWithEquality(val id: Int) extends Serializable {

  override def equals(comparedObject: scala.Any): Boolean = {
    if (comparedObject.isInstanceOf[InstanceClassWithEquality]) {
      val comparedInstance = comparedObject.asInstanceOf[InstanceClassWithEquality]
      id == comparedInstance.id
    } else {
      false
    }
  }

  override def hashCode(): Int = {
    id
  }

}
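
One nuance worth stressing: the deserialized copy is still a different JVM object; equals/hashCode only make it indistinguishable from the original for the hash-code-based checks used here. A short sketch reusing InstanceClassWithEquality from the listing above (EqualityVersusIdentityDemo is a made-up name):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object EqualityVersusIdentityDemo extends App {
  val original = new InstanceClassWithEquality(1)

  // Serialize once, as the driver does with the closure
  val output = new ByteArrayOutputStream()
  val objectOutput = new ObjectOutputStream(output)
  objectOutput.writeObject(original)
  objectOutput.close()

  // Deserialize once, as a task on an executor would
  val objectInput = new ObjectInputStream(new ByteArrayInputStream(output.toByteArray))
  val copy = objectInput.readObject().asInstanceOf[InstanceClassWithEquality]
  objectInput.close()

  println(copy eq original)                       // false - a different JVM object
  println(copy == original)                       // true  - equals() considers them the same
  println(copy.hashCode() == original.hashCode()) // true  - hence a single hash in the accumulator
}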

Used in a Spark program, a class with equality logic implemented (explicitly or through a case class) gives the following:

"a single instance coming from class with equality logic implemented" should "be kept in every stage" in {
  for (i <- 1 to 10) {
    dataQueue += streamingContext.sparkContext.makeRDD(Seq(i), 1)
  }
  val seenInstancesAccumulator = streamingContext.sparkContext.collectionAccumulator[Int]("seen instances")
  val queueStream = streamingContext.queueStream(dataQueue, true)
  val instanceClass = new InstanceClassWithEquality(1)
  queueStream.foreachRDD((rdd, time) => {
    rdd.foreachPartition(numbers => {
      seenInstancesAccumulator.add(instanceClass.hashCode())
      println(s"Instance of ${instanceClass} with hash code ${instanceClass.hashCode()}")
    })
  })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  val seenClassesHashes = seenInstancesAccumulator.value.stream().collect(Collectors.toSet())
  seenClassesHashes.size() shouldEqual(1)
}

// The same thing as in above test can be achieved easier
// thanks to Scala's case classes
"a single instance coming from case class" should "be kept in every stage" in {
  for (i <- 1 to 10) {
    dataQueue += streamingContext.sparkContext.makeRDD(Seq(i), 1)
  }
  val seenInstancesAccumulator = streamingContext.sparkContext.collectionAccumulator[Int]("seen instances")
  val queueStream = streamingContext.queueStream(dataQueue, true)
  val instanceClass = InstanceClassAsCaseClass(1)
  queueStream.foreachRDD((rdd, time) => {
    rdd.foreachPartition(numbers => {
      seenInstancesAccumulator.add(instanceClass.hashCode())
      println(s"Instance of ${instanceClass} with hash code ${instanceClass.hashCode()}")
    })
  })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  val seenClassesHashes = seenInstancesAccumulator.value.stream().collect(Collectors.toSet())
  seenClassesHashes.size() shouldEqual(1)
}

class InstanceClassWithEquality(val id:Int) extends Serializable {
  override def equals(comparedObject: scala.Any): Boolean = {
    if (comparedObject.isInstanceOf[InstanceClassWithEquality]) {
      val comparedInstance = comparedObject.asInstanceOf[InstanceClassWithEquality]
      id == comparedInstance.id
    } else {
      false
    }
  }

  override def hashCode(): Int = {
    id
  }
}

case class InstanceClassAsCaseClass(id: Int)
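
Case classes give this behavior for free: equals and hashCode are generated from the constructor parameters, so any two copies carrying the same values compare as equal and hash identically, wherever they were deserialized. A quick illustration (CaseClassEqualityDemo is a made-up name):

object CaseClassEqualityDemo extends App {
  val builtOnDriver = InstanceClassAsCaseClass(1)
  val builtOnExecutor = InstanceClassAsCaseClass(1) // stands for a deserialized copy

  println(builtOnDriver eq builtOnExecutor)                       // false - two different objects
  println(builtOnDriver == builtOnExecutor)                       // true  - structural equality
  println(builtOnDriver.hashCode() == builtOnExecutor.hashCode()) // true  - same hash everywhere
}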

From this post we can learn that, because of Java serialization, Spark creates a new instance of the serialized objects for every stage executed on the executors. The tests made in the second part proved that when a class instance is serialized, deserialization creates a new object every time. The same test made on a singleton (Scala's object) showed the contrary: even if it's read 10 times, the same instance always comes back, thanks to the generated readResolve method. The last test cases showed that a plain class can behave as a single logical instance too, as long as equality (equals/hashCode) is implemented, either by hand or through a case class.

