Testing Spark applications

Versions: Spark 2.1.0

It's difficult to contest the importance of testing in programming. Tests help to avoid regressions (a lot of regressions) and also to better understand the developed code. Spark (and other data processing frameworks, by the way) is no exception to this rule. But, obviously, testing applications that run in distributed mode is trickier than testing standalone programs.

In this post we focus on writing tests for Spark applications. The first part concerns unit tests and shows some useful patterns that make unit test definition easier. The second part is about integration tests using a local SparkContext.

Unit tests in Spark

As a reminder, let's recall the short and simple definition of a unit test. It's a test verifying a particular piece of code in isolation from its dependencies. In practice, it usually targets individual methods.

In Spark a lot of things depend on RDDs, Datasets or DStreams, and sometimes it's an excuse for programmers to ignore unit tests. However, besides these 3 abstractions, Spark is also an assembly of functional transformations (map, filter, group, ...). Thanks to that, the code isn't tightly coupled to Spark's data abstractions and can easily be tested with any testing framework. The examples below show some tests of transformations based on Scalatest:

import org.scalatest.{FlatSpec, Matchers}

class UnitTestExampleTest extends FlatSpec with Matchers {

  "filters" should "be testable without Spark context" in {
    // example of use:
    // rdd.filter(isGreaterThan(_, 5))
    Filters.isGreaterThan(1, 5) shouldBe(false)
  }

  "mapper" should "be testable without Spark context" in {
    // example of use:
    // rdd.map(mapToString(_))
    Mappers.mapToString(1) shouldEqual("Number 1")
  }

  "to pair mapper" should "be testable without Spark context" in {
    // example of use:
    // rdd.map(mapToPair(_))
    Mappers.mapToPair("vegetables: potato, carrot") shouldEqual(("vegetables", "potato, carrot"))
  }

  "partitioner" should "return 0 partition for pair key" in {
    // example of use:
    // rdd.partitionBy(new SamplePartitioner())
    val partitioner = new SamplePartitioner()

    partitioner.getPartition(4) shouldEqual(0)
  }

  "aggregate by partition" should "be testable without SparkContext" in {
    // example of use:
    // rdd.aggregate("")((text, number) => text + ";" + number, (text1, text2) => text1 + "-"+text2)

    Reducers.partitionConcatenator("123", 4) shouldEqual("123;4")
  }

  "aggregate combiner" should "be testable without SparkContext" in {
    // refer to the "aggregate by partition" test above
    // to see the example of use

    Reducers.combineConcatenations("123", "456") shouldEqual("123-456")
  }

}


object Filters {

  def isGreaterThan(number: Int, lowerBound: Int): Boolean = {
    number > lowerBound
  }

}

object Mappers {

  def mapToString(number: Int): String = {
    s"Number ${number}"
  }

  // splits a line like "vegetables: potato, carrot" into a (category, items) pair
  def mapToPair(shoppingList: String): (String, String) = {
    val textParts = shoppingList.split(": ")
    (textParts(0).trim, textParts(1).trim)
  }

}

object Reducers {

  def partitionConcatenator(concatenatedText: String, newNumber: Int): String = {
    concatenatedText + ";" + newNumber
  }

  def combineConcatenations(concatenatedText1: String, concatenatedText2: String): String = {
    concatenatedText1 + "-" + concatenatedText2
  }


}

import org.apache.spark.Partitioner

class SamplePartitioner extends Partitioner {
  override def numPartitions: Int = 2

  override def getPartition(key: Any): Int = {
    key match {
      // match integer keys on their runtime type
      case intKey: Int => partitionIntKey(intKey)
      case _ => throw new IllegalArgumentException(s"Unsupported key ${key}")
    }
  }

  private def partitionIntKey(key: Int): Int = {
    if (key%2 == 0) 0 else 1
  }
}

As you can see, the above tests never rely on Spark objects (RDDs, etc.), yet they exercise exactly the pieces of code used in the processing pipelines.
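
To make the link with a real pipeline more concrete, here is a minimal sketch of how these functions could be plugged into RDD transformations (the sparkContext value and the sample data are illustrative assumptions, not part of the tested code):

// a minimal sketch, assuming a SparkContext named sparkContext is already available
val numbersRdd = sparkContext.parallelize(Seq(1, 3, 7, 10))

// the functions unit tested above plug directly into the transformations
val labelledBigNumbers = numbersRdd
  .filter(number => Filters.isGreaterThan(number, 5)) // keeps 7 and 10
  .map(Mappers.mapToString) // "Number 7", "Number 10"
  .collect()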

Integration tests in Spark

Another category of tests is integration tests. Unlike unit tests, their goal is to verify that multiple units work correctly together. In the case of Spark, it means that we'll check how the data pipeline behaves, i.e. whether all defined transformations and actions work correctly, possibly with mocks standing in for input sources or output sinks.

As you can deduce, integration tests need either a lot of mocking or the use of a real SparkContext. Since the tests must, among other things, remain easily maintainable, the mock-heavy option is not realistic. That's why a real SparkContext is used. But as we know, there can be only one context per JVM, so some additional effort must be devoted to avoiding its manual creation in every test. In our example we use a Scala trait, but in the post's conclusion you can find some open source projects that already implement this feature.

Another point related to integration tests (and unit tests, by the way) concerns data generation. A popular solution that helps to avoid the nasty effects of desynchronization between data stored in files and an evolving code base is a Domain-Specific Language (DSL). We can use it to build the test dataset directly from the objects used in transformations or actions.

Another improvement in testing can be custom assertions. They help to centralize test validation logic in one place and thus better control how assertions are used, e.g. they prevent forgetting to write assertions on specific objects. To do that in Scala we can use a custom Scalatest Matcher, while for a Java code base a great solution appears to be AssertJ custom assertions. The example below shows the use of matchers because it's written in Scala; a sketch of a Scalatest Matcher alternative follows the full example.

Below is an example of an integration test for batch processing of unstructured data, using the SparkContext trait, the DSL and the matchers:

import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ListBuffer
import scala.language.postfixOps

/**
  * Integration test example on Spark. It shows the use
  * of 3 concepts described in the blog post about unit and
  * integration tests in Spark: shared context, DSL and custom matchers.
  */
class IntegrationTestExampleTest extends FlatSpec with Matchers with BeforeAndAfter with WithMethodScopedSparkContext {

  before {
    withLocalContext("Integration test example")
  }

  "RDD representing user sessions" should "be grouped by key and collected in the action" in {
    val rdd = (UserSessionDsl sessionsData sparkContext) > UserSession(1, "home.html", 30) >
      UserSession(2, "home.html", 45) > UserSession(2, "catalog.html", 10) > UserSession(2, "exit.html", 5) >
      UserSession(1, "cart.html", 10) > UserSession(1, "payment.html", 10) toRdd

    val wholeSessions = rdd.map(session => (session.userId, session.timeInSec))
      .reduceByKey((time1, time2) => time1 + time2)
      .map((userStatPair) => UserVisit(userStatPair._1, userStatPair._2))
      .collect()

    UserVisitMatcher.user(wholeSessions, 1).spentTotalTimeOnSite(50)
    UserVisitMatcher.user(wholeSessions, 2).spentTotalTimeOnSite(60)
  }

}


/**
  * Trait that can be used to facilitate work with SparkContext lifecycle
  * in tests.
  * It's only for illustration purposes. To your production
  * application you can consider to use Spark Testing Base from https://github.com/holdenk/spark-testing-base
  */
trait WithMethodScopedSparkContext {

  var sparkContext:SparkContext = null

  def withLocalContext(appName: String) = {
    if (sparkContext != null) {
      sparkContext.stop()
    }
    val conf = new SparkConf().setAppName(appName).setMaster("local")
    sparkContext = SparkContext.getOrCreate(conf)
  }
}
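
For comparison, spark-testing-base already provides a similar shared context facility. A minimal sketch of its use could look as follows (assuming the library is added as a test dependency; the exact setup is described in the project's documentation):

import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.FunSuite

// SharedSparkContext manages the context lifecycle and exposes it through the sc field
class SharedContextExampleTest extends FunSuite with SharedSparkContext {

  test("RDD created from the shared context") {
    val rdd = sc.parallelize(Seq(1, 2, 3))

    assert(rdd.count() === 3)
  }
}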


/**
  * Below code shows sample custom DSL used to inject data
  * representing user sessions for batch processing.
  */
class UserSessionDsl(sparkContext: SparkContext) {

  val sessions:ListBuffer[UserSession] = new ListBuffer()

  def >(userSession: UserSession): UserSessionDsl = {
    sessions.append(userSession)
    this
  }

  def toRdd(): RDD[UserSession] = {
    sparkContext.parallelize(sessions)
  }
}

object UserSessionDsl {
  def sessionsData(sparkContext: SparkContext): UserSessionDsl = {
    new UserSessionDsl(sparkContext)
  }
}

case class UserSession(userId: Long, page: String, timeInSec: Long)

/**
  * Sample code used to build a matcher.
  */
case class UserVisit(userId: Long, totalSpentTimeInSec: Long)

class UserVisitMatcher(visits: Seq[UserVisit], userId: Long) {

  val userVisitOption = visits.find(_.userId == userId)
  assert(userVisitOption.isDefined)
  val userVisit = userVisitOption.get

  def spentTotalTimeOnSite(expectedTime: Long): UserVisitMatcher = {
    assert(userVisit.totalSpentTimeInSec == expectedTime)
    this
  }


}

object UserVisitMatcher {
  def user(visits: Seq[UserVisit], userId: Long): UserVisitMatcher = {
    new UserVisitMatcher(visits, userId)
  }
}
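
The UserVisitMatcher above relies on plain assertions. As mentioned earlier, the same check could also be written as a Scalatest custom Matcher; the snippet below is only a sketch of that alternative (the trait and method names are illustrative, not part of the tested code):

import org.scalatest.matchers.{MatchResult, Matcher}

// hypothetical Scalatest-based alternative to the assertion-based UserVisitMatcher
trait UserVisitMatchers {

  def spendTotalTimeOnSite(expectedTime: Long): Matcher[UserVisit] = new Matcher[UserVisit] {
    override def apply(visit: UserVisit): MatchResult = MatchResult(
      visit.totalSpentTimeInSec == expectedTime,
      s"user ${visit.userId} spent ${visit.totalSpentTimeInSec} seconds on the site instead of ${expectedTime}",
      s"user ${visit.userId} spent the expected ${expectedTime} seconds on the site"
    )
  }
}

// usage in a test class mixing in Matchers and UserVisitMatchers:
// userVisit should spendTotalTimeOnSite(50)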

This post shows how to test Spark-based applications at the programming level. The first part talks about unit tests and proves that code involved in Spark processing can easily be checked with classic testing frameworks (JUnit, Scalatest...). The second part explains how to write integration tests focused on a single processing pipeline composed of unit-tested transformations and actions. One kind of test not described here is the end-to-end test, which can verify the data pipeline's integration with the other parts of the system. The post also doesn't cover the available Spark testing libraries, such as sscheck or spark-testing-base.