Regression tests with Apache Spark SQL joins

Versions: Apache Spark 2.4.0

Regressions are one of the risks of our profession. Fortunately, we can limit them thanks to different testing strategies. One of them is regression testing, which checks whether the modified data processing logic introduced regressions, simply by comparing two datasets.

In this post I will show a simple implementation of regression tests applied to JSON files generated with Apache Spark SQL. The first part presents different testing approaches that we can use in the context of data. The second section focuses on one of them, based on SQL JOINs.

Regression tests in data context

The primary goal of regression tests is to ensure that previously developed features don't break after code changes. Regression tests are at a higher level than unit tests, which should remain the main testing strategy of every software project, including data ones. Regression tests should rather be considered a complement to them.

A regression test in a data-centric system can be a simple comparison of 2 datasets. The first one is considered the reference dataset because it was generated by the currently running, and therefore previously validated, version of the system. The second one is generated by the release candidate version.
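To make the comparison idea concrete, here is a minimal sketch in plain Scala collections (not Spark), assuming both datasets fit in memory and each record is uniquely identified by an `id`. The `Order`, `ComparisonReport` and `compareDatasets` names are hypothetical and only serve the illustration. Every key ends up in one of three buckets: missing from the candidate output, extra in the candidate output, or present on both sides with different values.

```scala
// Hypothetical record type; `id` is assumed to identify a record uniquely.
case class Order(id: Long, amount: Double, itemIds: Set[Long])

// Keys found only in the reference dataset, only in the candidate dataset,
// and in both datasets but with at least one different attribute.
case class ComparisonReport(missing: Set[Long], extra: Set[Long], different: Set[Long])

object DatasetComparison {
  // Compares the output of the validated version (reference)
  // with the output of the release candidate (candidate).
  def compareDatasets(reference: Seq[Order], candidate: Seq[Order]): ComparisonReport = {
    val referenceById = reference.map(order => order.id -> order).toMap
    val candidateById = candidate.map(order => order.id -> order).toMap
    // Keys present on only one side
    val missing = referenceById.keySet -- candidateById.keySet
    val extra = candidateById.keySet -- referenceById.keySet
    // Keys present on both sides whose records are not equal
    val different = referenceById.keySet.intersect(candidateById.keySet)
      .filter(id => referenceById(id) != candidateById(id))
    ComparisonReport(missing, extra, different)
  }
}
```

With the sample data used later in this post, the reference ids 1 to 4 and the candidate ids 1, 2, 3 and 100 give `missing = Set(4)`, `extra = Set(100)` and `different = Set(2, 3)`.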

To build your regression testing system you can use different strategies:

Among the discussed solutions, regression tests applied to real data are valuable not only because they increase confidence in the correctness, but also because they make it possible to compare the execution time of both code versions.

But as you can see, there is no one-size-fits-all solution. It's up to you and your team to weigh the advantages and disadvantages against your application context. For illustration purposes, in this post I will opt for the solution using a dataset from production.
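Since a production-based regression test runs both versions on the same input, it can also give a rough comparison of their execution times. The following plain Scala sketch assumes that each version can be invoked as a simple function; `timed`, `currentVersion` and `candidateVersion` are hypothetical placeholders, not part of any Spark API.

```scala
object ExecutionTimeComparison {
  // Evaluates the block once and returns its result with the elapsed wall-clock time in milliseconds.
  def timed[T](block: => T): (T, Long) = {
    val start = System.nanoTime()
    val result = block
    val elapsedMs = (System.nanoTime() - start) / 1000000L
    (result, elapsedMs)
  }

  // Hypothetical stand-ins for the deployed version and the release candidate of a job.
  def currentVersion(input: Seq[Int]): Long = input.map(_.toLong).sum
  def candidateVersion(input: Seq[Int]): Long = input.foldLeft(0L)(_ + _)

  def main(args: Array[String]): Unit = {
    val input = 1 to 1000000
    val (currentResult, currentMs) = timed(currentVersion(input))
    val (candidateResult, candidateMs) = timed(candidateVersion(input))
    println(s"current: $currentMs ms, candidate: $candidateMs ms, same output: ${currentResult == candidateResult}")
  }
}
```

A single run on the JVM is sensitive to JIT warm-up and to other load on the machine, so in practice you would probably repeat the measurement several times before drawing any conclusion.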

JOIN-based regression tests

To represent regression tests in the code I will use SQL JOINs. This operator comes naturally to mind if we consider regression tests as a comparison between 2 datasets. The following paragraphs present the different types of JOINs and their purpose in the regression tests context.

The first JOIN type, which by the way covers almost all of my testing needs, is the LEFT JOIN, where the left side is the reference dataset. With the LEFT JOIN you can detect the records missing in the tested dataset and also make some fine-grained checks on the data. Depending on your use case, you can later export invalid records to another place for ad-hoc querying, or simply count them to get the error ratio:

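import org.apache.spark.sql.SparkSession
import org.scalatest.{FlatSpec, Matchers}

// The members below are declared inside a test class extending FlatSpec with Matchers
  val sparkSession: SparkSession = SparkSession.builder()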
    .appName("Spark SQL JOIN-based regression test")
    .master("local[*]").getOrCreate()
  import sparkSession.implicits._

  private val referenceDataset = Seq(
    RegressionTestOrder(1L, 39.99d, Set(1L, 2L)),
    RegressionTestOrder(2L, 41.25d, Set(1L)),
    RegressionTestOrder(3L, 100d, Set(1L, 2L, 3L)),
    RegressionTestOrder(4L, 120d, Set(1L))
  ).toDF.cache() // cache it in order to avoid the recomputation
  private val generatedDataset = Seq(
    RegressionTestOrder(1L, 39.99d, Set(1L, 2L)),
    RegressionTestOrder(2L, 41.25d, Set.empty),
    RegressionTestOrder(3L, 200d, Set(1L, 2L, 3L)),
    RegressionTestOrder(100L, 200d, Set(1L, 2L, 3L))
  ).toDF.select($"id".as("generated_id"), $"amount".as("generated_amount"), $"itemIds".as("generated_itemIds")).cache()


  "LEFT JOIN with a custom comparator" should "be used to detect missing data" in {
    val allReferenceDataWithOptionalMatches =
      referenceDataset.join(generatedDataset, referenceDataset("id") === generatedDataset("generated_id"), "left")

    // getAs[Long] returns 0L for a NULL value, so comparing it with null is always false; isNullAt checks NULLs explicitly
    val notGeneratedReferenceData = allReferenceDataWithOptionalMatches
      .filter(row => row.isNullAt(row.fieldIndex("generated_id")))
      .count()
    val commonEntries = allReferenceDataWithOptionalMatches
      .filter(row => !row.isNullAt(row.fieldIndex("generated_id")))
    commonEntries.cache()
    val invalidAmountGeneratedData = commonEntries
      .filter(row => row.getAs[Double]("generated_amount") != row.getAs[Double]("amount"))
    // The Set is stored as an array column and read back as a Seq, hence the .toSet conversion before comparing
    val invalidItemIdsGeneratedData = commonEntries
      .filter(row => row.getAs[Seq[Long]]("generated_itemIds").toSet != row.getAs[Seq[Long]]("itemIds").toSet)

    // Please notice that I'm using .count() as the action, but you can use any other valid action, like materializing
    // the non-matching data in order to investigate the inconsistencies later.
    notGeneratedReferenceData shouldEqual 1
    invalidAmountGeneratedData.count() shouldEqual 1
    invalidItemIdsGeneratedData.count() shouldEqual 1
  }

case class RegressionTestOrder(id: Long, amount: Double, itemIds: Set[Long])
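The error ratio mentioned before the test is not computed there, so here is a small emulation of the same LEFT JOIN logic on in-memory Scala collections (not Spark), which can help to reason about the semantics on a small sample. It assumes unique ids on both sides (Spark would produce one row per match for duplicated keys) and counts a reference record as invalid when it is either missing from the generated data or differs in at least one attribute. `SampleOrder`, `leftJoin` and `errorRatio` are hypothetical names.

```scala
case class SampleOrder(id: Long, amount: Double, itemIds: Set[Long])

object LeftJoinEmulation {
  // LEFT JOIN on `id`: every reference record is kept and paired with the matching
  // generated record, if any (None plays the role of the NULL columns).
  def leftJoin(reference: Seq[SampleOrder], generated: Seq[SampleOrder]): Seq[(SampleOrder, Option[SampleOrder])] = {
    val generatedById = generated.map(order => order.id -> order).toMap
    reference.map(order => (order, generatedById.get(order.id)))
  }

  // Share of reference records that are missing or different in the generated dataset.
  def errorRatio(reference: Seq[SampleOrder], generated: Seq[SampleOrder]): Double = {
    val joined = leftJoin(reference, generated)
    val invalid = joined.count {
      case (_, None) => true                              // missing record
      case (expected, Some(actual)) => expected != actual // at least one attribute differs
    }
    if (joined.isEmpty) 0.0 else invalid.toDouble / joined.size
  }
}
```

With the datasets of the test above, 3 of the 4 reference records are invalid (id 2 for its items, id 3 for its amount, id 4 because it's missing), so the error ratio is 0.75.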

The second useful type is the LEFT ANTI JOIN, which detects the records from the new dataset that are absent from the reference dataset:

  "LEFT ANTI JOIN" should "be used to detect the data missing in the reference dataset" in {
    val extraGeneratedData = generatedDataset
      .join(referenceDataset, referenceDataset("id") === generatedDataset("generated_id"), "leftanti").count()

    extraGeneratedData shouldEqual 1
  }
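In terms of semantics, a LEFT ANTI JOIN keeps only the left-side rows without a match on the right side and returns only the left side's columns. A plain Scala sketch of that behavior (not Spark), assuming a `Long` join key; `GeneratedOrder` and `leftAntiJoin` are hypothetical names:

```scala
case class GeneratedOrder(id: Long, amount: Double)

object LeftAntiJoinEmulation {
  // Keeps the left-side records whose key has no match on the right side.
  // As with a LEFT ANTI JOIN, no right-side attributes are returned.
  def leftAntiJoin[L, R](left: Seq[L], right: Seq[R])(leftKey: L => Long, rightKey: R => Long): Seq[L] = {
    val rightKeys = right.map(rightKey).toSet
    left.filterNot(record => rightKeys.contains(leftKey(record)))
  }
}
```

Applied to the generated records with ids 1, 2, 3 and 100 and the reference ids 1 to 4, only the record with id 100 is returned, which matches the count of 1 asserted in the test.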

Of course, you could also do both checks with a single FULL OUTER JOIN. It's a valid solution too, but it requires a bit more conditional logic in the code:

  "FULL OUTER JOIN" should "be used to detect all errors with a single join" in {
    val allReferenceDataWithOptionalMatches =
      referenceDataset.join(generatedDataset, referenceDataset("id") === generatedDataset("generated_id"), "full_outer")

    // isNullAt instead of getAs[Long](...) == null: getAs returns 0L for a NULL Long, so that comparison never matches
    val notGeneratedReferenceData = allReferenceDataWithOptionalMatches
      .filter(row => row.isNullAt(row.fieldIndex("generated_id")))
      .count()
    val commonEntries = allReferenceDataWithOptionalMatches
      .filter(row => !row.isNullAt(row.fieldIndex("id")) && !row.isNullAt(row.fieldIndex("generated_id")))
    commonEntries.cache()
    val invalidAmountGeneratedData = commonEntries
      .filter(row => row.getAs[Double]("generated_amount") != row.getAs[Double]("amount"))
    // Array columns are read back as Seq, so both sides are converted to Set before the comparison
    val invalidItemIdsGeneratedData = commonEntries
      .filter(row => row.getAs[Seq[Long]]("generated_itemIds").toSet != row.getAs[Seq[Long]]("itemIds").toSet)
    val extraGeneratedData = allReferenceDataWithOptionalMatches
      .filter(row => !row.isNullAt(row.fieldIndex("generated_id")) && row.isNullAt(row.fieldIndex("id")))
      .count()

    notGeneratedReferenceData shouldEqual 1
    invalidAmountGeneratedData.count() shouldEqual 1
    invalidItemIdsGeneratedData.count() shouldEqual 1
    extraGeneratedData shouldEqual 1
  }
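The conditional logic of the FULL OUTER JOIN version becomes more explicit if every joined row is seen as a pair of optional sides. The following plain Scala sketch (not Spark) assumes unique ids on both sides; `JoinedOrder`, `Outcome`, `fullOuterJoin` and `classify` are hypothetical names. Unlike the test above, it merges the amount and item checks into a single `Invalid` case.

```scala
case class JoinedOrder(id: Long, amount: Double, itemIds: Set[Long])

sealed trait Outcome
case object NotGenerated extends Outcome   // only in the reference dataset
case object ExtraGenerated extends Outcome // only in the generated dataset
case object Valid extends Outcome          // on both sides with identical attributes
case object Invalid extends Outcome        // on both sides, at least one attribute differs

object FullOuterJoinEmulation {
  // FULL OUTER JOIN on `id`: the keys of both sides are kept and the missing side is None.
  def fullOuterJoin(reference: Seq[JoinedOrder], generated: Seq[JoinedOrder]): Seq[(Option[JoinedOrder], Option[JoinedOrder])] = {
    val referenceById = reference.map(order => order.id -> order).toMap
    val generatedById = generated.map(order => order.id -> order).toMap
    (referenceById.keySet ++ generatedById.keySet).toSeq.sorted
      .map(id => (referenceById.get(id), generatedById.get(id)))
  }

  // The conditional logic of the FULL OUTER JOIN test expressed as a pattern match.
  def classify(pair: (Option[JoinedOrder], Option[JoinedOrder])): Outcome = pair match {
    case (Some(_), None) => NotGenerated
    case (None, Some(_)) => ExtraGenerated
    case (Some(expected), Some(actual)) => if (expected == actual) Valid else Invalid
    case (None, None) => throw new IllegalStateException("a joined row always has at least one side")
  }
}
```

For the sample datasets, the ids 1, 2, 3, 4 and 100 are classified as `Valid`, `Invalid`, `Invalid`, `NotGenerated` and `ExtraGenerated` respectively.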

Regression tests, or more generally any tests applied to big volumes of data, are not a piece of cake. The execution time, the maintenance effort and the complexity may seem scary at first glance. However, with some simple practices like the ones discussed in this post, all these negative points can be mitigated. The tested dataset can be dynamic, and the code complexity can be reduced to SQL JOINs and attribute-focused comparison logic.
