GraphX and fault-tolerance

Versions: Apache Spark GraphX 2.4.0 https://github.com/bartosz25/spark-...aitingforcode/graphx/faulttolerance

Bad things happen in distributed data processing, and it's better to be prepared for them. To protect against such issues, Apache Spark is able to recompute a failed partition but also to store a snapshot of the computation as a checkpoint. Both properties apply to the GraphX module's fault-tolerance mechanism.

This post shows how GraphX ensures fault-tolerance. Its first 2 parts talk about the checkpoint mechanism, which we should already know from the RDD and streaming contexts. The last one is about GraphX's self-recovery process.

GraphX checkpointing

Without delving into the details, we can say that checkpointing consists of persisting an RDD at a given computation stage. To reduce the size of the stored data, the checkpointed RDD doesn't keep the references to its parents. In order to make it work, we must define the checkpoint directory as one of SparkContext's properties:

  private def TestSparkContext(withCheckpoint: Boolean) = {
    val context = SparkContext.getOrCreate(new SparkConf().setAppName("GraphX checkpoint")
      .setMaster("local[*]"))
    if (withCheckpoint) context.setCheckpointDir("/tmp/graphx-checkpoint")
    context
  }
  it should "fail when the checkpoint directory is not defined" in {
    val checkpointError = intercept[SparkException] {
      graph(false).checkpoint()
    }

    checkpointError.getMessage should include ("Checkpoint directory has not been set in the SparkContext")
  }

  private def graph(withCheckpoint: Boolean) = {
    val vertices = TestSparkContext(withCheckpoint).parallelize(
      Seq((1L, Friend("A", 20)), (2L, Friend("B", 21)), (3L, Friend("C", 22)))
    )
    val edges = TestSparkContext(withCheckpoint).parallelize(
      Seq(Edge(1L, 2L, RelationshipAttributes(0L, 1)), Edge(1L, 3L, RelationshipAttributes(0L, 0)))
    )

    Graph(vertices, edges)
  }
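
The Friend and RelationshipAttributes attribute classes aren't included in the snippet. Below is a minimal sketch of how they could look, deduced from the constructor calls and the copy(name = ...) usage later in the post; the RelationshipAttributes field names are pure assumptions made for this sketch:

  // Vertex attribute - the name field is used later in copy(name = ...), the Int is presumably an age
  case class Friend(name: String, age: Int)
  // Edge attribute - a Long and an Int per the calls above; field names assumed for this sketch
  case class RelationshipAttributes(creationTime: Long, maturityLevel: Int)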

Aside from showing the failure when the checkpoint directory is not specified, the above code also shows how to use checkpointing. The checkpoint() method triggers the checkpointing operation for the given graph. It must be called before any operation is applied to that graph. You must also know that it's blocking, i.e. the engine will do nothing with the checkpointed graph as long as the persistence operation doesn't terminate:

  it should "be blocking operation" in {
    val graphWithCheckpointDir = graph(true)

    val graphWithMappedVertices = graphWithCheckpointDir.mapVertices {
      case (vertexId, vertexValue) => vertexValue.copy(name = s"Copy ${vertexValue.name}")
    }

    // Here we want to checkpoint the graph after mapping vertices
    val beforeCheckpoint = System.currentTimeMillis()
    graphWithMappedVertices.checkpoint()
    val afterCheckpoint = System.currentTimeMillis()

    graphWithMappedVertices.collectEdges(EdgeDirection.Either)

    afterCheckpoint > beforeCheckpoint shouldBe true
  }

Therefore, the checkpoint has an impact on the execution time. As shown in the "Graph Processing in a Distributed Dataflow Framework" article, the checkpoint takes time. But it helps to recover quickly from temporary failures - in the article, the code recovering from the checkpoint executed much faster (631 seconds) than the same code recovering through recomputation (760 seconds, the 2nd solution described in this post).

As you may already know, the graph is represented by 2 RDDs: vertices and edges. Underneath, the checkpoint method calls the checkpoint of both of them:

  override def checkpoint(): Unit = {
    vertices.checkpoint()
    replicatedVertexView.edges.checkpoint()
  }

Both checkpoint() calls are delegated to the RDD implementation. Hence, restoring a graph should work exactly the same as restoring a simple RDD:

  it should "restore the graph from checkpoint" in {
    // Please notice that GraphX advises to use Pregel for iterative computation
    // However, to not introduce too many concepts at once, I'll use a more naive implementation
    val accumulatedVertices = new scala.collection.mutable.HashMap[String, Seq[String]]
    val checkpoints = new scala.collection.mutable.HashSet[String]()
    val graphWithCheckpointDir = graph(true)
    var currentGraph = graphWithCheckpointDir
    for (i <- 1 until 5) {
      val graphWithMappedVertices = currentGraph.mapVertices {
        case (vertexId, vertexValue) => {
          if (i == 5) {
            throw new RuntimeException("Expected failure")
          }
          vertexValue.copy(name = s"Copy ${vertexValue.name}_${i}")
        }
      }

      // Here we want to checkpoint the graph after mapping vertices
      graphWithMappedVertices.checkpoint()
      currentGraph.getCheckpointFiles.foreach(checkpointFile => checkpoints.add(checkpointFile))
      if (i == 5) {
        intercept[SparkException] {
          graphWithMappedVertices.vertices.collectAsMap()
        }
      } else {
        accumulatedVertices.put(s"run1_${i}",
          graphWithMappedVertices.vertices.collect().map {
            case (vertexId, friendVertex) => friendVertex.name
          })
      }
      currentGraph = graphWithMappedVertices
    }
    for (i <- 4 to 10) {
      val graphWithMappedVertices = currentGraph.mapVertices {
        case (vertexId, vertexValue) => vertexValue.copy(name = s"Copy ${vertexValue.name}_${i}")
      }
      graphWithMappedVertices.checkpoint()
      currentGraph.getCheckpointFiles.foreach(checkpointFile => checkpoints.add(checkpointFile))
      accumulatedVertices.put(s"run2_${i}",
        graphWithMappedVertices.vertices.collect().map {
          case (vertexId, friendVertex) => friendVertex.name
        })
      currentGraph = graphWithMappedVertices
    }

    accumulatedVertices should have size 11
    accumulatedVertices.keys should contain allOf("run1_1", "run1_2", "run1_3", "run1_4",
      "run2_4", "run2_5", "run2_6", "run2_7", "run2_8", "run2_9", "run2_10")
    accumulatedVertices("run1_1") should contain allOf("Copy A_1", "Copy B_1", "Copy C_1")
    accumulatedVertices("run1_2") should contain allOf("Copy Copy A_1_2", "Copy Copy B_1_2", "Copy Copy C_1_2")
    accumulatedVertices("run1_3") should contain allOf("Copy Copy Copy A_1_2_3", "Copy Copy Copy B_1_2_3",
      "Copy Copy Copy C_1_2_3")
    accumulatedVertices("run1_4") should contain allOf("Copy Copy Copy Copy A_1_2_3_4", "Copy Copy Copy Copy B_1_2_3_4",
      "Copy Copy Copy Copy C_1_2_3_4")
    accumulatedVertices("run2_4") should contain allOf("Copy Copy Copy Copy Copy A_1_2_3_4_4",
      "Copy Copy Copy Copy Copy B_1_2_3_4_4", "Copy Copy Copy Copy Copy C_1_2_3_4_4")
    accumulatedVertices("run2_5") should contain allOf("Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5",
      "Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5", "Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5")
    accumulatedVertices("run2_6") should contain allOf("Copy Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5_6",
      "Copy Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5_6", "Copy Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5_6")
    accumulatedVertices("run2_7") should contain allOf("Copy Copy Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5_6_7",
      "Copy Copy Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5_6_7",
      "Copy Copy Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5_6_7")
    accumulatedVertices("run2_8") should contain allOf("Copy Copy Copy Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5_6_7_8",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5_6_7_8",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5_6_7_8")
    accumulatedVertices("run2_9") should contain allOf("Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5_6_7_8_9",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5_6_7_8_9",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5_6_7_8_9")
    accumulatedVertices("run2_10") should contain allOf("Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy A_1_2_3_4_4_5_6_7_8_9_10",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy B_1_2_3_4_4_5_6_7_8_9_10",
      "Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy Copy C_1_2_3_4_4_5_6_7_8_9_10")
    checkpoints should have size 10
  }

After executing the above test case we can observe checkpointing activity in the logs:

Written partitioner to file:/tmp/graphx-checkpoint/6a534fac-1850-4d64-892c-69929f31659d/rdd-104/_partitioner (org.apache.spark.rdd.ReliableCheckpointRDD:58)
Done checkpointing RDD 104 to file:/tmp/graphx-checkpoint/6a534fac-1850-4d64-892c-69929f31659d/rdd-104, new parent is RDD 112 (org.apache.spark.rdd.ReliableRDDCheckpointData:54)

As you can see, the RDD's partitioner is written alongside the data. This write is best-effort: any error occurring here doesn't break the data checkpointing.
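
To see those files directly, we can list the checkpoint directory configured in TestSparkContext once an action has materialized the checkpoint. A quick sketch, assuming the /tmp/graphx-checkpoint location used earlier:

  import java.io.File

  // Each checkpointed RDD gets its own rdd-<id> directory under a per-application random
  // subdirectory; it contains the partition files and, on a best-effort basis, the
  // _partitioner file shown in the logs above.
  val checkpointRoot = new File("/tmp/graphx-checkpoint")
  checkpointRoot.listFiles().flatMap(_.listFiles()).foreach { rddDir =>
    println(s"${rddDir.getName}: ${rddDir.list().mkString(", ")}")
  }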

Pregel checkpointing

The partitioner is not the only checkpoint-related element left to cover. Another one is PeriodicGraphCheckpointer. It's used in computations based on the Pregel computation model which, as mentioned in the first comment of the previous test case, is the advised way to deal with iterative algorithms in GraphX.

In Pregel, the checkpointer's update method is invoked at every iteration:

    while (activeMessages > 0 && i < maxIterations) {
      // Receive the messages and update the vertices.
      prevG = g
      g = g.joinVertices(messages)(vprog)
      graphCheckpointer.update(g)

This special type of checkpointer delays the physical checkpointing process. Concretely, when the update(graph) method is invoked by the client, PeriodicGraphCheckpointer first puts the vertices and the edges into the cache. If the number of cached graphs is greater than 3, it unpersists the oldest ones. Later, the PeriodicGraphCheckpointer instance does the checkpointing only when the number of successive update calls reaches the checkpointInterval parameter defined in its constructor. If it's the case, the checkpointer checkpoints the most recent data and removes all older checkpoints.
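
Since PeriodicGraphCheckpointer is a private[spark] class, it can't be instantiated from client code, but its behavior can be illustrated with a simplified, naive sketch. It's not the real implementation - notably, the real class also physically deletes the files of the removed checkpoints:

  import scala.collection.mutable
  import org.apache.spark.graphx.Graph

  // Naive illustration of the delayed checkpointing: every updated graph is cached,
  // only 1 update out of checkpointInterval is physically checkpointed, and only a few
  // previous versions are kept alive.
  class NaiveGraphCheckpointer[VD, ED](checkpointInterval: Int) {
    private var updateCount = 0
    private val persistedGraphs = mutable.Queue[Graph[VD, ED]]()
    private val checkpointedGraphs = mutable.Queue[Graph[VD, ED]]()

    def update(graph: Graph[VD, ED]): Unit = {
      // 1) always cache the new version and drop the oldest cached ones
      graph.cache()
      persistedGraphs.enqueue(graph)
      while (persistedGraphs.size > 3) {
        persistedGraphs.dequeue().unpersist(blocking = false)
      }
      // 2) physically checkpoint only every checkpointInterval calls
      updateCount += 1
      if (checkpointInterval > 0 && updateCount % checkpointInterval == 0) {
        graph.checkpoint()
        checkpointedGraphs.enqueue(graph)
        // the real checkpointer also removes the checkpoint files of the older graphs
        while (checkpointedGraphs.size > 1) {
          checkpointedGraphs.dequeue()
        }
      }
    }
  }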

PeriodicGraphCheckpointer is used in Pregel because of the initial lack of support for long lineage chains. In other words, the engine is not prepared to deal with iterative algorithms where each iteration produces a new RDD based on the previous one. It was shown that without PeriodicGraphCheckpointer, which eliminates the parents, the processing was slowing down and eventually failing with a StackOverflowError. You can find more information in the SPARK-5484 ticket from the Read also section. A quick way to observe this lineage growth is shown below.
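
The following snippet is not from the original implementation; it's just a quick illustration, reusing the graph(...) helper defined at the beginning of the post, of how the vertices RDD's plan grows when mapVertices is chained without checkpointing:

  it should "show a growing lineage without checkpointing" in {
    var currentGraph = graph(true)
    for (i <- 1 to 20) {
      currentGraph = currentGraph.mapVertices {
        case (_, vertexValue) => vertexValue.copy(name = s"${vertexValue.name}_${i}")
      }
    }
    // every iteration adds new parents to the plan - without checkpointing the lineage
    // (and the serialized task size) keeps growing with the number of iterations
    println(currentGraph.vertices.toDebugString.split("\n").length)
  }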

Recomputation

Another way to deal with failures in GraphX consists of letting the task fail and recover automatically. During the automatic recovery, the engine recomputes the data needed by the task, i.e. the partition processed by the failed task. Since a Graph is composed of 2 RDDs (vertices and edges), does it mean that the partitions of both of them are always recomputed? No. We can figure that out by making an exception happen and analyzing the produced stack trace. For an exception produced in the mapVertices transformation, we can see that the computation concerns the VertexRDD and, hence, it's this RDD that will be recomputed, unless it was checkpointed or cached:

ERROR Exception in task 0.0 in stage 2.0 (TID 4) (org.apache.spark.executor.Executor:91)
java.lang.RuntimeException: Expected failure
        at com.waitingforcode.graphx.faulttolerance.RecomputeTest$$anonfun$4$$anonfun$3.apply$mcIJI$sp(RecomputeTest.scala:44)
        at com.waitingforcode.graphx.faulttolerance.RecomputeTest$$anonfun$4$$anonfun$3.apply(RecomputeTest.scala:40)
        at com.waitingforcode.graphx.faulttolerance.RecomputeTest$$anonfun$4$$anonfun$3.apply(RecomputeTest.scala:40)
        at org.apache.spark.graphx.impl.VertexPartitionBaseOps.map(VertexPartitionBaseOps.scala:61)
        at org.apache.spark.graphx.impl.GraphImpl$$anonfun$5.apply(GraphImpl.scala:129)
        at org.apache.spark.graphx.impl.GraphImpl$$anonfun$5.apply(GraphImpl.scala:129)
        at scala.collection.Iterator$$anon$11.next(Iterator.scala:410)
        at org.apache.spark.storage.memory.MemoryStore.putIterator(MemoryStore.scala:221)
        at org.apache.spark.storage.memory.MemoryStore.putIteratorAsValues(MemoryStore.scala:298)
        at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1165)
        at org.apache.spark.storage.BlockManager$$anonfun$doPutIterator$1.apply(BlockManager.scala:1156)
        at org.apache.spark.storage.BlockManager.doPut(BlockManager.scala:1091)
        at org.apache.spark.storage.BlockManager.doPutIterator(BlockManager.scala:1156)
        at org.apache.spark.storage.BlockManager.getOrElseUpdate(BlockManager.scala:882)
        at org.apache.spark.rdd.RDD.getOrCompute(RDD.scala:335)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:286)
        at org.apache.spark.graphx.VertexRDD.compute(VertexRDD.scala:69)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)

Below you can find the example of GraphX automatic recovery that produced the previous stack trace:


  // Since the local mode doesn't retry failed tasks, we must run this test against a standalone installation
  private val TestSparkContext = SparkContext.getOrCreate(new SparkConf().setAppName("GraphX automatic recovery")
    .set("spark.task.maxFailures", "5")
    .set("spark.executor.extraClassPath", sys.props("java.class.path"))
    .setMaster("spark://localhost:7077"))

  private val EdgesFile = new File("/tmp/graphx-recompute/edges.txt")

  before {
    val edgesContent =
      """
        |# comment
        |1 2
        |1 3
        |1 4
        |4 2""".stripMargin
    FileUtils.writeStringToFile(EdgesFile, edgesContent)
  }

  after {
    EdgesFile.delete()
  }

  behavior of "recompute recovery"

  it should "recover from temporary failure" in {
    val failingGraph = GraphLoader.edgeListFile(TestSparkContext, EdgesFile.getAbsolutePath)

    val mappedVertices = failingGraph.mapVertices {
      case (vertexId, vertexValue) => {
        if (vertexId == 2 && !FailingFlag.wasFailed) {
          FailingFlag.wasFailed = true
          throw new RuntimeException("Expected failure")
        }
        vertexValue + 10
      }
    }

    val vertices = mappedVertices.vertices.collect().map {
      case (vertexId, vertex) => vertex
    }

    vertices should have size 4
    vertices should contain allElementsOf(Seq(11, 11, 11, 11))
  }

object FailingFlag {
  var wasFailed = false
}

This fault-tolerance mechanism is also known as the self-recovery process and it's guaranteed by the DAG described in the Directed Acyclic Graph in Spark post. Thanks to it, the engine knows how the data of the failed task is computed and can retrigger the whole computation in order to retry the task execution. Of course, if the error is a programming bug, the retry will never work. But in the case of temporary errors, as for instance a temporary unavailability of a 3rd party service, it should work.

In this post we discovered how GraphX ensures fault-tolerance. The module uses 2 solutions implemented in other Apache Spark modules: checkpointing and automatic task recovery. The former stores the graph before it's materialized. It's thus a blocking operation which, however, helps to recover from failures faster than the automatic task recovery. The checkpointing in GraphX also offers a special implementation reserved for Pregel iterative algorithms, called PeriodicGraphCheckpointer. It addresses the problem of the long lineages created after multiple iterations. By truncating the lineage to 0 parents, it helps to avoid the StackOverflowError described in the SPARK-5484 ticket.