The local shuffle reader presented in one of the previous posts might have raised some doubts about how the broadcast join works. If that's the case, this blog post should shed some light on it. If not, it can give you more in-depth details than the post introducing this type of join a few years ago.
It's not the first post about the broadcast join on this blog. The previous one, broadcast join in Spark SQL, gives a high-level view of the internals, and the article you're reading now tries to complete it.
First, let's check the branch of the query plan representing the broadcast join for a query like SELECT * FROM input4 JOIN input5 ON input4.key = input5.key WHERE input4.value = '1':
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))), [id=#37]
   +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, TestEntryKV, true])).key AS key#13, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, TestEntryKV, true])).value, true, false) AS value#14]
      +- Scan[obj#12]
As you can see, there are 2 important nodes. The first one is BroadcastExchange, which represents the broadcasted side. Its child node describes how the broadcasted data is computed. In the above example, the logic is quite straightforward since it's a simple data read. And that's actually what will be broadcasted (almost, but you'll see why later), as shown in the BroadcastExchangeExec class:
private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
  SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
    sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
    // ...
    // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
    val (numRows, input) = child.executeCollectIterator()
    // ...
    // Construct the relation.
    val relation = mode.transform(input, Some(numRows))
    // ...
    // Broadcast the relation
    val broadcasted = sparkContext.broadcast(relation)
    // ...
    promise.trySuccess(broadcasted)
    broadcasted
    // ...
  }
}

override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
  try {
    relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
    // ...
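The relationFuture and doExecuteBroadcast pair follows a common pattern. The relation is computed once, asynchronously, on a dedicated thread pool. The consumer then blocks on it with a timeout, which in Spark comes from spark.sql.broadcastTimeout. Below is a minimal Java sketch of that pattern, assuming in-memory rows. The LazyBroadcastSketch class and its parameters are hypothetical stand-ins for child.executeCollectIterator(), mode.transform and BroadcastExchangeExec.executionContext. The final sparkContext.broadcast call is left out.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical sketch of the relationFuture/doExecuteBroadcast pattern:
// the relation is built once on a separate pool, callers wait for it with a timeout.
class LazyBroadcastSketch<T, R> {
    private final Supplier<List<T>> collectChild; // stand-in for child.executeCollectIterator()
    private final Function<List<T>, R> transform; // stand-in for mode.transform(input, ...)
    private final ExecutorService pool;           // stand-in for BroadcastExchangeExec.executionContext
    private CompletableFuture<R> future;          // created on first access only, like a lazy val

    LazyBroadcastSketch(Supplier<List<T>> collectChild, Function<List<T>, R> transform,
                        ExecutorService pool) {
        this.collectChild = collectChild;
        this.transform = transform;
        this.pool = pool;
    }

    // Starts the collect + transform work asynchronously the first time it's called.
    synchronized CompletableFuture<R> relationFuture() {
        if (future == null) {
            future = CompletableFuture.supplyAsync(
                () -> transform.apply(collectChild.get()), pool);
        }
        return future;
    }

    // Counterpart of doExecuteBroadcast: blocks until the relation is ready or the timeout expires.
    R executeBroadcast(long timeoutSeconds) {
        try {
            return relationFuture().get(timeoutSeconds, TimeUnit.SECONDS);
        } catch (TimeoutException e) {
            throw new IllegalStateException(
                "Broadcast relation not ready after " + timeoutSeconds + " seconds", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }
}
```

Calling executeBroadcast several times returns the same relation, while the collect step runs only once. That's the behavior the lazy val gives Spark.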
From the BroadcastExchangeExec snippet you can already understand why I wrote "almost" in the previous paragraph. The broadcasted element is not the collected output of the child node as such, but a HashedRelation built from it by mode.transform. Why is it needed? Simply because without it, the part responsible for joining rows from the 2 datasets wouldn't be able to find the matches efficiently. In simpler words, a HashedRelation can be thought of as a map from the join key to the row(s) having that key. And depending on the key type, one of 2 HashedRelation implementations will be used.
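To make the "map from join key to rows" idea concrete, here is a deliberately naive Java model of a hashed relation. It's a sketch under simplifying assumptions, not Spark's implementation. SimpleHashedRelation is a hypothetical name, and rows are plain strings. Keys go into a java.util.HashMap instead of Spark's binary maps, and null keys get no special treatment.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified model of a hashed relation: join key -> build-side rows.
final class SimpleHashedRelation<K, R> {
    private final Map<K, List<R>> rowsByKey = new HashMap<>();

    // Build phase: every build-side row is stored under the key extracted by keyOf.
    SimpleHashedRelation(Iterable<R> buildRows, Function<R, K> keyOf) {
        for (R row : buildRows) {
            rowsByKey.computeIfAbsent(keyOf.apply(row), k -> new ArrayList<>()).add(row);
        }
    }

    // Probe phase: returns the build-side rows sharing the key, or an empty list.
    List<R> get(K key) {
        return rowsByKey.getOrDefault(key, Collections.emptyList());
    }
}
```

A probe-side row needs only one get call with its own key to find all the matching build-side rows, instead of scanning the whole build side.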
LongHashedRelation is used if the join key is a single long, and UnsafeHashedRelation otherwise. How does Apache Spark decide which one to use? The logic is based on the expressions building the join key, as defined in HashedRelation's apply method:
def apply(
    input: Iterator[InternalRow],
    key: Seq[Expression],
    sizeEstimate: Int = 64,
    taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
  // ...
  if (key.length == 1 && key.head.dataType == LongType) {
    LongHashedRelation(input, key, sizeEstimate, mm)
  } else {
    UnsafeHashedRelation(input, key, sizeEstimate, mm)
  }
}
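The choice itself can be sketched as below. The KeyType enum and the RelationChooser class are hypothetical simplifications of Catalyst's data types. The packIntKeys helper rests on an assumption: that it matches what Spark's planner does in HashJoin.rewriteKeyExpr. It shows how integral keys fitting in 8 bytes can be rewritten into a single long. A single int key is simply cast to long, and two int keys are packed as the helper does. That would explain why the int key of the example appears as cast(... as bigint) in the BroadcastExchange node.

```java
import java.util.List;

// Hypothetical, simplified key types; Spark works with Catalyst DataTypes such as LongType.
enum KeyType { INT, LONG, STRING }

final class RelationChooser {
    // Same condition as in HashedRelation.apply: exactly one key, of long type.
    static String implementationFor(List<KeyType> keyTypes) {
        if (keyTypes.size() == 1 && keyTypes.get(0) == KeyType.LONG) {
            return "LongHashedRelation";
        }
        return "UnsafeHashedRelation";
    }

    // Packs two int keys into one long: the first one in the high 32 bits,
    // the second one, masked to avoid sign extension, in the low 32 bits.
    static long packIntKeys(int first, int second) {
        return ((long) first << 32) | ((long) second & 0xFFFFFFFFL);
    }
}
```

Without such a rewrite, an int key would fall into the generic UnsafeHashedRelation branch.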
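As for retrieving the rows to join, it happens in the generated code. In our example, the int join key was cast to bigint (see the HashedRelationBroadcastMode in the plan), hence the cast of the broadcasted value to LongHashedRelation: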
public void init(int index, scala.collection.Iterator[] inputs) {
  partitionIndex = index;
  this.inputs = inputs;
  inputadapter_input_0 = inputs[0];
  serializefromobject_mutableStateArray_1[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
  serializefromobject_mutableStateArray_1[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 32);
  bhj_relation_0 = ((org.apache.spark.sql.execution.joins.LongHashedRelation) ((org.apache.spark.broadcast.TorrentBroadcast) references[4] /* broadcast */).value()).asReadOnlyCopy();
  incPeakExecutionMemory(bhj_relation_0.estimatedSize());
  serializefromobject_mutableStateArray_1[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(4, 64);
}
// ...
// generate join key for stream side
boolean bhj_isNull_0 = serializefromobject_isNull_0;
long bhj_value_0 = -1L;
if (!serializefromobject_isNull_0) {
  bhj_value_0 = (long) serializefromobject_value_0;
}
// find matches from HashedRelation
UnsafeRow bhj_matched_0 = bhj_isNull_0 ? null: (UnsafeRow)bhj_relation_0.getValue(bhj_value_0);
if (bhj_matched_0 != null) {
  {
  // ...
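The probe part of the generated code boils down to the loop below. This Java sketch mirrors the bhj_isNull_0 and bhj_value_0 variables. It assumes the build-side keys are unique, so that a single getValue-like lookup returns at most one row. As far as I know, Spark generates the getValue call in that case and an iterator-based get call otherwise. ProbeSketch, the string rows and the "|" concatenation standing for the joined row are all hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class ProbeSketch {
    // keys/keyIsNull mimic the bhj_value_0/bhj_isNull_0 pair of the generated code;
    // relation maps each (unique) build-side key to its single build row.
    static List<String> innerJoin(long[] keys, boolean[] keyIsNull, String[] streamRows,
                                  Map<Long, String> relation) {
        List<String> output = new ArrayList<>();
        for (int i = 0; i < streamRows.length; i++) {
            // find matches from the relation; a null stream key is never looked up
            String matched = keyIsNull[i] ? null : relation.get(keys[i]);
            if (matched != null) {
                // stand-in for writing the joined row
                output.add(streamRows[i] + "|" + matched);
            }
        }
        return output;
    }
}
```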
When the generated code calls the value() method, the usual broadcast retrieval flow runs. Unless the value is already available locally, the executor fetches the broadcast blocks from the driver and from other executors that already hold them.
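To picture that retrieval flow, here is a minimal Java sketch of the torrent-like idea behind TorrentBroadcast, the implementation visible in the generated code. The serialized value is split into fixed-size blocks, sized by the spark.broadcast.blockSize setting (4m by default). A reader fetches the blocks in random order and then reassembles them. In Spark, fetched blocks are also stored in the executor's block manager, which is how other executors can read them from there. The ChunkedBroadcastSketch class, its methods and the in-memory "remote" list are hypothetical. There is no network, serialization, compression or checksum here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of a torrent-like broadcast: fixed-size blocks fetched in random order.
final class ChunkedBroadcastSketch {
    // Writer side: cuts the serialized value into blocks of at most blockSize bytes.
    static List<byte[]> split(byte[] serialized, int blockSize) {
        List<byte[]> blocks = new ArrayList<>();
        for (int offset = 0; offset < serialized.length; offset += blockSize) {
            int end = Math.min(offset + blockSize, serialized.length);
            blocks.add(Arrays.copyOfRange(serialized, offset, end));
        }
        return blocks;
    }

    // Reader side: fetches every block in a random order, then concatenates them by block id.
    static byte[] fetchAndReassemble(List<byte[]> remoteBlocks, long seed) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < remoteBlocks.size(); i++) {
            order.add(i);
        }
        // random order, presumably to spread the load over the block holders
        Collections.shuffle(order, new Random(seed));

        byte[][] fetched = new byte[remoteBlocks.size()][];
        int total = 0;
        for (int pieceId : order) {
            fetched[pieceId] = remoteBlocks.get(pieceId); // stand-in for a remote block fetch
            total += fetched[pieceId].length;
        }
        byte[] result = new byte[total];
        int position = 0;
        for (byte[] block : fetched) {
            System.arraycopy(block, 0, result, position, block.length);
            position += block.length;
        }
        return result;
    }
}
```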
As for the content of the broadcasted relation, it contains all the records of the broadcasted side. Check the apply methods of LongHashedRelation and UnsafeHashedRelation. Both iterate over the complete input and discard only the rows with null join keys, since a null key can't satisfy the = join condition anyway:
// LongHashedRelation
var numFields = 0
while (input.hasNext) {
  val unsafeRow = input.next().asInstanceOf[UnsafeRow]
  numFields = unsafeRow.numFields()
  val rowKey = keyGenerator(unsafeRow)
  if (!rowKey.isNullAt(0)) {
    val key = rowKey.getLong(0)
    map.append(key, unsafeRow)
  }
}

// UnsafeHashedRelation
// Create a mapping of buildKeys -> rows
val keyGenerator = UnsafeProjection.create(key)
var numFields = 0
while (input.hasNext) {
  val row = input.next().asInstanceOf[UnsafeRow]
  numFields = row.numFields()
  val key = keyGenerator(row)
  if (!key.anyNull) {
    val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
    val success = loc.append(
      key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
      row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
    if (!success) {
      binaryMap.free()
      // scalastyle:off throwerror
      throw new SparkOutOfMemoryError("There is not enough memory to build hash map")
      // scalastyle:on throwerror
    }
  }
}
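The LongHashedRelation loop can be reproduced with a hypothetical LongKeyBuildSketch class. In it, a null Long stands for a null join key, and a HashMap replaces Spark's own long-keyed map. The sketch shows the point made by the Spark code: every row with a non-null key is stored, no matter how few of them the probe side will actually match.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch mirroring the LongHashedRelation build loop.
final class LongKeyBuildSketch {
    private final Map<Long, List<String>> map = new HashMap<>();
    private int storedRows = 0;

    // keys[i] == null stands for a row whose join key is null
    LongKeyBuildSketch(Long[] keys, String[] rows) {
        for (int i = 0; i < rows.length; i++) {
            Long key = keys[i];
            // same role as !rowKey.isNullAt(0): null keys are the only discarded rows
            if (key != null) {
                // same role as map.append(key, unsafeRow)
                map.computeIfAbsent(key, k -> new ArrayList<>()).add(rows[i]);
                storedRows++;
            }
        }
    }

    int numRows() {
        return storedRows;
    }

    List<String> rowsFor(long key) {
        return map.getOrDefault(key, Collections.emptyList());
    }
}
```

Even if the streamed side keeps a single key after its WHERE filter, the relation built from the unfiltered broadcast side still holds every row with a non-null key.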
Below you can find a demo video with some breakpoints added to prove that the whole dataset is broadcasted:
If you were confused after reading the local shuffle reader post, this extra explanation of the broadcast join internals should shed some light on it. Adaptive Query Execution can add the local shuffle reader as an extra optimization when it transforms a join into a broadcast join. Its goal is mostly to avoid a shuffle that has no reason to happen, since the query will use the broadcast join strategy in the next step. The broadcast join strategy itself stays simple and broadcasts the whole dataset eligible for broadcasting.