https://github.com/bartosz25/spark-playground/tree/master/spark-4-structured-streaming-new-state-api
Last week we discovered the new way to write arbitrary stateful transformations in Apache Spark 4 with the transformWithState API. Today it's time to delve into the implementation details and try to understand the internal logic a bit better.
State initialization
When it comes to state initialization, timing matters. Apache Spark initializes the state only once, in the first micro-batch (batch id 0) of a query started for the first time. A query restarted from an existing checkpoint doesn't reload it. In other words, initialization is tied to the first micro-batch, not to the first appearance of a grouping key. This is currently controlled by the IncrementalExecution class, which is responsible for planning the micro-batches:
```scala
object StateOpIdRule extends SparkPlanPartialRule {
  // ...
  case t: TransformWithStateExec =>
    val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
    t.copy(
      stateInfo = Some(nextStatefulOperationStateInfo()),
      batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
      eventTimeWatermarkForLateEvents = None,
      eventTimeWatermarkForEviction = None,
      hasInitialState = hasInitialState)
  // ...
}
```
The previous snippet also highlights an important property that is set when you define the initial state DataFrame: the hasInitialState flag. It tells the physical execution node to perform the state initialization before handling the input rows of the first micro-batch:

Concretely, the state initialization consists of grouping the initial state DataFrame by the grouping key and invoking StatefulProcessor#handleInitialState before processing the records from the input:

Any complexity in the initial state logic lives in your implementation of the handleInitialState function. The TransformWithStateExec node doesn't perform any action on the data itself besides converting it to the format internally expected by Apache Spark. Consequently, if your state DataFrame has many rows for a given grouping key, handleInitialState is called once per row, so many times for that key. It's therefore important to either guarantee a single row per grouping key in the state DataFrame, or to handle the records correctly inside the initialization method, i.e. apply some state reconciliation logic. You can find a code example in my GitHub repository: TransformWithStateInitStateWithReconciliation.
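The reconciliation idea can be sketched in plain Scala, without Spark. The model below is hypothetical. `InitialVisit`, `InitialStateReconciliation` and the "keep the row with the latest `updatedAt`" rule are assumptions for illustration, not Spark's API. What it mirrors is the behavior described above: the handler runs once per initial-state row, so it has to look at what a previous call stored for the same key and decide which version wins.

```scala
import scala.collection.mutable

// Hypothetical row of the initial state DataFrame; several rows may share one userId.
final case class InitialVisit(userId: String, visits: Int, updatedAt: Long)

// Hypothetical stand-in for a stateful processor whose value state is modeled as a Map.
final class InitialStateReconciliation {
  val state: mutable.Map[String, InitialVisit] = mutable.Map.empty

  // Called once per initial-state row, like handleInitialState, so it must
  // reconcile with whatever an earlier call stored for the same key.
  def handleInitialState(key: String, row: InitialVisit): Unit =
    state.get(key) match {
      case Some(existing) if existing.updatedAt >= row.updatedAt => () // keep the newer row
      case _ => state.update(key, row)
    }

  // Mimics grouping the initial state by key and feeding every row to the handler.
  def initialize(rows: Seq[InitialVisit]): Unit =
    rows.groupBy(_.userId).foreach { case (key, keyRows) =>
      keyRows.foreach(row => handleInitialState(key, row))
    }
}
```

With three rows for user `a` (updatedAt 10, 20, 15), only the row with updatedAt 20 survives. Without the check in `handleInitialState`, the last row processed would silently win.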
Input data processing
The next step, arguably even more important than the optional state initialization, is data processing. When new data arrives at the stateful transformation, it follows this execution path:

The input dataset can be processed from two entry points: state initialization or the regular input data processing workflow. As an end user, the most important part for you is the last box, where Apache Spark passes an iterator of input rows to your StatefulProcessor#handleInputRows function. Although this workflow looks simple, it hides an important detail: the input rows iterator won't necessarily contain all your input data! If the stateful logic runs in event-time mode (watermark-based timers), Apache Spark removes rows older than the watermark before passing the input rows to your handleInputRows:
```scala
private def processDataWithPartition(
    // ...
  ): CompletionIterator[InternalRow, Iterator[InternalRow]] = {
  // ...
  val filteredIter = watermarkPredicateForDataForLateEvents match {
    case Some(predicate) if timeMode == TimeMode.EventTime() =>
      applyRemovingRowsOlderThanWatermark(iter, predicate)
    case _ => iter
  }

  val newDataProcessorIter =
    CompletionIterator[InternalRow, Iterator[InternalRow]](
      processNewData(filteredIter),
      // ...
    )
  // ...
}
```
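The filtering step above can be modeled in a few lines of plain Scala. This is a hypothetical sketch. `LateRowsFilter`, `Mode`, `Event` and `rowsForHandleInputRows` are illustrative names, not Spark classes. It also assumes that a row counts as late when its event time is at or below the watermark, so check the exact boundary in your Spark version. The point it shows is that the predicate only applies in event-time mode and only when a watermark exists.

```scala
// Hypothetical model of the late-row filtering applied before handleInputRows.
object LateRowsFilter {
  sealed trait Mode
  case object EventTime extends Mode
  case object ProcessingTime extends Mode

  // Hypothetical input row carrying its event time in milliseconds.
  final case class Event(key: String, eventTimeMs: Long)

  // Assumption: a row is late when its event time is at or below the watermark.
  def rowsForHandleInputRows(
      input: Iterator[Event],
      watermarkForLateEventsMs: Option[Long],
      mode: Mode): Iterator[Event] =
    watermarkForLateEventsMs match {
      case Some(watermark) if mode == EventTime =>
        input.filterNot(_.eventTimeMs <= watermark) // late rows never reach your code
      case _ => input // processing time or no watermark: rows pass through untouched
    }
}
```

With a watermark of 10 ms in event-time mode, a row at 5 ms is dropped and a row at 15 ms reaches the handler. In processing-time mode, both rows pass through.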
State expiration - timers
Since input filtering is a concept similar to state expiration, let's now see when the expiration happens. The new abstraction involved in state expiration is the timer. In the aforementioned processDataWithPartition, after processing all input rows, Apache Spark triggers timers in the following snippet:
```scala
// inside processDataWithPartition(...)
// ...
val timeoutProcessorIter = new Iterator[InternalRow] {
  private def getIterator(): Iterator[InternalRow] =
    CompletionIterator[InternalRow, Iterator[InternalRow]](
      processTimers(timeMode, processorHandle), {
        // ...
      })
  // ...
}

val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
```
The processTimers method has two execution workflows, one for the event time timers, and one for the processing time timers:

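The two processTimers workflows differ mainly in the clock they compare timers with. The plain-Scala model below sketches that. It uses the eviction watermark in event-time mode and the micro-batch timestamp in processing-time mode, matching the `eventTimeWatermarkForEviction` and `batchTimestampMs` fields from the first snippet. `TimerFiring` and its methods are hypothetical, and the inclusive `<=` boundary is an assumption. The last method shows the ordering from the snippet above: timer output is appended lazily after the input-row output.

```scala
// Hypothetical model of the two processTimers workflows.
object TimerFiring {
  sealed trait Mode
  case object EventTime extends Mode
  case object ProcessingTime extends Mode

  // Event time compares timers with the eviction watermark,
  // processing time with the micro-batch timestamp.
  def expiryThreshold(mode: Mode, evictionWatermarkMs: Long, batchTimestampMs: Long): Long =
    mode match {
      case EventTime      => evictionWatermarkMs
      case ProcessingTime => batchTimestampMs
    }

  // Timers are (groupingKey, expiryTimestampMs) pairs; those at or below the
  // threshold fire (inclusive boundary assumed), earliest first.
  def expiredTimers(timers: Seq[(String, Long)], thresholdMs: Long): Seq[(String, Long)] =
    timers.filter { case (_, ts) => ts <= thresholdMs }.sortBy(_._2)

  // Per partition, rows emitted by handleInputRows come first; the timer
  // iterator is passed by name, so it is only evaluated after them.
  def partitionOutput[O](fromInput: Iterator[O], fromTimers: => Iterator[O]): Iterator[O] =
    fromInput ++ fromTimers
}
```

Because `Iterator.++` takes its argument by name, the timer side is not evaluated until the input side is exhausted. This is the same "inputs first, timers second" order described above.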
This is a good moment to take a closer look at timers, the new construct for managing state expiration. Timers work per grouping key and don't target individual state variables. Put differently, if your stateful transformation stores several state variables, such as a custom object and a list, a timer is registered for the grouping key as a whole. It therefore applies to all those state variables rather than to each of them individually. This becomes obvious when you analyze the timer management API of StatefulProcessorHandle:
```scala
trait StatefulProcessorHandle extends Serializable {
  // ...
  def registerTimer(expiryTimestampMs: Long): Unit

  def deleteTimer(expiryTimestampMs: Long): Unit

  def listTimers(): Iterator[Long]
  // ...
}
```
As you can see, none of these methods references a state variable. They are therefore global for the grouping key and shared by all its state variables.
Internally, timers leverage a state store to keep track of the registered timers for a given grouping key:
```scala
/**
 * Class that provides the implementation for storing timers
 * used within the `transformWithState` operator.
 * @param store - state store to be used for storing timer data
 * @param timeMode - mode of timeout (event time or processing time)
 * @param keyExprEnc - encoder for key expression
 */
class TimerStateImpl(
    store: StateStore,
    timeMode: TimeMode,
    keyExprEnc: ExpressionEncoder[Any]) extends Logging {
  // ...
}
```
The key for each timer is composed of the grouping key and the expiration time in milliseconds. It is written twice: once to the keyToTsCFName column family and once to the secondary tsToKeyCFName column family. A given key is registered only once, so you don't need any deduplication logic in your code:
```scala
def registerTimer(expiryTimestampMs: Long): Unit = {
  val groupingKey = getGroupingKey(keyToTsCFName)
  if (exists(groupingKey, expiryTimestampMs)) {
    logWarning(log"Failed to register timer for key=${MDC(KEY, groupingKey)} and " +
      log"timestamp=${MDC(EXPIRY_TIMESTAMP, expiryTimestampMs)} ms since it already exists")
  } else {
    store.put(rowEncoder.encodedKey(groupingKey, expiryTimestampMs), EMPTY_ROW, keyToTsCFName)
    store.put(rowEncoder.encodeSecIndexKey(groupingKey, expiryTimestampMs), EMPTY_ROW, tsToKeyCFName)
    logDebug(s"Registered timer for key=$groupingKey and timestamp=$expiryTimestampMs")
  }
}
```
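To see why two column families are useful, here is a hypothetical in-memory model of them in plain Scala. `TimerStoreModel` and its methods are illustrative, not Spark classes. Unlike Spark's `registerTimer`, which logs a warning and returns `Unit`, this model returns `false` for a duplicate so the deduplication is visible. The primary index answers per-key questions such as listing a key's timers. The timestamp-ordered secondary index lets expired timers be found with a range scan instead of a full scan. The inclusive `<=` expiry boundary is an assumption.

```scala
import scala.collection.mutable

// Hypothetical in-memory model of the two timer indices written by registerTimer.
final class TimerStoreModel {
  // Primary index: (groupingKey, expiryTimestampMs), used to list or delete a key's timers.
  private val keyToTs = mutable.Set.empty[(String, Long)]
  // Secondary index sorted by timestamp, so expired timers are found by a range scan.
  private val tsToKey = mutable.TreeSet.empty[(Long, String)]

  // Returns false for an already registered timer, like the warning branch above.
  def registerTimer(key: String, expiryTimestampMs: Long): Boolean =
    if (keyToTs.contains((key, expiryTimestampMs))) false
    else {
      keyToTs += ((key, expiryTimestampMs))
      tsToKey += ((expiryTimestampMs, key))
      true
    }

  def deleteTimer(key: String, expiryTimestampMs: Long): Unit = {
    keyToTs -= ((key, expiryTimestampMs))
    tsToKey -= ((expiryTimestampMs, key))
  }

  // Timers are scoped by grouping key only; no state variable name is involved.
  def listTimers(key: String): List[Long] =
    keyToTs.toList.collect { case (k, ts) if k == key => ts }.sorted

  // Walks the secondary index in timestamp order up to the threshold (inclusive, assumed).
  def expiredTimers(thresholdMs: Long): List[(Long, String)] =
    tsToKey.iterator.takeWhile { case (ts, _) => ts <= thresholdMs }.toList
}
```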
A very important point here: transformWithState doesn't clean your state when a timer expires! The operator only removes the fired timer from the timers state store. Your custom state must be removed explicitly in your StatefulProcessor#handleExpiredTimer function.
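The consequence is easy to overlook, so here is a hypothetical plain-Scala model of it. `ExpiredTimerModel`, its fields and the `clearStateInHandler` switch are illustrative names, not Spark's API, and the inclusive `<=` boundary is an assumption. When a timer fires, the model calls a stand-in for `handleExpiredTimer` and then removes the timer. The user state disappears only if the handler removes it.

```scala
import scala.collection.mutable

// Hypothetical model: the operator calls the handler and drops the fired timer;
// the user state goes away only if the handler removes it.
final class ExpiredTimerModel(clearStateInHandler: Boolean) {
  val timers = mutable.Set.empty[(String, Long)] // (groupingKey, expiryTimestampMs)
  val userState = mutable.Map.empty[String, Int] // your own state, keyed by grouping key

  // Stand-in for StatefulProcessor#handleExpiredTimer.
  private def handleExpiredTimer(key: String): Unit =
    if (clearStateInHandler) userState -= key // explicit cleanup is your job

  def fireTimers(thresholdMs: Long): Unit =
    timers.filter { case (_, ts) => ts <= thresholdMs }.toList.foreach { case timer @ (key, _) =>
      handleExpiredTimer(key)
      timers -= timer // the operator only removes the timer itself
    }
}
```

With `clearStateInHandler = false`, the timer is gone after firing but the state entry stays in the store indefinitely.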
State cleaning - TTL
Even though you can define many timers for one grouping key, you can't use them to expire a single state variable. However, the new transformWithState API offers this targeted expiration policy with state TTL. Whenever you get a state value from StatefulProcessorHandle#getValueState, you can configure the TTL duration:
```scala
trait StatefulProcessorHandle extends Serializable {
  // ...
  def getValueState[T](
      stateName: String,
      valEncoder: Encoder[T],
      ttlConfig: TTLConfig): ValueState[T]
  // ...
}

case class TTLConfig(ttlDuration: Duration)
```
The TTL is based on processing time, and Apache Spark recalculates the expiration time whenever you update a TTL-enabled state value:

As you can notice, whenever you retrieve a state with TTL, Apache Spark registers it in ttlStates, its internal collection of TTLState instances. The TTL implementation depends on the underlying state type (value, map, list, ...), but they all follow the same execution path, including the registration in ttlStates.
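The processing-time semantics can be sketched with a hypothetical plain-Scala model of a TTL-enabled value state for one grouping key. `TtlValueStateModel` is an illustrative name, not Spark's `ValueState`. The model assumes two rules. First, the expiration is the micro-batch timestamp plus the TTL duration, recomputed on every update. Second, a value counts as expired once the batch timestamp reaches its expiration. Like `TTLConfig` above, it takes a `java.time.Duration`.

```scala
import java.time.Duration

// Hypothetical model of a TTL-enabled value state for a single grouping key.
final class TtlValueStateModel[S](ttl: Duration) {
  private var stored: Option[(S, Long)] = None // (value, expirationMs)

  // Each update recomputes the expiration from the micro-batch timestamp (assumed: ts + ttl).
  def update(value: S, batchTimestampMs: Long): Unit =
    stored = Some((value, batchTimestampMs + ttl.toMillis))

  // Expired values are hidden on read, even before cleanup physically removes them.
  def get(batchTimestampMs: Long): Option[S] =
    stored.collect { case (v, expiresAt) if expiresAt > batchTimestampMs => v }

  def physicallyPresent: Boolean = stored.isDefined
}
```

With a one-minute TTL, a value written at batch time 0 is visible at 59,999 ms and hidden at 60,000 ms, although it is still stored. Updating it at 30,000 ms pushes the expiration to 90,000 ms.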
OneToOneTTLState vs. OneToManyTTLState
The difference between the two is crucial from the TTL management standpoint. OneToOneTTLState always maps one state value to one TTL value. This is the case, for example, of a value state storing a case class. On the other side, you'll find OneToManyTTLState, for example with a list state, where the elements of the underlying list can have different TTL values. For instance, if your transformWithState logic appends a new element to the list in each micro-batch, each element of the list will have a different expiration time, as the TTL is computed from the processing time of the micro-batch that appended it.
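The one-to-many case can be illustrated with a hypothetical plain-Scala model of a TTL-enabled list state. `TtlListStateModel` is an illustrative name and its `appendValue` only echoes the spirit of the list state API. Each element stores its own expiration, computed as the append-time batch timestamp plus the TTL (an assumed formula). Reads hide only the elements that are already expired.

```scala
import java.time.Duration
import scala.collection.mutable.ArrayBuffer

// Hypothetical model of a TTL-enabled list state: each element carries its own
// expiration, computed from the batch timestamp at which it was appended.
final class TtlListStateModel[S](ttl: Duration) {
  private val elements = ArrayBuffer.empty[(S, Long)] // (element, expirationMs)

  def appendValue(value: S, batchTimestampMs: Long): Unit =
    elements += ((value, batchTimestampMs + ttl.toMillis))

  // Only non-expired elements are visible; each element expires on its own schedule.
  def get(batchTimestampMs: Long): List[S] =
    elements.collect { case (v, expiresAt) if expiresAt > batchTimestampMs => v }.toList
}
```

With a 10-second TTL, an element appended at 0 ms and another at 5,000 ms are both visible at 9,999 ms. Only the second remains at 10,000 ms, and neither is visible at 15,000 ms.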
Each time you update a TTL-enabled state value, Apache Spark will update the expiration time based on the micro-batch processing time:
```scala
override def update(newState: S): Unit = {
  val encodedKey = stateTypesEncoder.encodeGroupingKey()
  val ttlExpirationMs = StateTTL
    .calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)
  val encodedValue = stateTypesEncoder.encodeValue(newState, ttlExpirationMs)
  updatePrimaryAndSecondaryIndices(encodedKey, encodedValue, ttlExpirationMs)
}
```
The cleaning part happens after processing all input rows and expired timers:

More exactly, here:
```scala
val outputIterator = newDataProcessorIter ++ timeoutProcessorIter
CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
  allRemovalsTimeMs += timeTakenMs {
    processorHandle.doTtlCleanup()
  }
  // ...
})
```
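The cleanup runs as a completion callback, so it only happens once downstream has consumed the whole output of the partition. Spark's internal `CompletionIterator` plays this role. The class below is a simplified, hypothetical stand-in written in plain Scala, not Spark's actual code. It delegates to an underlying iterator and runs the callback exactly once, the first time `hasNext` reports that no elements are left.

```scala
// Hypothetical minimal equivalent of a completion iterator: it delegates to `underlying`
// and runs `completion` exactly once, right after the last element has been consumed.
final class CompletionIteratorModel[A](underlying: Iterator[A], completion: () => Unit)
    extends Iterator[A] {
  private var completed = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !completed) {
      completed = true
      completion() // e.g. the TTL cleanup in the snippet above
    }
    more
  }

  override def next(): A = underlying.next()
}
```

Wrapping an iterator of two output rows with a counter callback leaves the counter at 0 until the rows are fully consumed. After consumption it becomes 1 and stays 1 even if `hasNext` is called again.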
Another important point about the TTL: whenever you try to access an expired state value, you get nothing, even if the entry hasn't been physically removed from the state store yet. Here is an example, for ListStateImplWithTTL this time:
```scala
override def get(): Iterator[S] = {
  // ...
  new NextIterator[S] {
    override protected def getNext(): S = {
      val iter = unsafeRowValuesIterator.dropWhile { row =>
        stateTypesEncoder.isExpired(row, batchTimestampMs)
      }
      // returns the iter iterator without the expired values
      // ...
    }
    // ...
  }
}
```
State update
To update your state you need to call - unsurprisingly - the update method of your ValueState. Depending on the state type (with or without the TTL configured), the update function follows one of the following workflows:

As you can see from the picture, both workflows start similarly: both states first call the update method, and both write the state value. The difference lies in the secondary index managed by the TTL-based state value, i.e. the additional write to the state store that keeps the expiration time of each state.
State removal
Removing the state follows a process similar to the state update. Let's take a look at the next schema:

The only difference from the state update is the removal operation on the state store instead of the upsert. But state removal is a good moment to stop and ask: if clear(...) is the way to remove the state, does the state stay in the state store forever when you never call it? It depends on the expiration method:
- Whenever a state expires because of a timer, TransformWithStateExec only calls your handleExpiredTimer. It doesn't clear the state for you, but you can do it in the handler.
- When your state expires because of the TTL configuration, TTLState#clearExpiredStateForAllKeys takes care of removing the state value from the secondary (TTL) and primary (value) indexes:
```scala
override private[sql] def clearExpiredStateForAllKeys(): Long = {
  ttlEvictionIterator().foreach { ttlKey =>
    // Delete from secondary index
    deleteFromTTLIndex(ttlKey)
    // Delete from primary index
    store.remove(toTTLRow(ttlKey).elementKey, stateName)
    // ...
  }
  // ...
}
```
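The interplay of the two indexes, for both the TTL-enabled update and the eviction above, can be sketched with a hypothetical plain-Scala model of a one-to-one TTL state. `TtlIndexesModel` and `clearExpired` are illustrative names, not Spark's implementation. Dropping the stale TTL entry on update and the inclusive `<=` eviction boundary are modeling assumptions. The secondary index is ordered by expiration, so eviction can stop at the first entry that is still alive.

```scala
import scala.collection.mutable

// Hypothetical model of TTL bookkeeping for a one-to-one state: a primary index
// holding values and a secondary index ordered by expiration, used only for eviction.
final class TtlIndexesModel {
  val primary = mutable.Map.empty[String, (Int, Long)] // key -> (value, expirationMs)
  val ttlIndex = mutable.TreeSet.empty[(Long, String)] // (expirationMs, key)

  // A TTL-enabled update writes both indices; in this model the stale TTL entry is dropped.
  def update(key: String, value: Int, expirationMs: Long): Unit = {
    primary.get(key).foreach { case (_, oldExpirationMs) => ttlIndex -= ((oldExpirationMs, key)) }
    primary.update(key, (value, expirationMs))
    ttlIndex += ((expirationMs, key))
  }

  // Scans the secondary index in expiration order and deletes from both indices,
  // returning the number of evicted entries.
  def clearExpired(batchTimestampMs: Long): Long = {
    val expired = ttlIndex.iterator.takeWhile { case (exp, _) => exp <= batchTimestampMs }.toList
    expired.foreach { case entry @ (_, key) =>
      ttlIndex -= entry // delete from secondary index
      primary -= key    // delete from primary index
    }
    expired.size.toLong
  }
}
```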
The new transformWithState introduces new concepts for stateful processing in Apache Spark. A single group state can now keep multiple state instances. This one-to-many relationship gives you more flexibility and lets you, for example, manage different expiration times. But this is not the last blog post about transformWithState. While I was writing this one, I felt obliged to close the topic with the... batch-based transformWithState!