Local deduplication or dropDuplicates?

Versions: Apache Spark 2.4.2

One of the points I wanted to cover during my talk, but for which I didn't have enough time, was the dilemma between a local deduplication and Apache Spark's dropDuplicates method to avoid integrating duplicated logs. That will be the topic of this post.

In this post I will analyze what happens for a query combining a join with dropDuplicates, and for another query using a local deduplication method.

Plan for dropDuplicates and join

In my example I will use simpler code than in the sessionization solution and try to compute a sum:

// Shape of the full outer join between the new input logs and the previously
// accumulated state
case class JoinedData(id: Option[Int], user: Option[String], value: Option[Int],
                      user_previous: Option[String], sum: Option[Int])
case class OutputData(user_previous: String, sum: Int)

object Mapping {
  // Sums the values of all rows in the group and adds the previously
  // accumulated sum carried by the first row
  def computeSum(key: String, joinedData: Iterator[JoinedData]): OutputData = {
    val head = joinedData.next()
    val newSum = joinedData.map(input => input.value.getOrElse(0)).sum + head.value.getOrElse(0)
    OutputData(key, newSum + head.sum.getOrElse(0))
  }
}
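
For completeness, the two Datasets used below, newInput and previousData, come from in-memory test data that I don't reproduce here. A minimal sketch of their assumed shape could look like this (the case class names and the rows are illustrative and only match the column names used in the queries; sparkSession is the usual SparkSession with its implicits imported):

// Illustrative only - not the exact test data behind the results shown later
case class InputLog(id: Option[Int], user: Option[String], value: Option[Int])
case class PreviousState(user_previous: String, sum: Int)

import sparkSession.implicits._

// new logs, possibly containing duplicated ids
val newInput = Seq(
  InputLog(Some(1), Some("user1"), Some(5)),
  InputLog(Some(1), Some("user1"), Some(5)), // duplicate of id=1
  InputLog(Some(2), Some("user2"), Some(3))
).toDF()

// previously accumulated sums per user
val previousData = Seq(
  PreviousState("user1", 10),
  PreviousState("user3", 20)
).toDF()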

In the first method, I'm deduplicating the input logs with dropDuplicates:

    newInput.dropDuplicates("id")
      .join(previousData, newInput("user") === previousData("user_previous"), "full_outer").as[JoinedData]
      .groupByKey(data => data.user.getOrElse(data.user_previous.get))
      .mapGroups(Mapping.computeSum)
      .show(20, false)

The returned result is correct:

+-------------+---+
|user_previous|sum|
+-------------+---+
|user1        |19 |
|user3        |20 |
|user2        |6  |
+-------------+---+

But the execution plan is a little bit heavy:

*(6) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, com.waitingforcode.sql.OutputData, true]).user_previous, true, false) AS user_previous#48, assertnotnull(input[0, com.waitingforcode.sql.OutputData, true]).sum AS sum#49]
+- MapGroups , value#44.toString, newInstance(class com.waitingforcode.sql.JoinedData), [value#44], [id#7, user#8, value#9, user_previous#18, sum#19], obj#47: com.waitingforcode.sql.OutputData
   +- *(5) Sort [value#44 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#44, 200) 
         +- AppendColumns , newInstance(class com.waitingforcode.sql.JoinedData), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#44]
            +- SortMergeJoin [user#8], [user_previous#18], FullOuter
               :- *(3) Sort [user#8 ASC NULLS FIRST], false, 0
               :  +- Exchange hashpartitioning(user#8, 200)
               :     +- SortAggregate(key=[id#7], functions=[first(user#8, false), first(value#9, false)], output=[id#7, user#8, value#9])
               :        +- *(2) Sort [id#7 ASC NULLS FIRST], false, 0
               :           +- Exchange hashpartitioning(id#7, 200)
               :              +- SortAggregate(key=[id#7], functions=[partial_first(user#8, false), partial_first(value#9, false)], output=[id#7, first#58, valueSet#59, first#60, valueSet#61])
               :                 +- *(1) Sort [id#7 ASC NULLS FIRST], false, 0
               :                    +- LocalTableScan [id#7, user#8, value#9]
               +- *(4) Sort [user_previous#18 ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(user_previous#18, 200)
                     +- LocalTableScan [user_previous#18, sum#19]

As you can see, Apache Spark first shuffles the data to gather all input logs with the same id on the same partition and keep only the first one (SortAggregate(key=[id#7], functions=[first(user#8, false), first(value#9, false)])). Then the framework performs another shuffle, this time to handle the join (Exchange hashpartitioning(user#8, 200) and Exchange hashpartitioning(user_previous#18, 200)). This double shuffle would be less problematic if we used the same key for dropDuplicates and for the join, but that's a rather unrealistic requirement if, for instance, you want to deduplicate by an event time but keep the data grouped by another attribute. Finally, the last shuffle is made for the groupByKey expression.
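
If you prefer to verify that programmatically rather than by reading the plan, a small helper can count the Exchange nodes. This is a sketch I'm adding for illustration, assuming Spark 2.4 where the shuffle node is ShuffleExchangeExec:

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Counts the shuffle (Exchange) nodes in the executed physical plan
def countShuffles(dataset: Dataset[_]): Int = {
  dataset.queryExecution.executedPlan.collect {
    case exchange: ShuffleExchangeExec => exchange
  }.size
}

Applied to the two queries of this post it should return 4 for the dropDuplicates version above and 3 for the local dedupe presented below, which matches the Exchange nodes visible in the printed plans.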

Local dedupe

A much simpler execution plan exists for a local dedupe:

  // Deduplicates the group locally by materializing it into a Set (relying on
  // the case class equality) and sums the values of the distinct rows
  def computeSumLocal(key: String, joinedData: Iterator[JoinedData]): OutputData = {
    val distinctLogs = joinedData.toSet
    // .toSeq prevents distinct rows sharing the same value from collapsing in the mapped Set
    val newSum = distinctLogs.toSeq.map(input => input.value.getOrElse(0)).sum
    OutputData(key, newSum + distinctLogs.head.sum.getOrElse(0))
  }

    newInput.join(previousData, newInput("user") === previousData("user_previous"), "full_outer").as[JoinedData]
      .groupByKey(data => data.user.getOrElse(data.user_previous.get))
      .mapGroups(Mapping.computeSumLocal)
      .show(20, false)

I'm still getting the correct results:

+-------------+---+
|user_previous|sum|
+-------------+---+
|user1        |19 |
|user3        |20 |
|user2        |6  |
+-------------+---+

But this time the execution plan is much simpler:

== Physical Plan ==
*(4) SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, com.waitingforcode.sql.OutputData, true]).user_previous, true, false) AS user_previous#88, assertnotnull(input[0, com.waitingforcode.sql.OutputData, true]).sum AS sum#89]
+- MapGroups , value#84.toString, newInstance(class com.waitingforcode.sql.JoinedData), [value#84], [id#7, user#8, value#9, user_previous#18, sum#19], obj#87: com.waitingforcode.sql.OutputData
   +- *(3) Sort [value#84 ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(value#84, 200)
         +- AppendColumns , newInstance(class com.waitingforcode.sql.JoinedData), [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, java.lang.String, true], true, false) AS value#84]
            +- SortMergeJoin [user#8], [user_previous#18], FullOuter
               :- *(1) Sort [user#8 ASC NULLS FIRST], false, 0
               :  +- Exchange hashpartitioning(user#8, 200)
               :     +- LocalTableScan [id#7, user#8, value#9]
               +- *(2) Sort [user_previous#18 ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(user_previous#18, 200)
                     +- LocalTableScan [user_previous#18, sum#19]

This time the data is shuffled only twice, for the join and for the group operations, which are the real keys of this processing logic. The dedupe is made locally, at the level of every group. The only drawback of the snippet is the need to materialize the data. Just to recall, the difference between an iterator and a materialized collection is that the iterator returns one value at a time and forgets about it, whereas a materialized collection stores all values, which can put a stronger memory pressure than the iteration. However, that overhead will probably be much smaller than the one of the distributed dropDuplicates. Besides that, you can also try to control duplicates by simply keeping the already seen keys aside and checking every time whether the current row's key was already met (see the sketch below). But let's keep things simple first and optimize later if needed.
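
To illustrate that last idea, here is a hedged sketch of the "remember the seen keys" variant (my illustration, not code from the talk). It deduplicates by id, like the dropDuplicates("id") version, and accumulates the sum on the fly instead of materializing the whole group:

  // Sketch only: keeps the already seen ids in a small mutable set instead of
  // materializing every row of the group
  def computeSumWithSeenKeys(key: String, joinedData: Iterator[JoinedData]): OutputData = {
    val seenIds = scala.collection.mutable.Set[Option[Int]]()
    var newSum = 0
    var previousSum = 0
    joinedData.foreach { input =>
      // the previously accumulated sum is carried by the rows coming from previousData
      previousSum = input.sum.getOrElse(previousSum)
      if (!seenIds.contains(input.id)) {
        seenIds += input.id
        newSum += input.value.getOrElse(0)
      }
    }
    OutputData(key, newSum + previousSum)
  }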

And avoiding that extra shuffle was the main reason why I preferred a local dedupe over the distributed dropDuplicates in my sample code. But that's my case and yours may have a different context. Maybe you will use dropDuplicates because you're not comfortable with writing Scala or Python code, and that can be an acceptable trade-off.

