Stream-to-stream state management

on waitingforcode.com

Over the last weeks we've discovered 2 stream-to-stream join types in Apache Spark Structured Streaming. As explained in these posts, state management may sometimes be omitted (for inner joins), but it's generally advised in order to reduce memory pressure. Apache Spark proposes 3 different state management strategies, detailed in the following sections.

This post is divided into 4 parts. The first one recalls the specificity of state in stream-to-stream joins. The next 3 present the state management strategies.

State and streaming joins

As mentioned in the post about outer joins in Apache Spark Structured Streaming, each potentially joinable row is buffered in the state store. For the inner join type, the join and the result emission happen every time a matching row is found. For outer joins the logic is a little bit different: the results are emitted either because a matching row was found or because the state expired. State expires when we no longer expect to receive matching events for a given entry. Fortunately, this behavior also applies to inner joins, the difference being that there it's optional.
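
To make the difference between inner and outer joins more concrete, here is a deliberately simplified model of one side's join buffer in plain Scala. Buffered, JoinBufferModel and their methods are hypothetical names used only for illustration. The real StreamingSymmetricHashJoinExec logic is much richer (two state stores, arbitrary join conditions, null-padded rows instead of Option).

```scala
// Deliberately simplified model of ONE side of a stream-to-stream join buffer.
// Hypothetical names, not Spark's StreamingSymmetricHashJoinExec implementation.
final case class Buffered(key: String, eventTime: Long, value: String, matched: Boolean = false)

object JoinBufferModel {
  // (buffered value, other side's value); None stands for the missing side of an outer join
  type Emitted = List[(String, Option[String])]

  /** A row arrives on the other side: join it with every buffered row having the same key. */
  def onOtherSideRow(buffer: List[Buffered], key: String, otherValue: String): (List[Buffered], Emitted) = {
    val joined: Emitted = buffer.filter(_.key == key).map(r => (r.value, Option(otherValue)))
    val updated = buffer.map(r => if (r.key == key) r.copy(matched = true) else r)
    (updated, joined)
  }

  /** The watermark moves: rows at or below it can't match anymore and leave the buffer. */
  def onWatermark(buffer: List[Buffered], watermark: Long, outerJoin: Boolean): (List[Buffered], Emitted) = {
    val (expired, alive) = buffer.partition(_.eventTime <= watermark)
    // Outer join: expired rows that never matched are emitted with an empty other side.
    // Inner join: expired rows are simply dropped.
    val emitted: Emitted =
      if (outerJoin) expired.filterNot(_.matched).map(r => (r.value, Option.empty[String]))
      else Nil
    (alive, emitted)
  }
}
```

In both cases the expired rows leave the buffer. Only the outer join turns expiration into an output.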

Without this concept of "expired state", the engine would keep the rows waiting for a match indefinitely. Since the data source is unbounded, it would inevitably fail sooner or later, for instance by running out of memory. Hence Apache Spark provides 3 different strategies to manage state expiration, all of them based on the watermark.
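
The watermark all these strategies rely on can be modeled in a few lines. The sketch below assumes the usual withWatermark semantics: the watermark is the maximum event time seen so far minus the declared delay, it's updated at the end of each micro-batch and it never moves backwards. WatermarkTracker is a hypothetical class, not a Spark API.

```scala
// Hypothetical, simplified model of one stream's watermark declared with withWatermark(column, delay):
// watermark = max event time seen so far - delay, updated per micro-batch, never moving backwards.
final class WatermarkTracker(delayMs: Long) {
  private var maxEventTimeMs: Long = Long.MinValue
  private var currentWatermarkMs: Long = 0L // starts at epoch 0

  /** Records the event times of one micro-batch and advances the watermark at its end. */
  def endOfBatch(eventTimesMs: Seq[Long]): Long = {
    if (eventTimesMs.nonEmpty) maxEventTimeMs = math.max(maxEventTimeMs, eventTimesMs.max)
    if (maxEventTimeMs != Long.MinValue)
      currentWatermarkMs = math.max(currentWatermarkMs, maxEventTimeMs - delayMs)
    currentWatermarkMs
  }

  /** Simplification: a row counts as late when its event time is not after the watermark. */
  def isLate(eventTimeMs: Long): Boolean = eventTimeMs <= currentWatermarkMs
}
```

With a 2-second delay and events up to 10 seconds, the watermark becomes 8 seconds. Older events arriving later don't move it back.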

State key watermark

The first of these strategies is called state key watermark. It applies to a query when:

  • a watermark column is defined in at least one of the joined streams. It may be either a timestamp column or a window column. If the watermark is defined on only one side, Apache Spark is able to deduce the watermark for the other side from it.
  • the watermark column is used in the JOIN clause as an equality constraint
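
Here is a minimal sketch of the idea, assuming a state store keyed by (join key, event time). Since the event time belongs to the key, whole keys can be evicted as soon as their event time falls at or below the watermark, without looking at the buffered values. StateKey and KeyWatermarkedState are hypothetical names, and Spark's state store API looks different.

```scala
import scala.collection.mutable

// Hypothetical model of "state key watermark": the event time is part of the state key,
// so expiration is decided from the key alone.
final case class StateKey(joinKey: String, eventTimeMs: Long)

final class KeyWatermarkedState {
  private val rows = mutable.Map.empty[StateKey, List[String]]

  def add(key: StateKey, row: String): Unit =
    rows.update(key, row :: rows.getOrElse(key, Nil))

  def lookup(key: StateKey): List[String] = rows.getOrElse(key, Nil)

  /** Evicts every key whose event time is at or below the watermark and returns how many were removed. */
  def evict(watermarkMs: Long): Int = {
    val expired = rows.keys.filter(_.eventTimeMs <= watermarkMs).toList
    expired.foreach(k => rows.remove(k))
    expired.size
  }

  def size: Int = rows.size
}
```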

The name of this strategy comes from the use of the watermark column directly in the equality condition of the JOIN clause. The column becomes a part of the state key, hence "state key". Several examples illustrate it, starting with the watermark defined on both sides:

"state key watermark" should "be built from watermark used in join" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark")
    .withWatermark("mainEventTimeWatermark", "2 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark")
    .withWatermark("joinedEventTimeWatermark", "2 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    mainEventsDataset("mainEventTimeWatermark") === joinedEventsDataset("joinedEventTimeWatermark"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 10000L // 10 sec
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970) 
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      // Keys 2, 3, 4, 5, 6 and 7 are sent late (at or before the current watermark) to see the watermark applied
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query) - 5000L
      while (query.isActive) {
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime)
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val groupedByKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  groupedByKeys.keys should contain noneOf("key2", "key3", "key4", "key5", "key6", "key7")
  // Check some initial keys that should be sent after the first generated watermark
  groupedByKeys.keys should contain allOf("key8", "key9", "key10") 
}

When the watermark is defined on only 1 side, the results should be the same as above:

"state key watermark" should "be built from watermark used in one side of join" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark")
    .withWatermark("mainEventTimeWatermark", "2 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    mainEventsDataset("mainEventTimeWatermark") === joinedEventsDataset("joinedEventTimeWatermark"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 10000L // 10 sec
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      // Keys 2, 3, 4, 5, 6 and 7 are sent late (at or before the current watermark) to see the watermark applied
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query) - 5000L
      while (query.isActive) {
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime)
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val groupedByKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  groupedByKeys.keys should contain noneOf("key2", "key3", "key4", "key5", "key6", "key7")
  // Check some initial keys that should be sent after the first generated watermark
  groupedByKeys.keys should contain allOf("key8", "key9", "key10") 
}

State key watermark also works with a windowed watermark column:

"state key watermark" should "be built from watermark used in join window" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark",
    window($"mainEventTimeWatermark", "5 seconds").as("watermarkWindow")).withWatermark("watermarkWindow", "5 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark",
    window($"joinedEventTimeWatermark", "5 seconds").as("watermarkWindow")).withWatermark("watermarkWindow", "5 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    mainEventsDataset("watermarkWindow") === joinedEventsDataset("watermarkWindow"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 0
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query)
      while (query.isActive) {
        val joinedKeyTime = if (key % 2 == 0) {
          startingTime
        } else {
          // for odd keys we define the time for previous window
          startingTime - 6000L
        }
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime, Some(joinedKeyTime))
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val allKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key).keys
  val oddNumberKeys = allKeys.map(key => key.substring(3).toInt).filter(key => key > 1 && key % 2 != 0)
  oddNumberKeys shouldBe empty
}
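
Why do the odd keys never show up? The sketch below models the tumbling window assignment done by window($"...", "5 seconds"), assuming non-negative epoch milliseconds and no start offset. The joined-side time of odd keys is shifted by 6 seconds, which is more than one window length. Because of that, both sides never fall into the same window and the equality condition can't match. TumblingWindows and Window are hypothetical helpers.

```scala
// Hypothetical model of tumbling window assignment (assumes non-negative epoch millis, no start offset).
final case class Window(startMs: Long, endMs: Long)

object TumblingWindows {
  def assign(eventTimeMs: Long, sizeMs: Long): Window = {
    val start = eventTimeMs - (eventTimeMs % sizeMs)
    Window(start, start + sizeMs)
  }

  /** Equality on the window column, as in the JOIN condition of the windowed test. */
  def joinable(mainTimeMs: Long, joinedTimeMs: Long, sizeMs: Long): Boolean =
    assign(mainTimeMs, sizeMs) == assign(joinedTimeMs, sizeMs)
}
```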

State value watermark

Another state management strategy in stream-to-stream joins is used when the JOIN clause doesn't contain an equality condition on the watermark column. Instead, the query has a so-called range condition, expressed as an inequality on the watermark column. The watermark column is then checked as a part of the buffered row's value and not of the state key, hence the name of this strategy: state value watermark. Its use requires:

  • a watermark column, either a timestamp or a window column, as previously
  • a range condition on the watermark column in the JOIN clause, i.e. the watermark columns must be compared with an inequality instead of an equality

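The range condition defined in the JOIN clause automatically impacts the state watermark of one of the joined sides. Suppose the condition is expressed as leftTimeWatermark > rightTimeWatermark + 10 minutes. Future right rows have event times later than the right side's watermark, so a buffered left row can only match them if it occurred later than the right side's watermark + 10 minutes. That said, if the right side's watermark is 10:00, the state watermark of the left side automatically becomes 10:10 and older left rows can be removed from the state. It also works the other way round, i.e. the left watermark can drive the cleanup of the right side's state. This requires the condition to also bound the right column from below (e.g. rightTimeWatermark > leftTimeWatermark - 1 hour). To see that, nothing is better than a simple example: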

"state value watermark" should "be built from a watermark column and range condition" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark")
    .withWatermark("mainEventTimeWatermark", "2 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark")
    .withWatermark("joinedEventTimeWatermark", "2 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    expr("joinedEventTimeWatermark > mainEventTimeWatermark + interval 2 seconds"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 10000L // 10 sec
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query)
      while (query.isActive) {
        val joinedSideEventTime = if (startingTime % 2000 == 0) {
          startingTime + 3000L
        } else {
          // even if this row isn't dropped by the watermark, it can't satisfy the stricter range condition
          // (joined > main + 2 seconds), so it won't be joined
          startingTime - 1000L
        }
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime, Some(joinedSideEventTime))
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val processedKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key).keys
  val keyNumbers = processedKeys.map(key => key.substring(3).toInt)
  val oddKeyNumbers = keyNumbers.filter(keyNumber => keyNumber % 2 != 0)
  oddKeyNumbers shouldBe empty
}
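
The range-condition arithmetic from the 10:00/10:10 example can be sketched for a condition of the form leftTime > rightTime + offset. In the test above, joinedEventTimeWatermark plays the left role and the offset is 2 seconds. StateValueWatermark and its members are hypothetical. Spark derives the bound from the JOIN expression itself.

```scala
// Hypothetical model of "state value watermark" for: leftTime > rightTime + offsetMs.
// Future right rows have rightTime > rightWatermark, so a buffered left row can only
// match them if leftTime > rightWatermark + offsetMs. Anything at or below is useless.
object StateValueWatermark {
  final case class LeftRow(key: String, leftTimeMs: Long)

  def leftStateWatermark(rightWatermarkMs: Long, offsetMs: Long): Long =
    rightWatermarkMs + offsetMs

  /** Keeps only the buffered left rows that could still match a future right row. */
  def cleanLeftState(state: Seq[LeftRow], rightWatermarkMs: Long, offsetMs: Long): Seq[LeftRow] = {
    val bound = leftStateWatermark(rightWatermarkMs, offsetMs)
    state.filter(_.leftTimeMs > bound)
  }
}
```

With a right watermark of 10:00 and a 10-minute offset, a left row from 10:05 is evicted while one from 10:15 stays.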

This strategy also works when both sides define different watermark delays. The difference is that the engine then uses one common watermark: the smallest watermark among the joined streams. We can observe that in the following test case. The watermark retrieved from the query progress is 1970-01-01T00:00:01.000Z, i.e. the joined stream's watermark (21 seconds - 20 seconds) and not the main stream's one:

"state value watermark" should "be built from different watermark columns and range condition" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark")
    .withWatermark("mainEventTimeWatermark", "2 seconds")
  // To see what happens, let's set the watermark delay of the joined side 10 times bigger than the main dataset's one
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark")
    .withWatermark("joinedEventTimeWatermark", "20 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    expr("joinedEventTimeWatermark > mainEventTimeWatermark + interval 2 seconds"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  var firstWatermark: Option[String] = None
  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      // 21 sec: the watermark is MAX(event_time) - 20 sec, so any value <= 20 sec would never move it above 0
      val processingTimeFrom1970 = 21000L
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      firstWatermark = Some(query.lastProgress.eventTime.get("watermark"))
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query)
      while (query.isActive) {
        val joinedSideEventTime = if (startingTime % 2000 == 0) {
          startingTime + 3000L
        } else {
          startingTime - 1000L
        }
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime, Some(joinedSideEventTime))
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  firstWatermark shouldBe defined
  firstWatermark.get shouldEqual "1970-01-01T00:00:01.000Z"
  val processedKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key).keys
  // In this case we don't expect even numbers because the odd-numbered keys go to the first sending condition
  // and the even-numbered ones to the second one
  val keyNumbers = processedKeys.map(key => key.substring(3).toInt)
  val evenKeyNumbers = keyNumbers.filter(keyNumber => keyNumber % 2 == 0)
  evenKeyNumbers shouldBe empty
}
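
The single watermark observed in this test can be reproduced with a small model. It assumes that both inputs have already seen events up to 21 seconds. Each input computes max(event time) - delay and the query keeps the minimum. GlobalWatermark is a hypothetical helper.

```scala
// Hypothetical model: one watermark per input, and the query keeps the smallest one.
object GlobalWatermark {
  final case class Input(name: String, delayMs: Long, maxEventTimeMs: Long) {
    // never below epoch 0
    def watermarkMs: Long = math.max(0L, maxEventTimeMs - delayMs)
  }

  def global(inputs: Seq[Input]): Long = inputs.map(_.watermarkMs).min
}
```

Since Apache Spark 2.4 this behavior can be changed with the spark.sql.streaming.multipleWatermarkPolicy configuration, whose default value is min.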

Obviously, 2 different watermarks won't work in the state key watermark examples because there the streams are joined on the equality of their watermark columns.

State value watermark can also be applied to windowed watermark columns. As with timestamp columns, the condition must be expressed as an inequality in the JOIN clause:

"state value watermark" should "be built from a watermark window column and range condition" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark",
    window($"mainEventTimeWatermark", "3 seconds").as("mainWatermarkWindow")).withWatermark("mainWatermarkWindow", "3 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark",
    window($"joinedEventTimeWatermark", "3 seconds").as("joinedWatermarkWindow")).withWatermark("joinedWatermarkWindow", "3 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    expr("joinedWatermarkWindow > mainWatermarkWindow"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 10000L // 10 sec
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query)
      while (query.isActive) {
        val joinedSideEventTime = if (key % 2 == 0) {
          startingTime + 4000L
        } else {
          startingTime - 4000L
        }
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime, Some(joinedSideEventTime))
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val processedKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key).keys
  processedKeys.nonEmpty shouldBe true
  val keyNumbers = processedKeys.map(key => key.substring(3).toInt)
  val oddKeyNumbers = keyNumbers.filter(keyNumber => keyNumber % 2 != 0)
  oddKeyNumbers shouldBe empty
}

Mixed

The last strategy is called mixed and it occurs when both previous strategies are defined in the same JOIN clause. However, the engine uses only one of them, the state key watermark, because of its stricter character. It's shown pretty clearly in getOneSideStateWatermarkPredicate(oneSideInputAttributes: Seq[Attribute], oneSideJoinKeys: Seq[Expression], otherSideInputAttributes: Seq[Attribute]), where the join key check comes first:

if (isWatermarkDefinedOnJoinKey) { // case 1 and 3 in the StreamingSymmetricHashJoinExec docs
  // ...
} else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs
  // ...
} else {
  None
}
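
The precedence encoded in this if/else chain can be modeled outside Spark. In the sketch below, StatePredicate, KeyWatermarkPredicate, ValueWatermarkPredicate and StrategyChoice are hypothetical names. Only the order of the checks mirrors the excerpt. The range-condition bound is passed by name, so in the mixed case it isn't even evaluated.

```scala
// Hypothetical model of the precedence in getOneSideStateWatermarkPredicate:
// a watermark on a join key wins over a bound derived from a range condition.
sealed trait StatePredicate
final case class KeyWatermarkPredicate(watermarkMs: Long) extends StatePredicate
final case class ValueWatermarkPredicate(watermarkMs: Long) extends StatePredicate

object StrategyChoice {
  def choose(
      isWatermarkDefinedOnJoinKey: Boolean,
      isWatermarkDefinedOnInput: Boolean,
      eventTimeWatermarkMs: Long,
      rangeConditionStateWatermarkMs: => Option[Long]): Option[StatePredicate] =
    if (isWatermarkDefinedOnJoinKey) Some(KeyWatermarkPredicate(eventTimeWatermarkMs)) // cases 1 and 3
    else if (isWatermarkDefinedOnInput) rangeConditionStateWatermarkMs.map(ms => ValueWatermarkPredicate(ms)) // case 2
    else None
}
```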

The following test shows the presence of both strategies in the JOIN clause:

"mixed watermark" should "use stricter state key watermark" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark")
    .withWatermark("mainEventTimeWatermark", "2 seconds")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark")
    .withWatermark("joinedEventTimeWatermark", "2 seconds")

  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    mainEventsDataset("mainEventTimeWatermark") === joinedEventsDataset("joinedEventTimeWatermark") &&
    expr("joinedEventTimeWatermark >= mainEventTimeWatermark - interval 2 seconds"))

  val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      val stateManagementHelper = new StateManagementHelper(mainEventsStream, joinedEventsStream)
      var key = 0
      val processingTimeFrom1970 = 10000L // 10 sec
      stateManagementHelper.waitForWatermarkToChange(query, processingTimeFrom1970)
      println("progress changed, got watermark" + query.lastProgress.eventTime.get("watermark"))
      key = 2
      // Keys 2, 3, 4, 5, 6 and 7 are sent late (at or before the current watermark) to see the watermark applied
      var startingTime = stateManagementHelper.getCurrentWatermarkMillisInUtc(query) - 5000L
      while (query.isActive) {
        stateManagementHelper.sendPairedKeysWithSleep(s"key${key}", startingTime)
        startingTime += 1000L
        key += 1
      }
    }
  }).start()
  query.awaitTermination(90000)

  val groupedByKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  groupedByKeys.keys should contain noneOf("key2", "key3", "key4", "key5", "key6", "key7")
  // Check some initial keys that should be sent after the first generated watermark
  groupedByKeys.keys should contain allOf("key8", "key9", "key10")
  println(s"got keys=${groupedByKeys.mkString("\n")}")
}

Stream-to-stream joins are an interesting Apache Spark Structured Streaming feature. However, without a properly managed state lifecycle, they may become a nightmare. An unbounded data source with unbounded state means unbounded hardware resources, hence high costs and maintenance complexity. That's why the engine provides 3 different strategies to deal with the state: state key watermark, state value watermark and mixed. All of them use the concept of watermark to detect the rows coming too late. These rows are automatically discarded in the next processing loop.
