Data+AI Summit: custom state store integration feedback

Versions: Apache Spark 3.0.1

After the introductory part, it's time to share what I learned from the custom state store implementation.

The post is divided into 4 short parts. Each of them presents one feedback point related to a custom state store implementation.

The meaning of the state store name

I didn't notice it at the beginning, but the StateStoreId class is composed of 3 mandatory fields (checkpointRootLocation, operatorId, partitionId) and 1 optional one (storeName). Knowing that there is one state store per operator and partition, and that the default implementation stores states in the checkpoint location, the 3 mandatory fields are quite easy to understand. But what about the state store name?

In fact, some stateful operations can use 2 different state stores, and the name is used to distinguish between them. This is the case of streaming joins, where 2 stores are created for each side of the join. The first one stores the number of values for every key and the second one the corresponding rows on a key+index basis.
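To make the two-store layout more concrete, here is a minimal in-memory sketch of one join side. It's not Spark's code: Strings stand in for UnsafeRow, and both the class name and the `<side>-<storeType>` store names are hypothetical, chosen only to show why a name is needed to tell the two stores apart when they share the same operator and partition.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of the 2 state stores used by one join side.
// Strings replace UnsafeRow; the store names are illustrative only.
final class OneSideJoinStateSketch(joinSide: String) {
  // Both stores share the operator and partition, so only the name distinguishes them
  val storeNames: Seq[String] =
    Seq(s"$joinSide-keyToNumValues", s"$joinSide-keyWithIndexToValue")

  // First store: key -> number of values kept for this key
  private val keyToNumValues = mutable.Map.empty[String, Int]
  // Second store: (key, index) -> row
  private val keyWithIndexToValue = mutable.Map.empty[(String, Int), String]

  def append(key: String, row: String): Unit = {
    val index = keyToNumValues.getOrElse(key, 0)
    keyWithIndexToValue.put((key, index), row)
    keyToNumValues.put(key, index + 1)
  }

  // Reads the count first, then every (key, index) entry
  def get(key: String): Seq[String] = {
    val count = keyToNumValues.getOrElse(key, 0)
    (0 until count).map(index => keyWithIndexToValue((key, index)))
  }
}
```

The count store lets a reader know how many indexed entries to fetch without scanning the second store.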

But if for whatever reason you want to store something related to the state store locally, it's better to take the name into account. In my sample MapDB implementation, the name is used in MapDBStateStoreNamingFactory to generate file names that are unique per operator, partition and store name:

import java.io.File
case class MapDBStateStoreNamingFactory(checkpointStorePath: String, localStorePath: String,
                                       operatorId: Long, partitionNumber: Int, stateStoreName: String) {
// ...
  val allEntriesFile = s"${localStorePath}/all-entries-${stateStoreName}-${operatorId}-${partitionNumber}.db"
  private def updateFile(dir: String, version: Long) = {
    new File(s"${dir}/${version}").mkdirs()
    s"${dir}/${version}/updates-${stateStoreName}-${operatorId}-${partitionNumber}.db"
  }
  private def deleteFile(dir: String, version: Long) = {
    new File(s"${dir}/${version}").mkdirs()
    s"${dir}/${version}/deletes-${stateStoreName}-${operatorId}-${partitionNumber}.db"
  }
  private def snapshotFile(dir: String, version: Long) = s"${dir}/${version}/snapshot-${stateStoreName}-${operatorId}-${partitionNumber}.db"
}
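If you prefer to avoid string concatenation, the same per-version layout can be built with java.nio. The sketch below is a hypothetical variant (DeltaFileNamesSketch is not part of the MapDB implementation). It differs in one detail worth considering: Files.createDirectories throws when the directory can't be created, whereas File.mkdirs() only returns false, a result the snippet above ignores.

```scala
import java.nio.file.{Files, Path, Paths}

// Hypothetical helper building per-version delta file paths, like the naming factory above
final case class DeltaFileNamesSketch(localStorePath: String, operatorId: Long,
                                      partitionNumber: Int, stateStoreName: String) {
  // The store name is part of every file name, so 2 stores of the same
  // operator and partition never share a file
  private val suffix = s"$stateStoreName-$operatorId-$partitionNumber.db"

  def updatesFile(version: Long): Path = fileIn(version, s"updates-$suffix")
  def deletesFile(version: Long): Path = fileIn(version, s"deletes-$suffix")

  private def fileIn(version: Long, fileName: String): Path = {
    val versionDir = Paths.get(localStorePath, version.toString)
    // Throws on failure; does nothing if the directory already exists
    Files.createDirectories(versionDir)
    versionDir.resolve(fileName)
  }
}
```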

UnsafeRow - to copy or not to copy?

That's the question. UnsafeRow is a mutable class, which is why, if you execute the code below, you will see that unsafeRows always contains the same (last) value:

  import org.apache.spark.sql.catalyst.expressions.UnsafeRow
  import scala.collection.mutable
  val unsafeRows = new mutable.ListBuffer[UnsafeRow]()
  val bytesFromUnsafeRow = new mutable.ListBuffer[String]()
  val bytes = "123".getBytes("utf-8")
  val unsafeRowFromBytes = new UnsafeRow(1)
  unsafeRowFromBytes.pointTo(bytes, bytes.length)
  bytesFromUnsafeRow.append(new String(unsafeRowFromBytes.getBytes))
  unsafeRows.append(unsafeRowFromBytes)

  val nextBytes = "345".getBytes("utf-8")
  unsafeRowFromBytes.pointTo(nextBytes, nextBytes.length)
  bytesFromUnsafeRow.append(new String(unsafeRowFromBytes.getBytes))
  unsafeRows.append(unsafeRowFromBytes)

  // `bytesFromUnsafeRow` keeps 2 different values, normal since it
  // extracts the bytes from every `UnsafeRow` at the given moment
  println(bytesFromUnsafeRow)
  // `unsafeRows` keeps only 1 value, since the `UnsafeRow` is mutable
  println(unsafeRows)
  unsafeRows.foreach(row => println(new String(row.getBytes)))

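But the snippet also shows that the byte array returned by getBytes is not affected when the UnsafeRow is later pointed to other bytes. One caveat: when a row is backed by a byte array of exactly its size, getBytes returns that array itself rather than a copy, so this only holds as long as that buffer isn't overwritten. Under this condition, the array can be stored without an explicit copy() call on the UnsafeRow, which is why the put method of MapDBStateStore simply takes the byte arrays of the key and value: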

  override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
    updatesFromVersion.put(key.getBytes, value.getBytes)
    mapWithAllEntries.remove(key.getBytes)
  }

But to be perfectly honest, I also used the byte arrays for the sake of simplicity, to avoid implementing a custom MapDB serializer/deserializer pair for UnsafeRow :P
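Using raw byte arrays as keys relies on MapDB comparing them by content through the Serializer.BYTE_ARRAY passed at map creation. If you ever replace MapDB with a plain JVM collection, for example in unit tests, keep in mind that arrays use reference equality there. A minimal sketch of such an in-memory stand-in, assuming a hypothetical ByteKeyedStoreSketch class, wraps the keys in java.nio.ByteBuffer, whose equals and hashCode are based on the content, and copies the arrays defensively:

```scala
import java.nio.ByteBuffer
import scala.collection.mutable

// Hypothetical in-memory stand-in for the MapDB map of updates
final class ByteKeyedStoreSketch {
  // Array keys would be compared by reference; ByteBuffer compares content
  private val updates = mutable.HashMap.empty[ByteBuffer, Array[Byte]]

  def put(key: Array[Byte], value: Array[Byte]): Unit =
    // Defensive copies: the arrays may be the buffers backing an UnsafeRow
    updates.put(ByteBuffer.wrap(key.clone()), value.clone())

  def get(key: Array[Byte]): Option[Array[Byte]] =
    updates.get(ByteBuffer.wrap(key))
}
```

With a plain Map[Array[Byte], _], a lookup made with a new array holding the same bytes would return nothing.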

UnsafeRowPair trap

But UnsafeRow was not the only mutable object causing problems. The second one was UnsafeRowPair, used in the iterator() method:

  override def iterator(): Iterator[UnsafeRowPair] = {
    val unsafeRowPair = new UnsafeRowPair()
    def setKeyAndValueToUnsafeRowPair(entry: java.util.Map.Entry[Array[Byte], Array[Byte]]): UnsafeRowPair = {
      val key = new UnsafeRow(keySchema.fields.length)
      key.pointTo(entry.getKey, entry.getKey.length)
      val value = convertValueToUnsafeRow(entry.getValue)
      unsafeRowPair.withRows(key, value)
    }
    updatesFromVersion.getEntries.asScala.toIterator.map(entry => {
      setKeyAndValueToUnsafeRowPair(entry)
    }) ++ mapWithAllEntries.getEntries.asScala.toIterator.map(entry => {
      setKeyAndValueToUnsafeRowPair(entry)
    })
  }

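Here you can see the final version, which works on iterators rather than on materialized collections. If you used a materialized collection instead, only the last value would be kept, as in the UnsafeRow case. The same UnsafeRowPair instance is reused for every entry, so each pair must be read before the next one is produced, which is exactly what consuming an iterator does. So if you run the following code snippet: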

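  import org.apache.spark.sql.catalyst.expressions.UnsafeRow
  import org.apache.spark.sql.execution.streaming.state.UnsafeRowPair
  println("--- Checking the materialized collection version ---")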
  val unsafeRowPair = new UnsafeRowPair()
  val mappedUnsafeRowPairs = Seq(Seq("1", "2", "3"), Seq("4", "5", "6")).flatMap(numbers => {
    numbers.map(number => {
      val numberBytes = number.getBytes("utf-8")
      val unsafeRowNumber = new UnsafeRow(1)
      unsafeRowNumber.pointTo(numberBytes, numberBytes.length)
      unsafeRowPair.withRows(unsafeRowNumber, unsafeRowNumber)
    })
  })

  mappedUnsafeRowPairs.foreach(pair => {
    println(s"${new String(pair.key.getBytes)} ==> ${new String(pair.value.getBytes)}")
  })

  println("--- Checking the iterator version ---")
  val unsafeRowPairInternal = new UnsafeRowPair()
  def setKeyAndValueToUnsafeRowPair(number: String): UnsafeRowPair = {
    val numberBytes = number.getBytes("utf-8")
    val unsafeRowNumber = new UnsafeRow(1)
    unsafeRowNumber.pointTo(numberBytes, numberBytes.length)
    unsafeRowPairInternal.withRows(unsafeRowNumber, unsafeRowNumber)
  }
  val mappedUnsafeRowPairsInternal = Seq("1", "2", "3").toIterator.map(number => setKeyAndValueToUnsafeRowPair(number)) ++
    Seq("4", "5", "6").toIterator.map(number => setKeyAndValueToUnsafeRowPair(number))
  // This one retrieves the rows correctly:
  // UnsafeRowPair is a shared buffer, but since the iterator materializes
  // one item at a time, each pair is read before the buffer is overwritten
  mappedUnsafeRowPairsInternal.foreach(pair => {
    println(s"${new String(pair.key.getBytes)} ==> ${new String(pair.value.getBytes)}")
  })

You should observe this output for the first print:

--- Checking the materialized collection version ---
6 ==> 6
6 ==> 6
6 ==> 6
6 ==> 6
6 ==> 6
6 ==> 6

And this one for the second group:

--- Checking the iterator version ---
1 ==> 1
2 ==> 2
3 ==> 3
4 ==> 4
5 ==> 5
6 ==> 6
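If a consumer ever needs to materialize such an iterator (for example, to sort it), each element has to be copied as it is produced; with Spark classes, that would mean calling copy() on the key and value rows. The sketch below shows the same rule without Spark, using a hypothetical ReusedPair holder that plays the role of UnsafeRowPair:

```scala
// Hypothetical stand-in for UnsafeRowPair: one mutable holder reused for every element
final class ReusedPair(var key: String = null, var value: String = null) {
  def withRows(k: String, v: String): ReusedPair = { key = k; value = v; this }
  // Immutable copy of the current content
  def snapshot(): (String, String) = (key, value)
}

object ReusedPairSketch {
  def pairs(numbers: Seq[String]): Iterator[ReusedPair] = {
    val shared = new ReusedPair()
    numbers.iterator.map(number => shared.withRows(number, number))
  }
}
```

Calling `.toList` before `snapshot()` keeps 3 references to the same holder, whereas mapping to `snapshot()` first copies each pair while it still holds its own values.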

Managing 2 state store readings for the same version

The next challenge was the state store being retrieved twice within the same micro-batch. By "state store retrieval", I mean a call to the getStore(version: Long) method of StateStoreProvider.

When does it happen? With aggregations. The first time, Apache Spark loads the state store in the StateStoreRestoreExec operator to get the previously computed aggregates. In my MapDB-backed implementation, this call creates the delta files storing the updated and deleted states:

  private val updatesFromVersionDb = DBMaker
    .fileDB(updatesFileFullPath)
    .fileMmapEnableIfSupported()
    .make()
  private val updatesFromVersion = updatesFromVersionDb
    .hashMap(MapDBStateStore.EntriesName, Serializer.BYTE_ARRAY, Serializer.BYTE_ARRAY)
    .createOrOpen()

  private val deletesFileFullPath = namingFactory.localDeltaForDelete(version)
  private val deletesFromVersionDb = DBMaker
    .fileDB(deletesFileFullPath)
    .fileMmapEnableIfSupported()
    .make()
  private val deletesFromVersion = deletesFromVersionDb
    .hashSet(MapDBStateStore.EntriesName, Serializer.BYTE_ARRAY)
    .createOrOpen()

The problem is that the second operator, StateStoreSaveExec, also calls getStore with the same version. Before I identified this behavior, the state store was always created from scratch in this method, which led to checksum problems for the MapDB files.

One way to solve the issue is to disable the checksum verification with checksumHeaderBypass(). But to avoid data corruption, I opted for an alternative approach: keeping a reference to the current state store instance in the provider:


class MapDBStateStoreProvider extends StateStoreProvider with Logging {
  private var lastCommittedVersion = NoCommittedVersionFlag
  private var previousStateStoreInstance: MapDBStateStore = null

  override def getStore(version: Long): StateStore = {
    if (previousStateStoreInstance == null || lastCommittedVersion != version) {
      // ... creates the store only if it's the first execution
      //     or if the version changed
    }
    // ...
  }
  // ...
}
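The idea of reusing the instance can be checked in isolation with the sketch below. It's a simplified, hypothetical provider, not the MapDB one: "opening" a store is only recorded in a list, standing in for the files that must not be opened twice, and the cache is keyed on the requested version, a simplification of the lastCommittedVersion check above.

```scala
import scala.collection.mutable

// Hypothetical store: only remembers the version it was opened for
final class StoreSketch(val version: Long)

// Hypothetical provider reusing the instance for repeated calls with the same version
final class CachingProviderSketch {
  private var current: Option[StoreSketch] = None
  // Every real "open" is recorded here
  val openedVersions = mutable.ListBuffer.empty[Long]

  def getStore(version: Long): StoreSketch = current match {
    // 2nd call of the micro-batch (e.g. the save step after the restore step)
    case Some(store) if store.version == version => store
    case _ =>
      val store = new StoreSketch(version)
      openedVersions += version
      current = Some(store)
      store
  }
}
```

Two calls with the same version return the same instance and open the store only once; a new version opens a new one.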

These are the extra points complementing the ones covered in the summary slide of my Data+AI Summit talk. Thanks for reading 📖
