State lifecycle management in Structured Streaming

Versions: Apache Spark 2.4.2

In this post about the state store in Structured Streaming, I will focus on state lifecycle management. The goal is to see what happens when the state expires, why removing it from the state store is so important, and to answer some other interesting questions!

How can the state expire?

The expiration of a state can be controlled either with a fixed timeout duration or with a specific timestamp. The GroupState instance you can access in your mapGroupsWithState or flatMapGroupsWithState functions exposes a method for each of them: setTimeoutDuration and setTimeoutTimestamp.

Which one you can use is decided at another moment: when you declare the function for the *WithState transformation. At that point, you specify whether a processing time-based or an event time-based expiration policy should apply:

    // processing-time policy
    val mappedValues = inputStream.toDS().toDF("id", "name")
      .groupByKey(row => row.getAs[Long]("id"))
      .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout())(MappingExpirationFunc)

    // event-time policy
    val mappedValues = inputStream.toDS().toDF("created", "id", "name")
      .withWatermark("created", "3 seconds")
      .groupByKey(row => row.getAs[Long]("id"))
      .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)

If you decide to use an event-time expiration policy, you'll also need to define the watermark for your query.
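The two decisions above (the timeout configuration declared on the transformation and the GroupState setter called inside the function) can be modeled without Spark. The object and its names below are hypothetical, not Spark's API. The rules it encodes are the watermark requirement described in this post and my understanding that in Spark 2.4 setTimeoutDuration goes with ProcessingTimeTimeout while setTimeoutTimestamp goes with EventTimeTimeout (mixing them should fail at runtime; verify against GroupStateImpl for your version).

```scala
// Spark-free model of the timeout configuration rules (hypothetical names).
object TimeoutConfModel {
  sealed trait TimeoutConf
  case object NoTimeout extends TimeoutConf
  case object ProcessingTimeTimeout extends TimeoutConf
  case object EventTimeTimeout extends TimeoutConf

  // Query-level rule: an event-time timeout is rejected when no watermark is defined.
  def checkQuery(conf: TimeoutConf, hasWatermark: Boolean): Either[String, TimeoutConf] =
    conf match {
      case EventTimeTimeout if !hasWatermark =>
        Left("Event-time timeout not supported without watermark")
      case other => Right(other)
    }

  // Function-level rule (assumption): the GroupState setter matching each configuration.
  def expectedSetter(conf: TimeoutConf): Option[String] = conf match {
    case ProcessingTimeTimeout => Some("setTimeoutDuration")
    case EventTimeTimeout      => Some("setTimeoutTimestamp")
    case NoTimeout             => None
  }
}
```

With NoTimeout there is simply no setter to call, which is the configuration discussed later in the "Can I manage the state on my own?" part.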

Why do we need to enable the watermark for event-time expiration?

Using an event-time state expiration requires a watermark. Otherwise, you will get this exception:

Exception in thread "main" org.apache.spark.sql.AnalysisException: Watermark must be specified in the query using '[Dataset/DataFrame].withWatermark()' for using event-time timeout in a [map|flatMap]GroupsWithState. Event-time timeout not supported without watermark.;;

Why this constraint? After all, the watermark should apply only to the input data, to decide how late data can arrive and still be integrated into the processing. For stateful processing, this definition is extended: the watermark is used alongside the event-time expiration timestamp to provide a finer-grained expiration control.

As stated in the documentation, "timeout will never occur before watermark has exceeded the set timeout". If we compare the event-time timeout with the processing-time one, we can say that the watermark plays for event-time state the role that the current processing time plays for processing-time state. In other words, it's the clock used to detect whether the state should expire.
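A minimal sketch of the watermark used as an expiration clock, assuming the watermark is the maximum event time seen so far minus the allowed delay and that it never moves backwards. The comparison is the same strict `<` as in the timeout filter of FlatMapGroupsWithStateExec quoted later in this post; the object name is hypothetical.

```scala
// Simplified model of the event-time clock used for state expiration.
object WatermarkClock {
  // Assumption: watermark = max event time - delay, and it never goes backwards.
  def nextWatermark(currentWatermarkMs: Long, maxEventTimeMs: Long, delayMs: Long): Long =
    math.max(currentWatermarkMs, maxEventTimeMs - delayMs)

  // A state times out only once the clock is strictly past its timeout timestamp.
  def hasTimedOut(timeoutTimestampMs: Long, watermarkMs: Long): Boolean =
    timeoutTimestampMs < watermarkMs
}
```

With a 3 seconds delay and a latest event at 10000 ms, the watermark is 7000 ms: a state whose timeout is 7000 ms doesn't expire yet, one at 6999 ms does.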

What happens if both state expiration and watermark are defined?

In that case, the expiration clock advances with the watermark, so the state for a particular key can expire as soon as more recent events move the watermark past its timeout while that key doesn't receive new events. Once again, since the watermark is the event-time equivalent of the processing-time clock, this behavior is perfectly fine.

The test below uses 3 keys. The state for key 2 expires because it doesn't receive new events after the first batch, while the states for keys 1 and 3 stay alive because they keep receiving some. In that situation, we can also consider the input messages as heartbeats marking the state as "still alive":

  "the state for key number 2" should "expire because of the watermark" in {
    val sparkSession: SparkSession = SparkSession.builder()
      .appName("Spark Structured Streaming watermark state expiration")
      .config("spark.sql.shuffle.partitions", "1")
      .master("local[2]").getOrCreate()
    import sparkSession.implicits._

    val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[String]) => String = (key, values, state) => {
      if (values.isEmpty && state.hasTimedOut) {
        println(s"State for ${key} expired / current watermark is ${state.getCurrentWatermarkMs()}")
        s"Expired state for ${key}"
      } else {
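        // getOption returns an Option[String], so the default must be a String
        // (Seq.empty would be rendered as "List()" in the first state value)
        val stateNames = state.getOption.getOrElse("")
        val stateNewNames = (stateNames + " " + values.map(row => row.getAs[String]("name")).mkString(" ")).trim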
        state.update(stateNewNames)
        val expirationTime = state.getCurrentWatermarkMs() + 4000L
        state.setTimeoutTimestamp(expirationTime)
        println(s"Updated state for ${key} will expire at ${expirationTime}  / " +
          s"current watermark is ${state.getCurrentWatermarkMs()}")
        stateNewNames
      }
    }
    val testKey = "summit19-watermark-expiration"
    val inputStream = new MemoryStream[(Long, Timestamp, String)](1, sparkSession.sqlContext)
    val mappedValues =inputStream.toDS().toDF("id", "event_time",  "name")
      .withWatermark("event_time", "0 seconds")
      .groupByKey(row => row.getAs[Long]("id"))
      .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)
    val startTime = 1000L
    inputStream.addData((1L, new Timestamp(startTime), "test10"),
      (2L, new Timestamp(startTime), "test20"), (3L, new Timestamp(startTime), "test30"))

    val query = mappedValues.writeStream.outputMode("update")
      .foreach(new InMemoryStoreWriter[String](testKey, (stateTxt) => stateTxt)).start()

    new Thread(new Runnable() {
      override def run(): Unit = {
        while (!query.isActive) {}
        var timeToAdd = 2000L
        while (true) {
          Thread.sleep(2000L)
          val newEventTime = startTime + timeToAdd
          inputStream.addData((1L, new Timestamp(newEventTime), "test12"),
            (3L, new Timestamp(newEventTime), "test31"))
          timeToAdd += 2000L
        }
      }
    }).start()


    query.awaitTermination(40000)
  }

In the following video, you can see that the state for key 2 falls behind the watermark at the very beginning of the query execution and, therefore, expires:
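To reason about the test without running Spark, here is a simplified replay of it in plain Scala. It assumes one micro-batch per addData call, a watermark computed from the events of the previous batches, and no extra no-data batches (the real engine may trigger some, so batch numbers won't necessarily match the video). Unlike the test, the sketch removes the expired state. All names are hypothetical.

```scala
// Simplified, Spark-free replay of the watermark-based expiration test.
object WatermarkExpirationReplay {
  final case class Event(key: Long, eventTimeMs: Long, name: String)
  final case class KeyState(names: String, timeoutMs: Long)

  // Returns the final state store and the (batchId, key) pairs that expired.
  def replay(batches: Seq[Seq[Event]], delayMs: Long, ttlMs: Long): (Map[Long, KeyState], Seq[(Int, Long)]) = {
    var watermarkMs = 0L
    var maxEventTimeMs = 0L
    var store = Map.empty[Long, KeyState]
    var expired = Vector.empty[(Int, Long)]

    batches.zipWithIndex.foreach { case (events, batchId) =>
      // 1. New data first: append the names and set timeout = watermark + ttl
      events.groupBy(_.key).foreach { case (key, rows) =>
        val previous = store.get(key).map(_.names).toSeq
        store += key -> KeyState((previous ++ rows.map(_.name)).mkString(" "), watermarkMs + ttlMs)
      }
      // 2. Then timed-out states: timeout strictly behind the watermark
      val timedOut = store.filter { case (_, state) => state.timeoutMs < watermarkMs }.keys.toSeq.sorted
      timedOut.foreach { key =>
        expired = expired :+ ((batchId, key))
        store -= key // unlike the test above, the sketch removes expired state
      }
      // 3. Advance the watermark used by the next batch
      if (events.nonEmpty) maxEventTimeMs = math.max(maxEventTimeMs, events.map(_.eventTimeMs).max)
      watermarkMs = math.max(watermarkMs, maxEventTimeMs - delayMs)
    }
    (store, expired)
  }
}
```

Replaying the first 5 batches of the test (startTime = 1000, new events every 2000 ms for keys 1 and 3, 0 seconds delay, 4000 ms expiration), key 2 expires in the fourth batch, when the watermark (5000 ms) passes its timeout (4000 ms), whereas every batch pushes the timeouts of keys 1 and 3 forward.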

Why can't the timeout be earlier than the watermark?

When I was playing with stateful aggregation, I made the mistake of defining the new timeout value earlier than the current watermark, and Apache Spark returned this error:

Caused by: java.lang.IllegalArgumentException: Timeout timestamp (1564291203000) cannot be earlier than the current watermark (1564291872001)
      at org.apache.spark.sql.execution.streaming.GroupStateImpl.setTimeoutTimestamp(GroupStateImpl.scala:114)

It happened with a watermark delay of 3 minutes and an expiration time of only 30 seconds. To overcome this issue, you can use GroupState's getCurrentWatermarkMs() method, which returns the event-time watermark at the moment of the call, and add your expiration time to it.
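The check and the workaround can be sketched in plain Scala. The validation below reproduces the rule expressed by the error message (a timeout earlier than the current watermark is rejected); it's a model, not Spark's GroupStateImpl code, and the object name is hypothetical. In a real function, the workaround corresponds to state.setTimeoutTimestamp(state.getCurrentWatermarkMs() + expirationMs), as in the first test of this post.

```scala
// Spark-free sketch of the "timeout >= watermark" rule and of the workaround.
object EventTimeoutPolicy {
  // Mirrors the rule from the error message: an earlier timeout is rejected.
  def validate(timeoutMs: Long, watermarkMs: Long): Either[String, Long] =
    if (timeoutMs < watermarkMs)
      Left(s"Timeout timestamp ($timeoutMs) cannot be earlier than the current watermark ($watermarkMs)")
    else Right(timeoutMs)

  // Workaround: start counting the expiration from the watermark, not from the event time.
  def fromWatermark(watermarkMs: Long, ttlMs: Long): Long = watermarkMs + ttlMs
}
```

Using the numbers from the error above, a 30 seconds expiration computed from the watermark (1564291872001 ms) gives 1564291902001 ms, which passes the check.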

Can the watermark drop too late state?

For the *GroupsWithState operations, it can't. The state won't be automatically dropped by the watermark. Instead, the watermark is used to mark group states as expired, and it's up to the user to explicitly remove them from the state store.

But it's only valid for the *GroupsWithState operations, where users control the state lifecycle. In the other operations, where the state internals are hidden (streaming deduplication, streaming joins), the watermark filters out expired rows and removes them from the state store. You can check this out in the SymmetricHashJoinStateManager.KeyWithIndexToValueStore#remove or WatermarkSupport#removeKeysOlderThanWatermark methods.
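For contrast, here's a Spark-free sketch of what the operators with hidden state do on their own: drop the entries the watermark has passed, without any user code. The mutable map stands in for the state store (key to event time) and the object name is hypothetical; the `<=` boundary is my reading of WatermarkSupport and should be treated as an assumption.

```scala
import scala.collection.mutable

// Sketch of automatic, watermark-driven eviction (no user code involved).
object WatermarkEviction {
  // Removes the entries whose event time is at or before the watermark and returns their keys.
  def evict[K](store: mutable.Map[K, Long], watermarkMs: Long): Set[K] = {
    val expired = store.filter { case (_, eventTimeMs) => eventTimeMs <= watermarkMs }.keys.toSet
    store --= expired
    expired
  }
}
```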

What happens if the state expires and is not removed in a *GroupsWithState operation?

From the previous point you know that the state must be explicitly removed by the user when it expires. This is done by calling GroupState's remove() method. But what happens if we don't call this method? Let's see it in this code snippet:

  "the state" should "expire but be still kept in the state store when it's not removed explicitly" in {
    val sparkSession: SparkSession = SparkSession.builder()
      .appName("Spark Structured Streaming output modes - mapGroupsWithState")
      .config("spark.sql.shuffle.partitions", "1")
      .master("local[2]").getOrCreate()
    import sparkSession.implicits._
    
    val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[String]) => String = (key, values, state) => {
      if (values.isEmpty && state.hasTimedOut) {
        println(s"State for ${key} expired")
        s"Expired state for ${key}"
      } else {
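        // getOption returns an Option[String], so the default must be a String
        // (Seq.empty would be rendered as "List()" in the first state value)
        val stateNames = state.getOption.getOrElse("")
        val stateNewNames = (stateNames + " " + values.map(row => row.getAs[String]("name")).mkString(" ")).trim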
        state.update(stateNewNames)
        state.setTimeoutDuration(4000L)
        println(s"Got state=${state}")
        stateNewNames
      }
    }
    val testKey = "summit19-no-state-removal"
    val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
    val mappedValues =inputStream.toDS().toDF("id", "name")
      .groupByKey(row => row.getAs[Long]("id"))
      .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout())(eventTimeExpirationFunc)
    inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

    val query = mappedValues.writeStream.outputMode("update")
      .option("checkpointLocation", "/tmp/no-state-removal")
      .foreach(new InMemoryStoreWriter[String](testKey, (stateTxt) => stateTxt)).start()

    new Thread(new Runnable() {
      override def run(): Unit = {
        while (!query.isActive) {}
        while (true) {
          Thread.sleep(2000L)
          // Do not add state=2L - it should expire as soon as possible
          inputStream.addData((1L, "test12"), (3L, "test31"))
        }
      }
    }).start()


    query.awaitTermination(60000)
  }

The snippet is quite straightforward. It accumulates some very short-lived state (4 seconds). The state for key 2 expires very soon because it doesn't receive new events from the producer thread, and it's the entry I would expect to disappear when I read the state store data for every version. In addition, I'm checkpointing the state store to my local file system to investigate its content with the following code, located under the org.apache.spark.sql.execution.streaming.state package:

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val sparkConfig = new SQLConf()
val provider = new HDFSBackedStateStoreProvider()
provider.init(
  StateStoreId(
    checkpointRootLocation = "file:///tmp/no-state-removal/state",
    operatorId = 0L, partitionId = 0,
    storeName = "default"
  ),
  keySchema = StructType(Seq(StructField("value", LongType, false))),
  valueSchema = StructType(Seq(
    StructField("groupState", StructType(Seq(StructField("value", StringType, true))), true),
    StructField("timeoutTimestamp", LongType, false)
  )),
  None,
  storeConf = StateStoreConf(sparkConfig),
  hadoopConf = new Configuration(true)
)

for (storeVersion <- 1 to 131) {
  println(s"Store version ${storeVersion}")
  provider.getStore(storeVersion).iterator()
    .foreach(rowPair => {
      val key = rowPair.key.getLong(0)
      // groupState is a nested struct (see valueSchema), so read it as a struct first
      val value = rowPair.value.getStruct(0, 1).getString(0)
      val timeout = rowPair.value.getLong(1)
      println(s"${rowPair.key} ${key}=${value} expires at ${timeout}")
    })
  println("----------------------------------")
}

provider.latestIterator().foreach(rowPair => {
  val key = rowPair.key.getLong(0)
  val value = rowPair.value.getStruct(0, 1).getString(0)
  val timeout = rowPair.value.getLong(1)
  println(s"${key}=${value} expires at ${timeout}")
})

In the screencast you can see that the state is removed from the state store only when the user explicitly removes it. So remember to handle this case correctly on your own. But as stated before, any operation involving the state store other than the stateful mapping takes care of removing the expired state on its own.
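The difference between removing and not removing the expired state can be modeled in a few lines. The sketch assumes, as stated in the GroupState documentation, that the timeout is reset every time the function is called: an entry that times out and is neither removed nor given a new timeout stays in the store with no timeout at all, so it won't be picked up again. Names and types are hypothetical.

```scala
// Spark-free model of a processing-time timeout pass over the state store.
object ExplicitRemovalModel {
  final case class Entry(value: String, timeoutMs: Option[Long])

  def handleTimeouts(store: Map[Long, Entry], batchTimeMs: Long, removeExpired: Boolean): Map[Long, Entry] = {
    val timedOutKeys = store.filter { case (_, entry) => entry.timeoutMs.exists(_ < batchTimeMs) }.keySet
    if (removeExpired) store -- timedOutKeys // the function calls state.remove()
    else store ++ timedOutKeys.map(key => key -> store(key).copy(timeoutMs = None)) // kept, without any timeout
  }
}
```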

Can I manage the state on my own?

Yes, you can control state expiration entirely on your own. Simply omit the call to the state expiration methods (setTimeoutDuration or setTimeoutTimestamp) and set the timeoutConf parameter of the *GroupsWithState function to GroupStateTimeout.NoTimeout(). In that situation, when Apache Spark invokes the method processing the expired state (FlatMapGroupsWithStateExec.InputProcessor#processTimedOutState), it does nothing because isTimeoutEnabled is false, and the returned result is an empty iterator:

def processTimedOutState(): Iterator[InternalRow] = {
  if (isTimeoutEnabled) {
     // handle expired states
  } else Iterator.empty
}
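A hypothetical sketch of self-managed expiration with NoTimeout: the state value carries its own last-seen time, and the function decides on the next event whether the previous state is too old to be continued. Keep in mind that, since processTimedOutState returns an empty iterator, such a function only runs for keys receiving new data; a key that stops receiving events is never revisited and its state stays in the store. The Session type and the maxIdleMs parameter are assumptions for illustration.

```scala
// Hypothetical self-managed expiration: the state tracks its own last-seen time.
object SelfManagedState {
  final case class Session(names: Vector[String], lastSeenMs: Long)

  def onNewData(previous: Option[Session], eventTimeMs: Long, name: String, maxIdleMs: Long): Session =
    previous match {
      case Some(session) if eventTimeMs - session.lastSeenMs <= maxIdleMs =>
        // Still fresh enough: continue the existing state
        session.copy(names = session.names :+ name, lastSeenMs = math.max(session.lastSeenMs, eventTimeMs))
      case _ =>
        // No state yet, or the old one is considered expired: start from scratch
        Session(Vector(name), eventTimeMs)
    }
}
```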

What happens if the state is about to expire and it receives new events?

It depends, because receiving new events doesn't extend the state automatically: the stateful function has to set a new timeout every time it's called. Internally, Apache Spark processes the state in this line of FlatMapGroupsWithStateExec:

        // Generate a iterator that returns the rows grouped by the grouping function
        // Note that this code ensures that the filtering for timeout occurs only after
        // all the data has been processed. This is to ensure that the timeout information of all
        // the keys with data is updated before they are processed for timeouts.
        val outputIterator =
          processor.processNewData(filteredIter) ++ processor.processTimedOutState()

The first method calls the state function with the new input rows. The second one applies to the expired state, which is computed this way:

        val timeoutThreshold = timeoutConf match {
          case ProcessingTimeTimeout => batchTimestampMs.get
          case EventTimeTimeout => eventTimeWatermark.get
          case _ =>
            throw new IllegalStateException(
              s"Cannot filter timed out keys for $timeoutConf")
        }
        val timingOutPairs = stateManager.getAllState(store).filter { state =>
          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
        }
        timingOutPairs.flatMap { stateData =>
          callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
        }

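As you can see, your stateful function can be called twice in a single micro-batch: once for the keys receiving new input rows and once for the keys whose state expired. And since the new data is processed first, a key that receives events can get a new timeout before the expiration filter runs.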
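The ordering can be illustrated with a Spark-free model of one micro-batch with a processing-time timeout. The user function is reduced to "set the timeout to batch time + ttl" on new data and "remove the state" on timeout; both behaviors, and all names, are assumptions of the sketch.

```scala
// Spark-free model of one micro-batch: new data first, then timed-out states.
object MicroBatchOrder {
  final case class Call(key: Long, hasTimedOut: Boolean)

  def runBatch(
      timeouts: Map[Long, Long], // key -> timeout timestamp
      newDataKeys: Set[Long],
      batchTimeMs: Long,
      ttlMs: Long): (Map[Long, Long], Seq[Call]) = {
    var state = timeouts
    val calls = Vector.newBuilder[Call]
    // processNewData: the function sees the key with data and pushes its timeout forward
    newDataKeys.toSeq.sorted.foreach { key =>
      calls += Call(key, hasTimedOut = false)
      state += key -> (batchTimeMs + ttlMs)
    }
    // processTimedOutState: evaluated only after all new data was handled
    val timedOut = state.filter { case (_, timeout) => timeout < batchTimeMs }.keys.toSeq.sorted
    timedOut.foreach { key =>
      calls += Call(key, hasTimedOut = true)
      state -= key // assuming the function calls state.remove()
    }
    (state, calls.result())
  }
}
```

Running it with keys 1 and 2 both holding a timeout of 1000 ms, new data for key 1 only and a batch time of 5000 ms, key 1 gets a new timeout of 9000 ms before the expiration filter runs. Only key 2 is timed out, so the function is called twice in the batch, for two different keys.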

A few takeaways after these questions. First, ensure that you remove the expired state. Otherwise, you may end up with an indefinitely growing state store. Besides that, think about configuring your timeout properly. You can either manage it on your own with the NoTimeout configuration, or use processing time or event time with the help of the watermark. All of them have their pros and cons, and it's up to you to choose the one that fits your processing logic.

