Don't get the title wrong! Having applyInPandasWithState in the PySpark API is huge. However, because of Python's duck typing, some operations are harder and riskier to express in code than in the strongly typed Scala API.
This blog post is a follow-up to my previous article introducing arbitrary stateful processing in PySpark. Back then, I was so excited to explore the ins and outs that I put writing beautiful code aside. But if you look at that code example, you'll see several code smells, or at least risky places, such as:
- Current state retrieval. The state is a tuple, so I'm reading it as start_time, end_time, = current_state.get. It works for a simple state, but typing all the fields of a more complex one would be painful.
- Besides, the current state must respect the schema defined in the stateStructType parameter of the applyInPandasWithState method. It's easy to get them out of sync, and even though you'll get an error, it'll be a runtime error, so one raised only after physically running the job (time consuming).
- Same for the outputStructType parameter that defines the output schema. I was once stuck for about 20 minutes trying to understand an "object of type <class 'str'> cannot be converted to int" error. With some output record schema enforcement, I could have found the reason much earlier.
- The code relies on typeless structures, such as tuples and dictionaries. With the former, it's easy to make a mistake while unpacking the variables; with the latter, it's easy to make a typo while declaring the dictionary. PySpark should still catch both mistakes at runtime, but why not make our life easier at the development stage?
- Finally, I don't like that the schema declarations and their usage are decoupled. The example here is simple and the pain isn't that visible, but in a more advanced project, having to switch between the declaration and the usage can be annoying.
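The output record schema enforcement mentioned in the third point can be approximated at development time with a unit test on the dictionary that the stateful function builds. The sketch below is plain Python and doesn't call PySpark: EXPECTED_OUTPUT_TYPES is a hypothetical mapping written by hand to mirror the outputStructType (IntegerType and LongType as int, StringType as str, BooleanType as bool), and check_output_record is an illustrative helper, not a PySpark API.

```python
from typing import Any, Dict, List

# Hypothetical mirror of the outputStructType, written by hand for tests only
EXPECTED_OUTPUT_TYPES: Dict[str, type] = {
    "group_id": int,
    "start_time": str,
    "end_time": str,
    "duration_in_milliseconds": int,
    "is_final": bool,
}


def check_output_record(record: Dict[str, List[Any]]) -> List[str]:
    """Returns the problems found in a column-oriented output record (empty list if none)."""
    problems = []
    for name in sorted(EXPECTED_OUTPUT_TYPES.keys() - record.keys()):
        problems.append(f"missing column: {name}")
    for name in sorted(record.keys() - EXPECTED_OUTPUT_TYPES.keys()):
        problems.append(f"unexpected column: {name}")
    for name, expected_type in EXPECTED_OUTPUT_TYPES.items():
        for value in record.get(name, []):
            # bool is a subclass of int, so a bool in an int column is reported too
            is_bool_in_int_column = expected_type is int and isinstance(value, bool)
            if not isinstance(value, expected_type) or is_bool_in_int_column:
                problems.append(f"{name}: expected {expected_type.__name__}, got {type(value).__name__}")
    return problems
```

Called on the dictionary returned by the output generation function in a unit test, it flags a string value sent to an integer column, the kind of mismatch behind the error above, without starting a streaming job.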
The code part we're going to change is this one:
```python
import datetime
from typing import Any, Dict, Iterable

import pandas
from pyspark.sql.streaming.state import GroupState
from pyspark.sql.types import (BooleanType, IntegerType, LongType, StringType,
                               StructField, StructType, TimestampType)


def map_with_state(group_id_tuple: Any, input_rows: Iterable[pandas.DataFrame],
                   current_state: GroupState) -> Iterable[pandas.DataFrame]:
    session_expiration_time_16min_as_ms = 16 * 60 * 1000
    group_id = group_id_tuple[0]

    def generate_final_output(start_time_for_output: datetime.datetime,
                              end_time_for_output: datetime.datetime,
                              is_final: bool) -> Dict[str, Any]:
        duration_in_milliseconds = (end_time_for_output - start_time_for_output).total_seconds() * 1000
        return {
            "group_id": [group_id],
            "start_time": [start_time_for_output.isoformat()],
            "end_time": [end_time_for_output.isoformat()],
            "duration_in_milliseconds": [int(duration_in_milliseconds)],
            "is_final": [is_final]
        }

    if current_state.hasTimedOut:
        print(f"Session ({current_state.get}) expired for {group_id}; let's generate the final output here")
        start_time, end_time, = current_state.get
        record = generate_final_output(start_time, end_time, is_final=True)
        current_state.remove()
    else:
        should_use_event_time_for_watermark = current_state.getCurrentWatermarkMs() == 0
        base_watermark = current_state.getCurrentWatermarkMs()
        first_event_timestamp_from_input = None
        last_event_timestamp_from_input = None
        for input_df_for_group in input_rows:
            if should_use_event_time_for_watermark:
                input_df_for_group['event_time_as_milliseconds'] = input_df_for_group['timestamp'] \
                    .apply(lambda x: int(pandas.Timestamp(x).timestamp()) * 1000)
                base_watermark = int(input_df_for_group['event_time_as_milliseconds'].max())
            first_event_timestamp_from_input = input_df_for_group['timestamp'].min()
            last_event_timestamp_from_input = input_df_for_group['timestamp'].max()

        start_time_for_sink = first_event_timestamp_from_input
        end_time_for_sink = last_event_timestamp_from_input
        if current_state.exists:
            start_time, end_time, = current_state.get
            start_time_for_sink = min(start_time, first_event_timestamp_from_input)
            end_time_for_sink = max(end_time, last_event_timestamp_from_input)
            current_state.update((
                start_time_for_sink,
                end_time_for_sink
            ))
        else:
            current_state.update((first_event_timestamp_from_input, last_event_timestamp_from_input))

        timeout_timestamp = base_watermark + session_expiration_time_16min_as_ms
        current_state.setTimeoutTimestamp(timeout_timestamp)
        record = generate_final_output(start_time_for_sink, end_time_for_sink, is_final=False)

    yield pandas.DataFrame(record)


# grouped_visits: the grouped streaming DataFrame (not shown here)
visits = grouped_visits.applyInPandasWithState(
    func=map_with_state,
    outputStructType=StructType([
        StructField("group_id", IntegerType()),
        StructField("start_time", StringType()),
        StructField("end_time", StringType()),
        StructField("duration_in_milliseconds", LongType()),
        StructField("is_final", BooleanType())
    ]),
    stateStructType=StructType([
        StructField("start_time", TimestampType()),
        StructField("end_time", TimestampType())
    ]),
    outputMode="update",
    timeoutConf="EventTimeTimeout"
)
```
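The places I'm going to optimize are the state tuple reads and writes, and the output dictionary built in generate_final_output. The next sections cover them in that order.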
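Before refactoring, the two time computations of map_with_state can be checked outside Spark. The sketch below extracts them into hypothetical helpers, timeout_timestamp and duration_in_milliseconds, that mirror the original expressions: the 16-minute expiration added to the base watermark for setTimeoutTimestamp, and the session duration reported in the output.

```python
import datetime

# Same value as session_expiration_time_16min_as_ms in map_with_state
SESSION_EXPIRATION_16MIN_AS_MS = 16 * 60 * 1000


def timeout_timestamp(base_watermark_ms: int) -> int:
    """Event-time timeout that map_with_state passes to setTimeoutTimestamp."""
    return base_watermark_ms + SESSION_EXPIRATION_16MIN_AS_MS


def duration_in_milliseconds(start_time: datetime.datetime, end_time: datetime.datetime) -> int:
    """Session duration as computed in generate_final_output, truncated to an int."""
    return int((end_time - start_time).total_seconds() * 1000)
```

For example, a base watermark at 1,000 ms gives a timeout at 961,000 ms, and a session from 10:00:00 to 10:01:30 lasts 90,000 ms.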
Optimizations - state management
To begin with, the state management. I could have created a state class, exactly like in the Scala API, but I didn't want to copy-paste that solution and deal with class initialization inside the stateful function. It's a valid approach, though! Instead, I'm creating some handlers that help deal with the state-related issues.
```python
from __future__ import annotations  # StructFieldWithStateUpdateHandler is declared later

from typing import Any, Dict, Tuple

from pyspark.sql.types import StructType


class StateSchemaHandler:
    def __init__(self, start_time: StructFieldWithStateUpdateHandler,
                 end_time: StructFieldWithStateUpdateHandler):
        self.start_time = start_time
        self.end_time = end_time
        self.schema = StructType([
            self.start_time.field,
            self.end_time.field
        ])

    def get_state_as_dict(self, state_tuple: Tuple) -> Dict[str, Any]:
        # The state tuple follows the schema order, so zipping restores the field names
        return dict(zip(self.schema.fieldNames(), state_tuple))

    def get_empty_state_dict(self) -> Dict[str, Any]:
        return {field_name: None for field_name in self.schema.fieldNames()}

    @staticmethod
    def transform_in_flight_state_to_state_to_write(in_flight_state: Dict[str, Any]) -> Tuple:
        # Dicts keep insertion order, i.e. the schema order for both dicts built above
        return tuple(in_flight_state.values())
```
That's the first of them. As you can see, it handles the two conversions between the state tuple and the state dictionary. Since it also defines the state schema, the risk of forgetting to unpack a field is lower. Additionally, the handler defines an empty state dict. You'll see it in a moment, but the empty dict is there to simplify some logic in the stateful function.
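To see the two conversions and the empty state in isolation, here is a stripped-down stand-in. It's an assumption made only so the snippet runs without Spark: MiniStateSchema (a hypothetical name) keeps a plain tuple of field names instead of a PySpark StructType, but the tuple-to-dict and dict-to-tuple logic is the same as in StateSchemaHandler.

```python
import datetime
from typing import Any, Dict, Tuple


class MiniStateSchema:
    """Hypothetical stand-in for StateSchemaHandler; field names replace the StructType."""

    def __init__(self, *field_names: str):
        self.field_names = field_names

    def get_state_as_dict(self, state_tuple: Tuple) -> Dict[str, Any]:
        # zip relies on the state tuple following the schema field order
        return dict(zip(self.field_names, state_tuple))

    def get_empty_state_dict(self) -> Dict[str, Any]:
        return {name: None for name in self.field_names}

    @staticmethod
    def transform_in_flight_state_to_state_to_write(in_flight_state: Dict[str, Any]) -> Tuple:
        # Dicts keep insertion order, so the values come back in schema order
        return tuple(in_flight_state.values())
```

Because both dictionaries are built from the schema field names, the tuple written back keeps the order expected by stateStructType. One caveat: zip silently drops extra values, so a state tuple with more fields than the schema wouldn't fail in this conversion.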
As for the StructFieldWithStateUpdateHandler, the state handler uses it to declare the PySpark schema. But it's not only a wrapper for the StructField class; it also provides safer state dictionary updates:
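```python
import dataclasses
from typing import Any, Dict

from pyspark.sql.types import StructField


@dataclasses.dataclass
class StructFieldWithStateUpdateHandler:
    field: StructField

    def get(self, state_dict_to_read: Dict[str, Any]) -> Any:
        return state_dict_to_read[self.field.name]

    def update(self, state_dict_to_update: Dict[str, Any], new_value: Any):
        # The key always comes from the declared schema field, never from a retyped string
        state_dict_to_update[self.field.name] = new_value
```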
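The read and update path can be exercised without Spark as well. In this sketch, a hypothetical FakeField dataclass only carries the name attribute that the handler reads from StructField. FieldHandler reproduces StructFieldWithStateUpdateHandler and adds one optional safety net that isn't in the original: update refuses keys absent from the in-flight dict, so a handler belonging to another schema fails immediately instead of adding a stray field.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class FakeField:
    """Hypothetical stand-in for pyspark.sql.types.StructField (only the name is used)."""
    name: str


@dataclasses.dataclass
class FieldHandler:
    field: FakeField

    def get(self, state_dict_to_read: Dict[str, Any]) -> Any:
        return state_dict_to_read[self.field.name]

    def update(self, state_dict_to_update: Dict[str, Any], new_value: Any) -> None:
        # Extra check (not in the original): only fields already present can be updated
        if self.field.name not in state_dict_to_update:
            raise KeyError(f"{self.field.name} is not part of this state")
        state_dict_to_update[self.field.name] = new_value
```

The check is safe for the flow in this article because both the empty state dict and the dict built from the state tuple already contain every schema field.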
What does it change in the code? First, the job:
```python
import functools

state_handler = StateSchemaHandler(
    start_time=StructFieldWithStateUpdateHandler(StructField("start_time", TimestampType())),
    end_time=StructFieldWithStateUpdateHandler(StructField("end_time", TimestampType()))
)

visits = grouped_visits.applyInPandasWithState(
    func=functools.partial(map_with_state_refactored, state_handler),
    # ...
    stateStructType=state_handler.schema,
    # ...
)
```
Here I'm declaring the stateful function as a partial, i.e. a function with some of its arguments already bound. The function itself becomes:
```python
def map_with_state_refactored(state_handler: StateSchemaHandler, group_id_tuple: Any,
                              input_rows: Iterable[pandas.DataFrame],
                              current_state: GroupState) -> Iterable[pandas.DataFrame]:
    # ...
    if current_state.hasTimedOut:
        print(f"Session ({current_state.get}) expired for {group_id}; let's generate the final output here")
        current_state_as_dict = state_handler.get_state_as_dict(current_state.get)
        current_state.remove()
    else:
        # ...
        start_time_for_sink = first_event_timestamp_from_input
        end_time_for_sink = last_event_timestamp_from_input
        current_state_as_dict = state_handler.get_empty_state_dict()
        if current_state.exists:
            current_state_as_dict = state_handler.get_state_as_dict(current_state.get)
            start_time_for_sink = min(state_handler.start_time.get(current_state_as_dict),
                                      first_event_timestamp_from_input)
            # The end bound is compared with the stored end_time, not the start_time
            end_time_for_sink = max(state_handler.end_time.get(current_state_as_dict),
                                    last_event_timestamp_from_input)
        state_handler.start_time.update(current_state_as_dict, start_time_for_sink)
        state_handler.end_time.update(current_state_as_dict, end_time_for_sink)
        state_to_update = StateSchemaHandler.transform_in_flight_state_to_state_to_write(current_state_as_dict)
        current_state.update(state_to_update)
        timeout_timestamp = base_watermark + session_expiration_time_16min_as_ms
        current_state.setTimeoutTimestamp(timeout_timestamp)
```
The first difference you can certainly see is that the if-else statement for the state update has disappeared. Now, a single if statement resolves the values to put into the state. The update goes through the update function of StructFieldWithStateUpdateHandler which, remember, uses the field name declared in the state schema. In the end, the dict-to-tuple conversion prepares the state in the type expected by the update.
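The single if statement boils down to widening the session: the stored bounds are combined with the bounds of the new input, using min for the start and max for the end. Here is a plain-Python sketch of that step, with a hypothetical merge_session_bounds function and the state passed as the in-flight dictionary (None when no state exists yet).

```python
import datetime
from typing import Any, Dict, Optional


def merge_session_bounds(stored_state: Optional[Dict[str, Any]],
                         first_event: datetime.datetime,
                         last_event: datetime.datetime) -> Dict[str, Any]:
    """Hypothetical helper reproducing the merge step of the refactored stateful function."""
    # Without a stored state, start from an empty dict, like get_empty_state_dict()
    state = dict(stored_state) if stored_state else {"start_time": None, "end_time": None}
    start_time, end_time = first_event, last_event
    if stored_state:
        # The start widens with the stored start, the end with the stored end
        start_time = min(state["start_time"], first_event)
        end_time = max(state["end_time"], last_event)
    state["start_time"] = start_time
    state["end_time"] = end_time
    return state
```

The copy made with dict(stored_state) keeps the caller's dictionary untouched, which makes the helper easy to test; mutating the dict in place, as the stateful function does, is fine there since a new dict is built on every call. A stored end later than the new events must survive the merge, which is why the end bound is compared with the stored end_time.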
Optimizations - output generation
The output generation is the next part to optimize. It suffers from the same field names redeclaration issue as the state management. One simple solution could be to declare constants and reuse them in both the stateful function and the job. But there is an alternative I'm going to explore, which encloses the schema declaration and the output generation in the same abstraction, the OutputHandler:
```python
from typing import Any, Dict

from pyspark.sql.types import StructField, StructType


class OutputHandler:
    def __init__(self, group_id: StructField, start_time: StructField, end_time: StructField,
                 duration_in_milliseconds: StructField, is_final: StructField,
                 state_schema_handler: StateSchemaHandler):
        self.group_id = group_id
        self.start_time = start_time
        self.end_time = end_time
        self.duration_in_milliseconds = duration_in_milliseconds
        self.is_final = is_final
        self.schema = StructType([
            self.group_id,
            self.start_time,
            self.end_time,
            self.duration_in_milliseconds,
            self.is_final
        ])
        self.state_schema_handler = state_schema_handler

    def generate_output(self, group_id: int, timed_out_state: bool, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        start_time_for_output = self.state_schema_handler.start_time.get(state_dict)
        end_time_for_output = self.state_schema_handler.end_time.get(state_dict)
        duration_in_milliseconds = (end_time_for_output - start_time_for_output).total_seconds() * 1000
        # The keys come from the same StructFields that build the output schema
        return {
            self.group_id.name: [group_id],
            self.start_time.name: [start_time_for_output.isoformat()],
            self.end_time.name: [end_time_for_output.isoformat()],
            self.duration_in_milliseconds.name: [int(duration_in_milliseconds)],
            self.is_final.name: [timed_out_state]
        }
```
As you can see, there is less magic and complexity than in the state management, but the advantage of having everything in a single place is still there. The job now looks like this:
```python
state_handler = StateSchemaHandler(
    start_time=StructFieldWithStateUpdateHandler(StructField("start_time", TimestampType())),
    end_time=StructFieldWithStateUpdateHandler(StructField("end_time", TimestampType()))
)
output_handler = OutputHandler(
    group_id=StructField("group_id", IntegerType()),
    start_time=StructField("start_time", StringType()),
    end_time=StructField("end_time", StringType()),
    duration_in_milliseconds=StructField("duration_in_milliseconds", LongType()),
    is_final=StructField("is_final", BooleanType()),
    state_schema_handler=state_handler
)

visits = grouped_visits.applyInPandasWithState(
    func=functools.partial(map_with_state_refactored, state_handler, output_handler),
    outputStructType=output_handler.schema,
    stateStructType=state_handler.schema,
    outputMode="update",
    timeoutConf="EventTimeTimeout"
)
```
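applyInPandasWithState calls func with three positional arguments: the grouping key tuple, an iterator of pandas DataFrames, and the GroupState. functools.partial binds the handlers in front of them, so Spark still sees a three-argument callable. The sketch below shows the mechanism with plain objects; stateful_function and the string "handlers" are placeholders, not the real ones.

```python
import functools
from typing import Any, Iterable, Iterator, List


def stateful_function(state_handler: Any, output_handler: Any, group_id_tuple: Any,
                      input_rows: Iterable[Any], current_state: Any) -> Iterator[List[Any]]:
    # Placeholder body: echoes what it received to show the argument order
    yield [state_handler, output_handler, group_id_tuple[0], list(input_rows), current_state]


# Binds the first two parameters; the remaining three are filled by the caller
bound_function = functools.partial(stateful_function, "state handler", "output handler")
```

Calling bound_function((42,), iter(["batch 1", "batch 2"]), "group state") fills the three remaining parameters in order. This keeps the stateful function at module level, where it can be unit tested with the handlers passed explicitly.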
And the output generation in the stateful function:
```python
def map_with_state_refactored(state_handler: StateSchemaHandler, output_handler: OutputHandler,
                              group_id_tuple: Any, input_rows: Iterable[pandas.DataFrame],
                              current_state: GroupState) -> Iterable[pandas.DataFrame]:
    # ...
    record = output_handler.generate_output(group_id=group_id,
                                            timed_out_state=current_state.hasTimedOut,
                                            state_dict=current_state_as_dict)
    yield pandas.DataFrame(record)
```
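The record returned by generate_output is column-oriented: every key maps to a one-element list, which pandas.DataFrame(record) turns into a one-row DataFrame. The following sketch reproduces that shape in plain Python with a hypothetical MiniOutput stand-in, where field names replace the StructFields, and lets a test confirm that the record keys match the declared output fields.

```python
import datetime
from typing import Any, Dict, List


class MiniOutput:
    """Hypothetical stand-in for OutputHandler: field names instead of StructFields."""

    def __init__(self, group_id: str, start_time: str, end_time: str,
                 duration_in_milliseconds: str, is_final: str):
        self.group_id = group_id
        self.start_time = start_time
        self.end_time = end_time
        self.duration_in_milliseconds = duration_in_milliseconds
        self.is_final = is_final
        self.field_names = (group_id, start_time, end_time, duration_in_milliseconds, is_final)

    def generate_output(self, group_id: int, timed_out_state: bool,
                        state_dict: Dict[str, Any]) -> Dict[str, List[Any]]:
        start_time = state_dict["start_time"]
        end_time = state_dict["end_time"]
        duration = int((end_time - start_time).total_seconds() * 1000)
        # One-element lists: pandas.DataFrame(record) would give a single row
        return {
            self.group_id: [group_id],
            self.start_time: [start_time.isoformat()],
            self.end_time: [end_time.isoformat()],
            self.duration_in_milliseconds: [duration],
            self.is_final: [timed_out_state],
        }
```

Since the keys and the schema come from the same declaration, a renamed field changes both at once instead of leaving the output and the outputStructType out of sync.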
The solution presented in this blog post is maybe yet another example of the saying that every problem can be solved with an extra layer of abstraction. This time, the abstraction is a kind of glue between the business logic expressed in the stateful function and the schema expectations of applyInPandasWithState. There are certainly other ways to solve this, and I'll be happy to read your ideas!