DataFrame or Dataset to solve the sessionization problem?

Versions: Apache Spark 2.4.2

When I was preparing the demo code for my talk about sessionization at Spark AI Summit 2019 in Amsterdam, I wrote the first version of the code with the DataFrame abstraction. I didn't have type safety, but the data manipulation was quite clear thanks to the mapper objects. Later, I tried to rewrite the code with Dataset. I got type safety but sacrificed a little bit of clarity. Let me delve deeper into that in this post.

In this post, I would like to explain why I decided to keep my initial DataFrame-based version. Please keep in mind that here too, I faced a trade-off between type safety and code clarity.

DataFrame code

The code I'll talk about here is:

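// FULL OUTER JOIN between the inactive sessions of the previous window and the new input logs
val joinedData = previousSessions.join(sessionsInWindow,
  sessionsInWindow("user_id") === previousSessions("userId"), "fullouter")
  // Group the joined rows by the user id of whichever side is present
  // (groupByKey and flatMapGroups need implicit Encoders in scope)
  .groupByKey(log => SessionGeneration.resolveGroupByKey(log))
  // Generate the sessions of each user with a 5-minute inactivity timeout
  .flatMapGroups(SessionGeneration.generate(TimeUnit.MINUTES.toMillis(5), windowUpperBound))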

I'm doing a FULL OUTER JOIN here to combine the inactive sessions from the previous processing window with the new input logs. Thanks to this type of join, I can manage 3 different situations: a completely new session, an existing session with new events to integrate, and an existing session without new events to integrate. I could also do that with 3 separate joins (one of many possible configurations): an INNER JOIN to extend the already existing sessions and 2 LEFT ANTI JOINs to cover the remaining cases. But I wanted to avoid too many shuffle operations, and the FULL OUTER JOIN seemed the best choice for that.
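To make the 3 situations more concrete, here is a minimal pure-Scala model of the join, without Spark. The PreviousSession and InputLog case classes and the fullOuterJoin function are hypothetical simplifications, not the demo's classes. The sketch only shows that a single FULL OUTER JOIN on the user id emits rows for all 3 cases:

```scala
object FullOuterJoinSketch {
  // Hypothetical, simplified shapes of the 2 joined sides
  final case class PreviousSession(userId: Long, pages: Seq[String])
  final case class InputLog(userId: Long, eventTime: String, page: String)

  // One output row of the FULL OUTER JOIN: each side may be missing
  type JoinedRow = (Option[PreviousSession], Option[InputLog])

  def fullOuterJoin(sessions: Seq[PreviousSession], logs: Seq[InputLog]): Seq[JoinedRow] = {
    val sessionsByUser: Map[Long, PreviousSession] = sessions.map(s => s.userId -> s).toMap
    val usersWithLogs: Set[Long] = logs.map(_.userId).toSet
    // Every log produces a row, matched with the user's previous session when it exists
    val logRows: Seq[JoinedRow] = logs.map(log => (sessionsByUser.get(log.userId), Some(log)))
    // A previous session without any new log still produces 1 row, with an empty log side
    val sessionOnlyRows: Seq[JoinedRow] =
      sessions.filterNot(s => usersWithLogs.contains(s.userId)).map(s => (Some(s), None))
    logRows ++ sessionOnlyRows
  }

  // User 1: session to extend, user 2: session without new logs, user 3: new session
  val example: Seq[JoinedRow] = fullOuterJoin(
    Seq(PreviousSession(1L, Seq("home")), PreviousSession(2L, Seq("cart"))),
    Seq(InputLog(1L, "10:00", "about"), InputLog(3L, "10:01", "home"))
  )
}
```

In the example, user 1 gets a row with both sides, user 3 a row without a previous session and user 2 a row without any log.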

After the join, I group all rows having the same key. The key resolution logic operates on the Row abstraction:

  def resolveGroupByKey(log: Row): Long = {
    if (SessionIntermediaryState.Mapper.language(log) != null) {
      SessionIntermediaryState.Mapper.userId(log)
    } else {
      InputLogMapper.userId(log)
    }
  }
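The same null check can be reproduced outside Spark. In this sketch, a hypothetical Map[String, Any] stands in for the joined Row, with null values for the missing side. The column names come from the join above, but the helper itself is only an illustration:

```scala
object GroupKeySketch {
  // Hypothetical stand-in for a Spark Row: column name -> value, null when the joined side is missing
  type NullableRow = Map[String, Any]

  // Mirrors resolveGroupByKey: when a session-side column is set, use the session's userId,
  // otherwise fall back to the input log's user_id
  def resolveGroupByKey(row: NullableRow): Long =
    if (row("language") != null) row("userId").asInstanceOf[Long]
    else row("user_id").asInstanceOf[Long]

  // A restored session without new logs, and a brand new log without a previous session
  val sessionOnlyRow: NullableRow = Map("language" -> "en", "userId" -> 7L, "user_id" -> null)
  val logOnlyRow: NullableRow = Map("language" -> null, "userId" -> null, "user_id" -> 9L)
}
```

Note that in Scala, casting a null value with `asInstanceOf[Long]` yields 0L instead of failing. Checking the wrong column would therefore silently group the rows under the key 0. This is one of the reasons why Row-based code needs careful null handling.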

I'm simply checking whether a mandatory property of the intermediary session, here the language, is set. If it is, I'm either extending or closing an inactive session, so I take the user id from the session side. Otherwise, I use the grouping key of the input log. Finally, I call the SessionGeneration#generate method:

def generate(inactivityDurationMs: Long, windowUpperBoundMs: Long)(userId: Long, logs: Iterator[Row]): Seq[SessionIntermediaryState] = {
  val materializedLogs = logs.toSeq
  val firstLog = materializedLogs.head

  val sessions = (Option(InputLogMapper.eventTimeString(firstLog)), Option(SessionIntermediaryState.Mapper.userId(firstLog))) match {
    case (Some(_), Some(_)) => generateRestoredSessionWithNewLogs(dedupedAndSortedLogs(materializedLogs), inactivityDurationMs, Some(firstLog))
    case (None, Some(_)) => generateRestoredSessionWithoutNewLogs(materializedLogs, windowUpperBoundMs)
    case (Some(_), None) => generateRestoredSessionWithNewLogs(dedupedAndSortedLogs(materializedLogs), inactivityDurationMs, None)
    case (None, None) => throw new IllegalStateException("Session generation when there is not input nor previous " +
      "session logs should never happen")
  }
  sessions
}
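The pattern matching only looks at the first row of each group. Assuming the grouping key is the join key, that's enough: either all rows of a given user carry the previous-session columns or none does, and the same holds for the log columns. Here is a minimal sketch of this dispatch, with a hypothetical JoinedRow reduced to the 2 inspected columns and descriptive strings instead of the real session generation functions:

```scala
object SessionDispatchSketch {
  // Hypothetical joined row reduced to the 2 columns inspected by the pattern matching
  final case class JoinedRow(logEventTime: Option[String], sessionUserId: Option[Long])

  // Inspects only the first row: within one user's group, all rows share the same
  // presence or absence of each side, so the head is representative.
  // Spark never calls the group function with an empty iterator, hence the bare next().
  def situation(rows: Iterator[JoinedRow]): String = {
    val firstRow = rows.next()
    (firstRow.logEventTime, firstRow.sessionUserId) match {
      case (Some(_), Some(_)) => "restored session with new logs"
      case (None, Some(_)) => "restored session without new logs"
      case (Some(_), None) => "new session"
      case (None, None) => throw new IllegalStateException("No input logs and no previous session")
    }
  }
}
```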

Since the joined row combines the schemas of the input log and the intermediary session, I simply try to retrieve the values of each side and apply the corresponding block of the pattern matching. And here we reach the problematic place: the generateRestoredSessionWithNewLogs() function, which calls SessionIntermediaryState#createNew:

  def createNew(logs: Iterator[Row], timeoutDurationMs: Long): SessionIntermediaryState = {
    val materializedLogs = logs.toSeq
    val visitedPages = SessionIntermediaryState.mapInputLogsToVisitedPages(materializedLogs)
    val headLog = materializedLogs.head

    SessionIntermediaryState(userId = InputLogMapper.userId(headLog), visitedPages = visitedPages, isActive = true,
      browser = InputLogMapper.browser(headLog), language = InputLogMapper.language(headLog),
      site = InputLogMapper.site(headLog),
      apiVersion = InputLogMapper.apiVersion(headLog),
      expirationTimeMillisUtc = getTimeout(visitedPages.last.eventTime, timeoutDurationMs)
    )
  }

Complicated? Not really. The problem is that createNew is also used in my streaming example:

  def mapStreamingLogsToSessions(timeoutDurationMs: Long)(key: Long, logs: Iterator[Row],
                                 currentState: GroupState[SessionIntermediaryState]): SessionIntermediaryState = {
    if (currentState.hasTimedOut) {
      val expiredState = currentState.get.expire
      currentState.remove()
      expiredState
    } else {
      val newState = currentState.getOption.map(state => state.updateWithNewLogs(logs, timeoutDurationMs))
        .getOrElse(SessionIntermediaryState.createNew(logs, timeoutDurationMs))
      currentState.update(newState)
      currentState.setTimeoutTimestamp(currentState.getCurrentWatermarkMs() + timeoutDurationMs)
      currentState.get
    }
  }
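Both the batch and the streaming paths end up in createNew, so it's worth looking at what such a function computes in isolation. The sketch below uses hypothetical flattened case classes instead of Row. It assumes that getTimeout adds the inactivity duration to the last event time and that the visited pages must be ordered by event time, since the post doesn't show the implementation of these helpers:

```scala
object CreateNewSketch {
  // Hypothetical, flattened versions of the input log and of the intermediary session state
  final case class InputLog(userId: Long, eventTimeMs: Long, page: String, browser: String)
  final case class VisitedPage(eventTimeMs: Long, page: String)
  final case class SessionState(userId: Long, visitedPages: Seq[VisitedPage], isActive: Boolean,
                                browser: String, expirationTimeMs: Long)

  // Assumption: the session expires after the inactivity duration following its last visit
  def getTimeout(lastEventTimeMs: Long, timeoutDurationMs: Long): Long =
    lastEventTimeMs + timeoutDurationMs

  def createNew(logs: Iterator[InputLog], timeoutDurationMs: Long): SessionState = {
    // An Iterator can be traversed only once, so materialize it before reading it several times
    val materializedLogs = logs.toSeq.sortBy(_.eventTimeMs)
    val visitedPages = materializedLogs.map(log => VisitedPage(log.eventTimeMs, log.page))
    // Session-level attributes are taken from the first log, as in the original createNew
    val headLog = materializedLogs.head
    SessionState(userId = headLog.userId, visitedPages = visitedPages, isActive = true,
      browser = headLog.browser,
      expirationTimeMs = getTimeout(visitedPages.last.eventTimeMs, timeoutDurationMs))
  }

  // 2 unordered logs of the same user with a 5-minute inactivity timeout
  val example: SessionState = createNew(Iterator(
    InputLog(1L, 2000L, "cart", "firefox"),
    InputLog(1L, 1000L, "home", "firefox")
  ), timeoutDurationMs = 5 * 60 * 1000L)
}
```

With these inputs, the session expires at 2000 + 300000 = 302000 ms, i.e. 5 minutes after the last visited page.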

Refactored Dataset version

Then I wanted to refactor the code to use the Dataset[JoinedLog] abstraction. My code with the FULL OUTER JOIN looked like this:

    val joinedData = previousSessions.join(sessionsInWindow,
      sessionsInWindow("user_id") === previousSessions("userId"), "fullouter").as[JoinedLog]
      .groupByKey(log => log.groupByKey)
      .flatMapGroups(SessionGeneration.generate(TimeUnit.MINUTES.toMillis(5), windowUpperBound))

case class JoinedLog(user_id: Option[Long], event_time: Option[String], page: Option[Page],
                     source: Option[Source], user: Option[User],
                     technical: Option[Technical], userId: Option[Long], visitedPages: Option[Seq[VisitedPage]],
                     browser: Option[String], language: Option[String], site: Option[String],
                     apiVersion: Option[String], expirationTimeMillisUtc: Option[Long], isActive: Option[Boolean]) {

  def groupByKey = user_id.getOrElse(userId.get)

  def isNewSession = user_id.isDefined && userId.isEmpty
  def isRestoredSessionWithNewLogs = user_id.isDefined && userId.isDefined
  def isRestoredSessionWithoutLogs = user_id.isEmpty && userId.isDefined

}

When I started the refactoring, I saw that my 2 versions weren't compatible anymore. In the streaming version, the state was managed by Spark's GroupState, and the input data was mapped to a simple InputLog:

    val query = dataFrame.selectExpr("CAST(value AS STRING)")
      .select(functions.from_json($"value", Visit.Schema).as("data"))
      .select($"data.*").as[InputLog]
      .withWatermark("event_time", "3 minutes")
      .groupByKey(inputLog => inputLog.user_id)
      .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(Mapping.mapStreamingLogsToSessions(sessionTimeout))

  def mapStreamingLogsToSessions(timeoutDurationMs: Long)(key: Long, logs: Iterator[InputLog],
                                                          currentState: GroupState[SessionIntermediaryState]): SessionIntermediaryState = {
    if (currentState.hasTimedOut) {
      val expiredState = currentState.get.expire
      currentState.remove()
      expiredState
    } else {
      val newState = currentState.getOption.map(state => state.updateWithNewLogs(logs, timeoutDurationMs))
        .getOrElse(SessionIntermediaryState.createNew(logs, timeoutDurationMs))
      currentState.update(newState)
      currentState.setTimeoutTimestamp(currentState.getCurrentWatermarkMs() + timeoutDurationMs)
      currentState.get
    }
  }

As you can deduce, I had to rework my createNew method to take JoinedLog as a parameter because of the batch version. I could also use this class as the input of the stateful function, but it would be misleading. After all, in streaming there is a clear separation between the input data and the state. This separation is weaker in batch mode because of the FULL OUTER JOIN. That's the reason why I decided to keep the Row-based version for state transformations in my demo code. Unfortunately, the Row-based version also has some drawbacks.

Drawbacks

The main drawback is that Row is very hard to manipulate. If I had to write an application oriented only towards streaming or only towards batch, I would start with the Dataset abstraction. Case classes are much easier to manipulate, especially if you are a TDD devotee. To create a test dataset with case classes, you only need to create the case class objects to test and wrap them in a Dataset. It's much more complicated to do with Row objects.
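For comparison with the mocking approach, here is what the test data looks like with case classes. The InputLog case class and the dedupedAndSortedLogs helper are hypothetical, simplified counterparts of the demo code, written without Spark to keep the sketch self-contained:

```scala
object CaseClassTestDataSketch {
  // Hypothetical case class reduced to the columns the test needs
  final case class InputLog(user_id: Long, event_time: String)

  // Hypothetical case-class counterpart of the dedupe-and-sort helper:
  // case classes have structural equality, so distinct removes exact duplicates
  def dedupedAndSortedLogs(logs: Seq[InputLog]): Seq[InputLog] =
    logs.distinct.sortBy(_.event_time)

  // Test data is a list of constructor calls: no getAs stubbing is needed
  val sampleLogs: Seq[InputLog] =
    Seq("4", "1", "2", "1", "3").map(eventTime => InputLog(user_id = 1L, event_time = eventTime))
}
```

In a Spark test, the same Seq can be turned into a Dataset with `toDS()` after importing the `implicits._` of your SparkSession instance.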

There are 2 main approaches to testing a DataFrame-based pipeline. The first one is mocking. You can create a factory method returning a mocked Row for your class:

  private def inputLog(eventTime: String): Row = {
    val mockedRow = Mockito.mock(classOf[Row])
    Mockito.when(mockedRow.getAs[String]("event_time")).thenReturn(eventTime)
    mockedRow
  }

  behavior of "logs dedupe and sort"

  it should "dedupe and sort input logs" in {
    val inputLogs = Seq(inputLog("4"), inputLog("1"), inputLog("2"), inputLog("1"), inputLog("3"))

    val sortedDedupedLogs = SessionGeneration.deduplicatedAndSortedLogs(inputLogs)

    sortedDedupedLogs should have size 4
    sortedDedupedLogs.map(row => row.getAs[String]("event_time")) should contain inOrderElementsOf(
      Seq("1", "2", "3", "4")
    )
  }

Depending on the complexity of your input (nesting levels, number of fields, ...), it can be painful. The second approach is to create the test data with the Dataset abstraction and convert it into a DataFrame. However, in this scenario you duplicate the schema information, and your production and testing code will very probably get out of sync sooner or later.

So which one should you choose, DataFrame or Dataset? If you are building your final application, Dataset seems a better starting point because it's simpler to operate on case class attributes than to extract Row properties with mappers. On the flip side, if you are writing a POC or an experiment and want to keep your code base small (no dedicated class for each tested scenario), working with Row can be good enough.

