Stream-to-stream joins internals

Versions: Apache Spark 2.3.1

In 3 recent posts about Apache Spark Structured Streaming we discovered streaming joins: inner joins, outer joins and state management strategies. Looking at what happens under the hood of all these operations is a good way to sum up the series.

This post starts with a section presenting the classes involved in the streaming join process. It's followed by a section focused on the state management internals related to joins. The post ends with a short section about the join mechanism.

Involved classes

Among the classes involved in streaming joins, 3 are very important: SymmetricHashJoinStateManager, StreamingSymmetricHashJoinExec and StreamingJoinHelper. They're used at different stages of the streaming query execution.

To begin, the streaming representation of the query, the IncrementalExecution instance, stores a reference to the state. If the query has stream-to-stream joins, this state is represented at every execution by an instance of StreamingSymmetricHashJoinExec. This instance is different in each execution: what changes are the offset metadata (among others, the current event-time watermark) and the state watermark predicates. The predicates are computed by:

def getStateWatermarkPredicates(
    leftAttributes: Seq[Attribute],
    rightAttributes: Seq[Attribute],
    leftKeys: Seq[Expression],
    rightKeys: Seq[Expression],
    condition: Option[Expression],
    eventTimeWatermark: Option[Long]): JoinStateWatermarkPredicates

This method computes the state watermark predicates used to discard rows that are too late from the state store, by applying the following rules. It first checks whether at least one of the columns used in the query's equality JOIN keys is marked with the watermark annotation. If yes, the state key watermark strategy is applied to late rows (you can read about the strategies in Outer joins in Apache Spark Structured Streaming). If not, it checks whether the join side has a watermarked column among its input columns and, if so, applies the state value watermark strategy, whose watermark value is derived from the join condition with StreamingJoinHelper. When a rule is met (the first one taking precedence), the predicate expression is built with the usual org.apache.spark.sql.execution.streaming.WatermarkSupport#watermarkExpression(optionalWatermarkExpression: Option[Expression], optionalWatermarkMs: Option[Long]) method.
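
The rules above can be sketched as a small Scala model. It's a simplification written for this illustration, not Spark's code: Attr, StatePredicate, KeyPredicate, ValuePredicate and oneSidePredicate are hypothetical names, and the state value watermark is assumed to be already computed (Spark derives it from the join condition). The `* 1000` step mirrors watermarkExpression, which compares the timestamp column, stored in microseconds, with the watermark converted from milliseconds.

```scala
// Hypothetical attribute: a column name and whether it carries the watermark annotation.
final case class Attr(name: String, hasWatermark: Boolean)

sealed trait StatePredicate
// evicts whole keys whose watermarked key column is <= thresholdMicros
final case class KeyPredicate(keyIndex: Int, thresholdMicros: Long) extends StatePredicate
// evicts single stored values whose watermarked column is <= thresholdMicros
final case class ValuePredicate(column: String, thresholdMicros: Long) extends StatePredicate

def oneSidePredicate(
    leftKeys: Seq[Attr],
    rightKeys: Seq[Attr],
    sideInput: Seq[Attr],
    eventTimeWatermarkMs: Option[Long],
    stateValueWatermarkMs: Option[Long]): Option[StatePredicate] = {
  // rule 1: a watermarked column among the equality join keys, on either side
  val keyIndex = leftKeys.zip(rightKeys).indexWhere { case (l, r) => l.hasWatermark || r.hasWatermark }
  if (keyIndex >= 0) {
    // watermarks are in milliseconds, timestamps are compared in microseconds
    eventTimeWatermarkMs.map(ms => KeyPredicate(keyIndex, ms * 1000))
  } else {
    // rule 2: a watermarked column among this side's input columns; the value
    // watermark is assumed to be already derived from the join condition
    sideInput.find(_.hasWatermark).flatMap { attr =>
      stateValueWatermarkMs.map(ms => ValuePredicate(attr.name, ms * 1000))
    }
  }
}
```

With the join keys of the equality-based query from the next section, where only mainEventTimeWatermark carries a watermark, the sketch returns KeyPredicate(1, 8000000) for an 8000 ms watermark. That's consistent with the `input[1, timestamp, true] <= 8000000` cleanup predicates printed in that query's plan.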

Equality importance in JOIN clause

The current implementation of stream-to-stream joins accepts only equality relations as join keys. It means that if we have 2 streams, stream#1(field1[int], field2[timestamp]) and stream#2(field10[int], field20[timestamp]), only equalities such as field1 = field10 or field2 = field20 can be used as join keys. Any other relation expressed in the ON part of the JOIN, such as an inequality, isn't used as a join key; it's kept as a filtering condition, similar to a WHERE clause.
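
The split between join keys and the remaining condition can be illustrated with a toy expression model. Spark does it on Catalyst expressions during planning (with the ExtractEquiJoinKeys pattern); the Expr, Col, EqualTo, GreaterThan and And types and the conjuncts, asJoinKey and splitJoinCondition functions below are hypothetical simplifications.

```scala
// Toy expression model; a column knows which join side ("left" or "right") it belongs to.
sealed trait Expr
final case class Col(side: String, name: String) extends Expr
final case class EqualTo(left: Expr, right: Expr) extends Expr
final case class GreaterThan(left: Expr, right: Expr) extends Expr
final case class And(left: Expr, right: Expr) extends Expr

// Flattens nested ANDs into a list of conjuncts.
def conjuncts(e: Expr): List[Expr] = e match {
  case And(l, r) => conjuncts(l) ++ conjuncts(r)
  case other => List(other)
}

// An equality between a left column and a right column, in either order, is a join key.
def asJoinKey(e: Expr): Option[(Col, Col)] = e match {
  case EqualTo(l @ Col("left", _), r @ Col("right", _)) => Some((l, r))
  case EqualTo(r @ Col("right", _), l @ Col("left", _)) => Some((l, r))
  case _ => None
}

// Returns the join key pairs (used for hash partitioning) and the remaining predicates.
def splitJoinCondition(on: Expr): (List[(Col, Col)], List[Expr]) = {
  val parts = conjuncts(on)
  (parts.flatMap(asJoinKey), parts.filter(p => asJoinKey(p).isEmpty))
}
```

For the first query below, the sketch keeps mainKey = joinedKey as the only join key and leaves joinedWatermarkWindow > mainWatermarkWindow as the residual condition.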

For instance, following query:

import org.apache.spark.sql.functions.{expr, window}
import org.apache.spark.sql.streaming.Trigger
// assumes spark.implicits._ is in scope; mainEventsStream, joinedEventsStream and RowProcessor are defined elsewhere
val mainEventsDataset = mainEventsStream.toDS()
  .select($"mainKey", $"mainEventTime", $"mainEventTimeWatermark",
    window($"mainEventTimeWatermark", "3 seconds").as("mainWatermarkWindow"))
  .withWatermark("mainWatermarkWindow", "3 seconds")
val joinedEventsDataset = joinedEventsStream.toDS()
  .select($"joinedKey", $"joinedEventTime", $"joinedEventTimeWatermark",
    window($"joinedEventTimeWatermark", "3 seconds").as("joinedWatermarkWindow"))
  .withWatermark("joinedWatermarkWindow", "3 seconds")

val stream = mainEventsDataset.join(joinedEventsDataset,
  mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") && expr("joinedWatermarkWindow > mainWatermarkWindow"))

val query = stream.writeStream.trigger(Trigger.ProcessingTime(5000L)).foreach(RowProcessor).start()

... is planned as follows (only joinedKey is used to hash partition this side; the inequality isn't a join key):

+- Exchange hashpartitioning(joinedKey#966, 200)
   +- EventTimeWatermark joinedWatermarkWindow#23: struct, interval 3 seconds
      +- Union
         :- *(7) Project [joinedKey#966, joinedEventTime#967L, joinedEventTimeWatermark#968, named_struct(start, precisetimestampconversion(((((CASE WHEN 
         // ...
         :  +- *(7) Filter (isnotnull(joinedEventTimeWatermark#968) && isnotnull(joinedKey#966))
         :     +- LocalTableScan [joinedKey#966, joinedEventTime#967L, joinedEventTimeWatermark#968]
// ...

On the other hand, a similar query with an equality condition on the event-time columns uses both of them as join keys, i.e. at the data shuffle level:

    val stream = mainEventsDataset.join(joinedEventsDataset, mainEventsDataset("mainKey") === joinedEventsDataset("joinedKey") &&
      mainEventsDataset("mainEventTimeWatermark") === joinedEventsDataset("joinedEventTimeWatermark"))

And the physical execution plan:

== Physical Plan ==
StreamingSymmetricHashJoin [mainKey#1447, mainEventTimeWatermark#1449-T2000ms], [joinedKey#1451, joinedEventTimeWatermark#1453], Inner, condition = [ leftOnly = null, rightOnly = null, both = null, full = null ], state info [ checkpoint = file:/tmp/temporary-4bd3025b-1e05-4410-9ee4-04dd4191c63b/state, runId = b85412cd-94b1-4112-8d81-1e5c2e94a0f9, opId = 0, ver = 2, numPartitions = 200], 8000, state cleanup [ left key predicate: (input[1, timestamp, true] <= 8000000), right key predicate: (input[1, timestamp, true] <= 8000000) ]
:- Exchange hashpartitioning(mainKey#1447, mainEventTimeWatermark#1449-T2000ms, 200)
:  +- *(5) Filter isnotnull(mainEventTimeWatermark#1449-T2000ms)
:     +- EventTimeWatermark mainEventTimeWatermark#1449: timestamp, interval 2 seconds
:        +- Union
// ...
:              +- LocalTableScan [mainKey#2479, mainEventTime#2480L, mainEventTimeWatermark#2481]
+- Exchange hashpartitioning(joinedKey#1451, joinedEventTimeWatermark#1453, 200)
   +- Union
      :- *(6) Filter (isnotnull(joinedEventTimeWatermark#1453) && isnotnull(joinedKey#1451))
      :  +- LocalTableScan [joinedKey#1451, joinedEventTime#1452L, joinedEventTimeWatermark#1453]

Old state cleaning

After defining the watermark expression, the engine wraps it in either a state key watermark (JoinStateKeyWatermarkPredicate(expr: Expression)) or a state value watermark (JoinStateValueWatermarkPredicate(expr: Expression)) predicate. It's important to emphasize that each side of the query has its own predicate, since each stream can have different lateness characteristics.
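
Why the predicates have to differ per side can be shown with a hand-derived example. The sketch assumes a hypothetical time-range condition, rightTime <= leftTime <= rightTime + boundMs, and that rows at or below the watermark are dropped as late; SideThresholds, valueThresholds and isExpired are illustrative names. Spark derives such values generically from the join condition, so this only reproduces the reasoning for one condition shape.

```scala
// Per-side state value thresholds for the hypothetical join condition
//   rightTime <= leftTime <= rightTime + boundMs
// A stored row whose event time is <= its side's threshold can't match any future row.
final case class SideThresholds(leftMs: Long, rightMs: Long)

def valueThresholds(watermarkMs: Long, boundMs: Long): SideThresholds = {
  // future right rows have rightTime > watermark (late rows are dropped);
  // since leftTime >= rightTime, a left row needs leftTime > watermark to still match
  val leftThreshold = watermarkMs
  // future left rows have leftTime > watermark; since leftTime <= rightTime + boundMs,
  // a right row needs rightTime > watermark - boundMs to still match
  val rightThreshold = watermarkMs - boundMs
  SideThresholds(leftThreshold, rightThreshold)
}

def isExpired(eventTimeMs: Long, thresholdMs: Long): Boolean = eventTimeMs <= thresholdMs
```

With a 10 s watermark and a 3 s bound, a stored left row at 9 s can't match anything anymore, while a right row at 9 s must stay because a future left row at 10.5 s could still join it.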

The predicates built this way are later used by the org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinExec.OneSideHashJoiner instances referenced by StreamingSymmetricHashJoinExec. During the removal process, the engine invokes the appropriate SymmetricHashJoinStateManager method: removeByKeyCondition for a key predicate or removeByValueCondition for a value predicate.

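Both methods take as argument a removal condition (an UnsafeRow => Boolean function) generated from the predicate of the corresponding watermark type. For the value one, the predicate is derived from the join condition, which may look like (joinedWatermarkWindow#23-T3000ms > mainWatermarkWindow#14-T3000ms). Internally, the state is stored as a multi-map, i.e. each key is associated with zero or more values. Thanks to that, the engine can remove either whole keys (key predicate) or only some of the values stored under a key (value predicate).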
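
The two removal modes can be compared on a hypothetical in-memory multi-map. MultiMapState, removeByKey and removeByValue only mimic the behavior described above and aren't SymmetricHashJoinStateManager's API; values are plain event times in microseconds instead of UnsafeRows. Both methods return the removed pairs, since the expired rows are what an outer join needs to emit its unmatched results.

```scala
// Hypothetical, simplified join state: a multi-map where one key has zero or more values.
final class MultiMapState[K] {
  private var state: Map[K, Vector[Long]] = Map.empty

  def append(key: K, eventTimeMicros: Long): Unit =
    state = state.updated(key, state.getOrElse(key, Vector.empty[Long]) :+ eventTimeMicros)

  def get(key: K): Seq[Long] = state.getOrElse(key, Vector.empty[Long])

  // key predicate: drops every value stored under a matching key, returns what was removed
  def removeByKey(expired: K => Boolean): Seq[(K, Long)] = {
    val (removed, kept) = state.partition { case (key, _) => expired(key) }
    state = kept
    removed.toSeq.flatMap { case (key, values) => values.map(v => (key, v)) }
  }

  // value predicate: drops only the matching values and keeps the other ones of the key
  def removeByValue(expired: Long => Boolean): Seq[(K, Long)] = {
    val removed = state.toSeq.flatMap { case (key, values) => values.filter(expired).map(v => (key, v)) }
    state = state
      .map { case (key, values) => key -> values.filterNot(expired) }
      .filter { case (_, values) => values.nonEmpty }
    removed
  }
}
```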


Join mechanism

So far we've learned about state management, but what about the joins themselves? A small overview of their execution was given in the first section, in the information block about equality. As we could see there, joins are always executed after a hash exchange, i.e. a shuffle. All rows with a given join key are moved to the same partition, and that's where the magic happens. The physical execution is done in the processPartitions(leftInputIter: Iterator[InternalRow], rightInputIter: Iterator[InternalRow]) method, helped by 2 instances of OneSideHashJoiner, one for each side of the join.
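
The co-location guarantee of the hash exchange can be sketched as below. Event, partitionFor and shuffle are hypothetical names, and String.hashCode is only a stand-in for the Murmur3 hash Spark computes on the join keys; what matters is that equal keys always produce the same partition index, whichever side the row comes from.

```scala
// Hypothetical event coming from either side of the join.
final case class Event(side: String, key: String, payload: String)

// Maps a join key to a partition index; floorMod keeps the result in [0, numPartitions)
// even for negative hash codes. String.hashCode stands in for Spark's Murmur3 hash.
def partitionFor(key: String, numPartitions: Int): Int =
  Math.floorMod(key.hashCode, numPartitions)

// Groups the rows of both sides by partition, as a shuffle would.
def shuffle(events: Seq[Event], numPartitions: Int): Map[Int, Seq[Event]] =
  events.groupBy(event => partitionFor(event.key, numPartitions))
```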

The result of joining is a row of org.apache.spark.sql.catalyst.expressions.JoinedRow type, built with the withLeft(InternalRow) and withRight(InternalRow) builder methods. Before that happens, OneSideHashJoiner retrieves all rows that aren't late and keeps only the ones passing the pre-join filter, i.e. the part of the join condition that applies only to the processed side. Then, for each valid row, the joiner looks for the matching row(s) on the other side and applies the post-join filter (the overall join condition) to the joined rows:

val key = keyGenerator(thisRow)
// every row stored under the same key on the other side is joined with thisRow,
// then the post-join filter (overall join condition) is applied
val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow =>
  generateJoinedRow(thisRow, thatRow)
}.filter(postJoinFilter)

Next, the joiner checks whether the processed row (the raw input row, not the joined one) should be persisted in the state store. It's stored only when neither of the 2 state watermark predicates (key and value, described earlier) matches it, i.e. when the row itself isn't expired. And that's all for the inner join. Outer joins get special treatment: the matched rows are completed with the unmatched ones, which happens here:

val outputIter: Iterator[InternalRow] = joinType match {
  case Inner =>
    innerOutputIter
  case LeftOuter =>
    // a left row is matched when at least one right row stored under the same key
    // satisfies the join condition
    def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = {
      rightSideJoiner.get(leftKeyValue.key).exists { rightValue =>
        postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue))
      }
    }
    // BK: removeOldState returns an iterator with expired rows.
    //     It clearly shows that without a watermark it wouldn't be possible
    //     to emit the unmatched rows in the case of an outer join.
    val removedRowIter = leftSideJoiner.removeOldState()
    val outerOutputIter = removedRowIter
      .filterNot(pair => matchesWithRightSideState(pair))
      .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))

    innerOutputIter ++ outerOutputIter
  case RightOuter =>
    // BK: does the same as LeftOuter but with the sides switched
  case _ => throwBadJoinTypeException()
}
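
The inner part (late rows, pre-join filter, lookup, post-join filter, storing) and the left outer part can be put together in a small single-partition model. It's a sketch under simplifying assumptions: rows are plain case classes, the state is an in-memory map, only value-based eviction is modeled, the processed side's row comes first in each pair, and rows failing the pre-join filter are simply dropped (Spark can emit them as outer rows directly). Row, SideState, storeAndJoin and leftOuterUnmatched are hypothetical names.

```scala
final case class Row(key: String, timeMs: Long, value: String)

// In-memory state of one side: join key -> rows stored under it.
final class SideState {
  private var rows: Map[String, Vector[Row]] = Map.empty

  def append(row: Row): Unit =
    rows = rows.updated(row.key, rows.getOrElse(row.key, Vector.empty[Row]) :+ row)

  def get(key: String): Vector[Row] = rows.getOrElse(key, Vector.empty[Row])

  // Removes and returns the rows at or below the threshold (value-based eviction only).
  def removeOld(thresholdMs: Long): Vector[Row] = {
    val old = rows.values.flatten.filter(_.timeMs <= thresholdMs).toVector
    rows = rows
      .map { case (key, stored) => key -> stored.filter(_.timeMs > thresholdMs) }
      .filter { case (_, stored) => stored.nonEmpty }
    old
  }
}

// Joins one side's batch rows with the other side's state; pairs are (thisRow, otherRow).
def storeAndJoin(
    input: List[Row],
    thisState: SideState,
    otherState: SideState,
    watermarkMs: Long,
    stateThresholdMs: Long,
    preJoinFilter: Row => Boolean,
    postJoinFilter: (Row, Row) => Boolean): List[(Row, Row)] =
  input
    .filter(_.timeMs > watermarkMs) // drops late rows
    .filter(preJoinFilter) // condition part using only this side
    .flatMap { row =>
      val matches = otherState.get(row.key)
        .map(other => (row, other))
        .filter { case (thisRow, otherRow) => postJoinFilter(thisRow, otherRow) }
      // stores the row only when it isn't already expired
      if (row.timeMs > stateThresholdMs) thisState.append(row)
      matches
    }

// Left outer join: expired left rows with no match in the right state get a null (None) right side.
def leftOuterUnmatched(
    leftState: SideState,
    rightState: SideState,
    leftThresholdMs: Long,
    postJoinFilter: (Row, Row) => Boolean): Vector[(Row, Option[Row])] =
  leftState.removeOld(leftThresholdMs)
    .filterNot(left => rightState.get(left.key).exists(right => postJoinFilter(left, right)))
    .map(left => (left, Option.empty[Row]))
```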

Once all joined rows are consumed, a callback is executed to remove old state from the state store:

val cleanupIter = joinType match {
  case Inner =>
    leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
  // BK: for outer joins only the other side is cleaned here; the outer side
  //     was already cleaned while generating the not fully matched rows
  case LeftOuter => rightSideJoiner.removeOldState()
  case RightOuter => leftSideJoiner.removeOldState()
  case _ => throwBadJoinTypeException()
}
// the removal is lazy, so the iterator must be fully consumed to actually delete the rows
while (cleanupIter.hasNext) {
  cleanupIter.next()
}
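
Running a callback once the output is fully consumed can be sketched with a small iterator wrapper. Spark uses its internal CompletionIterator for that; CompletionIteratorSketch and drain below are hypothetical stand-ins. The drain helper shows why the cleanup iterator is explicitly consumed: the expired rows are removed lazily, while iterating.

```scala
// Minimal stand-in for an iterator invoking a callback once it's fully consumed.
final class CompletionIteratorSketch[A](underlying: Iterator[A], onCompletion: () => Unit)
    extends Iterator[A] {
  private var completed = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    // runs the callback only once, the first time the underlying iterator is exhausted
    if (!more && !completed) {
      completed = true
      onCompletion()
    }
    more
  }

  override def next(): A = underlying.next()
}

// Lazy removal only happens when the iterator is consumed, hence the explicit drain.
def drain[A](cleanupIter: Iterator[A]): Unit =
  while (cleanupIter.hasNext) cleanupIter.next()
```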

State management in unbounded data sources is very important. It helps to reduce the number of stored rows by computing watermark predicates. But, as shown in the first section of this post, it's not the only part of the streaming joins feature. Another one is finding the matches. Covered in the 3rd section, it's the joining logic implemented in StreamingSymmetricHashJoinExec's processPartitions method. As we learned, the operation first retrieves the valid rows (filter + allowed lateness) and only later looks for the matching rows on the other side. After that, an iterator with all joined rows is returned, and at the end of its iteration a function cleaning old rows is invoked.
