How to use the IN clause in a Spark SQL query?

SqlInClauseTest.scala

Spark SQL supports many standard SQL operations, including the IN clause. With the Column API it is easiest to use after importing the implicits of the created SparkSession object, which enable the $"column" syntax:

private val sparkSession: SparkSession = SparkSession.builder()
    .appName("Spark SQL IN tip").master("local[*]").getOrCreate()

import sparkSession.implicits._

The method responsible for executing the IN clause is Column's def isin(list: Any*). Under the hood it creates the org.apache.spark.sql.catalyst.expressions.In predicate. Depending on the data source, Spark then either generates Java code that evaluates the predicate or pushes it down to the data source as a value IN (...) clause.
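
As a minimal sketch of the API described above, the snippet below passes a Scala collection to isin by expanding it with `: _*` and writes the same filter as a SQL expression string. The object name, the filterByCodes helper and the sample rows are illustrative assumptions, not part of the original test class.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object IsinSketch {
  // Hypothetical helper: keeps only the rows whose "iso" value is one of the codes.
  // isin takes varargs (Any*), so the collection is expanded with ": _*".
  def filterByCodes(countries: DataFrame, codes: Seq[String]): DataFrame =
    countries.where(countries("iso").isin(codes: _*))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("isin sketch").master("local[*]").getOrCreate()
    import spark.implicits._

    val countries = Seq(("FR", "France"), ("PL", "Poland"), ("HT", "Haiti"))
      .toDF("iso", "countryName")

    // Column API version
    val viaColumnApi = filterByCodes(countries, Seq("FR", "PL"))
    // The same predicate written as a SQL expression string
    val viaSqlString = countries.where("iso IN ('FR', 'PL')")

    // Prints the logical and physical plans so the filter can be inspected
    viaColumnApi.explain(true)
    viaColumnApi.show()
    viaSqlString.show()

    spark.stop()
  }
}
```

Both DataFrames should contain the same rows. The Column version is convenient when the list of values is only known at runtime.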

Concerning the generated code, one might expect it to call one of the available Java methods (some kind of contains*). That is not the case: Spark flattens the list and maps each entry to a block like the ones below:

boolean filter_value = false;
boolean filter_isNull = inputadapter_isNull;
if (!filter_isNull) {
  if (!filter_value) {
    if (false) {
      filter_isNull = true;
    } else if (inputadapter_value.equals(((UTF8String) references[1]))) {
      filter_isNull = false;
      filter_value = true;
    }
  }
  if (!filter_value) {
    if (false) {
      filter_isNull = true;
    } else if (inputadapter_value.equals(((UTF8String) references[2]))) {
      filter_isNull = false;
      filter_value = true;
    }
  }
}

As you can see, references[*] corresponds to the options from the IN clause. Each of them is tested in a simple if-else block, and only if the previous blocks didn't already set the filter_value field to true.
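
The control flow of the generated code can be modelled in plain Scala. This is only a sketch of the logic, not Spark's implementation: the object and method names are hypothetical, Option stands in for nullable values (None meaning SQL NULL), and it assumes that the `if (false)` branch is each option's null check, which is constant here because the literals are never null.

```scala
object InEvaluationSketch {
  // Returns None for a NULL result, Some(true) or Some(false) otherwise.
  def evaluateIn(input: Option[String], options: Seq[Option[String]]): Option[Boolean] = {
    var value = false           // plays the role of filter_value
    var isNull = input.isEmpty  // plays the role of filter_isNull
    if (!isNull) {
      options.foreach { option =>
        // Each option is only tested while no earlier option has matched
        if (!value) {
          option match {
            case None =>
              // A NULL option makes the result NULL unless a later option matches
              isNull = true
            case Some(candidate) if candidate == input.get =>
              isNull = false
              value = true
            case _ => // no match, move on to the next option
          }
        }
      }
    }
    if (isNull) None else Some(value)
  }
}
```

With this model, evaluateIn(Some("FR"), Seq(Some("FR"), Some("PL"))) yields Some(true), while evaluateIn(Some("HT"), Seq(Some("FR"), None)) yields None, the usual SQL behavior when a non-matching value is compared against a list containing NULL.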

To see how to use the IN clause in Spark, let's look at the following learning test cases:

override def beforeAll(): Unit = {
  case class CountryOperation(isoCode: String, name: String) extends DataOperation {
    override def populatePreparedStatement(preparedStatement: PreparedStatement): Unit = {
      preparedStatement.setString(1, isoCode)
      preparedStatement.setString(2, name)
    }
  }
  val countriesToInsert = Seq(CountryOperation("FR", "France"), CountryOperation("PL", "Poland"),
    CountryOperation("GB", "The United Kingdom"), CountryOperation("HT", "Haiti"), CountryOperation("JM", "Jamaica")
  )

  InMemoryDatabase.createTable("CREATE TABLE countries(iso VARCHAR(20) NOT NULL, countryName VARCHAR(20) NOT NULL)")
  InMemoryDatabase.populateTable("INSERT INTO countries (iso, countryName) VALUES (?, ?)", countriesToInsert)
}

override def afterAll(): Unit = {
  sparkSession.stop()
  InMemoryDatabase.cleanDatabase()
}

"SQL IN clause" should "be used to filter some rows" in {
  val countriesReader = sparkSession.read.format("jdbc")
    .option("url", InMemoryDatabase.DbConnection)
    .option("driver", InMemoryDatabase.DbDriver)
    .option("dbtable", "countries")
    .option("user", InMemoryDatabase.DbUser)
    .option("password", InMemoryDatabase.DbPassword)
    .load()

  import sparkSession.implicits._
  val europeanCountries: Array[String] = countriesReader.select("iso", "countryName")
    .where($"iso".isin("FR", "PL", "GB"))
    .map(row => row.getString(1))
    .collect()

  europeanCountries should have length 3
  europeanCountries should contain allOf("France", "Poland", "The United Kingdom")
}

"SQL IN clause" should "be applied for in-memory DataFrame" in {
  import sparkSession.implicits._
  val countriesDataFrame = Seq(
    ("FR", "France"), ("DE", "Germany"), ("CA", "Canada"), ("BR", "Brazil"), ("AR", "Argentina")
  ).toDF("iso", "countryName")

  val europeanCountries: Array[String] = countriesDataFrame.select("iso", "countryName")
    .where($"iso".isin("FR", "DE"))
    .map(row => row.getString(1))
    .collect()

  europeanCountries should have length 2
  europeanCountries should contain allOf("France", "Germany")

}