Why UnsafeRow.copy() for state persistence in the state store?

Versions: Apache Spark 2.4.2

In my recent Spark+AI Summit 2019 follow-up posts, I'm implementing a custom state store inspired by the default one. While analyzing the default implementation, I was intrigued by its put(key: UnsafeRow, value: UnsafeRow) method. Keep reading if you're curious why.

The put method is responsible for adding a new value to the state store, and its implementation is quite straightforward:

    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
      verify(state == UPDATING, "Cannot put after already committed or aborted")
      val keyCopy = key.copy()
      val valueCopy = value.copy()
      mapToUpdate.put(keyCopy, valueCopy)
      writeUpdateToDeltaFile(compressedStream, keyCopy, valueCopy)
    }

Do you see something intriguing? Yes, the key and value stored in the state store's in-memory map are both copies of the original data! Why? The answer is given in the StateStore interface comment:

    /**
     * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in
     * the params can be reused, and must make copies of the data as needed for persistence.
     */
    def put(key: UnsafeRow, value: UnsafeRow): Unit
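The contract above can be reproduced without Spark. The sketch below is plain Scala and only an illustration: `ReusedRow` is a hypothetical, heavily simplified stand-in for UnsafeRow that keeps a reference to the bytes passed to its `pointTo`. A producer reuses a single instance for three records, and the consumer stores it once as-is and once through `copy()`, which is what the state store's put does.

```scala
import java.nio.charset.StandardCharsets
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for UnsafeRow: a mutable wrapper that only
// references the bytes it was last pointed to.
final class ReusedRow {
  private var buf: Array[Byte] = new Array[Byte](0)
  def pointTo(bytes: Array[Byte]): Unit = { buf = bytes }
  def getString: String = new String(buf, StandardCharsets.UTF_8)
  // Deep copy: a new wrapper over a new array, in the spirit of UnsafeRow.copy()
  def copy(): ReusedRow = {
    val snapshot = new ReusedRow
    snapshot.pointTo(buf.clone())
    snapshot
  }
}

// The producer reuses one wrapper for every record.
val sharedRow = new ReusedRow
val storedWithoutCopy = ArrayBuffer[ReusedRow]()
val storedWithCopy = ArrayBuffer[ReusedRow]()

for (record <- Seq("A", "B", "C")) {
  sharedRow.pointTo(record.getBytes(StandardCharsets.UTF_8))
  storedWithoutCopy += sharedRow     // keeps a reference to the reused wrapper
  storedWithCopy += sharedRow.copy() // keeps an independent snapshot
}

println(storedWithoutCopy.map(_.getString).mkString(",")) // C,C,C
println(storedWithCopy.map(_.getString).mkString(","))    // A,B,C
```

The buffer that kept the reused wrapper ends up with three references to the same object, all reading the last record. The same would happen to mapToUpdate if put stored the key and value without copying them.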

You said "reused"? Yes. If you look at the UnsafeRow implementation, you will see that an instance isn't necessarily created per row. Its main constructor takes a single parameter, the number of fields in the schema. The data exposed by an UnsafeRow instance is assigned later, when pointTo(Object baseObject, long baseOffset, int sizeInBytes) or pointTo(byte[] buf, int sizeInBytes) is called. Very often Spark creates a single UnsafeRow instance and uses it as a wrapper for the real data objects. You can see that use in:

Before finishing, let's check what happens to an UnsafeRow instance that we keep around without copying it:

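    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    val stringEncoder = ExpressionEncoder[String]()
    // ExpressionEncoder's documentation warns that toRow() may return the same row object
    // on every call, so rowA and rowB below can even end up being the very same instance
    val rowA = stringEncoder.toRow("A").asInstanceOf[UnsafeRow]
    assert(rowA.getString(0) == "A")

    val rowB = stringEncoder.toRow("B").asInstanceOf[UnsafeRow]
    assert(rowB.getString(0) == "B")

    // pointTo() only repoints the wrapper at another byte array; rowA keeps no copy of "A"
    rowA.pointTo(rowB.getBytes(), rowB.getSizeInBytes)
    assert(rowA.getString(0) == "B")
    assert(rowB.getString(0) == "B")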

As you can see, the UnsafeRow created at the beginning now exposes the data of the second row. In the Apache Spark source code you will find many places where an UnsafeRow instance is shared, but also places where a method returns a new UnsafeRow instance. However, being created locally doesn't guarantee that an instance won't be mutated elsewhere, for example when it points to a buffer that another component keeps reusing.
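That last point can also be sketched without Spark. In this plain Scala example, `RowView` is a hypothetical stand-in for UnsafeRow and `writerBuffer` plays the role of memory owned and reused by another component (both names are assumptions made for the illustration). Creating a new wrapper doesn't isolate it from that component; only copying the bytes does.

```scala
import java.nio.charset.StandardCharsets

// Hypothetical stand-in for UnsafeRow: it keeps a reference to the bytes,
// the way pointTo() does, instead of copying them.
final class RowView(buf: Array[Byte]) {
  def getString: String = new String(buf, StandardCharsets.UTF_8)
  // Deep copy of the bytes, in the spirit of UnsafeRow.copy()
  def copy(): RowView = new RowView(buf.clone())
}

// A buffer owned and reused by some other component (a hypothetical row writer).
val writerBuffer = "A".getBytes(StandardCharsets.UTF_8)

val freshRow = new RowView(writerBuffer) // a brand new, locally created instance
val snapshot = freshRow.copy()           // an independent copy of the same data

// The other component overwrites its buffer in place for the next record.
writerBuffer(0) = 'B'.toByte

println(freshRow.getString) // B: the new instance still shares the buffer
println(snapshot.getString) // A: the copy owns its own bytes
```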