Making applyInPandasWithState less painful

Versions: Apache Spark 3.4.0

Don't get the title wrong! Having applyInPandasWithState in the PySpark API is huge! However, due to Python's duck typing, some operations are harder and riskier to express in code than in the strongly typed Scala API.

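This blog post is a follow-up to my previous article introducing arbitrary stateful processing in PySpark. Back then, I was so excited to explore the ins and outs that I put clean code aside. But if you look at the code example, you'll see several code smells, or at least risky places, such as:
- the positional unpacking of the state tuple returned by current_state.get,
- the if-else branches updating the state in two different ways,
- the state and output field names declared twice, once in the job's schemas and once in the stateful function.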

The code part we're going to change is this one:

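import datetime
from typing import Any, Dict, Iterable

import pandas
from pyspark.sql.streaming.state import GroupState
from pyspark.sql.types import (BooleanType, IntegerType, LongType, StringType,
                               StructField, StructType, TimestampType)

visits = grouped_visits.applyInPandasWithState(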
  func=map_with_state,
  outputStructType=StructType([
    StructField("group_id", IntegerType()),
    StructField("start_time", StringType()),
    StructField("end_time", StringType()),
    StructField("duration_in_milliseconds", LongType()),
    StructField("is_final", BooleanType())
  ]),
  stateStructType=StructType([
    StructField("start_time", TimestampType()),
    StructField("end_time", TimestampType())
  ]),
  outputMode="update",
  timeoutConf="EventTimeTimeout"
)

def map_with_state(group_id_tuple: Any,
                   input_rows: Iterable[pandas.DataFrame],
                   current_state: GroupState) -> Iterable[pandas.DataFrame]:
  session_expiration_time_16min_as_ms = 16 * 60 * 1000
  group_id = group_id_tuple[0]

  def generate_final_output(start_time_for_output: datetime.datetime, end_time_for_output: datetime.datetime,
                            is_final: bool) -> Dict[str, Any]:
    duration_in_milliseconds = (end_time_for_output - start_time_for_output).total_seconds() * 1000
    return {
      "group_id": [group_id],
      "start_time": [start_time_for_output.isoformat()],
      "end_time": [end_time_for_output.isoformat()],
      "duration_in_milliseconds": [int(duration_in_milliseconds)],
      "is_final": [is_final]
    }

  if current_state.hasTimedOut:
    print(f"Session ({current_state.get}) expired for {group_id}; let's generate the final output here")
    start_time, end_time = current_state.get
    record = generate_final_output(start_time, end_time, is_final=True)
    current_state.remove()
  else:
    should_use_event_time_for_watermark = current_state.getCurrentWatermarkMs() == 0
    base_watermark = current_state.getCurrentWatermarkMs()

    first_event_timestamp_from_input = None
    last_event_timestamp_from_input = None
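    for input_df_for_group in input_rows:
        if should_use_event_time_for_watermark:
            input_df_for_group['event_time_as_milliseconds'] = input_df_for_group['timestamp'] \
                .apply(lambda x: int(pandas.Timestamp(x).timestamp()) * 1000)
            base_watermark = max(base_watermark, int(input_df_for_group['event_time_as_milliseconds'].max()))

        # input_rows may hold several DataFrames for the same group, so keep the overall range
        batch_first_event = input_df_for_group['timestamp'].min()
        batch_last_event = input_df_for_group['timestamp'].max()
        if first_event_timestamp_from_input is None or batch_first_event < first_event_timestamp_from_input:
            first_event_timestamp_from_input = batch_first_event
        if last_event_timestamp_from_input is None or batch_last_event > last_event_timestamp_from_input:
            last_event_timestamp_from_input = batch_last_event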

    start_time_for_sink = first_event_timestamp_from_input
    end_time_for_sink = last_event_timestamp_from_input
    if current_state.exists:
        start_time, end_time = current_state.get
        start_time_for_sink = min(start_time, first_event_timestamp_from_input)
        end_time_for_sink = max(end_time, last_event_timestamp_from_input)
        current_state.update((
            start_time_for_sink,
            end_time_for_sink
        ))
    else:
        current_state.update((first_event_timestamp_from_input, last_event_timestamp_from_input))

    timeout_timestamp = base_watermark + session_expiration_time_16min_as_ms
    current_state.setTimeoutTimestamp(timeout_timestamp)

    record = generate_final_output(start_time_for_sink, end_time_for_sink, is_final=False)

  # emitted for both branches, including the final record of an expired session
  yield pandas.DataFrame(record)

The next sections optimize the state management and the output generation parts of this code.
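To see why the positional access is one of the risky places, here is a minimal, Spark-free sketch. The timestamps are made up, and the tuple simply mimics what `current_state.get` returns for the `(start_time, end_time)` state schema. Swapping the names during the unpacking raises no error; the bug only shows up later as a negative duration:

```python
import datetime

# State tuple as current_state.get returns it, in the declared schema order:
# (start_time, end_time).
state_tuple = (
    datetime.datetime(2023, 5, 1, 10, 0),
    datetime.datetime(2023, 5, 1, 10, 15),
)

# Correct positional unpacking.
start_time, end_time = state_tuple

# Hypothetical mistake: the variable names are swapped. Python only checks
# the number of elements, so nothing fails here.
end_time_swapped, start_time_swapped = state_tuple

duration_ok_ms = int((end_time - start_time).total_seconds() * 1000)
duration_swapped_ms = int((end_time_swapped - start_time_swapped).total_seconds() * 1000)
print(duration_ok_ms)       # 900000
print(duration_swapped_ms)  # -900000
```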

Optimizations - state management

Let's begin with the state management. I could have created a state class, exactly like in the Scala API, but I didn't want to copy that solution and deal with class initialization inside the stateful function. It's a valid approach, though! Instead, I'm creating a few handlers that deal with the state-related issues.
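For comparison, the state class approach could look like the minimal sketch below. `SessionState` and its `from_tuple`, `to_tuple`, and `extend` methods are hypothetical names, not part of PySpark; the class would still have to be instantiated from the state tuple inside the stateful function, which is the part I wanted to avoid:

```python
import dataclasses
import datetime
from typing import Tuple


@dataclasses.dataclass
class SessionState:
    # Hypothetical state class; the field order must match the state schema order.
    start_time: datetime.datetime
    end_time: datetime.datetime

    @classmethod
    def from_tuple(cls, state_tuple: Tuple) -> "SessionState":
        # Builds the object from the tuple returned by current_state.get.
        return cls(*state_tuple)

    def to_tuple(self) -> Tuple:
        # Fields in declaration order, the shape expected by current_state.update.
        return dataclasses.astuple(self)

    def extend(self, first_event: datetime.datetime, last_event: datetime.datetime) -> "SessionState":
        # Returns a new state covering both the stored range and the new events.
        return SessionState(min(self.start_time, first_event), max(self.end_time, last_event))


stored = SessionState.from_tuple((datetime.datetime(2023, 5, 1, 10, 0),
                                  datetime.datetime(2023, 5, 1, 10, 5)))
extended = stored.extend(datetime.datetime(2023, 5, 1, 10, 3), datetime.datetime(2023, 5, 1, 10, 12))
print(extended.to_tuple())
```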

from typing import Any, Dict, Tuple

from pyspark.sql.types import StructType


class StateSchemaHandler:

  # the annotations are quoted because StructFieldWithStateUpdateHandler is declared below
  def __init__(self, start_time: 'StructFieldWithStateUpdateHandler',
               end_time: 'StructFieldWithStateUpdateHandler'):
    self.start_time = start_time
    self.end_time = end_time
    self.schema = StructType([
        self.start_time.field, self.end_time.field
    ])

  def get_state_as_dict(self, state_tuple: Tuple) -> Dict[str, Any]:
    # pairs each state value with its field name, in the schema order
    return dict(zip(self.schema.fieldNames(), state_tuple))

  def get_empty_state_dict(self) -> Dict[str, Any]:
    # same keys as the state schema, no values yet
    return dict.fromkeys(self.schema.fieldNames())

  @staticmethod
  def transform_in_flight_state_to_state_to_write(in_flight_state: Dict[str, Any]) -> Tuple:
    # relies on the dict keeping the schema order, which holds for the dicts created by this class
    return tuple(in_flight_state.values())

That's the first of them. As you can see, it handles both conversions between the state tuple and the state dictionary. Since it also defines the state schema, the risk of forgetting to unpack a field is lower. Additionally, the handler defines an empty state dict. You'll see it in a moment; the empty dict is there to simplify some logic in the stateful function.
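The two conversions can be checked without Spark. In this sketch, a plain `field_names` list is an assumed stand-in for `self.schema.fieldNames()`, and `state_dict_to_tuple` is a hypothetical variant of `transform_in_flight_state_to_state_to_write` that reads the values by field name, so it no longer depends on the dictionary's insertion order:

```python
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in for state_handler.schema.fieldNames(), in schema order.
field_names: List[str] = ["start_time", "end_time"]


def get_state_as_dict(state_tuple: Tuple) -> Dict[str, Any]:
    # Tuple -> dict: pairs each value with the field name at the same position.
    return dict(zip(field_names, state_tuple))


def get_empty_state_dict() -> Dict[str, Any]:
    # Same keys in the same order, no values yet.
    return dict.fromkeys(field_names)


def state_dict_to_tuple(in_flight_state: Dict[str, Any]) -> Tuple:
    # Dict -> tuple: reads the values by name, in schema order, instead of
    # relying on the insertion order like tuple(in_flight_state.values()).
    return tuple(in_flight_state[name] for name in field_names)


state = get_state_as_dict(("10:00", "10:15"))
print(state)                            # {'start_time': '10:00', 'end_time': '10:15'}
print(get_empty_state_dict())           # {'start_time': None, 'end_time': None}

# A dict built in a different order: values() follows the insertion order.
reordered = {"end_time": "10:15", "start_time": "10:00"}
print(tuple(reordered.values()))        # ('10:15', '10:00')
print(state_dict_to_tuple(reordered))   # ('10:00', '10:15')
```

With the dicts created by the handler itself, both versions return the same tuple; the name-based lookup only matters if the in-flight dict is ever built elsewhere.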

When it comes to the StructFieldWithStateUpdateHandler, the state handler uses it to declare the PySpark schema. But it's not only a wrapper for the StructField class. It also provides safer reads and updates of the state dictionary:

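import dataclasses
from typing import Any, Dict

from pyspark.sql.types import StructField


@dataclasses.dataclass
class StructFieldWithStateUpdateHandler:
  field: StructField

  def get(self, state_dict_to_read: Dict[str, Any]) -> Any:
    # reads the value stored under the field name declared in the state schema
    return state_dict_to_read[self.field.name]

  def update(self, state_dict_to_update: Dict[str, Any], new_value: Any) -> None:
    # writes the value under the same, schema-declared field name
    state_dict_to_update[self.field.name] = new_value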
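Here is what the safer updates buy, in a Spark-free sketch. `FieldStub` is a hypothetical stand-in for StructField (only its `name` matters), and the handler body is the one shown above. A mistyped string key goes unnoticed because Python dictionaries accept any new key, while the handler can only write under the name declared in the schema:

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class FieldStub:
    # Hypothetical stand-in for pyspark.sql.types.StructField; only the name is used.
    name: str


@dataclasses.dataclass
class StructFieldWithStateUpdateHandler:
    field: FieldStub

    def get(self, state_dict_to_read: Dict[str, Any]) -> Any:
        return state_dict_to_read[self.field.name]

    def update(self, state_dict_to_update: Dict[str, Any], new_value: Any) -> None:
        state_dict_to_update[self.field.name] = new_value


start_time = StructFieldWithStateUpdateHandler(FieldStub("start_time"))
state = {"start_time": None, "end_time": None}

# A typo in a hand-written key silently adds a new entry...
state["strat_time"] = "10:00"
print(sorted(state))           # ['end_time', 'start_time', 'strat_time']

# ...whereas the handler always writes under the declared field name.
start_time.update(state, "10:00")
print(start_time.get(state))   # 10:00
```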

What does it change in the code? First, the job:

import functools

state_handler = StateSchemaHandler(
  start_time=StructFieldWithStateUpdateHandler(StructField("start_time", TimestampType())),
  end_time=StructFieldWithStateUpdateHandler(StructField("end_time", TimestampType()))
)

visits = grouped_visits.applyInPandasWithState(
  func=functools.partial(map_with_state_refactored, state_handler),
  # ...
  stateStructType=state_handler.schema,
  # ...
)

Here I'm declaring the stateful function as a partial, i.e. a function with some of its arguments already bound. The function itself becomes:
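The partial application can be tried in isolation. In the sketch below, `stateful_function` and the `"handler"` string are hypothetical stand-ins, and lists of ints replace the pandas DataFrames; the point is that `functools.partial` binds the leading argument, so the resulting callable keeps the `(key, iterator of DataFrames, state)` shape that applyInPandasWithState uses to call `func`:

```python
import functools
from typing import Any, Iterable, List


def stateful_function(state_handler: str, group_id_tuple: Any,
                      input_rows: Iterable[List[int]], current_state: Any) -> List[str]:
    # Hypothetical stand-in for map_with_state_refactored: state_handler is bound
    # by functools.partial, the three other arguments come from the caller.
    row_count = sum(len(rows) for rows in input_rows)
    return [f"{state_handler}: group={group_id_tuple[0]}, rows={row_count}"]


func = functools.partial(stateful_function, "handler")
# The caller only passes (key, iterator of batches, state); the partial
# prepends the bound "handler" argument.
print(func((1,), iter([[10, 20], [30]]), None))  # ['handler: group=1, rows=3']
print(func.args)                                 # ('handler',)
```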

def map_with_state_refactored(state_handler: StateSchemaHandler,
                              group_id_tuple: Any,
                              input_rows: Iterable[pandas.DataFrame],
                              current_state: GroupState) -> Iterable[pandas.DataFrame]:
  # ...
  if current_state.hasTimedOut:
    print(f"Session ({current_state.get}) expired for {group_id}; let's generate the final output here")
    current_state_as_dict = state_handler.get_state_as_dict(current_state.get)
    current_state.remove()
  else:
    # ...
    start_time_for_sink = first_event_timestamp_from_input
    end_time_for_sink = last_event_timestamp_from_input
    current_state_as_dict = state_handler.get_empty_state_dict()
    if current_state.exists:
        current_state_as_dict = state_handler.get_state_as_dict(current_state.get)
        start_time_for_sink = min(state_handler.start_time.get(current_state_as_dict),
                                  first_event_timestamp_from_input)
        # the session end comes from the stored end_time, not from the start_time
        end_time_for_sink = max(state_handler.end_time.get(current_state_as_dict),
                                last_event_timestamp_from_input)

    state_handler.start_time.update(current_state_as_dict, start_time_for_sink)
    state_handler.end_time.update(current_state_as_dict, end_time_for_sink)

    state_to_update = StateSchemaHandler.transform_in_flight_state_to_state_to_write(current_state_as_dict)
    current_state.update(state_to_update)

    timeout_timestamp = base_watermark + session_expiration_time_16min_as_ms
    current_state.setTimeoutTimestamp(timeout_timestamp)

The first difference you will certainly notice is that the if-else statement for the state update has disappeared. Now there is a single if statement that resolves the values to put into the state. The update goes through the StructFieldWithStateUpdateHandler's update function which, remember, uses the field name declared in the state schema. Finally, I'm using the dict-to-tuple conversion to prepare the value for the state update.
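Because the flow now has a single branch, it is easy to exercise without Spark. The sketch below assumes a hypothetical `FakeGroupState` test double that only implements `exists`, `get`, and `update` (the real GroupState has more members), and it inlines the handlers as plain dict lookups over an assumed `FIELD_NAMES` list. Note that the new end of the session is computed from the stored `end_time`:

```python
import datetime
from typing import Any, Dict, Optional, Tuple

FIELD_NAMES = ["start_time", "end_time"]  # assumed state schema order


class FakeGroupState:
    # Hypothetical test double exposing only the GroupState members used below.
    def __init__(self, initial: Optional[Tuple] = None):
        self._value = initial

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple:
        return self._value

    def update(self, new_value: Tuple) -> None:
        self._value = new_value


def merge_into_state(current_state: FakeGroupState, first_event: datetime.datetime,
                     last_event: datetime.datetime) -> Dict[str, Any]:
    # Default values: the empty state dict and the input range.
    state_dict: Dict[str, Any] = dict.fromkeys(FIELD_NAMES)
    start_time_for_sink, end_time_for_sink = first_event, last_event
    # Single if statement: widen the range only when a state exists.
    if current_state.exists:
        state_dict = dict(zip(FIELD_NAMES, current_state.get))
        start_time_for_sink = min(state_dict["start_time"], first_event)
        end_time_for_sink = max(state_dict["end_time"], last_event)
    state_dict["start_time"] = start_time_for_sink
    state_dict["end_time"] = end_time_for_sink
    current_state.update(tuple(state_dict[name] for name in FIELD_NAMES))
    return state_dict


def at(minute: int) -> datetime.datetime:
    return datetime.datetime(2023, 5, 1, 10, minute)


state = FakeGroupState()
merge_into_state(state, at(5), at(8))   # no state yet: stores the input range
merge_into_state(state, at(2), at(12))  # existing state: widens the range on both ends
print(state.get)
```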

Optimizations - output generation

The output generation is the next part to optimize. It suffers from the same issue as the state management: the field names are declared twice. A simple solution could be to define the names as constants and use them in both the stateful function and the job. But there is an alternative I'm going to explore that encloses the schema declaration and the output generation in the same abstraction, the OutputHandler:

from typing import Any, Dict

from pyspark.sql.types import StructField, StructType


class OutputHandler:

  def __init__(self, group_id: StructField, start_time: StructField, end_time: StructField,
               duration_in_milliseconds: StructField, is_final: StructField,
               state_schema_handler: StateSchemaHandler):
    self.group_id = group_id
    self.start_time = start_time
    self.end_time = end_time
    self.duration_in_milliseconds = duration_in_milliseconds
    self.is_final = is_final
    self.schema = StructType([
        self.group_id, self.start_time, self.end_time,
        self.duration_in_milliseconds, self.is_final
    ])
    self.state_schema_handler = state_schema_handler

  def generate_output(self, group_id: int, timed_out_state: bool, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # reads the state through the same handlers that declared the state schema
    start_time_for_output = self.state_schema_handler.start_time.get(state_dict)
    end_time_for_output = self.state_schema_handler.end_time.get(state_dict)
    duration_in_milliseconds = (end_time_for_output - start_time_for_output).total_seconds() * 1000
    # the keys come from the output schema fields; one-element lists give a single-row DataFrame
    return {
        self.group_id.name: [group_id],
        self.start_time.name: [start_time_for_output.isoformat()],
        self.end_time.name: [end_time_for_output.isoformat()],
        self.duration_in_milliseconds.name: [int(duration_in_milliseconds)],
        self.is_final.name: [timed_out_state]
    }

As you can see, there is less magic and complexity than in the state management part, but the advantage of having everything in a single place is still there. The job now looks like this:

state_handler = StateSchemaHandler(
  start_time=StructFieldWithStateUpdateHandler(StructField("start_time", TimestampType())),
  end_time=StructFieldWithStateUpdateHandler(StructField("end_time", TimestampType()))
)

output_handler = OutputHandler(
  group_id=StructField("group_id", IntegerType()),
  start_time=StructField("start_time", StringType()),
  end_time=StructField("end_time", StringType()),
  duration_in_milliseconds=StructField("duration_in_milliseconds", LongType()),
  is_final=StructField("is_final", BooleanType()),
  state_schema_handler=state_handler
)

visits = grouped_visits.applyInPandasWithState(
  func=functools.partial(map_with_state_refactored, state_handler, output_handler),
  outputStructType=output_handler.schema,
  stateStructType=state_handler.schema,
  outputMode="update",
  timeoutConf="EventTimeTimeout"
)

And the output generation in the stateful function:

def map_with_state_refactored(state_handler: StateSchemaHandler, output_handler: OutputHandler,
                              group_id_tuple: Any,
                              input_rows: Iterable[pandas.DataFrame],
                              current_state: GroupState) -> Iterable[pandas.DataFrame]:
  # ...
  record = output_handler.generate_output(group_id=group_id,
                                          timed_out_state=current_state.hasTimedOut,
                                          state_dict=current_state_as_dict)
  yield pandas.DataFrame(record)
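The output side can be checked the same way. In this Spark-free sketch, `FieldStub` is a hypothetical stand-in for StructField and `generate_output` repeats the OutputHandler computation with the state read by key. Since the record keys come from the same field objects as the schema, they cannot drift apart; each value is a one-element list, so `pandas.DataFrame(record)` produces a single-row DataFrame:

```python
import dataclasses
import datetime
from typing import Any, Dict, List


@dataclasses.dataclass
class FieldStub:
    # Hypothetical stand-in for StructField; only the name is used.
    name: str


# Output "schema" in the same order as OutputHandler.schema.
output_fields: List[FieldStub] = [
    FieldStub("group_id"), FieldStub("start_time"), FieldStub("end_time"),
    FieldStub("duration_in_milliseconds"), FieldStub("is_final"),
]
group_id_f, start_time_f, end_time_f, duration_f, is_final_f = output_fields


def generate_output(group_id: int, timed_out_state: bool, state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Same computation as OutputHandler.generate_output, with the state read by key.
    start_time = state_dict["start_time"]
    end_time = state_dict["end_time"]
    duration_in_milliseconds = (end_time - start_time).total_seconds() * 1000
    return {
        group_id_f.name: [group_id],
        start_time_f.name: [start_time.isoformat()],
        end_time_f.name: [end_time.isoformat()],
        duration_f.name: [int(duration_in_milliseconds)],
        is_final_f.name: [timed_out_state],
    }


record = generate_output(1, True, {
    "start_time": datetime.datetime(2023, 5, 1, 10, 0),
    "end_time": datetime.datetime(2023, 5, 1, 10, 1, 30),
})
# The record keys follow the declared output fields, by construction.
print(list(record) == [field.name for field in output_fields])  # True
print(record["duration_in_milliseconds"])                       # [90000]
```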

The solution presented in this blog post is maybe yet another example of the saying that every problem can be solved with an extra layer of abstraction. This time, the abstraction is a kind of glue between the business logic expressed in the stateful function and the schema expectations of applyInPandasWithState. There are certainly other ways to solve this, and I'll be happy to read your ideas!

