Unions in Apache Spark SQL

Versions: Apache Spark 2.4.5

Do you have 2 different datasets and want to process them as a single unit? Maybe you have some legacy data that you need to process alongside a brand new dataset? JOIN is not an option because the goal is to build a single processing unit and not to combine the rows. The UNION operation can be a good fit for that.

In this blog post I will explore the UNION operations in Apache Spark SQL.

Union types

The first thing to notice is that Apache Spark exposes 3 UNION types, and not the 2 that we know from relational databases. We still get the UNION and UNION ALL operations, but there is an extra one called UNION by name. It behaves exactly like UNION ALL except that it resolves columns by name and not by position! Keep in mind that in the Dataset API, union() has the UNION ALL semantics and doesn't remove duplicates. Below you can see the difference between the resolution by name and by position:

  "union by name" should "build unions by columns and union by position" in {
    import sparkSession.implicits._

    val legacyDataset = (1 to 3).map(nr => (s"user${nr}", nr)).toDF("login", "id")
    val newDataset = (3 to 5).map(nr => (nr, s"user${nr}")).toDF("id", "login")

    val unionByNames = legacyDataset.unionByName(newDataset).as[(String, Int)].collect()
    val unionByPositions = legacyDataset.union(newDataset).as[(String, String)].collect()

    unionByNames should contain allElementsOf (Seq(("user1", 1), ("user2", 2), ("user3", 3), ("user3", 3),
      ("user4", 4), ("user5", 5)
    ))
    unionByPositions should contain allElementsOf (Seq(("user1", "1"), ("user2", "2"), ("user3", "3"),
      ("3", "user3"), ("4", "user4"), ("5", "user5")
    ))
  }

The last type, UNION, deduplicates identical rows. With the Dataset API you get it by chaining distinct() after union():

  "union" should "deduplicate the same rows" in {
    import sparkSession.implicits._

    val legacyDataset = (1 to 3).map(nr => (nr, s"user${nr}")).toDF("id", "login")
    val newDataset = (3 to 5).map(nr => (nr, s"user${nr}")).toDF("id", "login")

    val uniqueRows = legacyDataset.union(newDataset).distinct().as[(Int, String)].collect()

    uniqueRows should have size 5
    uniqueRows should contain allOf((1, "user1"), (2, "user2"), (3, "user3"),
      (4, "user4"), (5, "user5"))
  }
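
The same distinction exists in the SQL API, where both keywords are available directly. Below is a minimal sketch, assuming Spark SQL is on the classpath; the SqlUnionSketch object and the legacy_users and new_users view names are hypothetical and only serve the illustration:

```scala
import org.apache.spark.sql.SparkSession

object SqlUnionSketch {
  // Returns (number of rows for UNION ALL, number of rows for UNION).
  // The view names below are made up for this sketch.
  def unionCounts(spark: SparkSession): (Long, Long) = {
    import spark.implicits._

    (1 to 3).map(nr => (nr, s"user$nr")).toDF("id", "login")
      .createOrReplaceTempView("legacy_users")
    (3 to 5).map(nr => (nr, s"user$nr")).toDF("id", "login")
      .createOrReplaceTempView("new_users")

    // UNION ALL keeps the (3, "user3") row from both sides
    val unionAllCount = spark.sql(
      "SELECT id, login FROM legacy_users UNION ALL SELECT id, login FROM new_users").count()
    // UNION removes the duplicated row
    val unionCount = spark.sql(
      "SELECT id, login FROM legacy_users UNION SELECT id, login FROM new_users").count()

    (unionAllCount, unionCount)
  }
}
```

With the datasets used in this post, the first count should be 6 and the second one 5, since only the (3, "user3") row is present on both sides.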

UNION ALL internals

Let's first check how UNION ALL executes. Internally, this type of UNION is represented by...yes, the Union logical node ;) Regarding the physical execution, it's handled by the UnionExec operator:

      case logical.Union(unionChildren) =>
        execution.UnionExec(unionChildren.map(planLater)) :: Nil

What does UnionExec do? It calls the low-level union method of the SparkContext class, which looks like this:

  /** Build the union of a list of RDDs. */
  def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope {
    val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty)
    val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet
    if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) {
      new PartitionerAwareUnionRDD(this, nonEmptyRdds)
    } else {
      new UnionRDD(this, nonEmptyRdds)
    }
  }
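
To observe both branches of this method, you can compare the number of partitions of the unioned RDD. The snippet below is only a sketch, assuming Spark core is on the classpath; the UnionPartitionsSketch object and its sample data are hypothetical:

```scala
import org.apache.spark.{HashPartitioner, SparkContext}

object UnionPartitionsSketch {
  // Returns (partitions when both RDDs share a partitioner,
  //          partitions when no partitioner is defined).
  def partitionCounts(sc: SparkContext): (Int, Int) = {
    val partitioner = new HashPartitioner(4)

    // Both inputs are partitioned with an equal partitioner,
    // so the PartitionerAwareUnionRDD branch is taken
    val left = sc.parallelize(Seq(1 -> "a", 2 -> "b"), 2).partitionBy(partitioner)
    val right = sc.parallelize(Seq(3 -> "c", 4 -> "d"), 3).partitionBy(partitioner)
    val partitionerAware = sc.union(Seq(left, right))

    // No partitioner defined, so a plain UnionRDD is built
    val plainLeft = sc.parallelize(Seq(1 -> "a", 2 -> "b"), 2)
    val plainRight = sc.parallelize(Seq(3 -> "c", 4 -> "d"), 3)
    val plainUnion = sc.union(Seq(plainLeft, plainRight))

    (partitionerAware.getNumPartitions, plainUnion.getNumPartitions)
  }
}
```

In the first case the result keeps the 4 partitions of the shared partitioner, whereas in the second one the partitions of both parents are simply added up (2 + 3).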

UnionRDD is a simple container for all the combined RDDs. So, contrary to my initial assumption, there is no shuffle, and every partition of every combined RDD is processed as it is:

  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
    val part = s.asInstanceOf[UnionPartition[T]]
    parent[T](part.parentRddIndex).iterator(part.parentPartition, context)
  }

And that's why the physical plan for UNION ALL looks like this:

== Physical Plan ==
Union
:- LocalTableScan [id#5, login#6]
+- LocalTableScan [id#14, login#15]
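
To make the "no shuffle" behavior more concrete, here is a simplified model of the partition mapping written in plain Scala, without any Spark dependency. It's not the real UnionRDD code: every RDD is modeled as a sequence of partitions, and UnionPartitionModel and ModelUnionPartition are hypothetical names invented for this sketch:

```scala
object UnionPartitionModel {
  // Hypothetical stand-in for Spark's UnionPartition: it only remembers
  // which parent it comes from and which partition of that parent it wraps.
  final case class ModelUnionPartition(index: Int, parentRddIndex: Int, parentPartitionIndex: Int)

  // An "RDD" is modeled as Seq[Seq[T]]: a sequence of partitions holding rows.
  // The union exposes one partition per parent partition, in the parents' order.
  def unionPartitions[T](parents: Seq[Seq[Seq[T]]]): Seq[ModelUnionPartition] = {
    val parentPartitions = for {
      (parent, parentIndex) <- parents.zipWithIndex
      partitionIndex <- parent.indices
    } yield (parentIndex, partitionIndex)

    parentPartitions.zipWithIndex.map { case ((parentIndex, partitionIndex), index) =>
      ModelUnionPartition(index, parentIndex, partitionIndex)
    }
  }

  // Computing a union partition only delegates to the wrapped parent partition;
  // the rows are never moved or regrouped.
  def compute[T](parents: Seq[Seq[Seq[T]]], partition: ModelUnionPartition): Iterator[T] =
    parents(partition.parentRddIndex)(partition.parentPartitionIndex).iterator
}
```

For 2 parents with 2 partitions each, the model exposes 4 partitions, and reading them one after another returns the rows of the first parent followed by the rows of the second one, duplicates included.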

UNION internals

Let's now see what happens for a UNION operation, so the one that handles duplicates. Whoever says deduplication, in Apache Spark but also in distributed computing in general, also says shuffle, because there is no other simple way to keep a single row for every key. The physical plan for union(...).distinct() shows that pretty well:

== Physical Plan ==
*(2) HashAggregate(keys=[id#5, login#6], functions=[], output=[id#5, login#6])
+- Exchange hashpartitioning(id#5, login#6, 200)
   +- *(1) HashAggregate(keys=[id#5, login#6], functions=[], output=[id#5, login#6])
      +- Union
         :- LocalTableScan [id#5, login#6]
         +- LocalTableScan [id#14, login#15]

As you can see, the shuffle stage is represented by the Exchange node, which computes the partitioning key from all the columns involved in the UNION. Under the hood, distinct() delegates to the dropDuplicates method, which returns a Deduplicate(groupCols, logicalPlan) logical node. You can see it in the parsed logical plan:

== Parsed Logical Plan ==
Deduplicate [id#5, login#6]
+- Union
   :- Project [_1#2 AS id#5, _2#3 AS login#6]
   :  +- LocalRelation [_1#2, _2#3]
   +- Project [_1#11 AS id#14, _2#12 AS login#15]
      +- LocalRelation [_1#11, _2#12]

Later, after the query optimization stage, the Deduplicate node becomes an Aggregate:

== Optimized Logical Plan ==
Aggregate [id#5, login#6], [id#5, login#6]
+- Union
   :- LocalRelation [id#5, login#6]
   +- LocalRelation [id#14, login#15]

Thanks to that, the Aggregate is later transformed into a HashAggregate. The rule responsible for transforming Deduplicate into Aggregate is called ReplaceDeduplicateWithAggregate:

=== Applying Rule org.apache.spark.sql.catalyst.optimizer.ReplaceDeduplicateWithAggregate ===
!Deduplicate [id#5, login#6]                         Aggregate [id#5, login#6], [id#5, login#6]
 +- Union                                            +- Union
    :- Project [_1#2 AS id#5, _2#3 AS login#6]          :- Project [_1#2 AS id#5, _2#3 AS login#6]
    :  +- LocalRelation [_1#2, _2#3]                    :  +- LocalRelation [_1#2, _2#3]
    +- Project [_1#11 AS id#14, _2#12 AS login#15]      +- Project [_1#11 AS id#14, _2#12 AS login#15]
       +- LocalRelation [_1#11, _2#12]                     +- LocalRelation [_1#11, _2#12]
                 (org.apache.spark.sql.internal.BaseSessionStateBuilder$$anon$2:62) 

The responsibility for the physical aggregation belongs to the Aggregation physical planning strategy.
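
If you want to reproduce these plans, explain(true) prints all of them, and the queryExecution field exposes the plan trees. The following is only a sketch, assuming Spark SQL is on the classpath; the DistinctPlanSketch object is hypothetical, and the catalyst classes it checks are internal, so they may change between versions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Deduplicate}

object DistinctPlanSketch {
  // Returns (is the root of the parsed plan a Deduplicate node?,
  //          is the root of the optimized plan an Aggregate node?).
  def planRoots(spark: SparkSession): (Boolean, Boolean) = {
    import spark.implicits._

    val legacyDataset = (1 to 3).map(nr => (nr, s"user$nr")).toDF("id", "login")
    val newDataset = (3 to 5).map(nr => (nr, s"user$nr")).toDF("id", "login")
    val uniqueRows = legacyDataset.union(newDataset).distinct()

    // Prints the parsed, analyzed and optimized logical plans, and the physical plan
    uniqueRows.explain(true)

    val queryExecution = uniqueRows.queryExecution
    (queryExecution.logical.isInstanceOf[Deduplicate],
      queryExecution.optimizedPlan.isInstanceOf[Aggregate])
  }
}
```

For the plans shown in this section, both flags are expected to be true: Deduplicate is the root before the optimization and Aggregate takes its place afterwards.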

UNION by name

Finally, let's check our UNION by name. Here, all the logic is applied when the logical plan is built, so it's already visible in the parsed plan. Let's look at that plan before going into the internal details:

== Parsed Logical Plan ==
Union
:- Project [_1#2 AS login#5, _2#3 AS id#6]
:  +- LocalRelation [_1#2, _2#3]
+- Project [login#15, id#14]
   +- Project [_1#11 AS id#14, _2#12 AS login#15]
      +- LocalRelation [_1#11, _2#12]

What? As you can see, even though each dataset declares its columns in a different order, Apache Spark understood that we expected a UNION by name and not by position, and added a projection reordering the columns of the right dataset. The whole logic lives inside the unionByName(other: Dataset[T]) method. First, for every column of the left dataset, the framework looks for the matching column in the right dataset:

    val rightProjectList = leftOutputAttrs.map { lattr =>
      rightOutputAttrs.find { rattr => resolver(lattr.name, rattr.name) }.getOrElse {
        throw new AnalysisException(
          s"""Cannot resolve column name "${lattr.name}" among """ +
            s"""(${rightOutputAttrs.map(_.name).mkString(", ")})""")
      }
    }

Later, all the extra columns of the right dataset are appended to the new projection node:

    // Delegates failure checks to `CheckAnalysis`
    val notFoundAttrs = rightOutputAttrs.diff(rightProjectList)
    val rightChild = Project(rightProjectList ++ notFoundAttrs, other.logicalPlan)

    // This breaks caching, but it's usually ok because it addresses a very specific use case:
    // using union to union many files or partitions.
    CombineUnions(Union(logicalPlan, rightChild))

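So as you can deduce, if the left dataset has a column that is missing from the right one, the query fails with a column resolution error. The opposite case fails too, but with a different error message! The extra columns of the right dataset are appended to the projection, so the UNION ends up with a different number of columns on each side. You can find these 2 situations in the following test case: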

  "union by name" should "fail if the columns aren't the same" in {
    import sparkSession.implicits._

    val legacyDataset = (1 to 3).map(nr => (s"user${nr}", nr, 40)).toDF("login", "id", "age")
    val newDataset = (3 to 5).map(nr => (nr, s"user${nr}")).toDF("id", "login")

    val sameNumberOfColumnsError = intercept[AnalysisException] {
      newDataset.unionByName(legacyDataset).show(false)
    }
    sameNumberOfColumnsError.getMessage() should startWith("Union can only be performed on tables with the same " +
      "number of columns, but the first table has 2 columns and the second table has 3 columns")

    val resolutionError = intercept[AnalysisException] {
      legacyDataset.unionByName(newDataset).show(false)
    }
    resolutionError.getMessage() should startWith("Cannot resolve column name \"age\" among (id, login)")
  }
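
The column resolution of unionByName can be summarized with a small model working on column names only. It's a simplified sketch and not Spark's implementation: UnionByNameModel and rightProjection are hypothetical names, and the case-insensitive comparison is an assumption that mirrors the default spark.sql.caseSensitive=false setting:

```scala
object UnionByNameModel {
  // Assumption: names are compared case-insensitively,
  // like with the default spark.sql.caseSensitive=false.
  val caseInsensitive: (String, String) => Boolean = _.equalsIgnoreCase(_)

  // Builds the column order of the projection put on top of the right dataset.
  // Left(...) carries the resolution error; Right(...) carries the projected columns.
  def rightProjection(
      leftColumns: Seq[String],
      rightColumns: Seq[String],
      resolver: (String, String) => Boolean = caseInsensitive): Either[String, Seq[String]] = {
    leftColumns.find(left => !rightColumns.exists(right => resolver(left, right))) match {
      case Some(missing) =>
        // A left column has no counterpart in the right dataset
        Left(s"""Cannot resolve column name "$missing" among (${rightColumns.mkString(", ")})""")
      case None =>
        // Right columns reordered to follow the left dataset...
        val matched = leftColumns.map(left => rightColumns.find(right => resolver(left, right)).get)
        // ...followed by the right-only columns, which later break the column count check
        val notFound = rightColumns.diff(matched)
        Right(matched ++ notFound)
    }
  }
}
```

Calling rightProjection(Seq("id", "login"), Seq("login", "id", "age")) returns 3 columns for a left side that has only 2, which corresponds to the "same number of columns" error from the test above, whereas swapping the arguments returns the "Cannot resolve column name" message.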

Despite their apparent simplicity, UNIONs in Apache Spark are quite an interesting topic. The first thing to notice is that even though they combine 2 datasets, they don't require a shuffle, or at least not all of the types do. You could also learn that Apache Spark comes with an operation that lets us execute UNIONs not by the position of the columns but by their names!
