Data+AI Summit follow-up: aggregations and state management

Versions: Apache Spark 3.0.1

In previous blog posts you discovered how the state store interacts with the dropDuplicates and limit operators. This time you will see how it's used in aggregations.

The organization here is similar to the previous follow-up posts. You will start by discovering how Apache Spark plans streaming aggregations. In the next section, you will see the first of the operators involved in the state store interaction, the StateStoreRestoreExec. Later, I will briefly introduce the state store manager component. To finish, you will see the last operator interacting with the state store, the StateStoreSaveExec.

Query planning

If you remember my aggregations execution in Apache Spark SQL post, the magic generating the aggregation plan happens in org.apache.spark.sql.execution.aggregate.AggUtils. For streaming aggregations, it happens more precisely in the planStreamingAggregation method.

The plan built in this method is very similar to batch queries' aggregation plan, except for a few differences. The first part is common because the plan starts with partial aggregations, executed locally on every node.

+- StateStoreRestore [window#37-T300000ms], state info [ checkpoint = file:/tmp/data+ai/stateful/aggregation_demo/checkpoint/state, runId = ff158a12-4fc3-4aca-ab7f-60a3de835438, opId = 0, ver = 0, numPartitions = 2], 2
  +- *(3) HashAggregate(keys=[window#37-T300000ms], functions=[merge_sum(cast(value#27 as bigint))], output=[window#37-T300000ms, sum#43L])
     +- Exchange hashpartitioning(window#37-T300000ms, 2), true, [id=#66]
        +- *(2) HashAggregate(keys=[window#37-T300000ms], functions=[partial_sum(cast(value#27 as bigint))], output=[window#37-T300000ms, sum#43L])

On top of the merged partial aggregates, a StateStoreRestoreExec node is added. Just after that, the restored state and the new partial aggregates are merged in another "reduce" aggregation. On top of it, another state-related operator called StateStoreSaveExec is added. At the end, one aggregation preparing the data for the output projection is added.

+- *(5) HashAggregate(keys=[window#37-T300000ms], functions=[sum(cast(value#27 as bigint))], output=[window#31-T300000ms, sum(value)#36L])
   +- StateStoreSave [window#37-T300000ms], state info [ checkpoint = file:/tmp/data+ai/stateful/aggregation_demo/checkpoint/state, runId = ff158a12-4fc3-4aca-ab7f-60a3de835438, opId = 0, ver = 0, numPartitions = 2], Complete, 0, 2
      +- *(4) HashAggregate(keys=[window#37-T300000ms], functions=[merge_sum(cast(value#27 as bigint))], output=[window#37-T300000ms, sum#43L])

StateStoreRestoreExec

The first operation involving the state store is StateStoreRestoreExec. Would you agree that the name is quite meaningful if I tell you that it restores the previous aggregation value from the state store? Anyway, that's the main purpose of this operator and you can notice it just by looking at the function passed to mapPartitionsWithStateStore:

        val hasInput = iter.hasNext
        if (!hasInput && keyExpressions.isEmpty) {
          store.iterator().map(_.value)
        } else {
          iter.flatMap { row =>
            val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
            val restoredRow = stateManager.get(store, key)
            numOutputRows += 1
            Option(restoredRow).toSeq :+ row
          }
        }
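
To make the flatMap above more concrete, here is a minimal sketch in plain Scala, without any Spark dependency. The Row case class, the Map playing the role of the state store, and the restore and merge functions are my own hypothetical stand-ins, not Spark classes. The sketch only illustrates the idea: for every input row, the previously stored aggregate (if any) is emitted right before the new partial aggregate, so that the next aggregation in the plan can merge both.

```scala
object RestoreSketch {
  // Hypothetical stand-in for a row: a grouping key and a partial sum.
  final case class Row(key: String, sum: Long)

  // Mirrors `Option(restoredRow).toSeq :+ row`: the stored aggregate,
  // when it exists, is emitted before the incoming partial aggregate.
  def restore(store: Map[String, Row], input: Iterator[Row]): Iterator[Row] =
    input.flatMap { row =>
      store.get(row.key).toSeq :+ row
    }

  // Hypothetical stand-in for the merge aggregation executed on top of the restore.
  def merge(rows: Iterator[Row]): Map[String, Long] =
    rows.toList.groupBy(_.key).map { case (key, group) => key -> group.map(_.sum).sum }
}
```

With a stored sum of 10 for the key w1 and new partial sums of 5 for w1 and 3 for w2, merging the restored rows gives 15 for w1 and 3 for w2.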

State store manager

Depending on the aggregation type, the restore operator retrieves either all the values from the state store (in fact, it's a single value, as you'll see in the next part) or the value for the key of the partial aggregation. But before moving on to StateStoreSaveExec, let's stop for a while and introduce a new abstraction layer for the state store interaction called the state store manager.

This component is represented by one of two classes implementing the StreamingAggregationStateManager interface. The main role of state managers, which are not specific to stateful aggregations, is to ensure backward compatibility of the state when its format evolves. For aggregations, the difference comes from the values. In V1 of the state manager, the value is composed of the key and value of the aggregated columns, whereas V2 only stores the value:

class StreamingAggregationStateManagerImplV1(
    keyExpressions: Seq[Attribute],
    inputRowAttributes: Seq[Attribute])
  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

  override def getStateValueSchema: StructType = inputRowAttributes.toStructType
// ...
}

class StreamingAggregationStateManagerImplV2(
    keyExpressions: Seq[Attribute],
    inputRowAttributes: Seq[Attribute])
  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

  private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
// ...
}
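
The difference between the two layouts can be illustrated with a small sketch in plain Scala. Everything here is hypothetical and simplified: a row is modeled as a Map from column name to value, and valueV1, valueV2 and restoreV2 are names I made up for the example. The assumption, based on the snippets above, is that V1 persists the whole input row whereas V2 persists only the non-key columns and has to put the key back when the row is read.

```scala
object StateValueSketch {
  // Hypothetical row representation: column name -> value.
  type Row = Map[String, Any]

  // V1-like layout: the stored value is the whole input row, keys included.
  def valueV1(row: Row): Row = row

  // V2-like layout: only the non-key columns are stored
  // (the equivalent of inputRowAttributes.diff(keyExpressions)).
  def valueV2(row: Row, keyCols: Seq[String]): Row =
    row.filter { case (col, _) => !keyCols.contains(col) }

  // Reading a V2-like value back requires joining it with its key.
  def restoreV2(key: Row, value: Row): Row = key ++ value
}
```

The V2-like layout avoids storing the grouping key twice, once as the state store key and once inside the value.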

StateStoreSaveExec

Regarding the StateStoreSaveExec operator, it's a bit more complex than the restore one because it relies on the output mode to figure out which aggregations to return. Let's go through the modes one by one.

The first of the modes is called complete. Its logic is quite straightforward. It starts by updating all aggregations changed in the given micro-batch execution. Just after, it commits the state store version and finally returns all of the aggregates, even the ones which didn't change, by calling StateStore's iterator() method under the hood:

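// Update - all aggregates changed in the micro-batch
allUpdatesTimeMs += timeTakenMs {
  while (iter.hasNext) {
    val row = iter.next().asInstanceOf[UnsafeRow]
    stateManager.put(store, row)
    numUpdatedStateRows += 1
  }
}
// Nothing is removed in the complete mode
allRemovalsTimeMs += 0
commitTimeMs += timeTakenMs {
  stateManager.commit(store)
}
setStoreMetrics(store)
// Return all aggregates, changed or not
stateManager.values(store).map { valueRow =>
  numOutputRows += 1
  valueRow
}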
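
The three steps of the complete mode can be summarized in a short, Spark-free sketch. The Row case class and the mutable Map used as the state store are hypothetical stand-ins, and the sorting is there only to make the output of the example deterministic.

```scala
import scala.collection.mutable

object CompleteModeSketch {
  // Hypothetical stand-in for an aggregated row.
  final case class Row(key: String, sum: Long)

  def process(store: mutable.Map[String, Row], batch: Iterator[Row]): List[Row] = {
    // 1. Upsert every aggregate changed in this micro-batch.
    batch.foreach(row => store.put(row.key, row))
    // 2. The real operator commits the new state store version at this point.
    // 3. Emit every aggregate, including the unchanged ones.
    store.values.toList.sortBy(_.key)
  }
}
```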

Regarding the second mode, append, the general algorithm is the same but it involves the watermark predicate. Every aggregation result which is not older than the watermark updates the state store. After that, an instance of NextIterator is created to return the corresponding results to the next node in the plan. One of the specificities of this iterator type is the possibility to define a kind of "callback" method that is invoked after returning the last element (= hasNext returns false). The state store commit() is called inside this method. The returned rows are only the ones whose keys are older than the watermark. In addition to being returned to the next operator, they're also physically removed from the state store:

// Update - only results that are not older than the watermark
val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
while (filteredIter.hasNext) {
  val row = filteredIter.next().asInstanceOf[UnsafeRow]
  stateManager.put(store, row)
  numUpdatedStateRows += 1
}

val removalStartTimeNs = System.nanoTime
val rangeIter = stateManager.iterator(store)
new NextIterator[InternalRow] {
  override protected def getNext(): InternalRow = {
    var removedValueRow: InternalRow = null
    while (rangeIter.hasNext && removedValueRow == null) {
      val rowPair = rangeIter.next()
      // Remove the value older than the watermark
      if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
        stateManager.remove(store, rowPair.key)
        removedValueRow = rowPair.value
      }
    }
    // Return only the removed values (not supposed to
    // receive new data since older than the watermark)
    if (removedValueRow == null) {
      finished = true
      null
    } else {
      numOutputRows += 1
      removedValueRow
    }
  }

  override protected def close(): Unit = {
    allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
    // As usual, commit the state store after
    // making all the writes
    commitTimeMs += timeTakenMs { stateManager.commit(store) }
    setStoreMetrics(store)
  }
}
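
The append logic can be modeled without Spark as follows. This is only a sketch under a few assumptions of mine: the Row case class is hypothetical, the state store is a mutable Map keyed by the window end, and "older than the watermark" is taken to mean a window end lower than or equal to the watermark. Unlike the real operator, the sketch is eager and doesn't use a lazy iterator.

```scala
import scala.collection.mutable

object AppendModeSketch {
  // Hypothetical stand-in for an aggregated window, identified by its end time.
  final case class Row(windowEnd: Long, sum: Long)

  def process(store: mutable.Map[Long, Row], batch: Iterator[Row], watermark: Long): List[Row] = {
    // Update the state only with the rows that are not older than the watermark.
    batch.filter(_.windowEnd > watermark).foreach(row => store.put(row.windowEnd, row))
    // Collect the expired keys first to avoid mutating the map while iterating over it.
    val expiredKeys = store.keys.filter(_ <= watermark).toList.sorted
    // Remove the expired aggregates and return only them.
    expiredKeys.flatMap(key => store.remove(key))
  }
}
```

A window is emitted exactly once, at the moment it leaves the state store, which is why the append mode needs the watermark on the grouping key.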

And finally, the update mode starts by creating an iterator that filters the aggregated results with the watermark predicate, or leaves them untouched if the predicate doesn't exist. Later, a NextIterator is created. For every row that is not older than the watermark (or for every row if there is no filtering), it updates the value in the state store. All the updated records are also returned to the next operator in the plan. At the end (= no more elements in the iterator), the keys that are too old are removed (= only if one of the grouping columns is the watermark column) and after this removal, the state is committed:

// Update results
lazy val watermarkExpression: Option[Expression] = {
  WatermarkSupport.watermarkExpression(
    child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)),
    eventTimeWatermark)
}
lazy val watermarkPredicateForData: Option[BasePredicate] =
  watermarkExpression.map(Predicate.create(_, child.output))

private[this] val baseIterator = watermarkPredicateForData match {
  case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row))
  case None => iter
}
private val updatesStartTimeNs = System.nanoTime

override protected def getNext(): InternalRow = {
  if (baseIterator.hasNext) {
    val row = baseIterator.next().asInstanceOf[UnsafeRow]
    stateManager.put(store, row)
    numOutputRows += 1
    numUpdatedStateRows += 1
    row
  } else {
    finished = true
    null
  }
}

// Commit the state
override protected def close(): Unit = {
  allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)

  // Remove old aggregates if watermark specified
  allRemovalsTimeMs += timeTakenMs {
    removeKeysOlderThanWatermark(stateManager, store)
  }
  commitTimeMs += timeTakenMs { stateManager.commit(store) }
  setStoreMetrics(store)
}
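
Here is the same flow as a Spark-free sketch. As before, the Row case class and the mutable Map are hypothetical stand-ins, the watermark is an optional window end threshold, and the evaluation is eager whereas the real operator works lazily through NextIterator.

```scala
import scala.collection.mutable

object UpdateModeSketch {
  // Hypothetical stand-in for an aggregated window, identified by its end time.
  final case class Row(windowEnd: Long, sum: Long)

  def process(
      store: mutable.Map[Long, Row],
      batch: Iterator[Row],
      watermark: Option[Long]): List[Row] = {
    // Filter out the late rows only if a watermark predicate exists.
    val baseIterator = watermark match {
      case Some(wm) => batch.filter(_.windowEnd > wm)
      case None => batch
    }
    // Every remaining row is written to the state and returned.
    val emitted = baseIterator.map { row =>
      store.put(row.windowEnd, row)
      row
    }.toList
    // Equivalent of close(): drop the expired keys, then the real operator commits.
    watermark.foreach { wm =>
      store.keys.filter(_ <= wm).toList.foreach(key => store.remove(key))
    }
    emitted
  }
}
```

Contrary to the append mode, the expired aggregates are removed silently and only the rows updated in the micro-batch reach the next operator.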

Watermark predicate creation

Streaming aggregation uses 2 types of watermark predicates, watermarkPredicateForKeys and watermarkPredicateForData. The creation of both is controlled by this expression:

  lazy val watermarkExpression: Option[Expression] = {
    WatermarkSupport.watermarkExpression(
      child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)),
      eventTimeWatermark)
  }

As you can see, it exists only if the output of the child node contains the watermark column. Hence, for a query like .withWatermark("event_time", "5 seconds").groupBy("user_id").count() the watermarkExpression will be empty. As you can deduce, the groupBy has to include the watermark column. That's the reason why the watermark-based predicates are not always created. Below you can find their current (3.0.1) implementation:

  /** Predicate based on keys that matches data older than the watermark */
  lazy val watermarkPredicateForKeys: Option[BasePredicate] = watermarkExpression.flatMap { e =>
    if (keyExpressions.exists(_.metadata.contains(EventTimeWatermark.delayKey))) {
      Some(Predicate.create(e, keyExpressions))
    } else {
      None
    }
  }

  /** Predicate based on the child output that matches data older than the watermark. */
  lazy val watermarkPredicateForData: Option[BasePredicate] =
    watermarkExpression.map(Predicate.create(_, child.output))
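
The conditions controlling the creation of both predicates can be reduced to a few lines of plain Scala. In this sketch, Attr is a hypothetical stand-in for a Spark attribute, and its hasWatermarkDelay flag replaces the metadata.contains(EventTimeWatermark.delayKey) check. The sketch also assumes that a watermark value is already known, so it only looks at the columns.

```scala
object WatermarkPredicateSketch {
  // Hypothetical stand-in for an attribute carrying (or not) the watermark metadata.
  final case class Attr(name: String, hasWatermarkDelay: Boolean)

  // The watermark expression exists only if the child output has a watermark column.
  def watermarkAttr(childOutput: Seq[Attr]): Option[Attr] =
    childOutput.find(_.hasWatermarkDelay)

  def hasPredicateForData(childOutput: Seq[Attr]): Boolean =
    watermarkAttr(childOutput).isDefined

  // The key predicate additionally requires the watermark column among the grouping keys.
  def hasPredicateForKeys(keys: Seq[Attr], childOutput: Seq[Attr]): Boolean =
    watermarkAttr(childOutput).isDefined && keys.exists(_.hasWatermarkDelay)
}
```

For the groupBy("user_id") query, neither the keys nor the aggregation output carry the watermark metadata, so both functions return false. For a grouping on a window built from the watermarked column, both return true.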

As you can see, stateful aggregations are more complicated than the 2 operations you saw so far, dropDuplicates and limit. Not only do they involve an intermediate layer represented by the state store manager, but they also adapt their behavior to the output mode.
