Stateful transformations with mapWithState

Versions: Spark 2.1.0 https://github.com/bartosz25/spark-...itingforcode/MapWithStateTest.scala

The updateStateByKey function, explained in the post about stateful transformations in Spark Streaming, is not the only solution provided by Spark Streaming to deal with state. Another one, much more optimized, is mapWithState.

This post presents mapWithState. The first section explains this function as well as its differences from updateStateByKey. The second part reveals some implementation details. The last part, through an example of session tracking, shows how to use mapWithState for stateful operations in Spark Streaming.

mapWithState explained

Spark's stateful operations work with key-value entries. Thus, before calling some of them, we must often transform scalar values into pairs. Used with stateful transformations, this pairing provokes shuffling, but only the incoming data is shuffled so that it can be handled by the nodes holding the state related to it.
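For instance (a trivial, illustrative snippet; the visits stream is an assumption reused from the examples further down), a stream of page visits would be keyed by user id before any stateful transformation:

// visits: DStream[Visit] - keying by userId makes the stream eligible
// for key-based stateful transformations
val visitsByUser = visits.map(visit => (visit.userId, visit))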

The addition of mapWithState to Spark's API is a response to updateStateByKey's weaknesses. The following list shortly describes the shortcomings of updateStateByKey and explains how mapWithState addresses them:

- updateStateByKey iterates over all state entries at every batch, even when only a small subset of keys received new data; mapWithState touches only the keys present in the current micro-batch (plus the ones being timed out), as detailed in the next section
- the DStream returned by updateStateByKey always carries the state values themselves, so the return type is tied to the state type; mapWithState lets the mapped value have a different type than the stored state
- updateStateByKey offers no built-in expiration mechanism; mapWithState, through its StateSpec, supports a timeout, an initial state and a custom partitioner

The mapWithState function takes 3 parameters: the key (any type), the new value (wrapped in an Option) and the state (a State object). Each of them is important from the point of view of the state lifecycle:

- the key identifies the state entry to read or update
- the new value is None when the function is invoked because of a timeout rather than because of incoming data
- the State object exposes the lifecycle operations: getOption()/get() read the current state, update() persists a new value, remove() discards the entry and isTimingOut() tells whether the call was triggered by the timeout
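As a minimal sketch (the function name and the types are illustrative, not taken from Spark's API or from the tested code), a mapping function keeping a running sum of Int values per String key could look like this:

import org.apache.spark.streaming.State

// Illustrative mapping function: keeps a running sum per key.
// The returned value is what ends up in the mapped DStream.
def sumPerKey(key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
  if (state.isTimingOut()) {
    // value is None here - the call was triggered by the timeout,
    // so we only emit the final sum; update() is forbidden at this point
    Some((key, state.get()))
  } else {
    val newSum = state.getOption().getOrElse(0) + value.getOrElse(0)
    state.update(newSum)
    Some((key, newSum))
  }
}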

mapWithState implementation details

To use mapWithState we must pass the stateful function to the function method of the StateSpec object. This object represents the state specification: it wraps the stateful function but can also define a timeout, a partitioner or an initial state.
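As an illustration (the initial state RDD, the partition count and the sumPerKey function from the previous sketch are assumptions, not part of the tested code), a full StateSpec could be built as follows:

import org.apache.spark.streaming.{Durations, StateSpec}

// Hypothetical setup reusing the sumPerKey function sketched above
val initialSums = streamingContext.sparkContext.parallelize(Seq(("userA", 0), ("userB", 10)))

val spec = StateSpec.function(sumPerKey _)
  .initialState(initialSums)      // states existing before the first batch
  .numPartitions(4)               // or .partitioner(...) for a custom one
  .timeout(Durations.seconds(30)) // idle keys expire after 30 seconds

// pairs: DStream[(String, Int)]
// val mappedStream = pairs.mapWithState(spec)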

The underlying DStream is represented by the org.apache.spark.streaming.dstream.MapWithStateDStream class. The state itself is, as already mentioned, represented by State objects holding the current state as well as information about removal and timeout.

But the real magic happens in MapWithStateRDD which, among other things, handles timeouts and optimized state updates. The timeout is managed thanks to a data structure called StateMap. Its characteristic feature is that it keeps key-value pairs together with the time of their last update. StateMap also exposes a getByTime(threshUpdatedTime: Long) method, helpful to retrieve entries older than the time specified as a parameter. We can see how it's used to mark states as expired in MapWithStateRDDRecord:

if (removeTimedoutData && timeoutThresholdTime.isDefined) {
  newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
    wrappedState.wrapTimingOutState(state)
    // You can see there that None value represents 
    // expired state
    val returned = mappingFunction(batchTime, key, None, wrappedState)
    mappedData ++= returned
    newStateMap.remove(key)
  }
}

Another interesting implementation detail concerns state updates. As told in the first section, updateStateByKey iterates over all entries, even if only a small subset of them has new values. mapWithState works differently: it iterates only over the states that actually received new values. MapWithStateRDD computes new data by calling MapWithStateRDDRecord's updateRecordWithData(prevRecord: Option[MapWithStateRDDRecord[K, S, E]], dataIterator: Iterator[(K, V)], mappingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long], removeTimedoutData: Boolean) function. As you can see, it takes a dataIterator object representing the incoming data to handle. Instead of checking all states, MapWithStateRDDRecord chooses the states to update according to the entries in dataIterator:

// stateMap is a map storing states as values and state keys as keys
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

// ...

dataIterator.foreach { case (key, value) =>
  wrappedState.wrap(newStateMap.get(key))
  val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
  if (wrappedState.isRemoved) {
    newStateMap.remove(key)
  } else if (wrappedState.isUpdated
      || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
    newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
  }
  mappedData ++= returned
}

Thanks to that, according to some benchmarks (e.g. Faster Stateful Stream Processing in Apache Spark Streaming), mapWithState is much faster than its predecessor.

mapWithState example

To give an example of mapWithState use, let's take the case of user sessions. The following test cases show a sample mechanism that could be used to manage them:

val dataQueue: mutable.Queue[RDD[Visit]] = new mutable.Queue[RDD[Visit]]()

"expired state" should "help to detect the end of user's visit" in {
  val visits = Seq(
    Visit(1, "home.html", 10), Visit(2, "cart.html", 5), Visit(1, "home.html", 10),
    Visit(2, "address/shipping.html", 10), Visit(2, "address/billing.html", 10)
  )
  visits.foreach(visit => dataQueue += streamingContext.sparkContext.makeRDD(Seq(visit)))

  def handleVisit(key: Long, visit: Option[Visit], state: State[Long]): Option[Any] = {
    (visit, state.getOption()) match {
      case (Some(newVisit), None) => {
        // the 1st visit
        state.update(newVisit.duration)
        None
      }
      case (Some(newVisit), Some(totalDuration)) => {
        // next visit
        state.update(totalDuration + newVisit.duration)
        None
      }
      case (None, Some(totalDuration)) => {
        // last state - timeout occurred and passed
        // value is None in this case
        Some(key, totalDuration)
      }
      case _ => None
    }
  }

  // The state expires 4 seconds after the last seen entry for
  // the given key. The schedule for our test will look like:
  // user1 -> 0+4, user2 -> 1+4, user1 -> 2+4, user2 -> 3+4, user2 -> 4+4
  val sessionsAccumulator = streamingContext.sparkContext.collectionAccumulator[(Long, Long)]("sessions")
  streamingContext.queueStream(dataQueue)
    .map(visit => (visit.userId, visit))
    .mapWithState(StateSpec.function(handleVisit _).timeout(Durations.seconds(4)))
    .foreachRDD(rdd => {
      val terminatedSessions =
        rdd.filter(_.isDefined).map(_.get.asInstanceOf[(Long, Long)]).collect()
      terminatedSessions.foreach(sessionsAccumulator.add(_))
    })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(10000)

  println(s"Terminated sessions are ${sessionsAccumulator.value}")
  sessionsAccumulator.value.size shouldEqual(2)
  sessionsAccumulator.value should contain allOf((1, 20), (2, 25))
}

"mapWithState" should "help to buffer messages" in {
  // This time the mapWithState operation acts as a buffer.
  // Suppose that we want to send user sessions to a data store
  // only at the end of the visit (as previously). To do so we need to
  // accumulate all the visits
  val visits = Seq(
    Visit(1, "home.html", 10), Visit(2, "cart.html", 5), Visit(1, "cart.html", 10),
    Visit(2, "address/shipping.html", 10), Visit(2, "address/billing.html", 10)
  )
  visits.foreach(visit => dataQueue += streamingContext.sparkContext.makeRDD(Seq(visit)))

  def bufferVisits(key: Long, visit: Option[Visit], state: State[ListBuffer[Visit]]): Option[Seq[Visit]] = {
    val currentVisits = state.getOption().getOrElse(ListBuffer[Visit]())
    if (visit.isDefined) {
      currentVisits.append(visit.get)
      if (currentVisits.length > 2) {
        val returnedVisits = Seq(currentVisits.remove(0), currentVisits.remove(1))
        state.update(currentVisits)
        Some(returnedVisits)
      } else {
        state.update(currentVisits)
        None
      }
    } else {
      // State expired, get all visits
      Some(currentVisits)
    }
  }

  val bufferedSessionsAccumulator = streamingContext.sparkContext
    .collectionAccumulator[Seq[Visit]]("buffered sessions")
  streamingContext.queueStream(dataQueue)
    .map(visit => (visit.userId, visit))
    .mapWithState(StateSpec.function(bufferVisits _).timeout(Durations.seconds(4)))
    .foreachRDD(rdd => {
      val bufferedSessions =
        rdd.filter(_.isDefined).map(_.get).collect()
      bufferedSessions.foreach(visits => bufferedSessionsAccumulator.add(visits))
    })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(12000)

  val bufferedSessions = bufferedSessionsAccumulator.value
  bufferedSessions.size() shouldEqual(3)
  bufferedSessions.get(0) should contain allOf(Visit(2, "cart.html", 5), Visit(2, "address/billing.html", 10))
  bufferedSessions.get(1) should contain allOf(Visit(1, "home.html", 10), Visit(1, "cart.html", 10))
  bufferedSessions.get(2)(0) shouldBe (Visit(2, "address/shipping.html", 10))
}

"shopping sessions" should "be discarded from session tracking thanks to state removal feature" in {
  // Here we want to keep only the sessions that
  // don't concern the shopping (cart) part. We consider that
  // cart.html is the last page visited when the
  // user makes some shopping
  val visits = Seq(
    Visit(1, "home.html", 2), Visit(2, "index.html", 11), Visit(1, "cart.html", 1),
    Visit(1, "forum.html", 1)
  )
  visits.foreach(visit => dataQueue += streamingContext.sparkContext.makeRDD(Seq(visit)))

  def keepNotShoppingSessions(key: Long, visit: Option[Visit], state: State[ListBuffer[Visit]]): Option[Seq[Visit]] = {
    // For simplicity we keep keys infinitely
    if (visit.isDefined) {
      val isCartPage = visit.get.page.contains("cart")
      if (isCartPage) {
        // Discard state for given user
        // However, the removal only discards the previously
        // accumulated entries, i.e. new state can still be
        // created for subsequent data of the given key
        state.remove()
      } else {
        val currentVisits = state.getOption().getOrElse(ListBuffer[Visit]())
        currentVisits.append(visit.get)
        // update() is mandatory to call
        // otherwise nothing persists
        state.update(currentVisits)
      }
      None
    } else {
      state.getOption()
    }
  }

  val notShoppingSessionsAccumulator = streamingContext.sparkContext
    .collectionAccumulator[Seq[Visit]]("not shopping sessions")
  streamingContext.queueStream(dataQueue)
    .map(visit => (visit.userId, visit))
    .mapWithState(StateSpec.function(keepNotShoppingSessions _).timeout(Durations.seconds(4)))
    .foreachRDD(rdd => {
      val notShoppingSessions =
        rdd.filter(_.isDefined).map(_.get).collect()
      notShoppingSessions.foreach(visits => notShoppingSessionsAccumulator.add(visits))
    })

  streamingContext.start()
  streamingContext.awaitTerminationOrTimeout(25000)

  val notShoppingSessions = notShoppingSessionsAccumulator.value
  notShoppingSessions.size() shouldEqual(2)
  notShoppingSessions should contain allOf(
    Seq(Visit(1, "forum.html", 1)), Seq(Visit(2, "index.html", 11))
  )
}
  
case class Visit(userId: Long, page: String, duration: Long) 

Stateful operations are mandatory for some tracking use cases (e.g. user sessions). Spark Streaming initially provided the updateStateByKey transformation, which appeared to have some drawbacks (a return type tied to the state value, slowness). The alternative is the mapWithState method: optimized, more flexible and providing more features (timeout, partitioner, initial state). The first part explained its specific grammar and advantages. The second section described some implementation details, thanks to which we could learn how the timeout is managed and what the selective state updates are. The last part, through an example of session tracking, showed how to use mapWithState.