Inner joins between streams in Apache Spark Structured Streaming

Versions: Apache Spark 2.3.1 https://github.com/bartosz25/spark-...dstreaming/join/InnerJoinTest.scala

Apache Kafka Streams supports joins between streams and the community expected the same from Apache Spark. This feature was implemented and released with the recent 2.3.0 version and now, some months after that release, it's a good moment to talk a little about it.

This post is composed of 2 sections. The first one introduces the idea of joins in streaming pipelines. The second one talks about one of the 2 supported types - the inner join.

Joins and streaming

Executing join operations on streams is quite challenging, and that for many reasons. The most obvious one is latency. For a bounded data source all the data to join is present in place, while for a streaming source the data is continuously moving. And sometimes (often?) it's not moving at the same speed. It may be because of technical reasons, such as ingestion pipeline issues, or simply because of functional requirements where related events aren't always generated during a similar period of time. One such functional case is the ordering process on an e-commerce store, where an order is rarely finalized very quickly. Thus, the join operation must somehow manage the case of related but very asynchronous events.

Another important point to address is state management. Since the data for a given event may arrive at any moment (very early as well as very late) and the amount of space reserved to store it is limited, the engine must figure out what to do with the accumulated state and, especially, when to discard it. This specific time corresponds to the moment when we're not expecting to receive any new event for a given join key.

Apache Spark Structured Streaming addressed both questions in the 2.3.0 release by providing the ability to join 2 or more streams. Stream-to-stream joins can be characterized by a few axes, such as the supported join types (inner and outer) and the way the accumulated state is managed (event-time watermarks and time-range conditions).

However, as of Apache Spark 2.3.1, stream-to-stream joins have several limitations: they work only with the append output mode and the joined streams can't be prepared with non-map-like operations, such as aggregations, before the join.
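To give an idea of how the state retention question is addressed, below is a minimal sketch, modeled on the pattern promoted in the official documentation, where event-time watermarks and a time-range condition tell the engine when the accumulated state can be discarded. The rate source and the impressions/clicks column names are purely illustrative:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

val sparkSession = SparkSession.builder().master("local[*]").appName("Watermarked join sketch").getOrCreate()

// Two illustrative streams - any streaming Dataset with an event-time column would do
val impressions = sparkSession.readStream.format("rate").load()
  .selectExpr("value AS impressionAdId", "timestamp AS impressionTime")
val clicks = sparkSession.readStream.format("rate").load()
  .selectExpr("value AS clickAdId", "timestamp AS clickTime")

// The watermarks define how late the events of each side are allowed to be...
val impressionsWithWatermark = impressions.withWatermark("impressionTime", "10 seconds")
val clicksWithWatermark = clicks.withWatermark("clickTime", "20 seconds")

// ...and the time-range condition bounds how long an impression waits for its click,
// so the state kept for a given key can eventually be dropped
val joinedStream = impressionsWithWatermark.join(clicksWithWatermark,
  expr("""clickAdId = impressionAdId AND
    clickTime >= impressionTime AND
    clickTime <= impressionTime + interval 1 minute"""))
// joinedStream.writeStream can then be started, only in the append output mode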

Streaming inner joins

The first type of supported stream-to-stream joins is the inner join. Since it's a strict join, where a row without a match is never emitted, it doesn't require any time constraint on the joined columns. However, this can be dangerous since the rows, even those that will never find a match, may remain in the state store for a very long time. That's why defining a condition telling how long the state for a particular key should be persisted is the recommended strategy. The simplest case of inner join, where the state is not cleaned, is shown below:

it should "output the result as soon as it arrives without watermark" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val stream = mainEventsStream.toDS().join(joinedEventsStream.toDS(), $"mainKey" === $"joinedKey")

  val query = stream.writeStream.foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      var key = 0
      while (true) {
        // Here the main event is always sent before the joined one
        // But we also send an event for key-10 in order to see whether the main event is still kept in the state store
        joinedEventsStream.addData(Events.joined(s"key${key-10}"))
        val mainEventTime = System.currentTimeMillis()
        mainEventsStream.addData(MainEvent(s"key${key}", mainEventTime, new Timestamp(mainEventTime)))
        Thread.sleep(1000L)
        joinedEventsStream.addData(Events.joined(s"key${key}"))
        key += 1
      }
    }
  }).start()
  query.awaitTermination(60000)

  // As you can see in this test, when neither a watermark nor a range condition is defined, the state isn't cleared
  // That's why we can still see matches for data coming 9-10 seconds after the first joined event of the same key
  val groupedByKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  val keysWith2Entries = groupedByKeys.filter(keyWithEntries => keyWithEntries._2.size == 2)
  keysWith2Entries.foreach(keyWithEntries => {
    val entries = keyWithEntries._2
    val metric1 = entries(0)
    val metric2 = entries(1)
    val diffBetweenEvents = metric2.joinedEventMillis - metric1.joinedEventMillis
    val timeDiffSecs = diffBetweenEvents/1000
    (timeDiffSecs >= 9 && timeDiffSecs <= 10) shouldBe true
  })
}
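One way to verify that the state really grows in such a configuration is to look at the state store metrics exposed by the query progress. A small sketch, assuming access to the query started in the test above:

// Without a watermark or a range condition the numRowsTotal metric of the join's
// state store only grows - old entries are never discarded
Option(query.lastProgress).foreach(progress => {
  progress.stateOperators.foreach(stateOperator =>
    println(s"rows kept in the state store: ${stateOperator.numRowsTotal}"))
})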

In the first test the streams were joined with a simple key-equality condition. Another kind of join condition uses windows. We can see how it works in the following example, where the rows with odd keys are sent late, i.e. outside the window in which the event of the non-nullable join side was emitted. As a result we should receive only the rows having even keys:

it should "join rows per windows" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsDataset = mainEventsStream.toDS().select($"mainKey", window($"mainEventTimeWatermark", "5 seconds"),
    $"mainEventTime", $"mainEventTimeWatermark")
  val joinedEventsDataset = joinedEventsStream.toDS().select($"joinedKey", window($"joinedEventTimeWatermark", "5 seconds"),
    $"joinedEventTime", $"joinedEventTimeWatermark")
  val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
    mainEventsDataset("window") === joinedEventsDataset("window"))

  val query = stream.writeStream.foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      var key = 0
      var iterationTimeFrom1970 = 1000L // 1 sec
      while (query.isActive) {
        val (key1, key2) = (key + 1, key + 2)
        // The join window lasts 5 seconds and the joined event for the odd key carries the event time of a much
        // later iteration, so it falls outside its main event's window. Thus, at the end we should retrieve only the rows with even keys
        joinedEventsStream.addData(Events.joined(s"key${key1-6}", eventTime = iterationTimeFrom1970))
        mainEventsStream.addData(MainEvent(s"key${key1}", iterationTimeFrom1970, new Timestamp(iterationTimeFrom1970)),
          MainEvent(s"key${key2}", iterationTimeFrom1970, new Timestamp(iterationTimeFrom1970)))
        Thread.sleep(1000L)
        joinedEventsStream.addData(Events.joined(s"key${key2}", eventTime = iterationTimeFrom1970))
        iterationTimeFrom1970 += iterationTimeFrom1970 // the event time doubles at every iteration, so the late-joined odd keys always land in newer windows
        key += 2
      }
    }
  }).start()
  query.awaitTermination(60000)

  // Because the rows with odd keys are joined late (outside the 5-second window), we should find
  // here only the rows with even keys
  val processedEventsKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  processedEventsKeys.keys.foreach(key => {
    val keyNumber = key.substring(3).toInt
    keyNumber % 2 == 0 shouldBe true
  })
}
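For reference, the window function used in the join condition buckets the event time into fixed, non-overlapping 5-second intervals and produces a struct column with start and end fields - two rows can only match when their event times fall into the same bucket. A quick way to inspect the generated column, reusing the stream defined above:

import org.apache.spark.sql.functions.window

// The window column is a struct<start: timestamp, end: timestamp> describing the 5-second bucket
mainEventsStream.toDS()
  .select(window($"mainEventTimeWatermark", "5 seconds").as("window"))
  .printSchema()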

As mentioned in the first section, stream-to-stream joins don't support all possible operations on the joined streams. Before the join they can only be combined with map-like transformations, such as map or filter, as proven in the following test cases:

it should "filter and map before joining" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val mainEventsWithMappedKey = mainEventsStream.toDS().filter(mainEvent => mainEvent.mainKey.length > 3)
    .map(mainEvent => mainEvent.copy(mainKey = s"${mainEvent.mainKey}_copy"))
  // For the nullable side we deliberately omit the filtering - it shows that the query
  // works even when both sides aren't transformed in exactly the same way
  val joinedEventsWithMappedKey = joinedEventsStream.toDS()
    .map(joinedEvent => joinedEvent.copy(joinedKey = s"${joinedEvent.joinedKey}_copy"))

  val stream = mainEventsWithMappedKey.join(joinedEventsWithMappedKey, $"mainKey" === $"joinedKey")

  val query = stream.writeStream.foreach(RowProcessor).start()

  while (!query.isActive) {}
  new Thread(new Runnable() {
    override def run(): Unit = {
      var key = 0
      while (query.isActive) {
        val eventTime = System.currentTimeMillis()
        mainEventsStream.addData(MainEvent(s"key${key}", eventTime, new Timestamp(eventTime)))
        joinedEventsStream.addData(Events.joined(s"key${key}", eventTime = eventTime))
        Thread.sleep(1000L)
        key += 1
      }
    }
  }).start()
  query.awaitTermination(60000)

  val groupedByKeys = TestedValuesContainer.values.groupBy(testedValues => testedValues.key)
  groupedByKeys.keys.foreach(key => {
    key should endWith("_copy")
  })
}

it should "fail when the aggregations are made before the join" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val exception = intercept[AnalysisException] {
    val mainEventsWithMappedKey = mainEventsStream.toDS()
    val joinedEventsWithMappedKey = joinedEventsStream.toDS().groupBy($"joinedKey").count()

    val stream = mainEventsWithMappedKey.join(joinedEventsWithMappedKey, $"mainKey" === $"joinedKey")

    stream.writeStream.foreach(RowProcessor).start()
  }
  exception.getMessage() should include("Append output mode not supported when there are streaming aggregations " +
    "on streaming DataFrames/DataSets without watermark")
}
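The append-only restriction mentioned in the first section can be verified in a similar way. The following test is only a sketch and not part of the original suite - the exact error message may differ between versions, so only the exception type is asserted:

it should "fail when the join query is started in update output mode" in {
  val mainEventsStream = new MemoryStream[MainEvent](1, sparkSession.sqlContext)
  val joinedEventsStream = new MemoryStream[JoinedEvent](2, sparkSession.sqlContext)

  val stream = mainEventsStream.toDS().join(joinedEventsStream.toDS(), $"mainKey" === $"joinedKey")

  // Stream-to-stream joins are supported only in the append output mode, so the
  // unsupported operations checker should reject the query when it starts
  intercept[AnalysisException] {
    stream.writeStream.outputMode("update").foreach(RowProcessor).start()
  }
}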

This first introductory post about stream-to-stream joins talked about inner joins. The first section presented the joins in the big picture by showing why they were difficult to implement. As we could learn from there, the main problems are late events and state management, both directly impacting the quality of the results and the hardware requirements. The second part showed one of the 2 implemented join types - the inner join. The examples presented the limitations of stream-to-stream joins as well as their behavior in the strictest case, the one where both sides must have matches.