Stateful transformations with mapGroupsWithState

on waitingforcode.com

Stateful stream processing in Apache Spark has evolved a lot since the first versions of the framework. At the beginning there was updateStateByKey, but some time later, judged inefficient, it was replaced by mapWithState. With the arrival of Structured Streaming, the latter was in its turn replaced by mapGroupsWithState.

This post focuses on the mapGroupsWithState transformation. The first part explains the method and its API details. The second section focuses on the object responsible for state storage. The last part shows, through the usual learning tests, how to use the transformation.

mapGroupsWithState explained

The mapGroupsWithState(timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U) transformation is applied on grouped data (it's a method of the KeyValueGroupedDataset returned by groupByKey). Since it requires the data to be grouped, it will most likely involve a shuffle. The state is computed by the function passed as the func parameter. This transformation can be used on both bounded and unbounded sources. In the first case, the final state is computed immediately, while in the second one the state can change at every triggered processing. That's why it's stored in a fault-tolerant store between subsequent invocations.
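To make the per-trigger behavior concrete, here is a minimal pure-Scala sketch that doesn't use Spark at all. SimpleState and MicroBatchSimulator are hypothetical names made up for this illustration, and the model assumes that each micro-batch is grouped by key, that func is called once per group, and that whatever state func leaves behind is kept for the next batch.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-in for Spark's GroupState: it only holds an optional value.
final class SimpleState[S](initial: Option[S]) {
  private var value: Option[S] = initial
  def getOption: Option[S] = value
  def update(newState: S): Unit = { value = Some(newState) }
}

// Hypothetical model of mapGroupsWithState on an unbounded source: each micro-batch is
// grouped by key, func is called once per key with that key's new values, and the state
// it leaves behind is kept for the next batch (Spark keeps it in a fault-tolerant store).
final class MicroBatchSimulator[K, V, S, U](func: (K, Iterator[V], SimpleState[S]) => U) {
  private val store = mutable.Map.empty[K, S]

  def runBatch(batch: Seq[(K, V)]): Seq[U] =
    batch.groupBy(_._1).toSeq.map { case (key, rows) =>
      val state = new SimpleState[S](store.get(key))
      val output = func(key, rows.iterator.map(_._2), state)
      state.getOption.foreach(s => store(key) = s)
      output
    }
}
```

With a function appending the new names to the state, a first batch with (1L, "test10"), (1L, "test11") and (2L, "test20") followed by a batch containing only (1L, "test12") makes the second call return Seq("test10", "test11", "test12"): the state survived between the two batches.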

Under the hood, mapGroupsWithState creates an org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState logical plan with the isMapGroupsWithState parameter set to true. Thanks to that, we can immediately discover the only output mode supported by this transformation: update. It's enforced by this check in FlatMapGroupsWithState:

if (isMapGroupsWithState) {
  assert(outputMode == OutputMode.Update)
}

Now let's analyze the mapGroupsWithState parameters in more detail:

  • timeoutConf: GroupStateTimeout - as the name indicates, this argument is responsible for handling timeouts. The mapGroupsWithState transformation accepts one of 3 timeout configurations: no timeout, processing-time based or event-time based (the latter works only if a watermark is defined). The state expiration time is defined inside the function passed as the func parameter, either with org.apache.spark.sql.streaming.GroupState#setTimeoutDuration (processing time) or org.apache.spark.sql.streaming.GroupState#setTimeoutTimestamp (event time).
    The timeout is reset every time the function is called for a group, so it has to be set again on each call (as MappingExpirationFunc in the tests below does). With a timeout set on every call, the state is considered as expired only when the group hasn't received any data during the specified threshold.
    The timeout is defined inside the GroupState object of each group (details below this list), so it's completely possible to specify different timeout configurations depending on the processed group.
  • func: (K, Iterator[V], GroupState[S]) => U - this function defines how the values of the group are processed in order to generate the state. As you can notice, the types are independent: neither the values (V) nor the state (S) determines the type of the returned Dataset (U). The function returns exactly one U per call.
    This function is called when 1 of 2 conditions is met: either the group has new values to process or its state has timed out. In the second case, the function is called with an empty iterator as the 2nd parameter and GroupState#hasTimedOut returns true.
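The timeout rules above can be sketched with another Spark-free model. TimeoutState and ProcessingTimeSimulator are hypothetical names, and the sketch assumes a simplified clock passed explicitly to each trigger. It only mimics the behavior described in the list: the timeout must be set on every call, expiration is checked only when a trigger runs, and an expired group is called with an empty iterator and hasTimedOut set to true.

```scala
import scala.collection.mutable

// Hypothetical handle exposing only the timeout-related subset of GroupState.
final class TimeoutState[S](initial: Option[S], val hasTimedOut: Boolean) {
  private var value: Option[S] = initial
  // Starts empty on every call: the function must set the timeout again each time.
  private var durationMs: Option[Long] = None

  def getOption: Option[S] = value
  def get: S = value.getOrElse(throw new NoSuchElementException("state is not defined"))
  def update(newState: S): Unit = { value = Some(newState) }
  def setTimeoutDuration(ms: Long): Unit = { durationMs = Some(ms) }
  def timeoutDuration: Option[Long] = durationMs
}

// Hypothetical model of a processing-time timeout. A trigger at `nowMs` calls func for every
// key with new data, then for every key without data whose deadline has passed, the latter
// with an empty iterator and hasTimedOut = true. Without a trigger, nothing can expire.
final class ProcessingTimeSimulator[K, V, S, U](func: (K, Iterator[V], TimeoutState[S]) => U) {
  private val states = mutable.Map.empty[K, S]
  private val deadlines = mutable.Map.empty[K, Long]

  def trigger(nowMs: Long, batch: Seq[(K, V)]): Seq[U] = {
    val withData = batch.groupBy(_._1).toSeq.map { case (k, rows) => (k, rows.map(_._2), false) }
    val dataKeys = withData.map(_._1).toSet
    val expired = deadlines.toSeq.collect {
      case (k, deadline) if deadline < nowMs && !dataKeys.contains(k) => (k, Seq.empty[V], true)
    }
    (withData ++ expired).map { case (key, values, timedOut) =>
      val state = new TimeoutState[S](states.get(key), timedOut)
      val output = func(key, values.iterator, state)
      state.getOption.foreach(s => states(key) = s)
      // A new deadline only exists if this call set a timeout.
      deadlines -= key
      state.timeoutDuration.foreach(ms => deadlines(key) = nowMs + ms)
      output
    }
  }
}
```

With a function setting a 3000 ms timeout on every call, a trigger at 0 ms with keys 1, 2 and 3 followed by a trigger at 5000 ms with data only for key 2 returns the updated state for key 2 and the timed-out outputs for keys 1 and 3, which mirrors the "a different state should be returned after the state expiration" test below.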

GroupState

Another important object involved in state management is the instance of org.apache.spark.sql.streaming.GroupState[T]. It represents the state kept by Spark for a given group (one instance per group). It defines methods that help to deal with the state lifecycle:

  • exists - tells if the state for the given group is defined. It returns false when the group is processed for the first time or after its state was removed with remove(). When the function is called because of a timeout, the state still exists, which is why the tests below can read it with state.get in that case.
  • get - gets the state associated with the given group. However, it must be used carefully since it throws NoSuchElementException when the state doesn't exist. A safer way to get the state is the getOption method, which returns the state wrapped in an Option, or None if it doesn't exist for the given group.
  • update(newState: S) - overwrites the existing state with the newState passed as parameter.
  • remove() - removes the state associated with the given group.
  • hasTimedOut - returns true if the function was called because the state has expired.
  • setTimeoutDuration(...) - defines the timeout value for the processing-time configuration. It accepts either a Long (milliseconds) or a duration string such as "10 seconds".
  • setTimeoutTimestamp(...) - defines the timeout for the event-time configuration, as milliseconds since the epoch. The timestamp can't be earlier than the current watermark.
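The lifecycle methods listed above can be summarized with a hypothetical GroupStateModel class. It is not Spark's GroupStateImpl and its method bodies are assumptions that only reproduce the behavior described in the list; the watermark check copies the error message asserted in the event-time test below.

```scala
// Hypothetical in-memory model of the GroupState lifecycle described in the list above.
// It is NOT Spark's implementation; it only mimics the behavior the list describes.
final class GroupStateModel[S](initial: Option[S], watermarkMs: Long, val hasTimedOut: Boolean = false) {
  private var value: Option[S] = initial
  private var timeoutTimestampMs: Option[Long] = None

  def exists: Boolean = value.isDefined
  // Throws when no state is defined, hence the advice to prefer getOption.
  def get: S = value.getOrElse(throw new NoSuchElementException("state is not defined"))
  def getOption: Option[S] = value
  def update(newState: S): Unit = { value = Some(newState) }
  def remove(): Unit = { value = None }

  // Event-time timeout as epoch milliseconds; like the Spark error checked in the tests
  // below, a timestamp earlier than the current watermark is rejected.
  def setTimeoutTimestamp(timestampMs: Long): Unit = {
    if (timestampMs < watermarkMs)
      throw new IllegalArgumentException(
        s"Timeout timestamp ($timestampMs) cannot be earlier than the current watermark ($watermarkMs)")
    timeoutTimestampMs = Some(timestampMs)
  }
  def timeoutTimestamp: Option[Long] = timeoutTimestampMs
}
```

Note that a model created with hasTimedOut = true and an existing value still reports exists as true, matching the timed-out branch of the tests, which reads the expired state with state.get.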

mapGroupsWithState example

The following tests illustrate the specifics of mapGroupsWithState:

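import java.sql.Timestamp
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, StreamingQueryException}

// The tests below belong to a ScalaTest FlatSpec with Matchers that also imports
// sparkSession.implicits._ (needed for the encoders). sparkSession, InMemoryStoreWriter,
// InMemoryKeyedStore and NoopForeachWriter are helpers of the test project, not shown here.

// Appends the names received by the group to its state and returns the accumulated names
private val MappingFunction: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {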
  val stateNames = state.getOption.getOrElse(Seq.empty)
  val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
  state.update(stateNewNames)
  stateNewNames
}

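// Same accumulation, but with a 3-second processing-time timeout set on every call;
// when called for a timed-out group, it returns the expired state instead
private val MappingExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {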
  if (values.isEmpty && state.hasTimedOut) {
    Seq(s"${state.get}: timed-out")
  } else {
    val stateNames = state.getOption.getOrElse(Seq.empty)
    state.setTimeoutDuration(3000)
    val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
    state.update(stateNewNames)
    stateNewNames
  }
}

"the state" should "expire after event time" in {
  val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
    if (values.isEmpty && state.hasTimedOut) {
      Seq(s"${state.get}: timed-out")
    } else {
      if (state.getOption.isEmpty) state.setTimeoutTimestamp(3000)
      val stateNames = state.getOption.getOrElse(Seq.empty)
      val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
      state.update(stateNewNames)
      stateNewNames
    }
  }
  val now = 5000L
  val testKey = "mapGroupWithState-state-expired-event-time"
  val inputStream = new MemoryStream[(Timestamp, Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("created", "id", "name")
    .withWatermark("created", "3 seconds")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)
  inputStream.addData((new Timestamp(now), 1L, "test10"),
    (new Timestamp(now), 1L, "test11"), (new Timestamp(now), 2L, "test20"),
    (new Timestamp(now+now), 3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      // Now the watermark is about 7000 (10000 - 3000 ms)
      // So not only the values for the id=1 will be rejected but also its state
      // will expire
      inputStream.addData((new Timestamp(now), 1L, "test12"),
        (new Timestamp(now+now), 3L, "test31"))
    }
  }).start()


  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 6
  savedValues should contain allOf("test30", "test10,test11", "test20", "List(test10, test11): timed-out",
    "test30,test31", "List(test20): timed-out")
}

"the event-time state expiration" should "fail when the set timeout timestamp is earlier than the watermark" in {
  val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
    if (values.isEmpty && state.hasTimedOut) {
      Seq(s"${state.get}: timed-out")
    } else {
      state.setTimeoutTimestamp(2000)
      val stateNames = state.getOption.getOrElse(Seq.empty)
      val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
      state.update(stateNewNames)
      stateNewNames
    }
  }
  val now = 5000L
  val testKey = "mapGroupWithState-event-time-state-expiration-failure"
  val inputStream = new MemoryStream[(Timestamp, Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("created", "id", "name")
    .withWatermark("created", "2 seconds")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)
  inputStream.addData((new Timestamp(now), 1L, "test10"),
    (new Timestamp(now), 1L, "test11"), (new Timestamp(now), 2L, "test20"),
    (new Timestamp(now+now), 3L, "test30"))

  val exception = intercept[StreamingQueryException] {
    val query = mappedValues.writeStream.outputMode("update")
      .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

    new Thread(new Runnable() {
      override def run(): Unit = {
        while (!query.isActive) {}
        Thread.sleep(5000)
        // The watermark here is 8000. But in eventTimeExpirationFunc we set the expiration timeout
        // to 2000, so it makes the watermark condition from
        // org.apache.spark.sql.execution.streaming.GroupStateImpl.setTimeoutTimestamp fail
        inputStream.addData((new Timestamp(now), 1L, "test12"),
          (new Timestamp(now + now), 3L, "test31"))
      }
    }).start()

    query.awaitTermination(30000)
  }
  exception.getMessage should include("Timeout timestamp (2000) cannot be earlier than the current watermark (8000)")
}

"an event time expiration" should "not be executed when the watermark is not defined" in {
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val exception = intercept[AnalysisException] {
    mappedValues.writeStream.outputMode("update")
      .foreach(new NoopForeachWriter[Seq[String]]()).start()
  }
  exception.getMessage() should include("Watermark must be specified in the query using " +
    "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a [map|flatMap]GroupsWithState. " +
    "Event-time timeout not supported without watermark")
}

"a different state" should "be returned after the state expiration" in {
  val testKey = "mapGroupWithState-state-returned-after-expiration"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      // This batch has no values for keys 1 and 3, so both are returned as expired
      // when the micro-batch is processed. It's because the processing-time timeout
      // is 3000 ms and here we wait 5000 ms before adding new data
      inputStream.addData((2L, "test21"))
    }
  }).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 6
  savedValues should contain allOf("test10,test11", "test30", "test20", "List(test10, test11): timed-out",
    "List(test30): timed-out", "test20,test21")
}

"the state" should "not be discarded when there is no new data to process" in {
  val testKey = "mapGroupWithState-state-not-returned-after-expiration"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  // Here, unlike in the previous test, only 3 values are returned. Since there is no new
  // micro-batch, the expired entries aren't detected as expired (at least with the Spark
  // version used in this post). It shows that state expiration depends on data arrival
  savedValues should have size 3
  savedValues should contain allOf("test10,test11", "test30", "test20")
}

"append mode" should "be disallowed in mapGroupsWithState" in {
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingFunction)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val exception = intercept[AnalysisException] {
    mappedValues.writeStream.outputMode("append")
      .foreach(new NoopForeachWriter[Seq[String]]).start()
  }
  exception.getMessage() should include("mapGroupsWithState is not supported with Append output mode on a " +
    "streaming DataFrame/Dataset")
}

"update mode" should "work for mapGroupsWithState" in {
  val testKey = "mapGroupWithState-update-output-mode"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingFunction)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      inputStream.addData((1L, "test12"), (1L, "test13"), (2L, "test21"))
    }
  }).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 5
  savedValues should contain allOf("test30", "test10,test11", "test20", "test10,test11,test12,test13",
    "test20,test21")
}

In this post we learned about the mapGroupsWithState transformation. The first part described what it does. Among other things, we saw that it not only enables per-group data processing with a persistent state, but also handles state timeouts with either event-time or processing-time based strategies. The second part presented GroupState, the object used to represent the state. In the third part we saw some of the rules defining mapGroupsWithState behavior as well as some working cases.
