Data+AI Summit follow-up: arbitrary stateful processing and state management

Versions: Apache Spark 3.0.1

After the previous posts about native stateful operations, it's time to focus on the one that lets you define your own stateful logic.

The blog post is organized as follows. The first section shows how arbitrary stateful processing is planned. The next part presents how this streaming operator interacts with the state store.

Planning

Arbitrary stateful processing is based on the mapGroupsWithState and flatMapGroupsWithState functions of KeyValueGroupedDataset, the grouped dataset returned by groupByKey. When you call either of them, you transparently execute this code:

// KeyValueGroupedDataset#mapGroupsWithState[S: Encoder, U: Encoder]
val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
Dataset[U](
  sparkSession,
  FlatMapGroupsWithState[K, V, S, U](
    flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
    groupingAttributes,
    dataAttributes,
    OutputMode.Update,
    isMapGroupsWithState = true,
    timeoutConf,
    child = logicalPlan))

// KeyValueGroupedDataset#flatMapGroupsWithState[S: Encoder, U: Encoder]
if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
  throw new IllegalArgumentException("The output mode of function should be append or update")
}

Dataset[U](
  sparkSession,
  FlatMapGroupsWithState[K, V, S, U](
    func.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
    groupingAttributes,
    dataAttributes,
    outputMode,
    isMapGroupsWithState = false,
    timeoutConf,
    child = logicalPlan))
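The wrapping performed by mapGroupsWithState can be reproduced without Spark. Below is a minimal, stdlib-only Scala sketch. MapToFlatMapSketch, MapFunc, FlatMapFunc and countWithPrevious are hypothetical names used for illustration, and the state is simplified to a plain value instead of a GroupState:

```scala
// Stdlib-only model of the mapGroupsWithState -> flatMapGroupsWithState wrapping.
// All names are hypothetical; the state is a plain value instead of a GroupState.
object MapToFlatMapSketch {
  // mapGroupsWithState-style function: exactly one result per group
  type MapFunc[K, V, S, U] = (K, Iterator[V], S) => U
  // flatMapGroupsWithState-style function: zero or more results per group
  type FlatMapFunc[K, V, S, U] = (K, Iterator[V], S) => Iterator[U]

  // Same idea as in KeyValueGroupedDataset: put the single result in a one-element Iterator
  def toFlatMap[K, V, S, U](func: MapFunc[K, V, S, U]): FlatMapFunc[K, V, S, U] =
    (key: K, it: Iterator[V], s: S) => Iterator(func(key, it, s))

  // Example mapping function: adds the number of new values to the previous total (the state)
  val countWithPrevious: MapFunc[String, Int, Int, (String, Int)] =
    (key, values, previous) => (key, previous + values.size)
}
```

For instance, toFlatMap(countWithPrevious) called with ("a", Iterator(1, 2, 3), 10) returns an Iterator holding the single element ("a", 13), so the rest of the pipeline only ever deals with the flatMap contract.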

mapGroupsWithState is only a thin wrapper around the same FlatMapGroupsWithState logical node: it wraps the mapping function so that its single result becomes a one-element Iterator, forces the Update output mode, and sets the isMapGroupsWithState flag. Nothing complex, but the real fun starts in the FlatMapGroupsWithStateExec class. The rule responsible for transforming the logical node into the FlatMapGroupsWithStateExec physical node is org.apache.spark.sql.execution.SparkStrategies.FlatMapGroupsWithStateStrategy. The strategy reads the state format version from the spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion configuration (2 by default) and passes it to the physical node:

// SparkStrategies.FlatMapGroupsWithStateStrategy
val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
val execPlan = FlatMapGroupsWithStateExec(
  func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
  outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
execPlan :: Nil

The version is later used in the physical operator to initialize one of the 2 available state managers:

// FlatMapGroupsWithStateExecHelper#createStateManager
  def createStateManager(
      stateEncoder: ExpressionEncoder[Any],
      shouldStoreTimestamp: Boolean,
      stateFormatVersion: Int): StateManager = {
    stateFormatVersion match {
      case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
      case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
      case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
    }
  }

The difference between the versions lies in how the user-defined state is stored. V1 stores it as flattened columns (UnsafeRow[ col1 | col2 | col3 | timestamp ]), whereas V2 stores it as a nested struct (UnsafeRow[ nested-struct | timestamp | UnsafeRow[ col1 | col2 | col3 ] ]). The reason for that change? With V1, it was impossible to set the timeout without a value for the state: the flattened columns cannot be collectively marked as null, so a row could not keep a timeout timestamp once the state was removed. This was considered a confusing rule, and changing the underlying format made it possible to store a state without a value but with a timeout.
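To make the layout difference more concrete, here is a simplified model of both formats. It's only an assumption-based sketch: case classes and Option stand in for UnsafeRow and the nullable nested struct, and StateFormatSketch, UserState, RowV1, RowV2, writeV1 and writeV2 are hypothetical names:

```scala
// Simplified, hypothetical model of the V1 and V2 state row layouts.
// Case classes stand in for UnsafeRow; Option stands in for the nullable nested struct.
object StateFormatSketch {
  final case class UserState(col1: Int, col2: String, col3: Long)

  // V1: UnsafeRow[ col1 | col2 | col3 | timestamp ]
  // There is no single slot meaning "the whole user state is absent".
  final case class RowV1(col1: Int, col2: String, col3: Long, timestamp: Long)

  // V2: UnsafeRow[ nested-struct | timestamp | ... ]
  // The whole user state is one nullable column.
  final case class RowV2(state: Option[UserState], timestamp: Long)

  // A V1 row can only be built when a state value exists.
  def writeV1(state: Option[UserState], timestamp: Long): Option[RowV1] =
    state.map(s => RowV1(s.col1, s.col2, s.col3, timestamp))

  // A V2 row can always be built, even when only the timeout is known.
  def writeV2(state: Option[UserState], timestamp: Long): RowV2 =
    RowV2(state, timestamp)
}
```

In this model, writeV1(None, 1000L) has nothing to produce, while writeV2(None, 1000L) keeps the timeout next to an empty state, which is the capability V2 added.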

FlatMapGroupsWithStateExec and state store

The operator interacts with the state store in 2 different operations. The first one is the processing of the new input rows in the InputProcessor's processNewData(dataIter: Iterator[InternalRow]) method. Under the hood, it creates a GroupedIterator, since the state applies to groups of rows, and flat-maps over it. GroupedIterator only groups consecutive rows, which is why the operator requires its input to be sorted by the grouping attributes:

      groupedIter.flatMap { case (keyRow, valueRowIter) =>
        val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
        callFunctionAndUpdateState(
          stateManager.getState(store, keyUnsafeRow),
          valueRowIter,
          hasTimedOut = false)
      }
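The processNewData flow can be modeled with plain Scala collections. In this hypothetical sketch, a mutable Map replaces the StateStore, groupConsecutive plays the role of GroupedIterator (so the input must be sorted by key, as in the real operator), the user function returns its output together with the new state, and timeouts are ignored:

```scala
import scala.collection.mutable

// Simplified, hypothetical model of processNewData: a mutable Map replaces the StateStore
// and timeouts are ignored.
object ProcessNewDataSketch {
  // Plays the role of GroupedIterator: groups *consecutive* rows with the same key,
  // so the input must already be sorted by key (as the real operator requires).
  def groupConsecutive[K, V](rows: Iterator[(K, V)]): Iterator[(K, List[V])] = {
    val buffered = rows.buffered
    new Iterator[(K, List[V])] {
      def hasNext: Boolean = buffered.hasNext
      def next(): (K, List[V]) = {
        val key = buffered.head._1
        val values = mutable.ListBuffer.empty[V]
        while (buffered.hasNext && buffered.head._1 == key) values += buffered.next()._2
        (key, values.toList)
      }
    }
  }

  // The user function gets the current state and returns its output plus the new state
  // (None meaning "remove"). Unlike Spark, which writes the state once the output iterator
  // is consumed, this sketch writes it as soon as the function returns.
  def processNewData[K, V, S, U](
      rows: Iterator[(K, V)],
      store: mutable.Map[K, S],
      func: (K, List[V], Option[S]) => (Iterator[U], Option[S])): Iterator[U] =
    groupConsecutive(rows).flatMap { case (key, values) =>
      val (output, newState) = func(key, values, store.get(key)) // read the state
      newState match {
        case Some(state) => store.update(key, state)             // write the state
        case None        => store.remove(key)                    // delete the state
      }
      output
    }
}
```

As in Spark, the result is lazy: the store is only touched when the returned iterator is consumed.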

The stateManager.getState call is the first interaction with the state store: it reads the state object for the key of the group. The 2 other interactions happen in callFunctionAndUpdateState where, once the function's output iterator is fully consumed, the state is either removed from or written to the state store. And it's in the write path where you can encounter the "empty" value phenomenon addressed by the V2 state format:

        if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
          stateManager.removeState(store, stateData.keyRow)
          numUpdatedStateRows += 1
        } else {
          val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
          val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
          val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged

          if (shouldWriteState) {
            val updatedStateObj = if (groupState.exists) groupState.get else null
            stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
            numUpdatedStateRows += 1
          }
        }
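The write decision can be isolated into a pure function. The sketch below is a simplified, hypothetical model: StateWriteDecisionSketch, GroupStateFlags and the Action cases aren't Spark classes, only the branch conditions follow the if-else of callFunctionAndUpdateState:

```scala
// Simplified, hypothetical model of the remove/put decision of callFunctionAndUpdateState.
// None of these types exist in Spark; only the branch conditions mirror the snippet above.
object StateWriteDecisionSketch {
  val NoTimestamp: Long = -1L // stands for GroupStateImpl's NO_TIMESTAMP marker

  sealed trait Action
  case object RemoveKey extends Action
  final case class PutState[S](value: Option[S], timeoutTimestamp: Long) extends Action
  case object NoWrite extends Action

  final case class GroupStateFlags[S](
      value: Option[S],        // None when the state doesn't exist (never set or removed)
      hasUpdated: Boolean,
      hasRemoved: Boolean,
      timeoutTimestamp: Long)  // timeout set during this call, NoTimestamp otherwise

  def decide[S](state: GroupStateFlags[S], storedTimeoutTimestamp: Long): Action =
    if (state.hasRemoved && state.timeoutTimestamp == NoTimestamp) {
      RemoveKey
    } else {
      val hasTimeoutChanged = state.timeoutTimestamp != storedTimeoutTimestamp
      if (state.hasUpdated || state.hasRemoved || hasTimeoutChanged) {
        // A removed state with a new timeout is written with an empty value (null in Spark)
        PutState(state.value, state.timeoutTimestamp)
      } else {
        NoWrite
      }
    }
}
```

A consequence visible in the model: since the timeout timestamp starts at NO_TIMESTAMP on every call, a function that doesn't set the timeout again gets hasTimeoutChanged = true and clears the stored timeout. This matches the GroupState documentation, which asks to set the timeout on every call.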

The if-else of callFunctionAndUpdateState explains the empty value: when the user removed the state but a timeout timestamp is still defined, the else branch writes the key with a null state object, because groupState.exists is false. As of this writing, the user cannot set the state to null directly because of the check in the update method. However, nothing prevents a group from having only a timeout defined, since the timeout setters don't perform any check on the state value:

  override def update(newValue: S): Unit = {
    if (newValue == null) {
      throw new IllegalArgumentException("'null' is not a valid state value")
    }
    value = newValue
    defined = true
    updated = true
    removed = false
  }
// Below, the timeout setters and their helper.
// None of them checks the state value.
  private def checkTimeoutTimestampAllowed(): Unit = {
    if (timeoutConf != EventTimeTimeout) {
      throw new UnsupportedOperationException(
        "Cannot set timeout timestamp without enabling event time timeout in " +
          "[map|flatMapGroupsWithState")
    }
  }
  override def setTimeoutDuration(durationMs: Long): Unit = {
    if (timeoutConf != ProcessingTimeTimeout) {
      throw new UnsupportedOperationException(
        "Cannot set timeout duration without enabling processing time timeout in " +
          "[map|flatMap]GroupsWithState")
    }
    if (durationMs <= 0) {
      throw new IllegalArgumentException("Timeout duration must be positive")
    }
    timeoutTimestamp = durationMs + batchProcessingTimeMs
  }

  override def setTimeoutDuration(duration: String): Unit = {
    setTimeoutDuration(parseDuration(duration))
  }

  override def setTimeoutTimestamp(timestampMs: Long): Unit = {
    checkTimeoutTimestampAllowed()
    if (timestampMs <= 0) {
      throw new IllegalArgumentException("Timeout timestamp must be positive")
    }
    if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < eventTimeWatermarkMs) {
      throw new IllegalArgumentException(
        s"Timeout timestamp ($timestampMs) cannot be earlier than the " +
          s"current watermark ($eventTimeWatermarkMs)")
    }
    timeoutTimestamp = timestampMs
  }

  override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = {
    checkTimeoutTimestampAllowed()
    setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs)
  }

  override def setTimeoutTimestamp(timestamp: Date): Unit = {
    checkTimeoutTimestampAllowed()
    setTimeoutTimestamp(timestamp.getTime)
  }

  override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = {
    checkTimeoutTimestampAllowed()
    setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration))
  }
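The same rules can be reproduced in a tiny class. MiniGroupState is a hypothetical, stdlib-only mirror of the methods quoted above, limited to processing-time timeouts and to the flags used by callFunctionAndUpdateState; it isn't Spark's GroupStateImpl:

```scala
// Hypothetical, stdlib-only mirror of the GroupStateImpl methods quoted above, limited to
// processing-time timeouts; it isn't Spark's class.
final class MiniGroupState[S](initial: Option[S], batchProcessingTimeMs: Long) {
  private var value: Option[S] = initial
  private var updated = false
  private var removed = false
  // Starts at the NO_TIMESTAMP marker (-1): the stored timeout is not copied here
  private var timeoutTs = -1L

  def exists: Boolean = value.isDefined
  def hasUpdated: Boolean = updated
  def hasRemoved: Boolean = removed
  def getTimeoutTimestamp: Long = timeoutTs

  def update(newValue: S): Unit = {
    // Same guard as in Spark: a state can't be emptied with update(null)
    if (newValue == null) throw new IllegalArgumentException("'null' is not a valid state value")
    value = Some(newValue)
    updated = true
    removed = false
  }

  def remove(): Unit = {
    value = None
    updated = false
    removed = true
  }

  // No check on the state value: the timeout can be set even if the state doesn't exist
  def setTimeoutDuration(durationMs: Long): Unit = {
    if (durationMs <= 0) throw new IllegalArgumentException("Timeout duration must be positive")
    timeoutTs = durationMs + batchProcessingTimeMs
  }
}
```

With an empty state and a timeout set, callFunctionAndUpdateState falls into the else branch and writes a row with a null value, which is exactly what V1 couldn't represent.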

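The second place interacting with the state store is the cleanup of expired states. It happens in the processTimedOutState() method of the InputProcessor, which starts by reading all states with stateManager.getAllState (a full scan delegated to StateStore's getRange(None, None)) and keeping only those with a timeout timestamp lower than the threshold. The threshold is the batch timestamp for processing-time timeouts and the watermark for event-time timeouts: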

        val timingOutPairs = stateManager.getAllState(store).filter { state =>
          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
        }

Right after, the same callFunctionAndUpdateState method is called, but this time with an empty iterator of new rows and the hasTimedOut flag set to true:

        timingOutPairs.flatMap { stateData =>
          callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
        }
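The expiration flow can be sketched as follows. It's a simplified, hypothetical model: an immutable Map replaces the state store, the threshold selection mirrors the one of processTimedOutState, and TimedOutStatesSketch, StateRow, timeoutThreshold and processTimedOut are illustrative names:

```scala
// Simplified, hypothetical model of processTimedOutState: an immutable Map replaces the
// state store and StateRow replaces the stored (state, timeout timestamp) pair.
object TimedOutStatesSketch {
  val NoTimestamp: Long = -1L // stands for GroupStateImpl's NO_TIMESTAMP marker

  sealed trait TimeoutConf
  case object ProcessingTimeTimeout extends TimeoutConf
  case object EventTimeTimeout extends TimeoutConf

  final case class StateRow[S](state: Option[S], timeoutTimestamp: Long)

  // Processing-time timeouts compare with the batch timestamp, event-time ones with the watermark
  def timeoutThreshold(conf: TimeoutConf, batchTimestampMs: Long, watermarkMs: Long): Long =
    conf match {
      case ProcessingTimeTimeout => batchTimestampMs
      case EventTimeTimeout      => watermarkMs
    }

  def processTimedOut[K, V, S, U](
      store: Map[K, StateRow[S]],
      threshold: Long,
      func: (K, Iterator[V], Option[S], Boolean) => Iterator[U]): Iterator[U] = {
    val expired = store.iterator.filter { case (_, row) =>
      row.timeoutTimestamp != NoTimestamp && row.timeoutTimestamp < threshold
    }
    // No new rows for an expired group: an empty iterator and hasTimedOut = true
    expired.flatMap { case (key, row) => func(key, Iterator.empty, row.state, true) }
  }
}
```

States without a timeout (NoTimestamp) are never considered expired, whatever the threshold.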

The user-defined stateful function usually checks this flag through the state's hasTimedOut method. When it's true, the function typically calls the state's remove() method to mark the state for deletion. callFunctionAndUpdateState then deletes the key if no new timeout was set during this invocation (groupState.getTimeoutTimestamp == NO_TIMESTAMP in the snippet above). If the function set a new timeout after removing the state, the key is kept and rewritten with an empty value and the new timeout timestamp.

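You might think that a timed-out group keeps its stored timeout timestamp, never satisfies the removal condition, and therefore stays in the state store, including the in-memory hashmap of the default HDFS-backed backend. It does disappear, though: the GroupStateImpl passed to the function is recreated on every call with timeoutTimestamp set to NO_TIMESTAMP, whatever timestamp was stored with the state. If the function calls remove() without setting a new timeout, the first branch of this update snippet deletes the key: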

// FlatMapGroupsWithStateExec.InputProcessor#callFunctionAndUpdateState
private def callFunctionAndUpdateState(
    stateData: StateData,
    valueRowIter: Iterator[InternalRow],
    hasTimedOut: Boolean): Iterator[InternalRow] = {

  val keyObj = getKeyObj(stateData.keyRow)  // convert key to objects
  val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects

  // Here the state is created with timeoutTimestamp = NO_TIMESTAMP, whatever
  // timestamp was stored with it: createForStreaming() leaves timeoutTimestamp
  // at its default value. It doesn't mean that batchTimestampMs or
  // eventTimeWatermark are empty; they're passed as usual.
  val groupState = GroupStateImpl.createForStreaming(
    Option(stateData.stateObj),
    batchTimestampMs.getOrElse(NO_TIMESTAMP),
    eventTimeWatermark.getOrElse(NO_TIMESTAMP),
    timeoutConf,
    hasTimedOut,
    watermarkPresent)

  // ... (the user function is called here; elided)

  def onIteratorCompletion: Unit = {
    if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
      stateManager.removeState(store, stateData.keyRow)
      numUpdatedStateRows += 1
    } else {
      // ... same putState branch as in the earlier snippet
    }
  }
  // ...
}
If it's still not clear, this short video using the rate data source should shed some light on it:

Arbitrary stateful processing is the last stateful operation covered in this Data+AI Summit follow-up series. As you saw, it has some subtle specificities, like the empty state value written for removed states that still have a timeout. But overall, its logic is quite similar to the previous operations since it can be summarized as a 3-step workflow: process new data, process expired states, and commit all the changes. Next time I'll present the embedded database alternatives for the state store backend.