Apache Spark 2.4.0 features - EXCEPT ALL and INTERSECT ALL

Versions: Apache Spark 2.4.0

Apache Spark 2.4.0 brought a lot of internal changes but also some new features exposed to the end users, such as the already presented higher-order functions. In this post, I will present another new feature, or rather 2, because I will talk about 2 new SQL operators.

These 2 new SQL operators are EXCEPT ALL and INTERSECT ALL. Each of the 2 sections of this post describes one of them.

EXCEPT ALL

The simple EXCEPT operator returns the distinct rows of the first dataset that are not present in the second:

  private val TestedSparkSession: SparkSession = SparkSession.builder()
    .appName("EXCEPT ALL test").master("local[*]").getOrCreate()
  import TestedSparkSession.implicits._
  val orders1 = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"), (4L, "user1"), (5L, "user1"), (5L, "user1"))
    .toDF("order_id", "user_id")
  val orders2 = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"), (4L, "user1"), (6L, "user1"))
    .toDF("order_id", "user_id")

  "EXCEPT" should "return all rows not present in the second dataset" in {
    val rowsFromDataset1NotInDataset2 = orders1.except(orders2)
    
    val exceptResult = rowsFromDataset1NotInDataset2.collect().map(row => row.getAs[Long]("order_id"))

    exceptResult should have size 1
    exceptResult(0) shouldEqual 5
  }

There is nothing complicated about the snippet. Internally, Apache Spark translates this operation into a left anti join, i.e. a join taking all rows from the left dataset that don't have a matching row in the right one. If you're interested, you can discover more join types in Spark SQL.

At the physical execution level, the anti join runs as a broadcast hash join. It is followed by a hash aggregation on all columns, involving a shuffle, that removes the duplicates:

== Optimized Logical Plan ==
Aggregate [order_id#5L, user_id#6], [order_id#5L, user_id#6]
+- Join LeftAnti, ((order_id#5L <=> order_id#14L) && (user_id#6 <=> user_id#15))
   :- LocalRelation [order_id#5L, user_id#6]
   +- LocalRelation [order_id#14L, user_id#15]

== Physical Plan ==
*(2) HashAggregate(keys=[order_id#5L, user_id#6], functions=[], output=[order_id#5L, user_id#6])
+- Exchange hashpartitioning(order_id#5L, user_id#6, 200)
   +- *(1) HashAggregate(keys=[order_id#5L, user_id#6], functions=[], output=[order_id#5L, user_id#6])
      +- *(1) BroadcastHashJoin [coalesce(order_id#5L, 0), coalesce(user_id#6, )], [coalesce(order_id#14L, 0), coalesce(user_id#15, )], LeftAnti, BuildRight, ((order_id#5L <=> order_id#14L) && (user_id#6 <=> user_id#15))
         :- LocalTableScan [order_id#5L, user_id#6]
         +- BroadcastExchange HashedRelationBroadcastMode(List(coalesce(input[0, bigint, false], 0), coalesce(input[1, string, true], )))
            +- LocalTableScan [order_id#14L, user_id#15]
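
To make the plan above easier to read, here is a minimal sketch of the same 2 steps on plain Scala collections, without Spark. The `ExceptSketch` object and its `except` method are hypothetical names used only for this illustration; the filter stands for the LeftAnti join and `distinct` for the aggregation on all columns.

```scala
object ExceptSketch {
  type Order = (Long, String)

  val orders1: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (5L, "user1"), (5L, "user1"))
  val orders2: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (6L, "user1"))

  // Hypothetical helper: models EXCEPT as "left anti join, then deduplicate".
  def except(left: Seq[Order], right: Seq[Order]): Seq[Order] = {
    // The right side is small and looked up by value, like the broadcast side of the join.
    val rightRows = right.toSet
    left
      .filterNot(rightRows.contains) // LeftAnti: keep left rows without a match on the right
      .distinct                      // Aggregate on all columns: drop the duplicates
  }
}
```

Called with the 2 datasets of the test, `ExceptSketch.except(ExceptSketch.orders1, ExceptSketch.orders2)` keeps a single `(5L, "user1")` row, which matches the assertion of the EXCEPT test.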

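The EXCEPT ALL operator extends the behavior of EXCEPT by keeping the duplicated rows in the result. A row present m times in the first dataset and n times in the second is returned m - n times if m is greater than n, and is not returned at all otherwise: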

  "EXCEPT ALL" should "return all rows not present in the second dataset and keep the duplicates" in {
    val rowsFromDataset1NotInDataset2 = orders1.exceptAll(orders2)

    val exceptResultAll = rowsFromDataset1NotInDataset2.collect().map(row => row.getAs[Long]("order_id"))

    exceptResultAll should have size 2
    exceptResultAll(0) shouldEqual 5
    exceptResultAll(1) shouldEqual 5
    rowsFromDataset1NotInDataset2.explain(true)
  }

The execution of the above query is planned as:

== Optimized Logical Plan ==
Project [order_id#5L, user_id#6]
+- Generate replicaterows(sum#31L, order_id#5L, user_id#6), [2], false, [order_id#5L, user_id#6]
   +- Filter (isnotnull(sum#31L) && (sum#31L > 0))
      +- Aggregate [order_id#5L, user_id#6], [order_id#5L, user_id#6, sum(vcol#28L) AS sum#31L]
         +- Union
            :- LocalRelation [vcol#28L, order_id#5L, user_id#6]
            +- LocalRelation [vcol#29L, order_id#14L, user_id#15]

== Physical Plan ==
*(3) Project [order_id#5L, user_id#6]
+- Generate replicaterows(sum#31L, order_id#5L, user_id#6), [order_id#5L, user_id#6], false, [order_id#5L, user_id#6]
   +- *(2) Filter (isnotnull(sum#31L) && (sum#31L > 0))
      +- *(2) HashAggregate(keys=[order_id#5L, user_id#6], functions=[sum(vcol#28L)], output=[order_id#5L, user_id#6, sum#31L])
         +- Exchange hashpartitioning(order_id#5L, user_id#6, 200)
            +- *(1) HashAggregate(keys=[order_id#5L, user_id#6], functions=[partial_sum(vcol#28L)], output=[order_id#5L, user_id#6, sum#33L])
               +- Union
                  :- LocalTableScan [vcol#28L, order_id#5L, user_id#6]
                  +- LocalTableScan [vcol#29L, order_id#14L, user_id#15]

As you can see, EXCEPT ALL uses a replicaterows function. This function was also added in the 2.4.0 release and its main goal is to support the EXCEPT ALL and INTERSECT ALL operations. In our case, this function generates the triplets composed of the (COUNT, order_id, user_id) fields COUNT number of times. For instance, the query SELECT replicaterows(2, 'a', 'b') will output:

(2, 'a', 'b')
(2, 'a', 'b')

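Unlike EXCEPT, this plan doesn't contain any join. Both datasets are unioned and aggregated on all columns, and only the groups with a positive sum are passed to replicaterows. What the plan shares with the one of the EXCEPT operator is the partial and final hash aggregation around the shuffle.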
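
The EXCEPT ALL plan can be read as union, sum, filter and replicate. Below is a minimal sketch of these steps on plain Scala collections, without Spark. `ExceptAllSketch` and `exceptAll` are hypothetical names for this illustration. The plan doesn't print the values stored in the vcol column, so the sketch assumes +1 for the rows of the first dataset and -1 for the rows of the second, which is consistent with the `sum > 0` filter.

```scala
object ExceptAllSketch {
  type Order = (Long, String)

  val orders1: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (5L, "user1"), (5L, "user1"))
  val orders2: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (6L, "user1"))

  // Hypothetical helper: models EXCEPT ALL without any join.
  def exceptAll(left: Seq[Order], right: Seq[Order]): Seq[Order] = {
    // Union: every row gets a vcol marker (assumed values: +1 for left, -1 for right).
    val tagged = left.map(row => (row, 1L)) ++ right.map(row => (row, -1L))
    // Aggregate: sum the markers of each distinct row.
    val sums = tagged.groupBy(_._1).map { case (row, group) => (row, group.map(_._2).sum) }
    sums.toSeq
      .filter { case (_, sum) => sum > 0 }                     // Filter (sum > 0)
      .flatMap { case (row, sum) => Seq.fill(sum.toInt)(row) } // replicate each row `sum` times
  }
}
```

With the datasets of the test, only `(5L, "user1")` ends up with a positive sum (2 - 0), so the sketch returns this row twice, like the EXCEPT ALL test expects.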

INTERSECT ALL

The INTERSECT operator does the opposite of EXCEPT, i.e. it returns the rows that are present in both datasets:

  "INTERSECT" should "return all rows present in both datasets" in {
    val rowsFromDataset1InDataset2 = orders1.intersect(orders2)

    val intersectResult = rowsFromDataset1InDataset2.collect().map(row => row.getAs[Long]("order_id"))

    intersectResult should have size 4
    intersectResult should contain allOf(1L, 2L, 3L, 4L)
    rowsFromDataset1InDataset2.explain(true)
  }

Here too, the execution is based on one of the available join types. Unlike EXCEPT's anti join, INTERSECT uses a left semi join, which takes the rows present in both datasets. The final result contains only the values from the left side of the query though:

== Optimized Logical Plan ==
Aggregate [order_id#5L, user_id#6], [order_id#5L, user_id#6]
+- Join LeftSemi, ((order_id#5L <=> order_id#14L) && (user_id#6 <=> user_id#15))
   :- LocalRelation [order_id#5L, user_id#6]
   +- LocalRelation [order_id#14L, user_id#15]

== Physical Plan ==
*(2) HashAggregate(keys=[order_id#5L, user_id#6], functions=[], output=[order_id#5L, user_id#6])
+- Exchange hashpartitioning(order_id#5L, user_id#6, 200)
   +- *(1) HashAggregate(keys=[order_id#5L, user_id#6], functions=[], output=[order_id#5L, user_id#6])
      +- *(1) BroadcastHashJoin [coalesce(order_id#5L, 0), coalesce(user_id#6, )], [coalesce(order_id#14L, 0), coalesce(user_id#15, )], LeftSemi, BuildRight, ((order_id#5L <=> order_id#14L) && (user_id#6 <=> user_id#15))
         :- LocalTableScan [order_id#5L, user_id#6]
         +- BroadcastExchange HashedRelationBroadcastMode(List(coalesce(input[0, bigint, false], 0), coalesce(input[1, string, true], )))
            +- LocalTableScan [order_id#14L, user_id#15]

As you may guess, INTERSECT ALL does the same as INTERSECT but keeps the duplicates:

  "INTERSECT ALL" should "return all rows present in both datasets and keep the duplicates" in {
    val rowsFromDataset1InDataset2 = orders1.intersectAll(orders2)

    val intersectAllResult = rowsFromDataset1InDataset2.collect().map(row => row.getAs[Long]("order_id"))

    intersectAllResult should have size 5
    // theSameElementsAs also checks how many times each value occurs, unlike allElementsOf
    intersectAllResult should contain theSameElementsAs Seq(1L, 2L, 3L, 3L, 4L)
    rowsFromDataset1InDataset2.explain(true)
  }

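As for EXCEPT ALL, the generated plan doesn't involve any join but relies on replicaterows, this time with a different aggregation. It counts the occurrences of each row in both datasets and replicates the row the smaller number of times: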

== Optimized Logical Plan ==
Project [order_id#5L, user_id#6]
+- Generate replicaterows(min_count#43L, order_id#5L, user_id#6), [2], false, [order_id#5L, user_id#6]
   +- Project [order_id#5L, user_id#6, if ((vcol1_count#40L > vcol2_count#42L)) vcol2_count#42L else vcol1_count#40L AS min_count#43L]
      +- Filter ((vcol1_count#40L >= 1) && (vcol2_count#42L >= 1))
         +- Aggregate [order_id#5L, user_id#6], [count(vcol1#35) AS vcol1_count#40L, count(vcol2#38) AS vcol2_count#42L, order_id#5L, user_id#6]
            +- Union
               :- LocalRelation [vcol1#35, vcol2#38, order_id#5L, user_id#6]
               +- LocalRelation [vcol1#36, vcol2#37, order_id#14L, user_id#15]

== Physical Plan ==
*(3) Project [order_id#5L, user_id#6]
+- Generate replicaterows(min_count#43L, order_id#5L, user_id#6), [order_id#5L, user_id#6], false, [order_id#5L, user_id#6]
   +- *(2) Project [order_id#5L, user_id#6, if ((vcol1_count#40L > vcol2_count#42L)) vcol2_count#42L else vcol1_count#40L AS min_count#43L]
      +- *(2) Filter ((vcol1_count#40L >= 1) && (vcol2_count#42L >= 1))
         +- *(2) HashAggregate(keys=[order_id#5L, user_id#6], functions=[count(vcol1#35), count(vcol2#38)], output=[vcol1_count#40L, vcol2_count#42L, order_id#5L, user_id#6])
            +- Exchange hashpartitioning(order_id#5L, user_id#6, 200)
               +- *(1) HashAggregate(keys=[order_id#5L, user_id#6], functions=[partial_count(vcol1#35), partial_count(vcol2#38)], output=[order_id#5L, user_id#6, count#46L, count#47L])
                  +- Union
                     :- LocalTableScan [vcol1#35, vcol2#38, order_id#5L, user_id#6]
                     +- LocalTableScan [vcol1#36, vcol2#37, order_id#14L, user_id#15]
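
The INTERSECT ALL plan above can be sketched the same way on plain Scala collections, without Spark. `IntersectAllSketch` and `intersectAll` are hypothetical names for this illustration. The plan doesn't show what the vcol1 and vcol2 markers contain, so the sketch assumes that each row only carries the marker of its own side and models it with 1 (marker set) and 0 (marker missing), so that summing the markers behaves like counting the non-null ones.

```scala
object IntersectAllSketch {
  type Order = (Long, String)

  val orders1: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (5L, "user1"), (5L, "user1"))
  val orders2: Seq[Order] = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (3L, "user3"),
    (4L, "user1"), (6L, "user1"))

  // Hypothetical helper: models INTERSECT ALL without any join.
  def intersectAll(left: Seq[Order], right: Seq[Order]): Seq[Order] = {
    // Union: (row, vcol1, vcol2), where only the marker of the row's own side is set (assumption).
    val tagged = left.map(row => (row, 1L, 0L)) ++ right.map(row => (row, 0L, 1L))
    // Aggregate: occurrences of each distinct row on the left and on the right side.
    val counts = tagged.groupBy(_._1).map { case (row, group) =>
      (row, group.map(_._2).sum, group.map(_._3).sum)
    }
    counts.toSeq
      // Filter: the row must exist at least once on both sides.
      .filter { case (_, leftCount, rightCount) => leftCount >= 1 && rightCount >= 1 }
      // Replicate the row min(leftCount, rightCount) times.
      .flatMap { case (row, leftCount, rightCount) =>
        Seq.fill(math.min(leftCount, rightCount).toInt)(row)
      }
  }
}
```

The order of the returned rows is not guaranteed, as in Spark, so compare the result after sorting it. With the datasets of the test, the order 3 is present twice on both sides and is therefore returned twice, while the orders 5 and 6 exist on one side only and are filtered out.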

Thanks to the EXCEPT ALL and INTERSECT ALL operators, Apache Spark SQL becomes more SQL-compliant. Prior to 2.4.0, we could only use their simpler versions that don't keep the duplicates. The 2.4.0 release added the possibility to keep them. It's mainly possible thanks to the replicaterows function which, as shown in each section, returns the same row n times.