Apache Kafka source in Structured Streaming - "beyond the offsets"

Versions: Apache Spark 2.4.4

Even though I've already written a few posts about Apache Kafka as a data source in Apache Spark Structured Streaming, I still had some questions in my head. In this post I will try to answer them and leave the topic of Kafka integration in Spark for further investigation later.

The idea for this article came to me when I was preparing my talk for the Paris Kafka Meetup. While recalling what I already knew, I listed the questions below. This article complements my other blog posts about the Kafka source in Structured Streaming.

In this post, when I speak about the metadata consumer, I mean the consumer running on the driver whose principal responsibility is offsets management. By data consumers, I mean the consumers on the executors responsible for physically polling data from the Kafka brokers.

Can I define only 1 starting offset position for a topic with 2 partitions?

At some point in my preparation I wanted to understand the configuration options. So far, I had only used the global offset configuration, but during my analysis I saw that we could also use a more fine-grained, per-partition configuration, like this:

{"my_topic": {"0": 4, "1": 3}}

According to that configuration, Apache Spark will consume the records starting from offset 4 for partition 0 and offset 3 for partition 1. I wondered what happens if we subscribe to a topic and specify only a part of its partitions in this configuration. My initial intuition was that the missing partitions would fall back to some default entry. However, after quickly configuring my query and starting it, I got this exception:

Exception in thread "main" org.apache.spark.sql.streaming.StreamingQueryException: assertion failed: If startingOffsets contains specific offsets, you must specify all TopicPartitions.
Use -1 for latest, -2 for earliest, if you don't care.
Specified: Set(ss_starting_offsets-0) Assigned: Set(ss_starting_offsets-0, ss_starting_offsets-1)
=== Streaming Query ===
Identifier: [id = d176b5d0-f649-4538-bcad-7bb4fd3f41b3, runId = 09a072bc-c567-4fcb-ba99-8722c2b0d4ce]
Current Committed Offsets: {}
Current Available Offsets: {}

Current State: ACTIVE
Thread State: RUNNABLE

Logical Plan:
Project [cast(value#27 as string) AS value#40]
+- StreamingExecutionRelation KafkaV2[Subscribe[ss_starting_offsets]], [key#26, value#27, topic#28, partition#29, offset#30L, timestamp#31, timestampType#32]

    at org.apache.spark.sql.execution.streaming.StreamExecution.org$apache$spark$sql$execution$streaming$StreamExecution$$runStream(StreamExecution.scala:295)
    at org.apache.spark.sql.execution.streaming.StreamExecution$$anon$1.run(StreamExecution.scala:189)
Caused by: java.lang.AssertionError: assertion failed: If startingOffsets contains specific offsets, you must specify all TopicPartitions.

The check is quite straightforward: it compares the set of assigned partitions with the set of partitions specified in the offsets configuration. So, if you subscribe and want fine-grained control over the starting offsets, you must define them for all partitions explicitly. If you don't want to read all partitions, you should prefer the assign option with a subset of partitions, but here too, every assigned partition must appear in the startingOffsets option.
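
To make it concrete, here is a minimal sketch of both variants, assuming a local broker on localhost:9092 and the 2-partition my_topic from the example above (the spark-sql-kafka-0-10 dependency must be on the classpath):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("starting-offsets-demo")
  .master("local[*]")
  .getOrCreate()

// Subscribe mode: every partition of the topic must appear in startingOffsets;
// -2 means earliest and -1 means latest for the partitions we don't care about.
val subscribed = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("subscribe", "my_topic")
  .option("startingOffsets", """{"my_topic": {"0": 4, "1": -2}}""")
  .load()

// Assign mode: read only partition 0, which still has to be listed in startingOffsets.
val assigned = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("assign", """{"my_topic": [0]}""")
  .option("startingOffsets", """{"my_topic": {"0": 4}}""")
  .load()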

Does data locality exist in Apache Kafka source?

Yes! Apache Spark will always try to allocate the consumer reading a given Kafka partition to the same executor across micro-batches. How is it implemented? KafkaOffsetRangeCalculator, used to calculate the offset ranges processed by each Spark partition, defines a method called getLocation which, as you can guess, returns the preferred executor for a given Spark partition, if any:

  private def getLocation(tp: TopicPartition, executorLocations: Seq[String]): Option[String] = {
    def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b

    val numExecutors = executorLocations.length
    if (numExecutors > 0) {
      // This allows cached KafkaConsumers in the executors to be re-used to read the same
      // partition in every batch.
      Some(executorLocations(floorMod(tp.hashCode, numExecutors)))
    } else None
  }
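
To see which executor a partition would land on, here is a small standalone sketch that reuses the same modulo-based mapping; the executor locations are of course hypothetical, and kafka-clients must be on the classpath for TopicPartition:

import org.apache.kafka.common.TopicPartition

// Same modulo-based mapping as in getLocation above
def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b

// Hypothetical executor locations known to the driver
val executorLocations = Seq("executor-host-1", "executor-host-2", "executor-host-3")

(0 until 4).map(partition => new TopicPartition("my_topic", partition)).foreach { tp =>
  val preferred = executorLocations(floorMod(tp.hashCode, executorLocations.length))
  println(s"$tp -> $preferred")
}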

The preferred locations are later returned by KafkaSourceRDD through:

  override def getPreferredLocations(split: Partition): Seq[String] = {
    val part = split.asInstanceOf[KafkaSourceRDDPartition]
    part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty)
  }

As a reminder, the preferred locations are used by DAGScheduler to figure out where every task should be executed:

    // If the RDD has some placement preferences (as is the case for input RDDs), get those
    val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
    if (rddPrefs.nonEmpty) {
      return rddPrefs.map(TaskLocation(_))
    }

Is there any mechanism to deal with data skew?

When I was answering this question, I was pretty amazed by how far the Apache Kafka integration goes. Thanks to the property called minPartitions you can set the minimum number of Spark partitions to read from Kafka and, therefore, split bigger (skewed) Kafka partitions into a few smaller, better parallelizable ones on the Spark side. This is how it works:

      // Splits offset ranges with relatively large amount of data to smaller ones.
      val totalSize = offsetRanges.map(_.size).sum
      offsetRanges.flatMap { range =>
        val tp = range.topicPartition
        val size = range.size
        // number of partitions to divvy up this topic partition to
        val parts = math.max(math.round(size.toDouble / totalSize * minPartitions.get), 1).toInt
        var remaining = size
        var startOffset = range.fromOffset
        (0 until parts).map { part =>
          // Fine to do integer division. Last partition will consume all the round off errors
          val thisPartition = remaining / (parts - part)
          remaining -= thisPartition
          val endOffset = math.min(startOffset + thisPartition, range.untilOffset)
          val offsetRange = KafkaOffsetRange(tp, startOffset, endOffset, None)
          startOffset = endOffset
          offsetRange
        }
      }

The idea is to allocate roughly the same number of offsets to every Spark partition. Another interesting thing to notice is that Spark explicitly removes data locality by creating KafkaOffsetRange(tp, startOffset, endOffset, None). Otherwise, the optimization wouldn't have any positive effect on the processing since the smaller tasks would all be queued on the same executor.
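
To make the arithmetic more tangible, here is a simplified, standalone re-implementation of the splitting logic above, with hypothetical sizes: partition 0 holds 90 offsets, partition 1 holds 10, and minPartitions is set to 10:

// Simplified version of the splitting shown above, on plain case classes
// instead of KafkaOffsetRange.
case class SimpleRange(partition: Int, fromOffset: Long, untilOffset: Long) {
  def size: Long = untilOffset - fromOffset
}

val minPartitions = 10
val offsetRanges = Seq(SimpleRange(0, 0, 90), SimpleRange(1, 0, 10)) // skewed input
val totalSize = offsetRanges.map(_.size).sum

val split = offsetRanges.flatMap { range =>
  val parts = math.max(math.round(range.size.toDouble / totalSize * minPartitions), 1).toInt
  var remaining = range.size
  var startOffset = range.fromOffset
  (0 until parts).map { part =>
    val thisPartition = remaining / (parts - part)
    remaining -= thisPartition
    val endOffset = math.min(startOffset + thisPartition, range.untilOffset)
    val result = SimpleRange(range.partition, startOffset, endOffset)
    startOffset = endOffset
    result
  }
}
split.foreach(println)
// Kafka partition 0 ends up in 9 Spark partitions of 10 offsets each,
// Kafka partition 1 in a single one.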

How does Spark know how many offsets to read?

The skew management question pointed me to another aspect, this time related to the number of offsets. Apache Spark is able to cap the number of records processed per trigger with the maxOffsetsPerTrigger setting. This value is global and is distributed across the partitions, proportionally to their backlog, in the KafkaSource#rateLimit(limit: Long, from: Map[TopicPartition, Long], until: Map[TopicPartition, Long]) method.
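
The exact logic lives in rateLimit; the following simplified sketch only illustrates the proportional idea with hypothetical backlogs, it's not Spark's code:

// Hypothetical backlogs: partition 0 has 300 pending offsets, partition 1 has 100,
// and maxOffsetsPerTrigger is set to 100.
val limit = 100L
val backlogPerPartition = Map(0 -> 300L, 1 -> 100L)
val total = backlogPerPartition.values.sum.toDouble

// Each partition gets a share of the limit proportional to its backlog
val offsetsPerPartition = backlogPerPartition.map { case (partition, backlog) =>
  partition -> math.floor(limit * (backlog / total)).toLong
}
println(offsetsPerPartition) // Map(0 -> 75, 1 -> 25)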

But what happens if we don't specify maxOffsetsPerTrigger? In that case, the metadata consumer on the driver side seeks to the end of each assigned partition and returns these positions as the end offsets for every micro-batch execution:

  def fetchLatestOffsets(
      knownOffsets: Option[PartitionOffsetMap]): PartitionOffsetMap = runUninterruptibly {
    withRetriesWithoutInterrupt {
// ...
        consumer.seekToEnd(partitions)
        partitions.asScala.map(p => p -> consumer.position(p)).toMap
// ...
// Used for instance here
    endPartitionOffsets = Option(end.orElse(null))
        .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
        .getOrElse {
          val latestPartitionOffsets =
            kafkaOffsetReader.fetchLatestOffsets(Some(startPartitionOffsets))
          maxOffsetsPerTrigger.map { maxOffsets =>
            rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
          }.getOrElse {
            latestPartitionOffsets
          }
        }

What is the responsibility of failOnDataLoss property?

The Apache Spark Kafka source also implements the concept of data loss management. It lets you define the behavior on data loss by enabling or disabling the failOnDataLoss property. If you enable it and some data is missed between 2 consecutive reads, you should get an IllegalStateException with the following message:

Some data may have been lost because they are not available in Kafka any more; either the
 data was aged out by Kafka or the topic may have been deleted before all the data in the
 topic was processed. If you don't want your streaming query to fail on such cases, set the
 source option "failOnDataLoss" to "false".

It's easy to memorize but a little bit more complicated to understand. What, after all, is this data loss? Data loss may happen when Apache Spark resolves the offsets to read and finds out that some of them are no longer available, for instance because the records were aged out by the retention policy or the topic was deleted before all of its data was processed.
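
Setting the option itself is straightforward; a minimal sketch, assuming an existing SparkSession called spark and the placeholder broker and topic used before:

// Minimal sketch: disable the failure so that the query only logs a warning when
// some offsets are no longer available (records aged out, topic deleted, ...).
val tolerantStream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("subscribe", "my_topic")
  .option("failOnDataLoss", "false") // the default is "true"
  .load()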

What happens for pattern subscription at the query restart from checkpoint metadata?

Another question I had was about pattern subscription. After some quick research, I found that Apache Kafka has a property called metadata.max.age.ms to control how often the consumer looks for new topics matching the pattern. And it matters here since the consumer responsible for metadata management in Apache Spark is created only once for all micro-batch executions within a given run. But what happens if you restart your query? After all, Apache Spark stores the offsets of the processed topics in the checkpoint. Well, it stores only the offsets. All eligible topics will be resolved once again, so potentially including the ones created during the downtime, by the consumer initialized in org.apache.spark.sql.kafka010.SubscribePatternStrategy:

case class SubscribePatternStrategy(topicPattern: String) extends ConsumerStrategy {
  override def createConsumer(
      kafkaParams: ju.Map[String, Object]): Consumer[Array[Byte], Array[Byte]] = {
    val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams)
    consumer.subscribe(
      ju.regex.Pattern.compile(topicPattern),
      new NoOpConsumerRebalanceListener())
    consumer
  }

  override def toString: String = s"SubscribePattern[$topicPattern]"
}
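
For reference, a pattern subscription with a more frequent metadata refresh could be declared like this; the pattern and the 30-second interval are hypothetical values, the kafka. prefix forwards the property to the underlying consumer, and an existing SparkSession called spark is assumed:

val patternStream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("subscribePattern", "logs_.*")
  // Passed through to the underlying Kafka consumer: re-check the topics
  // matching the pattern every 30 seconds
  .option("kafka.metadata.max.age.ms", "30000")
  .load()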

Are data consumers cached?

When you analyze the source code of the connector, you can see that the Kafka consumer is not created directly but acquired via the org.apache.spark.sql.kafka010.KafkaDataConsumer#acquire(topicPartition: TopicPartition, kafkaParams: ju.Map[String, Object], useCache: Boolean) method. This method, with its useCache parameter, implies the existence of cached consumers. The consumer is cached per combination of topic/partition and consumer parameters. But the caching doesn't happen every time!

In continuous query execution, the consumer is never cached. That's logical, since continuous mode uses long-running tasks instead of the batch-scoped tasks of micro-batch execution. You can see that in the KafkaContinuousReader class:

class KafkaContinuousInputPartitionReader(
    topicPartition: TopicPartition,
    startOffset: Long,
    kafkaParams: ju.Map[String, Object],
    pollTimeoutMs: Long,
    failOnDataLoss: Boolean) extends ContinuousInputPartitionReader[InternalRow] {
  private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams, useCache = false)

Regarding micro-batch processing, based on KafkaMicroBatchReader, things are a little bit more complex. Whether the consumer is reused is driven by task concurrency. Remember the minPartitions param from the beginning of this post? If it creates multiple tasks reading the same topic/partition, Spark will acquire a non-cached Kafka consumer:

    // Reuse Kafka consumers only when all the offset ranges have distinct TopicPartitions,
    // that is, concurrent tasks will not read the same TopicPartitions.
    val reuseKafkaConsumer = offsetRanges.map(_.topicPartition).toSet.size == offsetRanges.size

    // Generate factories based on the offset ranges
    offsetRanges.map { range =>
      new KafkaMicroBatchInputPartition(
        range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer
      ): InputPartition[InternalRow]
    }.asJava

And here too, the explanation is easy to find. The concurrent tasks can be executed on the same executor and we don't want them to share a consumer, in order to avoid inconsistencies.
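
To tie it back to the skew section, here is a hypothetical set of offset range partitions where minPartitions split partition 0 in two, reproducing the reuse check shown above (kafka-clients on the classpath):

import org.apache.kafka.common.TopicPartition

// Partition 0 of my_topic was split into two offset ranges by minPartitions, so two
// concurrent tasks may read the same TopicPartition and the reuse flag becomes false.
val rangePartitions = Seq(
  new TopicPartition("my_topic", 0),
  new TopicPartition("my_topic", 0),
  new TopicPartition("my_topic", 1)
)
val reuseKafkaConsumer = rangePartitions.toSet.size == rangePartitions.size
println(reuseKafkaConsumer) // false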

Are the offsets to process retrieved only once?

This question came to my mind when I was rehearsing my talk and started to say that at every micro-batch execution, the metadata consumer knows the last offsets to process per partition because it fetches them only in the first query execution. As soon as I said it, I knew I was wrong. With such a strategy, the data consumers would process only a small part of the offsets. So, what's the real logic behind offsets retrieval?

The answer is quite simple. MicroBatchExecution sets the offset range for the current micro-batch this way:

      case s: MicroBatchReader =>
        updateStatusMessage(s"Getting offsets from $s")
        reportTimeTaken("setOffsetRange") {
          // Once v1 streaming source execution is gone, we can refactor this away.
          // For now, we set the range here to get the source to infer the available end offset,
          // get that offset, and then set the range again when we later execute.
          s.setOffsetRange(
            toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
            Optional.empty())
        }

Here, Optional.empty() represents the end offsets to set. This hardcoded value makes Apache Spark always fetch the latest offsets from the broker:

  override def setOffsetRange(start: ju.Optional[Offset], end: ju.Optional[Offset]): Unit = {
    // Make sure initialPartitionOffsets is initialized
    initialPartitionOffsets

    startPartitionOffsets = Option(start.orElse(null))
        .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
        .getOrElse(initialPartitionOffsets)

    endPartitionOffsets = Option(end.orElse(null))
        .map(_.asInstanceOf[KafkaSourceOffset].partitionToOffsets)
        .getOrElse {
          val latestPartitionOffsets = kafkaOffsetReader.fetchLatestOffsets()
          maxOffsetsPerTrigger.map { maxOffsets =>
            rateLimit(maxOffsets, startPartitionOffsets, latestPartitionOffsets)
          }.getOrElse {
            latestPartitionOffsets
          }
        }
  }

So only the starting offsets, once resolved in the initial query execution, are carried over from one micro-batch to another. The end offsets for every micro-batch are always fetched from the broker. The signature of this method uses optionals because it's also called during the physical execution of the micro-batch, and in that case the end offsets correspond to the previously fetched ones, so they'll be defined.

After writing 2 blog posts about the Apache Kafka source, I thought I was done with it. However, my recent talk preparation made me discover a few more interesting topics, like data loss and data skew management.