Creating graphs in GraphFrames

Versions: GraphFrames 0.6 https://github.com/bartosz25/spark-...graph/GraphRepresentationTest.scala

Project Tungsten revolutionized the Apache Spark ecosystem. Thanks to the new row-based data structure, jobs became more performant and easier to write. This revolution first affected batch processing and later streaming. At the time of writing, graph processing has not yet benefited from it, but hopefully the GraphFrames project can change this.

This post goes a little bit further than the Introduction to Apache Spark GraphX and proposes a much more detailed view of GraphFrames. Its first section focuses on graph creation. The second one shows some implementation details and their connection to the GraphX module.

Graphs in GraphFrames

GraphFrames graphs are built on the DataFrame API, so any DataFrame reader or writer can be used to load or save them. However, the persisted data must follow some specific rules, and the original attribute names can't always be kept as they are:

  "GraphFrames" should "create the graph from an in-memory structure" in {
    val parentOfRelationship = "is parent of"
    import sparkSession.implicits._
    val people = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"), (5L, "user5")).toDF("id", "name")
    val relationships = Seq((1L, 2L, parentOfRelationship), (1L, 3L, parentOfRelationship), (2L, 3L, parentOfRelationship),
      (3L, 4L, parentOfRelationship)).toDF("src", "dst", "label")
    val graph = GraphFrame(people, relationships)

    val mappedVertices = graph.vertices.collect().map(row => row.getAs[String]("name"))

    mappedVertices should have size 5
    mappedVertices should contain allOf("user1", "user2", "user3", "user4", "user5")
  }

As you can see, graphs in GraphFrames must have a specific structure, just like in GraphX. First, the vertices DataFrame must have a column named id (a Long in these examples; the GraphFrames user guide also builds graphs with String ids). Secondly, the edges must expose src and dst columns. And finally, these columns must hold the ids of the connected vertices:

  "GraphFrames" should "fail creating the vertices without id" in {
    val parentOfRelationship = "is parent of"
    import sparkSession.implicits._
    val people = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"), (5L, "user5")).toDF("user_id", "name")
    val relationships = Seq((1L, 2L, parentOfRelationship), (1L, 3L, parentOfRelationship), (2L, 3L, parentOfRelationship),
      (3L, 4L, parentOfRelationship)).toDF("src", "dst", "label")

    val error = intercept[IllegalArgumentException] {
      GraphFrame(people, relationships)
    }

    error.getMessage() should include("requirement failed: Vertex ID column 'id' missing from vertex DataFrame, which has columns: user_id,name")
  }

  "GraphFrames" should "fail creating the graph when edges don't have src or dst attributes" in {
    val parentOfRelationship = "is parent of"
    import sparkSession.implicits._
    val people = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"), (5L, "user5")).toDF("id", "name")
    val relationships = Seq((1L, 2L, parentOfRelationship), (1L, 3L, parentOfRelationship), (2L, 3L, parentOfRelationship),
      (3L, 4L, parentOfRelationship)).toDF("sourceVertex", "dst", "label")

    val error = intercept[IllegalArgumentException] {
      GraphFrame(people, relationships)
    }

    error.getMessage() should include("requirement failed: Source vertex ID column 'src' missing from edge DataFrame, " +
      "which has columns: sourceVertex,dst,label")
  }

  "GraphFrames" should "fail creating graph with incorrect src and dst values" in {
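    val parentOfRelationship = "is parent of"
    import sparkSession.implicits._
    val people = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"), (5L, "user5")).toDF("id", "name")
    // Note: the edges use "sourceVertex" instead of "src", so the exception asserted below
    // comes from the missing column name rather than from the String-typed values
    val relationships = Seq(("1", "2", parentOfRelationship), ("1", "3", parentOfRelationship), ("2", "3", parentOfRelationship),
      ("3", "4", parentOfRelationship)).toDF("sourceVertex", "dst", "label")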

    val error = intercept[IllegalArgumentException] {
      GraphFrame(people, relationships)
    }

    error.getMessage() should include("requirement failed: Source vertex ID column 'src' missing from edge DataFrame, " +
      "which has columns: sourceVertex,dst,label")
  }
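
The failures above are all about column names. As a complementary sketch, and assuming the factory method only checks that the id, src and dst columns exist (which is what the error messages suggest, and the GraphFrames user guide itself builds graphs with String ids), a graph keyed by strings should be accepted as well. The object and method names below are illustrative and not part of the original test suite:

```scala
import org.apache.spark.sql.SparkSession
import org.graphframes.GraphFrame

// Hypothetical helper, written only for this illustration
object StringIdGraphSketch {
  // Builds a small graph whose vertex ids are Strings and returns
  // (number of vertices, number of edges) to show that the creation succeeds
  def build(sparkSession: SparkSession): (Long, Long) = {
    import sparkSession.implicits._
    val people = Seq(("u1", "user1"), ("u2", "user2"), ("u3", "user3")).toDF("id", "name")
    val relationships = Seq(("u1", "u2", "is parent of"), ("u2", "u3", "is parent of"))
      .toDF("src", "dst", "label")
    // Same factory method as in the tests above, only the id type differs
    val graph = GraphFrame(people, relationships)
    (graph.vertices.count(), graph.edges.count())
  }
}
```

It can be called with any local session, for instance StringIdGraphSketch.build(sparkSession) inside one of the tests above.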

The above examples were quite simple. We've only created a graph from memory and collected its vertices. But we can do more complicated things, for instance read the data from JSON files, process it and write it back as JSON:

  "GraphFrames" should "read the graph from JSON and read it back after mapping the edges" in {
    val usersFile = new File("./users")
    usersFile.deleteOnExit()
    val users =
      """
        |{"id": 1, "name": "user1"}
        |{"id": 2, "name": "user2"}
        |{"id": 3, "name": "user3"}
      """.stripMargin
    FileUtils.writeStringToFile(usersFile, users)
    val friendsFile = new File("./friends")
    friendsFile.deleteOnExit()
    val friends =
      """|{"userFrom": 1, "userTo": 2, "confirmed": true, "isNew": false}
        |{"userFrom": 1, "userTo": 3, "confirmed": true, "isNew": false}""".stripMargin
    FileUtils.writeStringToFile(friendsFile, friends)

    val vertices = sparkSession.read.json(usersFile.getAbsolutePath)
    val edges = sparkSession.read.json(friendsFile.getAbsolutePath)
      .withColumnRenamed("userFrom", "src")
      .withColumnRenamed("userTo", "dst")

    val graph = GraphFrame(vertices, edges)

    val verticesToRestore = new File("./vertices_to_restore")
    verticesToRestore.deleteOnExit()
    graph.vertices.write.mode(SaveMode.Overwrite).json(verticesToRestore.getAbsolutePath)
    val edgesToRestoreFile = new File("./edges_to_restore")
    edgesToRestoreFile.deleteOnExit()
    graph.edges.write.mode(SaveMode.Overwrite).json(edgesToRestoreFile.getAbsolutePath)
    val verticesRestored = sparkSession.read.json(verticesToRestore.getAbsolutePath)
    val edgesRestored = sparkSession.read.json(edgesToRestoreFile.getAbsolutePath)
    val graphRestored = GraphFrame(verticesRestored, edgesRestored)

    val verticesFromRawSource = graph.vertices.collect()
    val verticesFromRestoredGraph = graphRestored.vertices.collect()
    verticesFromRawSource should contain allElementsOf(verticesFromRestoredGraph)
    def mapEdge(row: Row) = s"${row.getAs[Long]("src")} --> ${row.getAs[Long]("dst")}"
    val edgesFromRawSource = graph.edges.collect().map(row => mapEdge(row))
    val edgesFromRestoredGraph = graphRestored.edges.collect().map(row => mapEdge(row))
    edgesFromRawSource should contain allElementsOf(edgesFromRestoredGraph)
  }

Implementation details

GraphFrames graphs use Apache Spark SQL DataFrames, so Datasets of Row, to represent the sets of vertices and edges. The constructor shows it clearly:

class GraphFrame private(
    @transient private val _vertices: DataFrame,
    @transient private val _edges: DataFrame)

This underlying structure also lets GraphFrames leverage Spark SQL for some operations. For instance, to compute the out-degrees, it groups the edges by their source vertex and counts the rows of each group:

edges.groupBy(edges(SRC).as(ID)).agg(count("*").cast("int").as("outDegree"))
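
To make the aggregation more concrete, here is a minimal sketch that runs the same kind of query by hand and compares it with the outDegrees attribute exposed by GraphFrame. The dataset, the object name and the helper are assumptions made for the illustration; the string literals "src", "id" and "outDegree" stand in for the library's constants:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.count
import org.graphframes.GraphFrame

// Hypothetical helper, written only for this illustration
object OutDegreeSketch {
  // Returns (hand-written aggregation, GraphFrame.outDegrees), both as id -> outDegree maps
  def compute(sparkSession: SparkSession): (Map[Long, Int], Map[Long, Int]) = {
    import sparkSession.implicits._
    val vertices = Seq((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4")).toDF("id", "name")
    val edges = Seq((1L, 2L), (1L, 3L), (2L, 3L), (3L, 4L)).toDF("src", "dst")

    def toMap(rows: Array[Row]): Map[Long, Int] =
      rows.map(row => row.getAs[Long]("id") -> row.getAs[Int]("outDegree")).toMap

    // Group the edges by their source vertex and count the rows of every group
    val manual = edges.groupBy(edges("src").as("id"))
      .agg(count("*").cast("int").as("outDegree"))
    // The equivalent result exposed by the library
    val builtIn = GraphFrame(vertices, edges).outDegrees

    (toMap(manual.collect()), toMap(builtIn.collect()))
  }
}
```

Vertex 4 has no outgoing edge, so it doesn't appear in either result: the aggregation only sees the vertices present in the src column.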

Another interesting querying feature is pattern finding. If you've already worked with Cypher, you've certainly used patterns like (vertex1)-[:IS_FRIEND]->(vertex2) more than once to retrieve interesting parts of the graph. Pattern finding in GraphFrames offers a very similar way to find graph structures. I won't cover it here because the feature is interesting enough to merit its own post.

The graphs created with GraphFrames can be easily converted to the graphs manipulated by GraphX, and the other way around. We can observe that in some of the implemented graph algorithms, such as label propagation, PageRank or strongly connected components, which call the GraphX implementation under the hood:

private object LabelPropagation {
  private def run(graph: GraphFrame, maxIter: Int): DataFrame = {
    val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices
  }
// ...

private object PageRank {
  def run(
      graph: GraphFrame,
      maxIter: Int,
      resetProb: Double = 0.15,
      srcId: Option[Any] = None): GraphFrame = {
    val longSrcId = srcId.map(GraphXConversions.integralId(graph, _))
    val gx = graphxlib.PageRank.runWithOptions(
      graph.cachedTopologyGraphX, maxIter, resetProb, longSrcId)
    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(PAGERANK), edgeNames = Seq(WEIGHT))
  }
// ...

private object StronglyConnectedComponents {
  private def run(graph: GraphFrame, numIter: Int): DataFrame = {
    val gx = graphxlib.StronglyConnectedComponents.run(graph.cachedTopologyGraphX, numIter)
    GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(COMPONENT_ID)).vertices
  }
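
From the user's side, this delegation stays hidden behind the GraphFrame API. The sketch below, built on an illustrative dataset and a hypothetical object name, converts a graph to its GraphX representation with toGraphX and runs PageRank through the builder returned by pageRank. It returns the GraphX vertex and edge counts, and whether the result carries the pagerank and weight columns:

```scala
import org.apache.spark.graphx.Graph
import org.apache.spark.sql.{Row, SparkSession}
import org.graphframes.GraphFrame

// Hypothetical helper, written only for this illustration
object GraphXRoundTripSketch {
  def run(sparkSession: SparkSession): (Long, Long, Boolean, Boolean) = {
    import sparkSession.implicits._
    val vertices = Seq((1L, "user1"), (2L, "user2"), (3L, "user3")).toDF("id", "name")
    val edges = Seq((1L, 2L, "is parent of"), (2L, 3L, "is parent of")).toDF("src", "dst", "label")
    val graph = GraphFrame(vertices, edges)

    // GraphX view of the same graph: vertex and edge attributes are Rows
    val graphX: Graph[Row, Row] = graph.toGraphX

    // PageRank returns a new GraphFrame; the iteration count is arbitrary here
    val ranked = graph.pageRank.resetProbability(0.15).maxIter(5).run()

    (graphX.vertices.count(), graphX.edges.count(),
      ranked.vertices.columns.contains("pagerank"), ranked.edges.columns.contains("weight"))
  }
}
```

The opposite direction is available through GraphFrame.fromGraphX, which builds a GraphFrame from an existing GraphX graph.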

Graphs in GraphFrames have requirements similar to those of the graphs in GraphX. The vertices must have an id column and the edges must reference their vertices through src and dst columns. But it's not the only similarity. GraphFrames shares a lot of features with GraphX, such as exposing triplets, caching, or running iterative algorithms with a Pregel implementation. There is, however, an important difference: GraphFrames offers an easy way to find patterns through a syntax similar to the one used in the Cypher query language. This feature will be covered in more detail in one of the next posts.

