Stateful transformations with mapGroupsWithState

Versions: Spark 2.2.1 https://github.com/bartosz25/spark-...eaming/MapGroupsWithStateTest.scala

Streaming stateful processing in Apache Spark has evolved a lot since the first versions of the framework. In the beginning there was updateStateByKey but, judged inefficient, it was later replaced by mapWithState. With the arrival of Structured Streaming, that method was in its turn superseded by mapGroupsWithState.

This post focuses on the mapGroupsWithState transformation. The first part explains this method and its API details. The second section focuses on the object responsible for state storage. The last part shows, through the usual learning tests, how to use the transformation.

mapGroupsWithState explained

mapGroupsWithState(timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U) is a transformation applied on grouped data. Since it requires the data to be grouped, it introduces a big risk of shuffling. The state is computed by invoking the function passed as the func parameter. The transformation can be used on both bounded and unbounded sources. In the first case the final state is computed immediately, while in the second one the state can change at every triggered processing. That's why it's stored in a fault-tolerant store between subsequent invocations.

Under the hood, mapGroupsWithState creates an org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState node with the isMapGroupsWithState parameter set to true. Thanks to that we can immediately discover the only supported output mode for this transformation: update. It's because of this check from FlatMapGroupsWithState:

if (isMapGroupsWithState) {
  assert(outputMode == OutputMode.Update)
}

Now, let's analyze the mapGroupsWithState parameters in more detail:

- timeoutConf: GroupStateTimeout - the timeout strategy applied to the state. It can be GroupStateTimeout.NoTimeout, GroupStateTimeout.ProcessingTimeTimeout or GroupStateTimeout.EventTimeTimeout.
- func: (K, Iterator[V], GroupState[S]) => U - the mapping function invoked once per group and micro-batch. It receives the group key, the new values received for the group and the current state, and returns the mapped result. It's shown in the sketch below.
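To make these two parameter groups concrete, here is a minimal sketch keeping a running count of events per key. It's an illustration only: the events Dataset, its (Long, String) element type and the counting logic are assumptions made for the example, and sparkSession.implicits._ is supposed to be in scope.

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

// Illustrative mapping function: keeps a running count of the events observed per key
val countEvents: (Long, Iterator[(Long, String)], GroupState[Long]) => Long =
  (key, values, state) => {
    val newCount = state.getOption.getOrElse(0L) + values.size
    state.update(newCount)
    newCount
  }

// events is assumed to be a streaming Dataset[(Long, String)] of (id, name) pairs
val counts = events
  .groupByKey(pair => pair._1)
  .mapGroupsWithState(timeoutConf = GroupStateTimeout.NoTimeout)(countEvents)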

GroupState

Another important object involved in the state management is the instance of org.apache.spark.sql.streaming.GroupState[T]. It represents the state kept by Spark for a given group (= 1 instance per group). It defines the methods helping to deal with the state lifecycle:

- exists, getOption, get - check whether the state is defined and retrieve its current value
- update(newState) - overwrite the state with a new value
- remove() - discard the state kept for the group
- hasTimedOut - true when the function is invoked only because the state timed out
- setTimeoutDuration(...) - define a processing time-based timeout (requires GroupStateTimeout.ProcessingTimeTimeout)
- setTimeoutTimestamp(...) - define an event time-based timeout (requires GroupStateTimeout.EventTimeTimeout and a watermark)
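To illustrate this lifecycle, here is a hedged sketch of a mapping function closing a "session" once its state expires. The session vocabulary and the (Long, String) input type are assumptions made for the example; the function supposes the transformation was declared with GroupStateTimeout.ProcessingTimeTimeout.

import org.apache.spark.sql.streaming.GroupState

val sessionFunc: (Long, Iterator[(Long, String)], GroupState[Seq[String]]) => Seq[String] =
  (key, values, state) => {
    if (state.hasTimedOut) {
      // Invoked without new values, only because the timeout fired
      val expiredNames = state.get
      state.remove() // discards the state kept for this group
      expiredNames :+ "session closed"
    } else {
      val mergedNames = state.getOption.getOrElse(Seq.empty) ++ values.map(_._2)
      state.update(mergedNames) // overwrites the state with the new value
      state.setTimeoutDuration(3000) // expires after 3 seconds without new data for the key
      mergedNames
    }
  }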

mapGroupsWithState example

The following tests try to explain the specificities of mapGroupsWithState:

private val MappingFunction: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
  // Merge the names accumulated so far with the names from the current micro-batch
  val stateNames = state.getOption.getOrElse(Seq.empty)
  val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
  state.update(stateNewNames)
  stateNewNames
}

private val MappingExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
  if (values.isEmpty && state.hasTimedOut) {
    // Invoked without new values, only because the state expired
    Seq(s"${state.get}: timed-out")
  } else {
    val stateNames = state.getOption.getOrElse(Seq.empty)
    // Expire the state when no new data arrives for the key within 3 seconds of processing time
    state.setTimeoutDuration(3000)
    val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
    state.update(stateNewNames)
    stateNewNames
  }
}

"the state" should "expire after event time" in {
  val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
    if (values.isEmpty && state.hasTimedOut) {
      Seq(s"${state.get}: timed-out")
    } else {
      if (state.getOption.isEmpty) state.setTimeoutTimestamp(3000)
      val stateNames = state.getOption.getOrElse(Seq.empty)
      val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
      state.update(stateNewNames)
      stateNewNames
    }
  }
  val now = 5000L
  val testKey = "mapGroupWithState-state-expired-event-time"
  val inputStream = new MemoryStream[(Timestamp, Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("created", "id", "name")
    .withWatermark("created", "3 seconds")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)
  inputStream.addData((new Timestamp(now), 1L, "test10"),
    (new Timestamp(now), 1L, "test11"), (new Timestamp(now), 2L, "test20"),
    (new Timestamp(now+now), 3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      // Now the watermark is about 7000 ms (10000 - 3000 ms)
      // So not only will the values for id=1 be rejected, but its state
      // will also expire
      inputStream.addData((new Timestamp(now), 1L, "test12"),
        (new Timestamp(now+now), 3L, "test31"))
    }
  }).start()


  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 6
  savedValues should contain allOf("test30", "test10,test11", "test20", "List(test10, test11): timed-out",
    "test30,test31", "List(test20): timed-out")
}

"the event-time state expiration" should "fail when the set timeout timestamp is earlier than the watermark" in {
  val eventTimeExpirationFunc: (Long, Iterator[Row], GroupState[Seq[String]]) => Seq[String] = (key, values, state) => {
    if (values.isEmpty && state.hasTimedOut) {
      Seq(s"${state.get}: timed-out")
    } else {
      state.setTimeoutTimestamp(2000)
      val stateNames = state.getOption.getOrElse(Seq.empty)
      val stateNewNames = stateNames ++ values.map(row => row.getAs[String]("name"))
      state.update(stateNewNames)
      stateNewNames
    }
  }
  val now = 5000L
  val testKey = "mapGroupWithState-event-time-state-expiration-failure"
  val inputStream = new MemoryStream[(Timestamp, Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("created", "id", "name")
    .withWatermark("created", "2 seconds")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(eventTimeExpirationFunc)
    inputStream.addData((new Timestamp(now), 1L, "test10"),
      (new Timestamp(now), 1L, "test11"), (new Timestamp(now), 2L, "test20"),
      (new Timestamp(now+now), 3L, "test30"))

  val exception = intercept[StreamingQueryException] {
    val query = mappedValues.writeStream.outputMode("update")
      .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

    new Thread(new Runnable() {
      override def run(): Unit = {
        while (!query.isActive) {}
        Thread.sleep(5000)
        // The watermark here is 8000 (10000 - 2000 ms). But in the eventTimeExpirationFunc we set the timeout
        // timestamp to 2000, i.e. earlier than the watermark. So obviously it makes the watermark condition from
        // org.apache.spark.sql.execution.streaming.GroupStateImpl.setTimeoutTimestamp fail
        inputStream.addData((new Timestamp(now), 1L, "test12"),
          (new Timestamp(now + now), 3L, "test31"))
      }
    }).start()

    query.awaitTermination(30000)
  }
  exception.getMessage should include("Timeout timestamp (2000) cannot be earlier than the current watermark (8000)")
}

"an event time expiration" should "not be executed when the watermark is not defined" in {
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.EventTimeTimeout())(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val exception = intercept[AnalysisException] {
    mappedValues.writeStream.outputMode("update")
      .foreach(new NoopForeachWriter[Seq[String]]()).start()
  }
  exception.getMessage() should include("Watermark must be specified in the query using " +
    "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a [map|flatMap]GroupsWithState. " +
    "Event-time timeout not supported without watermark")
}

"a different state" should "be returned after the state expiration" in {
  val testKey = "mapGroupWithState-state-returned-after-expiration"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      // In this micro-batch we don't have values for keys 1 and 3, thus both will
      // be returned as expired when the micro-batch is processed.
      // It's because the processing-time timeout is 3000 ms and here we wait 5000 ms
      // before sending new data
      inputStream.addData((2L, "test21"))
    }
  }).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 6
  savedValues should contain allOf("test10,test11", "test30", "test20", "List(test10, test11): timed-out",
    "List(test30): timed-out", "test20,test21")
}

"the state" should "not be discarded when there is no new data to process" in {
  val testKey = "mapGroupWithState-state-not-returned-after-expiration"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingExpirationFunc)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  // Here, unlike in the above test, only 3 values are returned. Since there is no new
  // micro-batch, the expired entries won't be detected as expired.
  // It shows that the timeout handling depends on the arrival of new data
  savedValues should have size 3
  savedValues should contain allOf("test10,test11", "test30", "test20")
}

"append mode" should "be disallowed in mapGroupWithState" in {
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingFunction)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val exception = intercept[AnalysisException] {
    mappedValues.writeStream.outputMode("append")
      .foreach(new NoopForeachWriter[Seq[String]]).start()
  }
  exception.getMessage() should include("mapGroupsWithState is not supported with Append output mode on a " +
    "streaming DataFrame/Dataset")
}

"update mode" should "work for mapGroupWithState" in {
  val testKey = "mapGroupWithState-update-output-mode"
  val inputStream = new MemoryStream[(Long, String)](1, sparkSession.sqlContext)
  val mappedValues = inputStream.toDS().toDF("id", "name")
    .groupByKey(row => row.getAs[Long]("id"))
    .mapGroupsWithState(timeoutConf = GroupStateTimeout.ProcessingTimeTimeout)(MappingFunction)
  inputStream.addData((1L, "test10"), (1L, "test11"), (2L, "test20"), (3L, "test30"))

  val query = mappedValues.writeStream.outputMode("update")
    .foreach(new InMemoryStoreWriter[Seq[String]](testKey, (stateSeq) => stateSeq.mkString(","))).start()

  new Thread(new Runnable() {
    override def run(): Unit = {
      while (!query.isActive) {}
      Thread.sleep(5000)
      inputStream.addData((1L, "test12"), (1L, "test13"), (2L, "test21"))
    }
  }).start()

  query.awaitTermination(30000)

  val savedValues = InMemoryKeyedStore.getValues(testKey)
  savedValues should have size 5
  savedValues should contain allOf("test30", "test10,test11", "test20", "test10,test11,test12,test13",
    "test20,test21")
}

In this post we learned about the mapGroupsWithState transformation. The first part described what it does. Among other things, we saw that it not only enables per-group data processing with a persistent state but also handles state timeouts with either an event time-based or a processing time-based strategy. The second part presented GroupState, that is to say the object used to represent the state. In the third part we saw some of the rules defining mapGroupsWithState's behavior as well as some working cases.