Apache Kafka transactional writer with foreach sink, is it possible?

Versions: Apache Spark 3.1.1

Even though Apache Kafka supports transactional producers, the Apache Spark Kafka sink doesn't use them. But despite that, is it possible to implement a transactional producer in Apache Spark Structured Streaming? You will find out by the end of this article.

The post is organized as follows. In the first section, you will find a refresher on the foreach sink. Just after that, you will see how to implement a transactional producer in Apache Kafka. Finally, in the last 2 sections, you will see 2 implementations of it in Structured Streaming. The first one won't work correctly due to the micro-batch nature of the processing, whereas the second - thanks to some external help - will overcome this issue.

Foreach sink

To implement a custom writer in Apache Spark Structured Streaming, you have different choices. If the writer is available only for the batch Dataset, you can use foreachBatch. If the writer is not available at all, you can implement it on your own with the new APIs like DataWriter. It may be the best option, especially if you can give the implementation back to the community. But it requires a bit of knowledge of the Apache Spark API.

Fortunately, there is a third option that I'll explore here, the foreach sink. To use it, you have to extend the ForeachWriter[T] abstract class, which defines the following methods:

def open(partitionId: Long, epochId: Long): Boolean
def process(value: T): Unit
def close(errorOrNull: Throwable): Unit

The 3 methods exposed by the sink are easy to understand. open() is called when Apache Spark starts to write the results of a task. process() is the method where the framework writes a row to the sink, whereas close() is where it performs any cleanup action after processing all the rows. If you worked with Apache Beam before, you will certainly notice some similarities with the ParDo API (as of this writing, I'm preparing for the GCP Data Engineer certification and revising a lot of Beam concepts, hence this mention ;-))
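The call order can be illustrated with a small simulation. The SimpleWriter trait, the runTask helper and the RecordingWriter class below are hypothetical stand-ins written for this illustration, not Spark classes. They only mimic the contract described in the Structured Streaming documentation, where close is invoked after open regardless of the value open returned.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

// Hypothetical stand-in mirroring the 3 ForeachWriter methods (not Spark's class).
trait SimpleWriter[T] {
  def open(partitionId: Long, epochId: Long): Boolean
  def process(value: T): Unit
  def close(errorOrNull: Throwable): Unit
}

object ForeachLifecycle {
  // Mimics what a task does for one partition of one micro-batch.
  // Simplification: the error is only passed to close(); a real engine would also fail the task.
  def runTask[T](writer: SimpleWriter[T], partitionId: Long, epochId: Long, rows: Seq[T]): Unit = {
    var error: Throwable = null
    try {
      // process() is called only when open() accepted the partition
      if (writer.open(partitionId, epochId)) rows.foreach(writer.process)
    } catch {
      case NonFatal(e) => error = e
    } finally {
      // close() is called in every case, with the error if there was one
      writer.close(error)
    }
  }
}

// Records every callback so that the call order can be inspected.
class RecordingWriter(acceptEpoch: Boolean) extends SimpleWriter[String] {
  val calls = ListBuffer[String]()

  override def open(partitionId: Long, epochId: Long): Boolean = {
    calls += s"open($partitionId,$epochId)"
    acceptEpoch
  }

  override def process(value: String): Unit = {
    calls += s"process($value)"
  }

  override def close(errorOrNull: Throwable): Unit = {
    calls += (if (errorOrNull == null) "close(ok)" else "close(error)")
  }
}
```

Running the helper with a writer that returns false from open() records only the open and close calls, which matters for any writer that wants to skip an already processed epoch.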

Transactional producer in Apache Kafka

If you follow the blog, you know that understanding Apache Kafka transactions is one of my goals. I had to postpone it multiple times already, but I hope to get it right in 2021! Hopefully, thanks to a few excellent resources (links below the article), I will be able to present them at a high level.

The first difference you'll observe between a transactional and a non-transactional producer is the presence of the transactional.id configuration property. It should be unique per producer to prevent fencing; i.e., to prevent the producer instance created by one task from invalidating the transactional producer of another task.

What happens then if we start a producer with the same transactional.id? According to the "Transactions in Apache Kafka" article linked below the article:

The initTransactions API registers a transactional.id with the coordinator. At this point, the coordinator closes any pending transactions with that transactional.id and bumps the epoch to fence out zombies. This happens only once per producer session.

As you can see, all pending transactions for the same transactional.id should be discarded. It's particularly useful for unexpected failures, when the producer can't correctly abort the transaction with the transactional API. This API is very similar to the one of an RDBMS:

initTransactions()
beginTransaction()

// Depending on the outcome:
commitTransaction() // to validate the transaction
abortTransaction() // to abandon the transaction

The above presentation is not accidental because it fits the foreach sink interface! Let me show you how in the next section.
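Before moving on, the fencing rule quoted in this section can be modelled in a few lines. The ToyCoordinator class below is a hypothetical, in-memory simplification and not Kafka's transaction coordinator. It assumes that one epoch is kept per transactional.id, that every registration bumps it, and that a producer holding an older epoch is rejected.

```scala
import scala.collection.mutable

// Hypothetical in-memory model; the real coordinator is a broker-side component.
class ToyCoordinator {
  private val epochs = mutable.Map[String, Int]()

  // Models the effect of initTransactions(): registers the id and bumps its epoch.
  def register(transactionalId: String): Int = {
    val newEpoch = epochs.getOrElse(transactionalId, -1) + 1
    epochs(transactionalId) = newEpoch
    newEpoch
  }

  // A producer is fenced when a newer epoch exists for its transactional.id.
  def isFenced(transactionalId: String, producerEpoch: Int): Boolean =
    epochs.get(transactionalId).exists(_ > producerEpoch)
}
```

In this model, registering "transactional_writer_0" twice fences the first producer, while a producer registered under "transactional_writer_1" is not affected. That is why the transactional.id should be unique per task.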

Foreach-based transactional producer - first implementation

The first implementation of the writer was very straightforward but - spoiler alert - it didn't work at all. Let me show you:

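import java.util.Properties

import scala.collection.mutable

import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
import org.apache.spark.sql.ForeachWriter

class ForeachKafkaTransactionalWriter(outputTopic: String) extends ForeachWriter[String] {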

  private var kafkaProducer: KafkaProducer[String, String] = _

  override def open(partitionId: Long, epochId: Long): Boolean = { 
    kafkaProducer = ForeachKafkaTransactionalWriter.getOrCreate(partitionId)
    if (!ForeachKafkaTransactionalWriter.wasInitialized(partitionId)) {
      kafkaProducer.initTransactions()
      ForeachKafkaTransactionalWriter.setInitialized(partitionId)
    }
    kafkaProducer.beginTransaction()
    true
  }

  override def process(value: String): Unit = {
    // Simulated failure for the demo; ShouldFailOnK is a flag defined outside this snippet
    if (value == "K" && ShouldFailOnK) {
      throw new RuntimeException("Got letter that stops the processing")
    }
    kafkaProducer.send(new ProducerRecord[String, String](outputTopic, value))
  }

  override def close(errorOrNull: Throwable): Unit = {
    if (errorOrNull != null) { 
      kafkaProducer.abortTransaction()
      kafkaProducer.close()
    } else { 
      kafkaProducer.commitTransaction()
    }
  }
}

object ForeachKafkaTransactionalWriter {

  private val producers = mutable.HashMap[Long, KafkaProducer[String, String]]()
  private val producersInitialized = mutable.HashMap[Long, Boolean]()

  def getOrCreate(partitionId: Long): KafkaProducer[String, String] = { 
    val producer = producers.get(partitionId)
    if (producer.isDefined) {
      println("Getting defined producer")
      producer.get
    } else {
      println("Getting new producer")
      val newProducer = create(partitionId)
      producers.put(partitionId, newProducer)
      newProducer
    }
  }

  def wasInitialized(partitionId: Long) = producersInitialized.getOrElse(partitionId, false)
  def setInitialized(partitionId: Long) = producersInitialized.put(partitionId, true)

  private def create(partitionId: Long): KafkaProducer[String, String] = {
    val properties = new Properties()
    properties.setProperty("transactional.id", s"transactional_writer_${partitionId}")
    properties.setProperty("bootstrap.servers", "localhost:29092")
    properties.setProperty("key.serializer", "org.apache.kafka.common.serialization.StringSerializer")
    properties.setProperty("value.serializer", "org.apache.kafka.common.serialization.StringSerializer")

    new KafkaProducer[String, String](properties)
  }

}

Why is this implementation wrong? Let's take a moment and analyze this diagram:

This is weird, right? We used a transactional Kafka producer, but the database still contains duplicates! If you check the micro-batch semantics of a distributed data processing framework like Apache Spark, you will understand that it's normal. The transaction is local to every task and we can't guarantee that all the tasks commit at the same time! If one task commits and another one fails the query, the micro-batch is reprocessed at restart and the already committed task writes its records again. "But wait, it works with Kafka consumers!" - you may want to say. Indeed, it does, but Kafka consumers commit their offsets individually; i.e., you don't need to synchronize the whole group of consumers, as happens with the tasks in Structured Streaming.
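The scenario can be reproduced without any cluster. The sketch below is a plain Scala simulation under simplifying assumptions: the topic is a ListBuffer, a "commit" is an append, and the PerTaskCommitDemo object and its runAttempt method are hypothetical names introduced for this illustration.

```scala
import scala.collection.mutable.ListBuffer

object PerTaskCommitDemo {
  // Runs one attempt of a micro-batch. Every partition commits on its own,
  // except the ones listed in failingPartitions, which write nothing.
  // Returns true when all the tasks of the attempt succeeded.
  def runAttempt(batch: Map[Int, Seq[String]],
                 failingPartitions: Set[Int],
                 topic: ListBuffer[String]): Boolean = {
    batch.toSeq.sortBy(_._1).foreach { case (partition, rows) =>
      // Local commit: it does not wait for the other partitions
      if (!failingPartitions.contains(partition)) topic ++= rows
    }
    !batch.keys.exists(failingPartitions.contains)
  }
}
```

Running the same batch twice, first with partition 1 failing and then with no failure, leaves the records of partition 0 in the simulated topic twice, even though each individual write was "transactional".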

Can we fix it? Yes, but we need another storage for the already committed transactions. Let's see it in the next section.

Foreach-based transactional producer - second implementation

The second implementation uses an external K/V store to persist the committed transactions. Two changes were necessary to implement it in the foreach-based writer. First, we had to add an extra condition to the open(partitionId: Long, epochId: Long) method. Now, it also queries the committed transactions store:

class ForeachKafkaTransactionalWriter(outputTopic: String) extends ForeachWriter[String] {
  private var skipMicroBatch = false

  override def open(partitionId: Long, epochId: Long): Boolean = {
    println(s"Opening foreach for ${partitionId} and ${epochId}")
    skipMicroBatch = epochId <= CommittedTransactionsStore.getLastCommitForPartition(partitionId)
    if (!skipMicroBatch) {
// ... init the producer
      true
    } else {
      // False indicates that the partition should be skipped
      false
    }
  }
// ...
}

object CommittedTransactionsStore {

  private val committedTransactionsStoreBackend = DBMaker
    .fileDB("/tmp/waitingforcode/kafka-transactions-store")
    .fileMmapEnableIfSupported()
    .make()
  private val committedTransactionsStore = committedTransactionsStoreBackend
    .hashMap("committed-transactions", Serializer.LONG, Serializer.LONG)
    .createOrOpen()

  def commitTransaction(partitionId: Long, epochId: Long) = committedTransactionsStore.put(partitionId, epochId)

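  // Returns -1 when nothing was committed yet for the partition,
  // assuming the store returns null for an unknown key, as java.util.Map does
  def getLastCommitForPartition(partitionId: Long): Long =
    Option(committedTransactionsStore.get(partitionId)).map(_.longValue()).getOrElse(-1L)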

}

The second modification related to this extra component concerns the close method, which now has to update the committed transaction value and control when the commit or the abort is called:

  override def close(errorOrNull: Throwable): Unit = {
    // Check this too because close is called even if
    // open returned false
    if (!skipMicroBatch) {
      if (errorOrNull != null) {
        println("An error occurred, aborting the transaction!")
        kafkaProducer.abortTransaction()
        kafkaProducer.close()
      } else {
        println("Committing the transaction")
        kafkaProducer.commitTransaction()
        // partitionId and epochId are assumed to be fields set in open() (elided in this snippet)
        CommittedTransactionsStore.commitTransaction(partitionId, epochId)
      }
    }
  }
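The skip decision taken in open() and the bookkeeping done in close() can be checked in isolation. The InMemoryCommitStore class below is a hypothetical in-memory replacement for the MapDB-backed store, written only for this illustration. It assumes that -1 stands for "nothing committed yet", so that the very first epoch (0) is never skipped.

```scala
import scala.collection.mutable

// Hypothetical in-memory replacement for the file-backed store.
class InMemoryCommitStore {
  private val lastCommittedEpoch = mutable.Map[Long, Long]()

  // Called after a successful Kafka commit for the partition.
  def commit(partitionId: Long, epochId: Long): Unit = {
    lastCommittedEpoch(partitionId) = epochId
  }

  // -1 means that the partition has not committed any epoch yet.
  def lastCommit(partitionId: Long): Long =
    lastCommittedEpoch.getOrElse(partitionId, -1L)

  // Same condition as in open(): skip every epoch already committed for the partition.
  def shouldSkip(partitionId: Long, epochId: Long): Boolean =
    epochId <= lastCommit(partitionId)
}
```

If partition 0 committed epoch 0 and partition 1 did not, a retry of epoch 0 skips partition 0 and processes partition 1, which is the behavior missing from the first implementation.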

How does it work? Let's see in the following video:

As you can see, we can implement transactional writers, but it comes with a few drawbacks. First, since Apache Spark checkpoints only the metadata (offsets) and works with multiple tasks in parallel, we need secondary storage to know the committed epochs (micro-batches) for every task. Second, the transactions work only for natively partitioned data sources, which keep the input of each task idempotent, so the transaction will always apply to the same data. To strengthen this guarantee, use the maxOffsetsPerTrigger property.

Finally - and I deliberately kept it apart since it's a bigger drawback than the previous ones - the use of a transactions store involves a kind of double commit that cannot be atomic. Imagine that you correctly acknowledged the epoch processing by writing it to the transactions store, but Kafka's commit failed. Bad luck; you will miss the data at reprocessing. The opposite is also true, but in that case you'll process the data twice.
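The 4 possible states of this double commit can be written down as a small truth table. The DoubleCommitDemo object and its outcome names are hypothetical and only encode the reasoning of this paragraph. They describe what a retry of the same epoch would do, given which of the 2 commits succeeded before a crash.

```scala
object DoubleCommitDemo {
  sealed trait Outcome
  case object ExactlyOnce extends Outcome
  case object DataLoss extends Outcome
  case object Duplicates extends Outcome

  // kafkaCommitted: the Kafka transaction was committed before the crash
  // storeUpdated: the epoch was recorded in the transactions store before the crash
  def outcomeAfterRetry(kafkaCommitted: Boolean, storeUpdated: Boolean): Outcome =
    (kafkaCommitted, storeUpdated) match {
      case (true, true)   => ExactlyOnce // the retry is skipped and the data is in Kafka
      case (false, false) => ExactlyOnce // the retry writes the data for the first time
      case (false, true)  => DataLoss    // the retry is skipped but the data was never written
      case (true, false)  => Duplicates  // the retry writes the data a second time
    }
}
```

Whatever the order of the 2 commits, one of the mixed states stays reachable, so the choice is only between risking data loss and risking duplicates.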

When I started to analyze the Apache Kafka integration in Apache Spark 4 years ago (that's the first blog post I wrote on this topic), I couldn't figure out why implementing the transactional consumers was hard. In this blog post, I tried to implement it with a high-level solution, without changing anything in the framework. Even this approach exposed some of the difficulties of doing that in the context of micro-batch-based distributed stream processing.