Broadcast join - complementary notes for local shuffle reader

Versions: Apache Spark 3.0.0

The local shuffle reader presented in one of the previous posts may have raised some doubts about how the broadcast join works. If that's the case, this blog post should clear them up. If not, it will still give you more in-depth details than the post introducing this type of join a few years ago.

It's not the first blog post about the broadcast join here. Another one, broadcast join in Spark SQL, gives a high-level view of the internals that the article you're reading now will try to complete.

First, let's check a branch of the query plan representing the broadcast join for a query like SELECT * FROM input4 JOIN input5 ON input4.key = input5.key WHERE input4.value = '1':

+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))), [id=#37]
   +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, TestEntryKV, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, TestEntryKV, true])).value, true, false) AS value#14]
      +- Scan[obj#12]
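
If you want to reproduce a similar plan locally, the snippet below shows one possible way to do it. The TestEntryKV case class, the sample datasets and the BROADCAST hint are assumptions deduced from the plan above, not the original test code:

import org.apache.spark.sql.SparkSession

case class TestEntryKV(key: Int, value: String)

val spark = SparkSession.builder().master("local[*]")
  .appName("Broadcast join - complementary notes").getOrCreate()
import spark.implicits._

// Datasets created from RDDs to get the SerializeFromObject and Scan[obj#...] nodes of the plan
spark.createDataset(spark.sparkContext.parallelize(
  Seq(TestEntryKV(1, "1"), TestEntryKV(2, "2")))).createOrReplaceTempView("input4")
spark.createDataset(spark.sparkContext.parallelize(
  Seq(TestEntryKV(1, "A"), TestEntryKV(3, "C")))).createOrReplaceTempView("input5")

// The BROADCAST hint guarantees the broadcast join strategy, whatever the size estimates are
spark.sql("SELECT /*+ BROADCAST(input5) */ * FROM input4 JOIN input5 " +
  "ON input4.key = input5.key WHERE input4.value = '1'").explain(true)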

As you can see, there are 2 important nodes. The first one is BroadcastExchange, which represents the side being broadcasted. Its child node describes how the broadcasted part is computed; in the above example the logic is quite straightforward since it's a simple data read. And that's actually what will be broadcasted (almost, but you'll see that later), as shown in the BroadcastExchangeExec class:

  private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
      sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
// ...
            // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
            val (numRows, input) = child.executeCollectIterator()
// ...
            // Construct the relation.
            val relation = mode.transform(input, Some(numRows))
// ...
            // Broadcast the relation
            val broadcasted = sparkContext.broadcast(relation)
// ...
            promise.trySuccess(broadcasted)
            broadcasted

  override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
    try {
      relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]

From this snippet you can already understand the "almost" from the previous paragraph. The broadcasted element is not the last built DataFrame but a HashedRelation. Why is it needed? Simply because without it, the part responsible for joining the rows of the 2 datasets wouldn't be able to find the matches. In simpler words, a HashedRelation can be thought of as a hash map whose keys are the join keys and whose values are the rows to join.
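
To picture it, here is a purely conceptual sketch; the real classes work on UnsafeRow instances and managed memory rather than on Scala collections:

// The build side keyed by the join key - roughly what a HashedRelation exposes to the join
val buildSideRows = Seq((1L, "1"), (2L, "2"), (1L, "1 bis"))
val conceptualRelation: Map[Long, Seq[(Long, String)]] = buildSideRows.groupBy(_._1)

// The stream side probes this structure for every processed row
val matchesForKey1 = conceptualRelation.getOrElse(1L, Seq.empty) // the 2 rows with the key 1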

Depending on the key type, one of 2 HashedRelation implementations will be used: LongHashedRelation if the key is a single long, and UnsafeHashedRelation otherwise. How does Apache Spark check which one to use? The logic is based on the expressions used to build the join key, as defined in HashedRelation's apply method:

  def apply(
      input: Iterator[InternalRow],
      key: Seq[Expression],
      sizeEstimate: Int = 64,
      taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
// ...
    if (key.length == 1 && key.head.dataType == LongType) {
      LongHashedRelation(input, key, sizeEstimate, mm)
    } else {
      UnsafeHashedRelation(input, key, sizeEstimate, mm)
    }
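
The rule is easy to verify by hand. In the plan from the beginning of the post the int key is cast to bigint, so it falls into the LongHashedRelation branch, which you'll also see in the generated code below. The snippet is only an illustration of that rule built with hand-crafted Catalyst expressions, not the code executed by Spark:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

val intKey: Expression = AttributeReference("key", IntegerType)()
val castToBigint: Expression = Cast(intKey, LongType) // cast(key as bigint), as in the plan
val stringKey: Expression = AttributeReference("value", StringType)()

// The same condition as in HashedRelation.apply
def usesLongHashedRelation(keys: Seq[Expression]): Boolean =
  keys.length == 1 && keys.head.dataType == LongType

assert(usesLongHashedRelation(Seq(castToBigint)))             // LongHashedRelation
assert(!usesLongHashedRelation(Seq(stringKey)))               // UnsafeHashedRelation
assert(!usesLongHashedRelation(Seq(castToBigint, stringKey))) // composite key => UnsafeHashedRelation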

As for retrieving the rows to join, it happens in the generated code:

/* 019 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 020 */     partitionIndex = index;
/* 021 */     this.inputs = inputs;
/* 022 */     inputadapter_input_0 = inputs[0];
/* 023 */
/* 024 */     serializefromobject_mutableStateArray_1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
/* 025 */     serializefromobject_mutableStateArray_1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
/* 026 */
/* 027 */     bhj_relation_0 = ((org.apache.spark.sql.execution.joins.LongHashedRelation) ((org.apache.spark.broadcast.TorrentBroadcast) references[4] /* broadcast */).value()).asReadOnlyCopy();
/* 028 */     incPeakExecutionMemory(bhj_relation_0.estimatedSize());
/* 029 */
/* 030 */     serializefromobject_mutableStateArray_1[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(4, 64);
/* 031 */
/* 032 */   }

// ...
/* 089 */       // generate join key for stream side
/* 090 */       boolean bhj_isNull_0 = serializefromobject_isNull_0;
/* 091 */       long bhj_value_0 = -1L;
/* 092 */       if (!serializefromobject_isNull_0) {
/* 093 */         bhj_value_0 = (long) serializefromobject_value_0;
/* 094 */       }
/* 095 */       // find matches from HashedRelation
/* 096 */       UnsafeRow bhj_matched_0 = bhj_isNull_0 ? null: (UnsafeRow)bhj_relation_0.getValue(bhj_value_0);
/* 097 */       if (bhj_matched_0 != null) {
/* 098 */         {

When the value() method is called in the generated code, the usual broadcast retrieval flow is executed, i.e. the caller fetches the broadcast blocks from the driver and the other executors.
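
If you strip away the code generation, the mechanism is the same as for a manually created broadcast variable. The sketch below, reusing the spark session from the first snippet, is only a user-level analogy of what the generated code does with references[4]:

// Hand-rolled "broadcast join": the small side is broadcasted once and every task
// probes it locally instead of shuffling both sides
val smallSide: Map[Long, String] = Map(1L -> "1", 2L -> "2")
val broadcastedSide = spark.sparkContext.broadcast(smallSide)

val joined = spark.sparkContext.parallelize(Seq(1L, 2L, 3L)).flatMap { streamKey =>
  // .value triggers the retrieval of the broadcast blocks on the executor, like value() above
  broadcastedSide.value.get(streamKey).map(matched => (streamKey, matched)).toList
}
joined.collect().foreach(println) // (1,1) and (2,2); the key 3 has no match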

Regarding the content of the broadcasted relation, it contains all the records of the broadcasted side. If you check the apply methods of LongHashedRelation and UnsafeHashedRelation, you'll see that they both take the complete input dataset and that only the rows with null join keys are discarded:

// LongHashedRelation
    var numFields = 0
    while (input.hasNext) {
      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
      numFields = unsafeRow.numFields()
      val rowKey = keyGenerator(unsafeRow)
      if (!rowKey.isNullAt(0)) {
        val key = rowKey.getLong(0)
        map.append(key, unsafeRow)
      }
    }

// UnsafeHashedRelation
    // Create a mapping of buildKeys -> rows
    val keyGenerator = UnsafeProjection.create(key)
    var numFields = 0
    while (input.hasNext) {
      val row = input.next().asInstanceOf[UnsafeRow]
      numFields = row.numFields()
      val key = keyGenerator(row)
      if (!key.anyNull) {
        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
        val success = loc.append(
          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
        if (!success) {
          binaryMap.free()
          // scalastyle:off throwerror
          throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
          // scalastyle:on throwerror
        }
      }
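
Discarding those rows is safe because in an equi-join a null key never matches anything, so they couldn't contribute to the result anyway. A quick check, still with the spark session from the first snippet:

import spark.implicits._

val left = Seq((Some(1), "left-1"), (None, "left-null")).toDF("key", "value")
val right = Seq((Some(1), "right-1"), (None, "right-null")).toDF("key", "value")
// Only the rows with the key 1 are joined; the null-key rows never match, so the build
// side can skip them without changing the result of the join
left.join(right, left("key") === right("key")).show()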

Below you can find a demo video with some breakpoints added to prove that the whole dataset is broadcasted.

If you were confused when reading the local shuffle reader post, this extra explanation of the broadcast join internals should shed some light on it. Indeed, there can be an extra local shuffle reader optimization when the join is transformed into a broadcast join by Adaptive Query Execution, but its main purpose is to avoid a shuffle that has no reason to happen, since the query will use the broadcast join strategy in the next stage. The broadcast join strategy itself stays simple and broadcasts the whole dataset of the side eligible for broadcasting.
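
To put this conclusion in context, these are the main configuration entries involved in that scenario in Apache Spark 3.0.0; the threshold value below is only the default, given here as an example:

// Adaptive Query Execution must be enabled for the runtime conversion into a broadcast join
spark.conf.set("spark.sql.adaptive.enabled", "true")
// The local shuffle reader optimization discussed in the previous post (enabled by default)
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", "true")
// The size below which a join side is considered eligible for broadcasting (10MB is the default)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")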

