What's new in Apache Spark 3.1 - streaming joins

Versions: Apache Spark 3.1.1

In the previous blog post, you discovered what changed for joins in Apache Spark 3.1. If you remember the summary sentence, those were not the only join changes in this new release. Apart from them, you can also do a bit more with Structured Streaming joins!

To summarize these changes in a few words: support for new join types and bug fixes. If you want to know more, keep reading!

New joins - left-semi join

The first newly supported stream-to-stream join type is the left-semi join. Use it when you need a match between the 2 sides but your downstream consumer only needs the left-side row for processing.
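
To make these semantics concrete, here is a minimal sketch on plain Scala collections rather than on streams. Impression, Click and leftSemiJoin are hypothetical names used only for this illustration. The helper keeps each left row that has at least one match on the right, returns it only once even when several right rows match, and never outputs right-side columns:

```scala
// Hypothetical row types, used only for this illustration.
final case class Impression(adId: Int, page: String)
final case class Click(adId: Int, clickTime: Long)

object LeftSemiSemantics {
  // Keeps every left row that has at least one matching right row.
  // A left row is returned at most once, and right-side columns are never returned.
  def leftSemiJoin[L, R, K](left: Seq[L], right: Seq[R])(leftKey: L => K, rightKey: R => K): Seq[L] = {
    val rightKeys: Set[K] = right.map(rightKey).toSet
    left.filter(row => rightKeys.contains(leftKey(row)))
  }
}
```

For example, with impressions for ads 1, 2 and 3, two clicks on ad 1 and one click on ad 3, leftSemiJoin(impressions, clicks)(_.adId, _.adId) returns the impressions of ads 1 and 3, each of them once.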

The implementation then resolves the matching rows the same way as the inner join, i.e. it looks for the matching rows in the state store of the other side (more explanation in Data+AI Summit follow-up: joins and state management):

case class StreamingSymmetricHashJoinExec(
// ...
    left: SparkPlan,
    right: SparkPlan) extends BinaryExecNode with StateStoreWriter {
// ...
  private def processPartitions(
      partitionId: Int,
      leftInputIter: Iterator[InternalRow],
      rightInputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
// ...
    val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
      leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion())

    val outputIter: Iterator[InternalRow] = joinType match {
      case Inner | LeftSemi =>
        hashJoinOutputIter

But wait, does it mean the user has to filter out the right side in the downstream function? Not at all! The feature also changed the join logic by adding an excludeRowsAlreadyMatched flag (val excludeRowsAlreadyMatched = joinType == LeftSemi && joinSide == RightSide). SymmetricHashJoinStateManager uses it when a right-side row looks for its matches in the left-side state store, so that the left rows already matched, and therefore already emitted, are skipped:

  def getJoinedRows(
      key: UnsafeRow,
      generateJoinedRow: InternalRow => JoinedRow,
      predicate: JoinedRow => Boolean,
      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow] = {
    val numValues = keyToNumValues.get(key)
    keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue =>
      excludeRowsAlreadyMatched && keyIdxToValue.matched
    }.map { keyIdxToValue => ...
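
The effect of the flag can be sketched with a simplified, in-memory model of one side's state. SideState and its methods are hypothetical names, not Spark classes. As in the excerpt above, entries already marked as matched are skipped when excludeRowsAlreadyMatched is true. The sketch also flags every entry satisfying the predicate as matched, so that a second right row with the same key cannot emit the same left row again:

```scala
import scala.collection.mutable

// Hypothetical, simplified model of one side's state: key -> values with a matched flag.
final class SideState[K, V] {
  private final class Entry(val value: V, var matched: Boolean)
  private val store = mutable.Map.empty[K, mutable.ArrayBuffer[Entry]]

  def append(key: K, value: V): Unit =
    store.getOrElseUpdate(key, mutable.ArrayBuffer.empty[Entry]) += new Entry(value, false)

  // Simplified counterpart of getJoinedRows: optionally skips the entries already matched,
  // returns the values satisfying the predicate and flags them as matched.
  def getJoinedRows(key: K, predicate: V => Boolean, excludeRowsAlreadyMatched: Boolean): Seq[V] = {
    val entries = store.getOrElse(key, mutable.ArrayBuffer.empty[Entry])
    entries
      .filterNot(entry => excludeRowsAlreadyMatched && entry.matched)
      .filter(entry => predicate(entry.value))
      .map { entry =>
        entry.matched = true
        entry.value
      }
      .toList
  }
}
```

Without the flag, as for an inner join, the same left value would be returned for every right row sharing its key.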

To avoid returning the right side of the join, the code generating the output changed as well. For the left-semi join, it returns only the left side:

    val outputProjection = if (joinType == LeftSemi) {
      UnsafeProjection.create(output, output)
    } else {
      UnsafeProjection.create(left.output ++ right.output, output)
    }

      val generateOutputIter: (InternalRow, Iterator[JoinedRow]) => Iterator[InternalRow] =
        joinSide match {
          case LeftSide if joinType == LeftSemi =>
            (input: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
              // For left side of left semi join, generate one left row if there is matched
              // rows from right side. Otherwise, generate nothing.
              if (joinedRowIter.nonEmpty) {
                Iterator.single(input)
              } else {
                Iterator.empty
              }
          case RightSide if joinType == LeftSemi =>
            (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) =>
              // For right side of left semi join, generate matched left rows only.
              joinedRowIter.map(_.getLeft)
          case _ => (_: InternalRow, joinedRowIter: Iterator[JoinedRow]) => joinedRowIter
        }
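
A simplified model of the 2 left-semi branches may help. JoinedPair and LeftSemiOutput are hypothetical names. A new left row produces at most one output row (itself), while a new right row produces the left part of each pair it joined with:

```scala
// Hypothetical pair of joined values (left, right).
final case class JoinedPair[L, R](left: L, right: R)

object LeftSemiOutput {
  // New row from the left side: emit it once if it matched anything, otherwise nothing.
  def forLeftInput[L, R](input: L, joined: Iterator[JoinedPair[L, R]]): Iterator[L] =
    if (joined.hasNext) Iterator.single(input) else Iterator.empty

  // New row from the right side: emit only the left part of every pair it joined with.
  def forRightInput[L, R](joined: Iterator[JoinedPair[L, R]]): Iterator[L] =
    joined.map(_.left)
}
```

The right-side branch doesn't deduplicate anything by itself. In this model it relies on the lookup having already skipped the left rows matched before, which is the role of excludeRowsAlreadyMatched.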

OK, but does it work? Let's see:
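
The behavior can be checked with a hypothetical, plain-Scala simulation of the left-semi join over micro-batches. LeftSemiJoinModel is an illustrative name, not a Spark class. It keeps both sides in memory, ignores watermarks, state eviction and partitioning, and processes the new left rows before the new right rows of each batch, as the leftOutputIter ++ rightOutputIter expression suggests:

```scala
import scala.collection.mutable

// Hypothetical, in-memory simulation of a stream-to-stream left-semi join over micro-batches.
// Watermarks, state eviction and partitioning are deliberately ignored.
final class LeftSemiJoinModel[K, L, R](leftKey: L => K, rightKey: R => K) {
  private final class Entry[V](val value: V, var matched: Boolean)
  private val leftState = mutable.Map.empty[K, mutable.ArrayBuffer[Entry[L]]]
  private val rightState = mutable.Map.empty[K, mutable.ArrayBuffer[Entry[R]]]

  // Processes the new left rows first, then the new right rows, and returns the emitted left rows.
  def processBatch(newLeft: Seq[L], newRight: Seq[R]): Seq[L] = {
    val output = mutable.ArrayBuffer.empty[L]
    newLeft.foreach { row =>
      val key = leftKey(row)
      val hasMatch = rightState.get(key).exists(_.nonEmpty)
      // A left row with a match in the right state is emitted right away, once.
      if (hasMatch) output += row
      leftState.getOrElseUpdate(key, mutable.ArrayBuffer.empty[Entry[L]]) += new Entry(row, hasMatch)
    }
    newRight.foreach { row =>
      val key = rightKey(row)
      // Counterpart of excludeRowsAlreadyMatched = true: already emitted left rows are skipped.
      leftState.getOrElse(key, mutable.ArrayBuffer.empty[Entry[L]]).filterNot(_.matched).foreach { entry =>
        entry.matched = true
        output += entry.value
      }
      rightState.getOrElseUpdate(key, mutable.ArrayBuffer.empty[Entry[R]]) += new Entry(row, false)
    }
    output.toList
  }
}
```

With an impression for ad 1 in the first batch and 2 clicks for ad 1 in the second batch, the model emits the impression exactly once, in the second batch, and a later click for the same ad emits nothing.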

New joins - full outer join

But the left-semi join is not the only new type added in Apache Spark 3.1. The second one is the full outer join. You can think of it as a combination of the left and right outer joins. Since both are already supported in Structured Streaming, the full outer join implementation naturally relies on them!
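
The "combination" idea can be sketched on plain collections. A full outer join returns the inner join pairs, plus the left rows without a match paired with None (what the left outer join adds), plus the right rows without a match paired with None (what the right outer join adds). FullOuterSemantics is a hypothetical helper, and Option stands in for the nulls Spark would produce:

```scala
object FullOuterSemantics {
  // Full outer join on plain collections, with Option standing in for SQL nulls.
  def fullOuterJoin[L, R, K](left: Seq[L], right: Seq[R])(
      leftKey: L => K, rightKey: R => K): Seq[(Option[L], Option[R])] = {
    // Inner join part: every matching pair.
    val matched: Seq[(Option[L], Option[R])] = for {
      l <- left
      r <- right
      if leftKey(l) == rightKey(r)
    } yield (Some(l), Some(r))
    val leftKeys = left.map(leftKey).toSet
    val rightKeys = right.map(rightKey).toSet
    // Left outer part: left rows without any match.
    val leftOnly = left.filterNot(l => rightKeys.contains(leftKey(l))).map(l => (Some(l), None))
    // Right outer part: right rows without any match.
    val rightOnly = right.filterNot(r => leftKeys.contains(rightKey(r))).map(r => (None, Some(r)))
    matched ++ leftOnly ++ rightOnly
  }
}
```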

At the beginning of the full outer join algorithm, the joined rows are generated from the same hashJoinOutputIter variable, which returns only the rows matched on both sides. After that, the operator completes this output with the rows that never matched. For that part, StreamingSymmetricHashJoinExec uses the same logic as for the left and right outer joins:

    val hashJoinOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]](
      leftOutputIter ++ rightOutputIter, onHashJoinOutputCompletion())

    val outputIter: Iterator[InternalRow] = joinType match {
      case Inner | LeftSemi =>
        hashJoinOutputIter
// BK: left outer has similar logic to the right outer, 
// for the sake of simplicity, I put only the right side here
      case RightOuter =>
        def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = {
          leftSideJoiner.get(rightKeyValue.key).exists { leftValue =>
            postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value))
          }
        }
        val removedRowIter = rightSideJoiner.removeOldState()
        val outerOutputIter = removedRowIter.filterNot { kv =>
          stateFormatVersion match {
            case 1 => matchesWithLeftSideState(new UnsafeRowPair(kv.key, kv.value))
            case 2 => kv.matched
            case _ => throwBadStateFormatVersionException()
          }
        }.map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

        hashJoinOutputIter ++ outerOutputIter
      case FullOuter =>
        lazy val isKeyToValuePairMatched = (kv: KeyToValuePair) =>
          stateFormatVersion match {
            case 2 => kv.matched
            case _ => throwBadStateFormatVersionException()
          }
        val leftSideOutputIter = leftSideJoiner.removeOldState().filterNot(
          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
        val rightSideOutputIter = rightSideJoiner.removeOldState().filterNot(
          isKeyToValuePairMatched).map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))

        hashJoinOutputIter ++ leftSideOutputIter ++ rightSideOutputIter
      case _ => throwBadJoinTypeException()
    }
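
In the FullOuter branch, isKeyToValuePairMatched only handles stateFormatVersion 2 and throws for any other version, whereas RightOuter still supports version 1 by re-probing the other side's state. The sketch below compares both strategies in one scenario where the matching left row has already left the state when the right row expires. StoredRow and EvictionCheck are hypothetical names, the postJoinFilter is ignored, and this is a simplified illustration of why a persisted matched flag is more robust, not Spark's code:

```scala
// Hypothetical stored row: join key, value, and the matched flag kept by state format v2.
final case class StoredRow(key: Int, value: String, matched: Boolean)

object EvictionCheck {
  // Version 1 style: when a row expires, re-probe the other side's *current* state.
  def unmatchedV1(evicted: Seq[StoredRow], otherSideState: Seq[StoredRow]): Seq[StoredRow] =
    evicted.filterNot(row => otherSideState.exists(_.key == row.key))

  // Version 2 style: trust the matched flag persisted when the join happened.
  def unmatchedV2(evicted: Seq[StoredRow]): Seq[StoredRow] =
    evicted.filterNot(_.matched)
}
```

In this model, with an empty left state, the version 1 check treats a right row that did match earlier as never matched, while the version 2 check only returns the rows flagged as unmatched.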

As for the rows that never matched, Apache Spark returns them, with nulls for the missing side, only after the watermark expires them, exactly like in the demo below:
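
The watermark-driven emission can be simulated with a hypothetical, plain-Scala model (FullOuterWatermarkModel and Event are illustrative names). It assumes a row is evicted as soon as its event time is below the watermark, which simplifies Spark's eviction rules. Like the FullOuter branch above, it emits the matched pairs first and then the evicted, never matched rows of each side with None for the missing side:

```scala
import scala.collection.mutable

// Hypothetical input row: join key, a name used in the output, and an event time.
final case class Event(key: Int, name: String, eventTime: Long)

// Hypothetical model of a full outer stream-to-stream join; not Spark's implementation.
final class FullOuterWatermarkModel {
  private final class Entry(val event: Event, var matched: Boolean)
  private val leftState = mutable.ArrayBuffer.empty[Entry]
  private val rightState = mutable.ArrayBuffer.empty[Entry]

  def processBatch(newLeft: Seq[Event], newRight: Seq[Event],
      watermark: Long): Seq[(Option[String], Option[String])] = {
    val output = mutable.ArrayBuffer.empty[(Option[String], Option[String])]
    // 1. Matched pairs, like hashJoinOutputIter: new rows probe the other side's state.
    newLeft.foreach { l =>
      val matches = rightState.filter(_.event.key == l.key)
      matches.foreach { r => r.matched = true; output += ((Some(l.name), Some(r.event.name))) }
      leftState += new Entry(l, matches.nonEmpty)
    }
    newRight.foreach { r =>
      val matches = leftState.filter(_.event.key == r.key)
      matches.foreach { l => l.matched = true; output += ((Some(l.event.name), Some(r.name))) }
      rightState += new Entry(r, matches.nonEmpty)
    }
    // 2. Evict expired rows (assumption: eventTime < watermark) and emit the never matched ones.
    output ++= evict(leftState, watermark).filterNot(_.matched).map(e => (Some(e.event.name), None))
    output ++= evict(rightState, watermark).filterNot(_.matched).map(e => (None, Some(e.event.name)))
    output.toList
  }

  private def evict(state: mutable.ArrayBuffer[Entry], watermark: Long): Seq[Entry] = {
    val (expired, kept) = state.partition(_.event.eventTime < watermark)
    state.clear()
    state ++= kept
    expired.toList
  }
}
```

In a first batch with watermark 0, only the matched pair is returned. The unmatched rows stay in the state until a later batch moves the watermark past their event time, and only then are they emitted with None.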

Bug fixes

In addition to the left-semi and full outer join support, joins in the new Structured Streaming release also got important bug fixes. The first of them concerns the left outer join semantics and was already included in 3.0.1. However, it's worth mentioning because it confirmed something I couldn't understand at the beginning of my Structured Streaming code analysis: the need to call copy on UnsafeRow. UnsafeRow is mutable and can be reused. As a consequence, if you keep a reference to it without copying it, it will always represent the last value written to it. I demonstrated that in UnsafeRow - to copy or not to copy? and Jungtaek Lim used the copy to fix SPARK-32148, which was about the non-deterministic results generated by the left outer join:

  private class KeyWithIndexToValueRowConverterFormatV2 extends KeyWithIndexToValueRowConverter {
    override def convertValue(value: UnsafeRow): ValueAndMatchPair = {
      if (value != null) {
        ValueAndMatchPair(valueRowGenerator(value).copy(),
          value.getBoolean(indexOrdinalInValueWithMatchedRow))
      } else {
        null
      }
    }
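
The reuse problem can be reproduced without Spark. In the sketch below, MutableRow is a hypothetical stand-in for a reused, mutable row, not a Spark class: an iterator mutates and returns the same instance for every element, so collecting the references without copying gives the last value repeated, while copying each row before keeping it preserves all the values:

```scala
// Hypothetical stand-in for a mutable row buffer reused by an iterator.
final class MutableRow(var value: Int) {
  def copy(): MutableRow = new MutableRow(value)
}

object ReusedRowDemo {
  // Produces the given values through a single, reused MutableRow instance.
  def rows(values: Seq[Int]): Iterator[MutableRow] = {
    val reused = new MutableRow(0)
    values.iterator.map { v => reused.value = v; reused }
  }

  // Keeping references without copying: every element points to the same buffer.
  def collectWithoutCopy(values: Seq[Int]): Seq[Int] =
    rows(values).toList.map(_.value)

  // Copying each row right after it is produced preserves every value.
  def collectWithCopy(values: Seq[Int]): Seq[Int] =
    rows(values).map(_.copy()).toList.map(_.value)
}
```

For the input Seq(1, 2, 3), collectWithoutCopy returns Seq(3, 3, 3) and collectWithCopy returns Seq(1, 2, 3).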

The second fix concerns the global watermark issue spotted in the 3.0 release with SPARK-28074 and addressed with the correctness issue warning: "Detected pattern of possible 'correctness' issue due to global watermark.". What is this global watermark correctness issue? You will find a good example of it in Etienne Chauchot's Watermark architecture proposal for Spark Structured Streaming framework article. I will take a shortcut here and explain it with a direct example. Let's imagine a pipeline composed of 2 stateful operations. The first one takes the max value of a 3 seconds window and then passes it to the second one for aggregation. Let's imagine the following context:

# global watermark = 3
# op#1 max(nr) ⇒ (value=6, timestamp=1), (4, 2), (5, 3) 

Now, op#2 will get the pair (6, 1) because the first window ended due to the watermark. However, since the timestamp of this pair (1) is older than the global watermark (3), op#2 will consider it late and, therefore, discard it. Apache Spark detects this pattern in the query plan and reports it to the user either as an AnalysisException or as a warning message, depending on the value set for "spark.sql.streaming.statefulOperator.checkCorrectness.enabled":

    val failWhenDetected = SQLConf.get.statefulOperatorCorrectnessCheckEnabled
    try { 
      plan.foreach { subPlan => 
        if (isStatefulOperation(subPlan)) { 
          subPlan.find { p => 
            (p ne subPlan) && isStatefulOperationPossiblyEmitLateRows(p) 
          }.foreach { _ => 
            val errorMsg = "Detected pattern of possible 'correctness' issue " + 
              "due to global watermark. " + 
              "The query contains stateful operation which can emit rows older than " + 
              "the current watermark plus allowed late record delay, which are \"late rows\"" + 
              " in downstream stateful operations and these rows can be discarded. " + 
              "Please refer the programming guide doc for more details. If you understand " +
              "the possible risk of correctness issue and still need to run the query, " +
              "you can disable this check by setting the config " +
              "`spark.sql.streaming.statefulOperator.checkCorrectness.enabled` to false."
            throwError(errorMsg)(plan)
          }
        }
      }
    } catch {
      case e: AnalysisException if !failWhenDetected => logWarning(s"${e.message};\n$plan")
    }
  }
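
The shape of this check can be sketched with a tiny, hypothetical plan model. Plan, Source, StatefulOp, CheckResult and CorrectnessCheck are illustrative names, and mayEmitLateRows replaces Spark's isStatefulOperationPossiblyEmitLateRows logic with a simple flag. Like the excerpt, it looks for a stateful operator having another possibly late-rows-emitting stateful operator below it, stops at the first match, and turns it into either an error or a warning:

```scala
// Hypothetical, simplified logical plan used only to illustrate the check.
sealed trait Plan { def children: Seq[Plan] }
final case class Source(name: String) extends Plan {
  def children: Seq[Plan] = Nil
}
// mayEmitLateRows is a simplification of Spark's isStatefulOperationPossiblyEmitLateRows.
final case class StatefulOp(name: String, mayEmitLateRows: Boolean, child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}

sealed trait CheckResult
case object Passed extends CheckResult
final case class Warned(message: String) extends CheckResult
final case class Failed(message: String) extends CheckResult

object CorrectnessCheck {
  // All nodes below p, excluding p itself (like `p ne subPlan` in the excerpt).
  private def descendants(p: Plan): Seq[Plan] = p.children.flatMap(c => c +: descendants(c))

  private def hasLateRowsEmitterBelow(op: StatefulOp): Boolean =
    descendants(op).exists {
      case upstream: StatefulOp => upstream.mayEmitLateRows
      case _ => false
    }

  def check(plan: Plan, failWhenDetected: Boolean): CheckResult = {
    // Pre-order traversal, keeping only the first problematic operator.
    val firstProblem = (plan +: descendants(plan)).collectFirst {
      case op: StatefulOp if hasLateRowsEmitterBelow(op) =>
        s"Stateful operation ${op.name} can receive late rows from an upstream stateful operation"
    }
    firstProblem match {
      case None => Passed
      case Some(message) if failWhenDetected => Failed(message)
      case Some(message) => Warned(message)
    }
  }
}
```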

Let's check if it's true with some code:
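
Here is a minimal sketch of the scenario from the example, assuming op#1 emits the max row of its closed window with that row's timestamp, and op#2 drops every input row older than the global watermark. Record and GlobalWatermarkDemo are hypothetical names, not Spark APIs:

```scala
// Hypothetical record flowing between the two stateful operations.
final case class Record(value: Int, timestamp: Long)

object GlobalWatermarkDemo {
  // op#1 (simplified): once the window is closed, emit its max row with that row's timestamp.
  def op1Max(windowRows: Seq[Record]): Record = windowRows.maxBy(_.value)

  // op#2 (simplified): with a single global watermark, rows older than it are treated as late.
  // Returns (accepted, dropped).
  def op2Filter(input: Seq[Record], globalWatermark: Long): (Seq[Record], Seq[Record]) =
    input.partition(_.timestamp >= globalWatermark)
}
```

With the values from the example, op1Max returns Record(6, 1), and op2Filter with a global watermark of 3 puts this record in the dropped part, so the aggregation in op#2 never sees it.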

That's all for the joins in Apache Spark 3.1, but not for the Structured Streaming updates! You will find the complete list of the updates in the next blog post.