How to initialize state in Apache Spark Structured Streaming stateful jobs?

Versions: Apache Spark 3.3.0

Starting from Apache Spark 3.2.0, it is now possible to load an initial state into arbitrary stateful pipelines. Even though the feature is easy to use, it hides some interesting implementation details!

State initialization

The initial state is provided as a separate Dataset. It's passed to the stateful processing logic as a part of the mapGroupsWithState signature:

class KeyValueGroupedDataset[K, V] private[sql](
    kEncoder: Encoder[K],
    vEncoder: Encoder[V],
    @transient val queryExecution: QueryExecution,
    private val dataAttributes: Seq[Attribute],
    private val groupingAttributes: Seq[Attribute]) extends Serializable {

  def mapGroupsWithState[S, U](
      func: MapGroupsWithStateFunction[K, V, S, U],
      stateEncoder: Encoder[S],
      outputEncoder: Encoder[U],
      timeoutConf: GroupStateTimeout,
      initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
    mapGroupsWithState[S, U](timeoutConf, initialState)(
      (key: K, it: Iterator[V], s: GroupState[S]) => func.call(key, it.asJava, s)
    )(stateEncoder, outputEncoder)
  }
  // ...
}

You've certainly noticed one thing about the types. The initialState shares the key (K) and state (S) types with the func. Without this match it wouldn't be possible to integrate the initial state with the new incoming records. In client code, a state initialization could look like this:

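  // Imports needed by the snippet; sparkSession is an already created SparkSession
  import org.apache.spark.sql.execution.streaming.MemoryStream
  import org.apache.spark.sql.streaming.GroupStateTimeout
  import sparkSession.implicits._

  // Initial state: a (login, points) Dataset grouped by login, with the points as the state value
  val defaultState = Seq(
    ("user1", 10), ("user2", 20), ("user3", 30)
  ).toDF("login", "points").as[(String, Int)].groupByKey(row => row._1).mapValues(_._2)

  val inputStream = new MemoryStream[(String, Int)](1, sparkSession.sqlContext)
  inputStream.addData(("user1", 5))
  inputStream.addData(("user4", 2))

  // The grouping key type (String) and the state type (Int) match the ones of defaultState
  val statefulAggregation = inputStream.toDS().toDF("login", "points")
    .groupByKey(row => row.getAs[String]("login"))
    .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout(), defaultState)(StatefulMapper.apply)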
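The snippet above references StatefulMapper.apply without showing it. Below is a minimal sketch of what such a mapper could look like, assuming the goal is to accumulate the points per login. The object body and the mergePoints helper are hypothetical and are not taken from the original code; only the GroupState and Row calls come from the Apache Spark API.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.GroupState

// Hypothetical implementation: the snippet above only references StatefulMapper.apply.
object StatefulMapper {

  // Pure helper (assumption): adds the points of the current micro-batch to the stored ones.
  def mergePoints(previousPoints: Option[Int], newPoints: Iterator[Int]): Int =
    previousPoints.getOrElse(0) + newPoints.sum

  // The signature matches the types of the query: key = String, value = Row, state = Int.
  def apply(login: String, rows: Iterator[Row], state: GroupState[Int]): (String, Int) = {
    // getOption returns the stored state, if any, for this login
    val total = mergePoints(state.getOption, rows.map(_.getAs[Int]("points")))
    // Updating, instead of removing, keeps the key in the state store
    state.update(total)
    (login, total)
  }
}
```

The sketch deliberately never calls state.remove(), so a key loaded from the initial state stays in the state store even when no input row matches it.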

State initialization join

Under the hood, the initial state is first passed as a SparkPlan to the FlatMapGroupsWithStateExec by this strategy, which maps the logical node to its physical representation:

object FlatMapGroupsWithStateStrategy extends Strategy {
  override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan)

  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case FlatMapGroupsWithState(
        func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
        timeout, hasInitialState, stateGroupAttr, sda, sDeser, initialState, child) =>
      val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
      val execPlan = FlatMapGroupsWithStateExec(
        func, keyDeser, valueDeser, sDeser, groupAttr, stateGroupAttr, dataAttr, sda, outputAttr,
        None, stateEnc, stateVersion, outputMode, timeout, batchTimestampMs = None,
        eventTimeWatermark = None, planLater(initialState), hasInitialState, planLater(child)
      )
      execPlan :: Nil
    case _ =>
      Nil
  }
}

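Inside the physical node, Apache Spark processes new input rows in the processNewDataWithInitialState function instead of the usual processNewData. Inside, it creates a CoGroupedIterator that pairs the initial state rows with the input rows of the same key and executes 2 operations for each key: it first writes the initial state of the key, if any, to the state store, and then calls the stateful function with the input rows of that key.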

The operation is an implicit left join because all the keys from the initial Dataset, even those without a matching input row, are processed by the stateful function. Therefore, if you want to keep them for later, you shouldn't remove them from the state store in this initial call.
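To make this behavior more concrete, here is a simplified, Spark-free model of the first micro-batch. It is only a sketch under assumptions: a mutable Map stands in for the state store, the two per-key steps are modeled as "write the initial state" and "call the function", and every name in it (InitialStateCoGroupSketch, StateStore, sumPoints) is hypothetical. It is not the code of Apache Spark.

```scala
import scala.collection.mutable

object InitialStateCoGroupSketch {
  // Hypothetical stand-in for the state store: key -> state value.
  type StateStore = mutable.Map[String, Int]

  // Simplified model of the first micro-batch: every key present in the input
  // or in the initial state is visited exactly once.
  def processNewDataWithInitialState(
      input: Seq[(String, Int)],
      initialState: Seq[(String, Int)],
      store: StateStore)(
      func: (String, Iterator[Int], StateStore) => (String, Int)): Seq[(String, Int)] = {
    val groupedInput = input.groupBy(_._1)
    val groupedInitialState = initialState.groupBy(_._1)
    val allKeys = (groupedInput.keySet ++ groupedInitialState.keySet).toSeq.sorted
    allKeys.map { key =>
      val initialRows = groupedInitialState.getOrElse(key, Seq.empty)
      require(initialRows.size <= 1, s"Duplicate initial state for key $key")
      // Step 1: write the initial state of the key, if any, to the store.
      initialRows.foreach { case (_, state) => store.update(key, state) }
      // Step 2: call the stateful function, possibly with an empty input iterator.
      val values = groupedInput.getOrElse(key, Seq.empty).map(_._2).iterator
      func(key, values, store)
    }
  }

  // Hypothetical user function: adds the new points to the stored ones.
  def sumPoints(key: String, values: Iterator[Int], store: StateStore): (String, Int) = {
    val total = store.getOrElse(key, 0) + values.sum
    store.update(key, total)
    (key, total)
  }
}
```

With the data used earlier in the article (initial state for user1, user2 and user3; input rows for user1 and user4), the function in this model is called for all four keys. For user2 and user3 it receives an empty iterator, which is exactly the case where a careless state removal would drop the freshly loaded state.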

Initial state and micro-batches

That was for the first micro-batch execution. After all, the state initialization implies doing that only once, at the beginning of the job. But what about the subsequent executions? How does Apache Spark skip the state initialization? The answer is in the IncrementalExecution class, where the planner applies a special treatment to the state initialization depending on the micro-batch number:

class IncrementalExecution(
    sparkSession: SparkSession,
    logicalPlan: LogicalPlan,
    val outputMode: OutputMode,
    val checkpointLocation: String,
    val queryId: UUID,
    val runId: UUID,
    val currentBatchId: Long,
    val offsetSeqMetadata: OffsetSeqMetadata)
  extends QueryExecution(sparkSession, logicalPlan) with Logging {

  // ...
  override def apply(plan: SparkPlan): SparkPlan = plan transform {
    // ...
    case m: FlatMapGroupsWithStateExec =>
      // We set this to true only for the first batch of the streaming query.
      val hasInitialState = (currentBatchId == 0L && m.hasInitialState)
      m.copy(
        stateInfo = Some(nextStatefulOperationStateInfo),
        batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
        eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs),
        hasInitialState = hasInitialState
      )

As you can see, declaring an initial state in the query is not enough. The planner overwrites the hasInitialState parameter with false for any micro-batch but the first one. That way, any subsequent micro-batch execution goes through the usual processNewData, which contains no initialization logic.
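The condition from the planner can be isolated in a few lines. The sketch below is a reduced model, not the Apache Spark code: StatefulNode is a hypothetical stand-in for FlatMapGroupsWithStateExec that keeps only the flag, and planForBatch and processingPath are illustrative names.

```scala
object InitialStateGatingSketch {
  // Hypothetical, reduced stand-in for the physical node; only the flag matters here.
  final case class StatefulNode(hasInitialState: Boolean)

  // Same condition as in the planner rule: the flag survives only in batch 0.
  def planForBatch(node: StatefulNode, currentBatchId: Long): StatefulNode =
    node.copy(hasInitialState = currentBatchId == 0L && node.hasInitialState)

  // Name of the processing function the node would use after planning.
  def processingPath(node: StatefulNode): String =
    if (node.hasInitialState) "processNewDataWithInitialState" else "processNewData"
}
```

In this model a node declared with an initial state takes the initialization path for batch 0 and the regular path for every later batch number. The same condition also implies that a query resuming at a batch number greater than 0 would not load the initial state again.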

Loading an initial state is a relatively new feature in Apache Spark Structured Streaming, but it's an important component that stateful pipelines had been missing so far. Four years ago I already blogged about that topic (Initializing state in Structured Streaming) and found the problem quite challenging to solve. Now, with the complexity hidden by the framework, it should be much easier for the end-users!

