RPC in Apache Spark

Versions: Spark 2.2.1

Communication is an important element of distributed systems. Cluster members rarely share hardware components, so the main way for them to communicate is to exchange messages, usually in the client-server model.

This post explains how this kind of communication is implemented in Apache Spark. The first part gives a big-picture definition of RPC. The next part shows how it's implemented in Apache Spark. The third section gives some configuration input, while the last one shows a sample communication between cluster members.

Definition

RPC is an acronym for Remote Procedure Call. It's a protocol based on the client-server model. When the client executes a request, the call first goes to a local component called the stub. The stub knows which server is able to execute the request and packs the whole context the server needs (e.g. the parameters) into a message. When the request finally arrives at the appropriate server, it reaches a stub on the server side as well. This stub unpacks the captured request and translates it into a call to the server-side procedure. After the procedure is executed, the result is sent back to the client.
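To make the role of both stubs more concrete, here is a minimal, in-process sketch. It's not how Spark does it: there is no network and no real serialization, the "transport" is a plain function call, and the names RpcStubSketch, ClientStub, ServerStub and the "add" procedure are hypothetical. It only shows the client stub packing parameters into a request and the server stub translating that request into a local procedure call.

```scala
import scala.collection.mutable

// Hypothetical, in-process illustration of client and server stubs: no network, no real serialization.
object RpcStubSketch {
  // What travels "over the wire": the procedure name and its marshalled arguments
  final case class Request(procedure: String, args: List[String])
  final case class Response(result: String)

  // Server side: the stub maps procedure names to the procedures it can execute
  final class ServerStub {
    private val procedures = mutable.Map.empty[String, List[String] => String]

    def register(name: String)(procedure: List[String] => String): Unit =
      procedures(name) = procedure

    // Translates the captured request into a local call and packs the result
    def handle(request: Request): Response =
      procedures.get(request.procedure) match {
        case Some(procedure) => Response(procedure(request.args))
        case None            => Response(s"error: unknown procedure ${request.procedure}")
      }
  }

  // Client side: the stub exposes a local-looking method and packs its parameters
  final class ClientStub(transport: Request => Response) {
    def add(a: Int, b: Int): Int =
      transport(Request("add", List(a.toString, b.toString))).result.toInt
  }

  // Wires both stubs together with a direct function call standing in for the network
  def demo(): Int = {
    val server = new ServerStub
    server.register("add")(args => args.map(_.toInt).sum.toString)
    val client = new ClientStub(server.handle)
    client.add(2, 3)
  }
}
```

From the caller's point of view, client.add(2, 3) looks like a local call; everything about locating and invoking the remote procedure is hidden in the stubs.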

An example of this schema is represented in the following image:

Implementation

RPC in Apache Spark is implemented with the help of the Netty client-server framework. Most of the valuable RPC classes are stored in the org.apache.spark.rpc package (the Netty-based implementation lives in org.apache.spark.rpc.netty), but before describing them, let's explain how they interact.

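The object responsible for delivering messages to the appropriate endpoint is the Dispatcher class. Via one of its post* methods, it wraps the message into an inbox message (an RpcMessage when a reply is expected, a OneWayMessage for fire-and-forget calls) and puts it into the inbox of the expected endpoint, from which it's later processed by one of the dispatcher's threads.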
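The routing idea can be sketched as follows. This is a single-threaded simplification written under assumptions: DispatcherSketch, postRpcMessage and processAll are hypothetical names, replies travel through a scala.concurrent.Promise instead of Spark's callbacks, and the endpoint is reduced to two handlers (one for fire-and-forget messages, one for messages expecting a reply). It only illustrates the "wrap the message, put it into the target inbox, process it later" flow.

```scala
import scala.collection.mutable
import scala.concurrent.{Future, Promise}

// Hypothetical, single-threaded simplification of the dispatcher/inbox idea.
// The names echo Spark's concepts but this is NOT Spark's Dispatcher API.
object DispatcherSketch {
  sealed trait InboxMessage
  // Fire-and-forget message: only the content is delivered
  final case class OneWayMessage(content: Any) extends InboxMessage
  // Request-response message: carries the promise used to send the reply back
  final case class RpcMessage(content: Any, reply: Promise[Any]) extends InboxMessage

  trait Endpoint {
    def receive: PartialFunction[Any, Unit]
    def receiveAndReply: PartialFunction[Any, Any]
  }

  final class Dispatcher {
    private val endpoints = mutable.Map.empty[String, Endpoint]
    private val inboxes = mutable.Map.empty[String, mutable.Queue[InboxMessage]]

    def register(name: String, endpoint: Endpoint): Unit = {
      endpoints(name) = endpoint
      inboxes(name) = mutable.Queue.empty[InboxMessage]
    }

    def postOneWayMessage(name: String, content: Any): Unit =
      inboxes(name).enqueue(OneWayMessage(content))

    // The caller gets a Future completed once the endpoint has processed the message
    def postRpcMessage(name: String, content: Any): Future[Any] = {
      val reply = Promise[Any]()
      inboxes(name).enqueue(RpcMessage(content, reply))
      reply.future
    }

    // Spark drains the inboxes on a thread pool; here it's done synchronously, in one call.
    // A message not handled by the endpoint's partial function throws a MatchError.
    def processAll(): Unit =
      for ((name, inbox) <- inboxes) {
        val endpoint = endpoints(name)
        while (inbox.nonEmpty) {
          inbox.dequeue() match {
            case OneWayMessage(content)     => endpoint.receive(content)
            case RpcMessage(content, reply) => reply.success(endpoint.receiveAndReply(content))
          }
        }
      }
  }
}
```

Decoupling posting from processing is what lets the sender return immediately: the Future returned by postRpcMessage stays incomplete until the inbox is drained.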

RPC endpoints are represented by two types. The first one, RpcEndpoint, is the physical representation that could be compared to the client's and server's stubs. Among others, this trait defines the onStart, receive, receiveAndReply and onStop methods. As their names suggest, onStart and onStop are invoked when the endpoint starts and stops. receive processes the messages sent in fire-and-forget mode, while receiveAndReply processes the messages expecting a response and sends that response back through an RpcCallContext. The second type used with RPC is RpcEndpointRef, a reference to a (possibly remote) endpoint. Its main role is to send requests with one of the available semantics: fire-and-forget (send(message: Any) method), synchronous request-response (askSync[T: ClassTag](message: Any): T) or asynchronous request-response (ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]).
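The three semantics of an endpoint reference can be illustrated with plain Scala futures. This is a sketch under assumptions, not Spark's RpcEndpointRef: EndpointRefSketch and EndpointRef are hypothetical names, the endpoint runs in the same JVM on an ExecutionContext, a FiniteDuration replaces RpcTimeout, and replies are cast without the ClassTag check Spark performs.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Hypothetical sketch of the three delivery semantics of an endpoint reference.
// It's NOT Spark's RpcEndpointRef: no network, no RpcTimeout, no ClassTag-checked replies.
object EndpointRefSketch {
  trait Endpoint {
    // Lifecycle hooks: in Spark the RpcEnv calls them, this sketch never does
    def onStart(): Unit = ()
    def onStop(): Unit = ()
    // Handles messages sent with send(...)
    def receive: PartialFunction[Any, Unit]
    // Handles messages sent with ask(...) and askSync(...); the result is the reply
    def receiveAndReply: PartialFunction[Any, Any]
  }

  final class EndpointRef(endpoint: Endpoint)(implicit ec: ExecutionContext) {
    // Fire-and-forget: the caller gets nothing back, not even failures
    def send(message: Any): Unit = {
      Future(endpoint.receive(message))
      ()
    }

    // Asynchronous request-response; the cast is unchecked because T is erased
    def ask[T](message: Any): Future[T] =
      Future(endpoint.receiveAndReply(message).asInstanceOf[T])

    // Synchronous request-response: blocks until the reply arrives or the timeout expires
    def askSync[T](message: Any, timeout: FiniteDuration = 5.seconds): T =
      Await.result(ask[T](message), timeout)
  }
}
```

Note that askSync is simply ask followed by a blocking wait, which is why it needs a timeout: a lost reply would otherwise block the caller forever.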

Among the implementations of RPC endpoints we can find the following classes:

Configuration

RPC is not only about endpoints but also about configuration. Spark accepts the following RPC-related configuration properties:
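Some of these properties fall back on each other. The following sketch illustrates that mechanism for the ask timeout, assuming the Spark 2.x behavior as we understand it: the value is read from spark.rpc.askTimeout, then from spark.network.timeout, and finally defaults to 120 seconds. A plain Map stands in for SparkConf, the duration parser is simplified, and RpcTimeoutSketch, resolve and parseDuration are hypothetical names.

```scala
import scala.concurrent.duration._

// Hypothetical helper mimicking how Spark 2.x resolves the RPC ask timeout (to the best of our
// knowledge): spark.rpc.askTimeout, then spark.network.timeout, then a 120s default.
// A plain Map stands in for SparkConf.
object RpcTimeoutSketch {
  val AskTimeoutKeys: Seq[String] = Seq("spark.rpc.askTimeout", "spark.network.timeout")
  val DefaultTimeout: String = "120s"

  // Simplified parser for values like "500ms", "120s" or "2min"; bare numbers mean seconds
  def parseDuration(value: String): FiniteDuration = {
    val v = value.trim
    if (v.endsWith("ms")) v.dropRight(2).toLong.millis
    else if (v.endsWith("min")) v.dropRight(3).toLong.minutes
    else if (v.endsWith("s")) v.dropRight(1).toLong.seconds
    else if (v.endsWith("m")) v.dropRight(1).toLong.minutes
    else v.toLong.seconds
  }

  // Returns the key that provided the value (or "default") and the parsed timeout
  def resolve(conf: Map[String, String], keys: Seq[String] = AskTimeoutKeys,
              default: String = DefaultTimeout): (String, FiniteDuration) = {
    val found = keys.collectFirst { case k if conf.contains(k) => (k, conf(k)) }
    val (key, raw) = found.getOrElse(("default", default))
    (key, parseDuration(raw))
  }
}
```

With this order of keys, setting only spark.network.timeout also changes the ask timeout, which is a common source of surprise when tuning one of them.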

Spark RPC example

To discover what happens with RPC, we'll use a small trick and replace the mentioned Dispatcher object with this one:

package org.apache.spark.rpc.netty

import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef}

import scala.concurrent.Promise

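// Decorator around the original Dispatcher: every call is delegated, while the stopped
// endpoints and the types of posted messages are recorded for later inspection.
// Note: extending Dispatcher also starts the parent's own message loops, which stay idle.
class LoggingDispatcher(delegate: Dispatcher, nettyEnv: NettyRpcEnv) extends Dispatcher(nettyEnv) {

  // Updated from Spark's RPC threads, hence @volatile and the synchronized writes below
  @volatile var endpoints: Seq[RpcEndpointRef] = Seq.empty

  @volatile var messageTypes: Seq[String] = Seq.empty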

  override def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = 
    delegate.registerRpcEndpoint(name, endpoint)

  override def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = delegate.getRpcEndpointRef(endpoint)

  override def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = delegate.removeRpcEndpointRef(endpoint)

  override def stop(rpcEndpointRef: RpcEndpointRef): Unit = {
    synchronized { endpoints = endpoints :+ rpcEndpointRef }
    delegate.stop(rpcEndpointRef)
  }

  override def postToAll(message: InboxMessage): Unit = delegate.postToAll(message)

  override def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
    delegate.postRemoteMessage(message, callback)
    handleMessage(message)
  }

  override def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
    delegate.postLocalMessage(message, p)
    handleMessage(message)
  }

  override def postOneWayMessage(message: RequestMessage): Unit = {
    handleMessage(message)
    delegate.postOneWayMessage(message)
  }

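  // Records the class of the message payload, e.g. "class org.apache.spark.scheduler.local.StatusUpdate"
  private def handleMessage(message: RequestMessage): Unit = synchronized {
    messageTypes = messageTypes :+ message.content.getClass.toString
  }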

  override def stop(): Unit = delegate.stop()

  override def awaitTermination(): Unit = delegate.awaitTermination()

  override def verify(name: String): Boolean = delegate.verify(name)

}

As you can see, it does nothing more than delegate every call to the original dispatcher while memorizing the handled messages and the stopped endpoints. Both properties are checked in the learning test using a simple sum action:

"custom logging RPC dispatcher" should "show what messages are sent between nodes" in {
  // Reach the private SparkContext -> SparkEnv -> NettyRpcEnv -> Dispatcher chain through reflection
  val envField = ReflectUtil.getDeclaredField(classOf[SparkContext], "org$apache$spark$SparkContext$$_env")
  val sparkEnv: SparkEnv = envField.get(sparkContext).asInstanceOf[SparkEnv]
  val rpcEnvField = ReflectUtil.getDeclaredField(classOf[SparkEnv], "rpcEnv")
  val rpcEnv = rpcEnvField.get(sparkEnv).asInstanceOf[NettyRpcEnv]
  val dispatcherField = ReflectUtil.getDeclaredField(classOf[NettyRpcEnv], "dispatcher")
  val dispatcher = dispatcherField.get(rpcEnv).asInstanceOf[Dispatcher]
  // The logging dispatcher still delegates every call to the original one
  val loggingDispatcher = new LoggingDispatcher(dispatcher, rpcEnv)
  // Finally override the dispatcher
  dispatcherField.set(rpcEnv, loggingDispatcher)

  val numbersRdd = sparkContext.parallelize(1 to 100, 2)
  val sum = numbersRdd.sum()
  // Stopping the context stops all endpoints, so they get recorded by the logging dispatcher
  sparkContext.stop()

  sum shouldEqual 5050
  val endpointNames = loggingDispatcher.endpoints.map(endpoint => endpoint.name)
  endpointNames should contain allOf("HeartbeatReceiver", "MapOutputTracker", "BlockManagerEndpoint1",
    "BlockManagerMaster", "OutputCommitCoordinator")
  loggingDispatcher.messageTypes should contain allOf(
    "class org.apache.spark.storage.BlockManagerMessages$UpdateBlockInfo",
    "class org.apache.spark.scheduler.local.ReviveOffers$", "class org.apache.spark.scheduler.local.StatusUpdate",
    "class org.apache.spark.scheduler.local.StopExecutor$",
    "class org.apache.spark.StopMapOutputTracker$",
    "class org.apache.spark.storage.BlockManagerMessages$StopBlockManagerMaster$",
    "class org.apache.spark.scheduler.StopCoordinator$")
}

// ReflectUtil.getDeclaredField is a simple getter for Spark's private fields:
import java.lang.reflect.Field

object ReflectUtil {
  def getDeclaredField[T](fieldClass: Class[T], fieldName: String): Field = {
    val searchedField = fieldClass.getDeclaredField(fieldName)
    // Private fields can't be read or written without this flag
    searchedField.setAccessible(true)
    searchedField
  }
}

RPC is used for the communication between two remote nodes. As shown in this post, Apache Spark uses it mainly for driver-executor and master-slave synchronization. But, as we could discover in the second section, RPC in Spark is also about block management, heartbeats and streaming aggregations. As proven in the third part, this protocol shouldn't be neglected: incorrectly configured, it can have a negative impact on performance.