Apache Spark 2.4.0 features - bucket pruning

Versions: Apache Spark 2.4.0 https://github.com/bartosz25/spark-...tingforcode/sql/BucketingTest.scala

This post begins a new series dedicated to Apache Spark 2.4.0 features. The first topic covered is bucket pruning.

Bucket pruning is explained in 3 parts. The first one defines buckets, since they haven't been covered here yet and the concept may be unfamiliar to those of you who haven't worked with Hive. The second section focuses on the changes introduced in Apache Spark 2.4.0, while the last one shows them in action.

Bucket definition

Evenly balanced partitions let us process the data faster. For instance, we can collect IoT events, partition them by date and store them in a tree-like structure:

/events/2018/10/29
/events/2018/10/30
/events/2018/10/31

But what can we do if we would like to partition the same data by IoT device number? Technically it's possible, but it may be less efficient than the date-based partitioning. The device key is a value with very high cardinality (number of possible unique values) and we would end up with a tree having hundreds or thousands of subdirectories. One of the solutions for values which are not good candidates for partitioning is bucketing, also called clustering.

A bucket is a "partition inside a partition". The difference is that the number of buckets is fixed. Most of the time, the values are allocated to the buckets with a simple hash modulo-based strategy. Below you can find an example of buckets in Apache Spark:

  "Spark" should "create buckets in partitions for orders Dataset" in {
    val tableName = s"orders${System.currentTimeMillis()}"
    val orders = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user1")).toDF("order_id", "user_id")

    orders.write.mode(SaveMode.Overwrite).bucketBy(2, "user_id").saveAsTable(tableName)

    val metadata = TestedSparkSession.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
    metadata.bucketSpec shouldBe defined
    metadata.bucketSpec.get.numBuckets shouldEqual 2
    metadata.bucketSpec.get.bucketColumnNames(0) shouldEqual "user_id"
  }
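
To make the hash modulo-based allocation more concrete, here is a minimal sketch in plain Scala. It is only an illustration: bucketId is a hypothetical helper built on hashCode, whereas Apache Spark relies on its own hash function, so the ids computed here are not expected to match the ones of real bucket files.

```scala
object HashModuloBucketing {
  // Hypothetical helper: maps a value to a bucket id in [0, numBuckets).
  // Math.floorMod keeps the result non-negative even when hashCode is negative.
  def bucketId(value: Any, numBuckets: Int): Int =
    Math.floorMod(value.hashCode, numBuckets)

  def main(args: Array[String]): Unit = {
    val orders = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user1"))
    // Group the orders by the bucket of their user_id, with 2 buckets as in the test above.
    val buckets = orders.groupBy { case (_, userId) => bucketId(userId, 2) }
    buckets.toSeq.sortBy(_._1).foreach { case (id, rows) =>
      println(s"bucket $id -> $rows")
    }
  }
}
```

Whatever the hash function, the important property is that the same value always lands in the same bucket, so both orders of user1 end up together.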

In distributed data processing frameworks this approach often helps to avoid the shuffle stage. For instance, when bucketing is used on 2 Datasets joined with a sort-merge join in Spark SQL, the shuffle may not be necessary because the matching rows of both Datasets can already be located in the same partitions. Of course, both Datasets must have the same number of partitions and use the same hash partitioning algorithm.
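
The reason why no data has to move can be sketched in plain Scala, without any Spark dependency. The names below (bucketId, bucketize, joinByBucket) are hypothetical and the hash is a simple hashCode modulo, not Spark's real one. The sketch only shows the idea: when both sides share the bucketing function and the number of buckets, the join can be resolved bucket by bucket.

```scala
object BucketedJoinSketch {
  // Hypothetical stand-in for the bucketing function; both sides must share it.
  def bucketId(key: String, numBuckets: Int): Int =
    Math.floorMod(key.hashCode, numBuckets)

  // Splits the rows into at most numBuckets groups by the key extracted with keyOf.
  def bucketize[A](rows: Seq[A], numBuckets: Int)(keyOf: A => String): Map[Int, Seq[A]] =
    rows.groupBy(row => bucketId(keyOf(row), numBuckets))

  // Joins two bucketed datasets bucket by bucket: rows with the same key sit
  // under the same bucket id on both sides, so only buckets with equal ids are compared.
  def joinByBucket(
      orders: Map[Int, Seq[(Long, String)]],
      users: Map[Int, Seq[(String, String)]],
      numBuckets: Int): Seq[(Long, String, String)] =
    (0 until numBuckets).flatMap { id =>
      for {
        (orderId, userId) <- orders.getOrElse(id, Nil)
        (otherUserId, country) <- users.getOrElse(id, Nil)
        if userId == otherUserId
      } yield (orderId, userId, country)
    }
}
```

If the 2 sides used a different number of buckets, the same user_id could land under different bucket ids and the per-bucket comparison would miss matches, which is why the condition at the end of the previous paragraph matters.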

Bucket pruning implementation

Prior to Apache Spark 2.4.0, when one of the bucketing columns was involved in the query, the engine didn't apply any optimization to it. Yet, since the bucketing is deterministic, the engine could read only the bucket files storing the filtered values.

The feature was added as a new private method inside FileSourceStrategy, called only when the given dataset has exactly 1 bucketing column and at least 2 buckets:

  private def genBucketSet(
      normalizedFilters: Seq[Expression],
      bucketSpec: BucketSpec): Option[BitSet] = {

Inside the method, Apache Spark figures out which buckets should be involved in the query execution by calling one of 2 methods: getBucketSetFromIterable or getBucketSetFromValue. The choice depends on the filter type, which can be either an equality or an "is in" constraint. The bucket id is computed in the BucketingUtils.getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any) method and used to build the result of the 2 previous methods.

The selected buckets are then passed as a BitSet to FileSourceScanExec, which uses them to filter out the bucket files not storing the queried data:
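
The logic described above can be approximated with the following plain Scala sketch. It is not the Spark implementation: Filter, EqualTo and In are simplified stand-ins for Catalyst expressions, bucketIdFromValue is a hypothetical replacement for BucketingUtils.getBucketIdFromValue based on hashCode, and Scala's immutable BitSet is used instead of Spark's own BitSet class.

```scala
import scala.collection.immutable.BitSet

object BucketSetSketch {
  // Hypothetical, simplified filter model; Spark works on Catalyst expressions instead.
  sealed trait Filter
  final case class EqualTo(column: String, value: Any) extends Filter
  final case class In(column: String, values: Seq[Any]) extends Filter

  // Stand-in for BucketingUtils.getBucketIdFromValue; not Spark's real hash.
  def bucketIdFromValue(numBuckets: Int, value: Any): Int =
    Math.floorMod(value.hashCode, numBuckets)

  // Returns the buckets a filter can match, or None when the filter does not
  // target the bucketing column and therefore cannot prune anything.
  def bucketSet(filter: Filter, bucketColumn: String, numBuckets: Int): Option[BitSet] =
    filter match {
      case EqualTo(`bucketColumn`, value) =>
        // Equality: exactly one bucket can store the value.
        Some(BitSet(bucketIdFromValue(numBuckets, value)))
      case In(`bucketColumn`, values) =>
        // "Is in": the union of the buckets of every listed value.
        Some(BitSet.empty ++ values.map(v => bucketIdFromValue(numBuckets, v)))
      case _ => None
    }
}
```

Under these assumptions, bucketSet(EqualTo("user_id", "user1"), "user_id", 3) returns a set with a single bucket id, while a filter on order_id returns None and leaves all buckets to be read.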

    val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
      val bucketSet = optionalBucketSet.get
      filesGroupedToBuckets.filter {
        f => bucketSet.get(f._1)
      }
    } else {
      filesGroupedToBuckets
    }
    val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
      FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
    }
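
The effect of this snippet can be reproduced outside of Spark with a small sketch. Everything in it is an assumption made for illustration: the files are plain, made-up file names, prune is a hypothetical function and Scala's immutable BitSet replaces Spark's BitSet. It shows that one partition per bucket id is still created and that the pruned buckets simply receive no files.

```scala
import scala.collection.immutable.BitSet

object BucketFilePruningSketch {
  // Hypothetical simplification: the files of a bucket are plain file names here.
  def prune(
      filesGroupedToBuckets: Map[Int, Seq[String]],
      optionalBucketSet: Option[BitSet],
      numBuckets: Int): Seq[(Int, Seq[String])] = {
    val pruned = optionalBucketSet match {
      // Keep only the buckets whose id is set in the bit set.
      case Some(bucketSet) => filesGroupedToBuckets.filter { case (id, _) => bucketSet.contains(id) }
      // No bit set means that no pruning applies.
      case None => filesGroupedToBuckets
    }
    // One partition per bucket id is still created; pruned buckets get an empty file list.
    Seq.tabulate(numBuckets)(id => (id, pruned.getOrElse(id, Nil)))
  }
}
```

With 3 buckets and a bit set containing only the id 1, the result still has 3 entries but only the second one lists a file to read.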

Bucket pruning example

To see the optimization in action we'll use the same example as in the first section of this post, namely the table of orders:

  "Spark 2.4.0" should "not read buckets filtered out" in {
    val tableName = s"orders${System.currentTimeMillis()}"
    val orders = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user1"), (5L, "user4"), (6L, "user5"))
      .toDF("order_id", "user_id")

    orders.write.mode(SaveMode.Overwrite).bucketBy(3, "user_id").saveAsTable(tableName)

    val filteredBuckets = TestedSparkSession.sql(s"SELECT * FROM ${tableName} WHERE user_id = 'user1'")

    filteredBuckets.queryExecution.executedPlan.toString() should include("SelectedBucketsCount: 1 out of 3")
  }

As you can see, the assertion checks whether the physical plan contains the "SelectedBucketsCount" text. It was added in release 2.4.0 to show how many buckets remain after the pruning.

Bucket pruning is only one of the new features in Apache Spark 2.4.0. It helps to process only the buckets with the filtered entries, and hence to reduce the number of bucket files read. The implementation consists of passing a BitSet holding all processable bucket ids to the operator responsible for scanning a collection of files.