Distinct vs group by key difference

Versions: Apache Spark 3.2.0

I've heard that using DISTINCT can hurt the performance of big data workloads, and that queries written with GROUP BY are more performant. Is that true for Apache Spark SQL?

Remove duplicates example

Let's check it out with these 2 queries:

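    import org.apache.spark.sql.SparkSession
    val sparkSession = SparkSession.builder().master("local[*]").getOrCreate()
    import sparkSession.implicits._ // provides toDF on local collections
    (0 to 10).map(id => (s"id#${id}", s"login${id % 25}"))
      .toDF("id", "login").createTempView("users")
    sparkSession.sql("SELECT login FROM users GROUP BY login").explain(true)
    sparkSession.sql("SELECT DISTINCT(login) FROM users").explain(true)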

Both queries produce the same execution plan:

== Physical Plan ==
*(2) HashAggregate(keys=[login#8], functions=[], output=[login#8])
+- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#33]
   +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
      +- *(1) LocalTableScan [login#8]
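Comparing printed plans by eye works, but the same check can be scripted. The sketch below assumes a local Spark 3.2 session and the same `users` data; it reads the optimized logical plans through `queryExecution.optimizedPlan`, looks for leftover `Distinct` nodes, and uses `sameResult`, which compares canonicalized plans so that different expression IDs don't matter. The `local[1]` master and the variable names are illustrative choices.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct}

val sparkSession = SparkSession.builder().master("local[1]").getOrCreate()
import sparkSession.implicits._

(0 to 10).map(id => (s"id#${id}", s"login${id % 25}"))
  .toDF("id", "login").createOrReplaceTempView("users")

val groupByPlan = sparkSession.sql("SELECT login FROM users GROUP BY login")
  .queryExecution.optimizedPlan
val distinctPlan = sparkSession.sql("SELECT DISTINCT(login) FROM users")
  .queryExecution.optimizedPlan

// Collect the operators of each type left in the optimized DISTINCT plan
val remainingDistincts = distinctPlan.collect { case d: Distinct => d }
val aggregates = distinctPlan.collect { case a: Aggregate => a }
println(s"Distinct nodes: ${remainingDistincts.size}, Aggregate nodes: ${aggregates.size}")

// sameResult compares canonicalized plans, i.e. it ignores expression IDs
println(s"Same optimized plan: ${groupByPlan.sameResult(distinctPlan)}")
```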

This happens because Apache Spark has a logical optimization rule called ReplaceDistinctWithAggregate that replaces a Distinct operator with an Aggregate grouping by all the output columns of its child:

    object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case Distinct(child) => Aggregate(child.output, child.output, child)
      }
    }
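To make the rewrite concrete without depending on Spark internals, here is a toy model of it in plain Scala. The `ToyCatalyst` object and its node classes are hypothetical: they only mirror the names of Catalyst's `LocalRelation`, `Distinct` and `Aggregate`, and columns are reduced to plain strings. It shows the key idea of the rule: grouping by every output column of the child and returning those same columns keeps exactly one row per unique tuple.

```scala
// A toy model of the rewrite, NOT Spark's Catalyst classes: the node names only mirror them
object ToyCatalyst {
  sealed trait LogicalPlan { def output: Seq[String] }
  final case class LocalRelation(output: Seq[String]) extends LogicalPlan
  final case class Distinct(child: LogicalPlan) extends LogicalPlan {
    def output: Seq[String] = child.output
  }
  // groupingKeys drive the grouping, aggregates define the output columns
  final case class Aggregate(groupingKeys: Seq[String], aggregates: Seq[String],
                             child: LogicalPlan) extends LogicalPlan {
    def output: Seq[String] = aggregates
  }

  // Rewrite every Distinct found in the tree, like the Spark rule does with transform
  def replaceDistinctWithAggregate(plan: LogicalPlan): LogicalPlan = plan match {
    case Distinct(child) =>
      val newChild = replaceDistinctWithAggregate(child)
      // Group by all the child's columns and output them: one row per unique tuple
      Aggregate(newChild.output, newChild.output, newChild)
    case Aggregate(keys, aggregates, child) =>
      Aggregate(keys, aggregates, replaceDistinctWithAggregate(child))
    case relation: LocalRelation => relation
  }
}

import ToyCatalyst._

val rewritten = replaceDistinctWithAggregate(Distinct(LocalRelation(Seq("login"))))
println(rewritten)
```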

With plan change logging enabled, you should also see the rule in the logs as the following entry:

=== Applying Rule org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate ===
!Distinct                     Aggregate [login#8], [login#8]
 +- LocalRelation [login#8]   +- LocalRelation [login#8]
           (org.apache.spark.sql.catalyst.rules.PlanChangeLogger:65)
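As far as we know, the plan change logger writes at the TRACE level by default, so the entry above may not show up with a standard logging setup. The sketch below assumes Spark 3.1 or later, where the relevant SQL settings are named `spark.sql.planChangeLog.level` and `spark.sql.planChangeLog.rules`; it raises the log level and limits the output to the rule discussed here. If nothing gets logged, check these names against the SQL configuration reference of your Spark version.

```scala
import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder().master("local[1]").getOrCreate()
import sparkSession.implicits._

// Log plan changes at WARN so they appear without a custom log4j configuration
sparkSession.conf.set("spark.sql.planChangeLog.level", "WARN")
// Only log the changes made by the rule we're interested in
sparkSession.conf.set("spark.sql.planChangeLog.rules",
  "org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate")

Seq("login1", "login1", "login2").toDF("login").createOrReplaceTempView("users")
// Optimizing this query should trigger the rule and its log entry
sparkSession.sql("SELECT DISTINCT(login) FROM users").explain(true)
```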

Therefore, in this simple context of selecting the unique values of a column, DISTINCT and GROUP BY execute the same way: as an aggregation. What if we used dropDuplicates() instead?

    sparkSession.sql("SELECT login FROM users").dropDuplicates("login").explain(true)

Well, this time too, the execution plan remains the same!

== Physical Plan ==
*(2) HashAggregate(keys=[login#8], functions=[], output=[login#8])
+- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#16]
   +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
      +- *(1) LocalTableScan [login#8]
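With a single selected column, dropDuplicates() hides one detail: what happens to the columns that are not deduplication keys? In our reading of the Spark source, dropDuplicates() creates a Deduplicate node that another optimizer rule, ReplaceDeduplicateWithAggregate, turns into an aggregation keeping an arbitrary value of the other columns with first(). The sketch below, on a small hypothetical dataset, lets you check the plan on your own version; it only relies on the deduplicated logins because the kept id is not deterministic.

```scala
import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder().master("local[1]").getOrCreate()
import sparkSession.implicits._

// Hypothetical sample: login1 appears twice with different ids
val users = Seq(("id#1", "login1"), ("id#2", "login1"), ("id#3", "login2"))
  .toDF("id", "login")

// Deduplicate on login only; id is not a key, so one of its values is kept per login
val deduplicated = users.dropDuplicates("login")
// Look for the aggregate function applied to id in the printed plans
deduplicated.explain(true)

val uniqueLogins = deduplicated.select("login").as[String].collect().sorted
```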

Aggregates example

What about more complex queries, like the ones involving aggregations? Since the engine replaces the DISTINCT operation with an aggregation, you will see an extra shuffle step! Let's compare these two queries:

    sparkSession.sql("SELECT COUNT(login) FROM users GROUP BY login").explain(true)
    sparkSession.sql("SELECT COUNT(DISTINCT(login)) FROM users").explain(true)

The execution plan for the first query looks like this:

== Physical Plan ==
*(2) HashAggregate(keys=[login#8], functions=[count(login#8)], output=[count(login)#12L])
+- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#16]
   +- *(1) HashAggregate(keys=[login#8], functions=[partial_count(login#8)], output=[login#8, count#15L])
      +- *(1) LocalTableScan [login#8]

The second plan has an extra shuffle step:

== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(distinct login#8)], output=[count(DISTINCT login)#17L])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#48]
   +- *(2) HashAggregate(keys=[], functions=[partial_count(distinct login#8)], output=[count#21L])
      +- *(2) HashAggregate(keys=[login#8], functions=[], output=[login#8])
         +- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#43]
            +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
               +- *(1) LocalTableScan [login#8]
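To illustrate why this plan needs two exchanges, the following plain Scala sketch mimics its stages with collections instead of Spark. Nested Seqs stand for input partitions, and the 3 shuffle partitions and the `hashCode`-based routing are simplifying assumptions (Spark uses its own hash function and 200 shuffle partitions by default). It models the plan; it is not Spark's implementation.

```scala
// Each inner Seq stands for the logins held by one input partition
val inputPartitions: Seq[Seq[String]] = Seq(
  Seq("login1", "login2", "login1"),
  Seq("login2", "login3"),
  Seq("login1")
)
// Plays the role of the 200 shuffle partitions in the plan
val shufflePartitions = 3

// Stage 1: the first HashAggregate(keys=[login]) deduplicates within each partition
val locallyDeduplicated = inputPartitions.map(_.distinct)

// Exchange hashpartitioning(login): every occurrence of a login lands in the same reducer
val shuffled: Seq[Seq[String]] = (0 until shufflePartitions).map { reducer =>
  locallyDeduplicated.flatten
    .filter(login => Math.floorMod(login.hashCode, shufflePartitions) == reducer)
}

// Stage 2: the second HashAggregate(keys=[login]) finishes the deduplication,
// then partial_count(distinct login) counts the unique logins of each reducer
val partialCounts = shuffled.map(_.distinct.size.toLong)

// Exchange SinglePartition + final count: add the partial counts in a single place
val countDistinct = partialCounts.sum
println(countDistinct)
```

The partial counts can simply be added because the hash partitioning guarantees that a login never appears in two reducers.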

However, this difference is expected because the two queries are semantically different! The first query counts the occurrences of each login and returns one row per "login" group. The second query also deduplicates the logins, but it returns a single value: the number of unique logins. To compute it, Spark must gather the partial counts from all the partitions into a single one (Exchange SinglePartition), hence the extra shuffle. If you adapt the first query to this single-value constraint, you will see that the generated plan also has 2 shuffles:

    sparkSession.sql("SELECT COUNT(*) FROM (SELECT COUNT(login) FROM users GROUP BY login)").explain(true)

And the plan:

== Physical Plan ==
*(3) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#14L])
+- Exchange SinglePartition, ENSURE_REQUIREMENTS, [id=#32]
   +- *(2) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#17L])
      +- *(2) HashAggregate(keys=[login#8], functions=[], output=[])
         +- Exchange hashpartitioning(login#8, 200), ENSURE_REQUIREMENTS, [id=#27]
            +- *(1) HashAggregate(keys=[login#8], functions=[], output=[login#8])
               +- *(1) LocalTableScan [login#8]
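The semantic difference between these queries is easier to see on results than on plans. The sketch below assumes a local Spark session and uses a hypothetical three-row dataset where one login appears twice.

```scala
import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder().master("local[1]").getOrCreate()
import sparkSession.implicits._

// Hypothetical sample: login1 appears twice, login2 once
Seq(("id#1", "login1"), ("id#2", "login1"), ("id#3", "login2")).toDF("id", "login")
  .createOrReplaceTempView("users")

// One row per login group, holding how many times that login occurs
val perLoginCounts = sparkSession.sql("SELECT COUNT(login) FROM users GROUP BY login")
  .as[Long].collect().sorted
// A single row holding the number of unique logins
val uniqueLoginsCount = sparkSession.sql("SELECT COUNT(DISTINCT(login)) FROM users")
  .as[Long].head()
// The GROUP BY query adapted to return a single value: the number of login groups
val loginGroupsCount = sparkSession.sql(
  "SELECT COUNT(*) FROM (SELECT COUNT(login) FROM users GROUP BY login)")
  .as[Long].head()
```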

Self-join example

I found one more complex query using DISTINCT in "SQL Cookbook: Query Solutions and Techniques for All SQL Users" by Anthony Molinaro and Robert De Graaf. The goal of the following query is to find reciprocal rows, i.e. rows (v1, v2) for which a row (v2, v1) also exists. Together with DISTINCT, the s1.v1 <= s1.v2 condition guarantees that each reciprocal pair is returned only once:

    Seq((20, 20), (50, 25), (70, 90), (90, 70), (90, 70)).toDF("v1", "v2").createTempView("scores")
    sparkSession.sql("SELECT DISTINCT s1.* FROM scores s1, scores s2 " +
      "WHERE s1.v1 = s2.v2 AND s1.v2 = s2.v1 AND s1.v1 <= s1.v2").explain(true)

Apache Spark transforms this query into a join followed by an aggregation:

== Physical Plan ==
*(2) HashAggregate(keys=[v1#29, v2#30], functions=[], output=[v1#29, v2#30])
+- Exchange hashpartitioning(v1#29, v2#30, 200), ENSURE_REQUIREMENTS, [id=#43]
   +- *(1) HashAggregate(keys=[v1#29, v2#30], functions=[], output=[v1#29, v2#30])
      +- *(1) Project [v1#29, v2#30]
         +- *(1) BroadcastHashJoin [v1#29, v2#30], [v2#34, v1#33], Inner, BuildLeft, false
            :- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [id=#31]
            :  +- LocalTableScan [v1#29, v2#30]
            +- *(1) LocalTableScan [v1#33, v2#34]

If you check the logs, you will see ReplaceDistinctWithAggregate applied again. Logically, then, the same query using GROUP BY for the deduplication should have the same execution plan. Let's check:

    sparkSession.sql("SELECT s1.* FROM scores s1, scores s2 " +
      "WHERE s1.v1 = s2.v2 AND s1.v2 = s2.v1 AND s1.v1 <= s1.v2 GROUP BY s1.v1, s1.v2").explain(true)

...and the plan is the same, apart from the expression and exchange IDs:

== Physical Plan ==
*(2) HashAggregate(keys=[v1#29, v2#30], functions=[], output=[v1#29, v2#30])
+- Exchange hashpartitioning(v1#29, v2#30, 200), ENSURE_REQUIREMENTS, [id=#94]
   +- *(1) HashAggregate(keys=[v1#29, v2#30], functions=[], output=[v1#29, v2#30])
      +- *(1) Project [v1#29, v2#30]
         +- *(1) BroadcastHashJoin [v1#29, v2#30], [v2#38, v1#37], Inner, BuildLeft, false
            :- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, false] as bigint), 32) | (cast(input[1, int, false] as bigint) & 4294967295))),false), [id=#82]
            :  +- LocalTableScan [v1#29, v2#30]
            +- *(1) LocalTableScan [v1#37, v2#38]
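The deduplication really matters for this query. On the sample scores data, the join alone returns the (70, 90) row twice because (90, 70) appears twice, while DISTINCT and GROUP BY both reduce it to a single row. A sketch to verify it, assuming a local Spark session:

```scala
import org.apache.spark.sql.SparkSession

val sparkSession = SparkSession.builder().master("local[1]").getOrCreate()
import sparkSession.implicits._

Seq((20, 20), (50, 25), (70, 90), (90, 70), (90, 70)).toDF("v1", "v2")
  .createOrReplaceTempView("scores")

val whereClause = "WHERE s1.v1 = s2.v2 AND s1.v2 = s2.v1 AND s1.v1 <= s1.v2"
// Reciprocal rows without deduplication: (70, 90) matches both (90, 70) rows
val withoutDeduplication = sparkSession.sql(s"SELECT s1.* FROM scores s1, scores s2 $whereClause")
  .as[(Int, Int)].collect().sorted
val withDistinct = sparkSession.sql(s"SELECT DISTINCT s1.* FROM scores s1, scores s2 $whereClause")
  .as[(Int, Int)].collect().sorted
val withGroupBy = sparkSession.sql(
  s"SELECT s1.* FROM scores s1, scores s2 $whereClause GROUP BY s1.v1, s1.v2")
  .as[(Int, Int)].collect().sorted
```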

Long story short, under the hood Apache Spark executes the DISTINCT operation as a GROUP BY, with the same execution plans. I haven't found any example producing a different execution plan, but if you have a different experience, we'd be very thankful if you shared it!

