User Defined Functions

Versions: Spark 2.1.0

User Defined Types (UDT), described in one of the previous posts, aren't the only customization possibility in Apache Spark SQL. Another one is User Defined Functions (UDF).

This post explains what UDFs are and shows how to use them. The first part focuses on the theoretical explanation. The next section shows some UDF internals. Finally, the last section shows sample implementations of UDFs.

User Defined Function definition

The Spark SQL UDF is not a new invention: it already exists, under the same name, in the RDBMS world (the famous CREATE FUNCTION... statement). Just like its RDBMS counterpart, Spark's UDF extends the built-in functionality, i.e. it allows the definition of custom processing methods applied at the column level.

However, they don't come without a cost. Spark perceives them as black boxes and is unable to apply its optimizations to them. Moreover, a poorly written UDF can degrade program performance. Even if its execution time on a single row appears small, we should keep in mind that this time is multiplied by the number of processed rows. In consequence, if inside a UDF you perform actions with unpredictable execution time, such as calling REST services or reading data from a database, you'll probably need some extra optimization effort (e.g. adding a cache layer, as in the example of the last section).

User Defined Function internals

A UDF can be defined by calling org.apache.spark.sql.functions$#udf(scala.Function2<A1,A2,RT>, scala.reflect.api.TypeTags.TypeTag<RT>, scala.reflect.api.TypeTags.TypeTag<A1>, scala.reflect.api.TypeTags.TypeTag<A2>) or org.apache.spark.sql.UDFRegistration#register(java.lang.String, scala.Function1<A1,RT>, scala.reflect.api.TypeTags.TypeTag<RT>, scala.reflect.api.TypeTags.TypeTag<A1>). But the chosen registration method influences further use. A UDF registered with udf(...) can't be used in a select or filtering method containing a stringified expression. On the other hand, a UDF added through register(...) can be freely used both as a string expression (exactly like SQL functions, for instance: "SELECT LCASE(upper_cased_column)....") and as an org.apache.spark.sql.Column-based expression.
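
To make the difference concrete, below is a simplified sketch, assuming a letterNumbers DataFrame with "letter" and "number" columns like the one used in the last section (the names isEvenUdf and is_even are only illustrative):

import org.apache.spark.sql.functions.udf
import sparkSession.implicits._

// udf(...) returns a UserDefinedFunction that can only be used through the Column-based API
val isEvenUdf = udf((nr: Int) => nr % 2 == 0)
letterNumbers.select($"letter", isEvenUdf($"number") as "isEven")

// register(...) additionally exposes the function to stringified/SQL expressions
sparkSession.udf.register("is_even", (nr: Int) => nr % 2 == 0)
letterNumbers.selectExpr("letter", "is_even(number) as isEven")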

But in both cases the registration produces an instance of the org.apache.spark.sql.expressions.UserDefinedFunction class. Registered UDFs are later called in the while loop from the org.apache.spark.sql.catalyst.expressions.InterpretedProjection#apply method:

/**
 * Note that exprArray is the list of expressions applied to defined columns.
 * Below you can find 2 expressions: 0 -> gets the column as is, 
 * 1 -> applies the UDF defined in the test of the last section on the 2nd selected column
 * 
 * exprArray = {Expression[2]@7263} 
 * 0 = {Alias@7298} "input[0, string, true] AS letter#5"
 * 1 = {Alias@7286} "UDF(input[1, int, false]) AS isEven#30"
 *     child = {ScalaUDF@7301} "UDF(input[1, int, false])"
 *     function = {UserDefinedFunctionTest$$anonfun$6$$anonfun$apply$1@7331} ""
 */
protected val exprArray = if (expressions != null) expressions.toArray else null
// ...
while (i < exprArray.length) {
  outputArray(i) = exprArray(i).eval(input)
  i += 1
}

The evaluation invokes the apply method of the UserDefinedFunction class, which creates an org.apache.spark.sql.catalyst.expressions.ScalaUDF instance. This instance plays 2 different roles. If the UDF is used in filtering, it adds the UDF representation to the Java code generated by Spark. This representation is later added to the GeneratedIterator used to iterate over the selected rows. The second role of ScalaUDF consists in applying the UDF directly on each row when the UDF is used in a select statement.

Below you can find an example of the generated filter code with the UDF applied, taken from a GeneratedIterator instance:

/* 013 */   private scala.Function1 filter_udf;
// ...
/* 039 */   public void init(int index, scala.collection.Iterator[] inputs) { 
/* 042 */     wholestagecodegen_init_0();  
/* 046 */   }
/* 047 */
/* 048 */   private void wholestagecodegen_init_0() {
/* 051 */     this.filter_scalaUDF = (org.apache.spark.sql.catalyst.expressions.ScalaUDF) references[1];
/* 052 */     this.filter_catalystConverter = (scala.Function1)org.apache.spark.sql.catalyst.CatalystTypeConverters$.MODULE$.createToCatalystConverter(filter_scalaUDF.dataType());
/* 053 */     this.filter_converter = (scala.Function1)org.apache.spark.sql.catalyst.CatalystTypeConverters$.MODULE$.createToScalaConverter(((org.apache.spark.sql.catalyst.expressions.Expression)(((org.apache.spark.sql.catalyst.expressions.ScalaUDF)references[1]).getChildren().apply(0))).dataType());
/* 054 */     this.filter_udf = (scala.Function1)filter_scalaUDF.userDefinedFunc();
/* 060 */   }
// ...
/* 085 */   protected void processNext() throws java.io.IOException {
// ...
/* 094 */
/* 095 */       Boolean filter_result = null;
/* 096 */       try {
                // apply filtering here
/* 097 */         filter_result = (Boolean)filter_catalystConverter.apply(filter_udf.apply(filter_arg));
/* 098 */       } catch (Exception e) {
/* 099 */         throw new org.apache.spark.SparkException(filter_scalaUDF.udfErrorMessage(), e);
/* 100 */       }
/* 101 */
/* 102 */       boolean filter_isNull1 = filter_result == null;
/* 103 */       boolean filter_value1 = false;
/* 104 */       if (!filter_isNull1) {
/* 105 */         filter_value1 = filter_result;
/* 106 */       }
/* 107 */       if (!filter_isNull1) {
/* 108 */         filter_isNull = false; // resultCode could change nullability.
/* 109 */         filter_value = filter_value1 == true;
/* 110 */
/* 111 */       }
                // if the row doesn't match filter criteria, skip it
/* 112 */       if (filter_isNull || !filter_value) continue; 

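If you want to inspect similar generated code for your own query, a simplified sketch could look like the following (it assumes the letterNumbers DataFrame and the isEvenUdf function from the previous sketch):

import org.apache.spark.sql.execution.debug._

// Prints the whole-stage generated Java code for this query,
// including a ScalaUDF invocation similar to the listing above
letterNumbers.filter(isEvenUdf($"number")).debugCodegen()
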
As announced in the previous section, UDFs aren't optimized. However, sometimes they can be analyzed. It's the case of the org.apache.spark.sql.catalyst.analysis.Analyzer.HandleNullInputsForUDF rule, applied to handle null values for primitive-typed inputs. Its activity can be detected in logs with entries beginning with Batch UDF has....
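
One way to observe this rule is to apply a UDF with a primitive-typed parameter on a nullable column and inspect the analyzed plan; below is a simplified sketch (the DataFrame and names are only illustrative and assume the session implicits are imported):

// The "number" column is nullable because it comes from an Option[Int]; since the UDF
// parameter is a primitive Int, the analyzer should wrap the call with a null check
// instead of passing a null value to the function
val maybeNumbers = Seq(("A", Some(1)), ("B", None: Option[Int])).toDF("letter", "number")
val isEven = udf((nr: Int) => nr % 2 == 0)
maybeNumbers.select($"letter", isEven($"number") as "isEven").explain(true)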

User Defined Function example

After introducing UDFs, we can now see how to use them, either as part of select or filter statements:

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.functions.udf
import scala.collection.mutable

val sparkSession = SparkSession.builder().appName("UDF test")
  .master("local[*]").getOrCreate()

def evenFlagResolver(number: Int): Boolean = {
  number%2 == 0
}

def multiplicatorCurried(factor: Int)(columnValue: Int): Int = {
  factor * columnValue
}

import sparkSession.implicits._
val letterNumbers = Seq(
  ("A", 1), ("B", 2), ("C", 3), ("D", 4), ("E", 5), ("F", 6)
).toDF("letter", "number")

val watchedMoviesPerUser = Seq(
  (1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 2), (3, 4), (3, 5)
).toDF("user", "movie")

val letters = Seq(
  ("a", "A"), ("b", "B"), ("c", "C"), ("d", "D"), ("e", "E"), ("f", "F")
).toDF("lowercased_letter", "uppercased_letter")

override def afterAll {
  sparkSession.stop()
}

"UDF" should "be registered through register method" in {
  sparkSession.udf.register("EvenFlagResolver_registerTest", evenFlagResolver _)

  // function registered through udf.register can be used in select expressions
  // It could also be used with DSL, as:
  // select($"letter", udfEvenResolver($"number") as "isEven") where
  // udfEvenResolver = sparkSession.udf.register("EvenFlagResolver_registerTest", evenFlagResolver _)
  val rows = letterNumbers.selectExpr("letter", "EvenFlagResolver_registerTest(number) as isEven")
    .map(row => (row.getString(0), row.getBoolean(1)))
    .collectAsList()

  rows should contain allOf(("A", false), ("B", true), ("C", false), ("D", true), ("E", false), ("F", true))
}

"UDF registered with udf(...)" should "not be usable in select expression" in {
  val udfEvenResolver = udf(evenFlagResolver _)

  val analysisException = intercept[AnalysisException] {
    letterNumbers.selectExpr("letter", "udfEvenResolver(number) as isEven")
      .map(row => (row.getString(0), row.getBoolean(1)))
      .collectAsList()
  }

  analysisException.message.contains("Undefined function: 'udfEvenResolver'. " +
    "This function is neither a registered temporary function nor a permanent function registered " +
    "in the database 'default'.") shouldBe true
}

"UDF" should "be used with udf(...) method" in {
  val udfEvenResolver = udf(evenFlagResolver _)

  val rows = letterNumbers.select($"letter", udfEvenResolver($"number") as "isEven")
    .map(row => (row.getString(0), row.getBoolean(1)))
    .collectAsList()

  rows should contain allOf(("A", false), ("B", true), ("C", false), ("D", true), ("E", false), ("F", true))
}

"UDF taking arguments" should "be correctly called in select expression" in {
  val udfMultiplicatorCurried = udf(multiplicatorCurried(5)_)

  val rows = letterNumbers.select($"number", udfMultiplicatorCurried($"number") as "multiplied")
    .map(row => (row.getInt(0), row.getInt(1)))
    .collectAsList()

  rows should contain allOf((1,5), (2,10), (3,15), (4,20), (5,25), (6,30))
}

"UDF not optimized" should "slow processing down" in {
  val udfNotOptimized = udf(WebService.callWebServiceWithoutCache _)

  val watchedMoviesPerUser = Seq(
    (1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 2), (3, 4), (3, 5)
  ).toDF("user", "movie")

  // Unoptimized UDF
  val startUnoptimizedProcessing = System.currentTimeMillis()
  val rowsUnoptimized = watchedMoviesPerUser.select($"user", udfNotOptimized($"movie") as "movie_title")
    .map(row => (row.getInt(0), row.getString(1)))
    .collectAsList()
  val endUnoptimizedProcessing = System.currentTimeMillis()
  val totalUnoptimizedProcessingTime = endUnoptimizedProcessing - startUnoptimizedProcessing

  // Optimized UDF
  val startOptimizedProcessing = System.currentTimeMillis()
  val udfOptimized = udf(WebService.callWebServiceWithCache _)
  val rowsOptimized = watchedMoviesPerUser.select($"user", udfOptimized($"movie") as "movie_title")
    .map(row => (row.getInt(0), row.getString(1)))
    .collectAsList()
  val endOptimizedProcessing = System.currentTimeMillis()

  val totalOptimizedProcessingTime = endOptimizedProcessing - startOptimizedProcessing

  rowsUnoptimized should contain allOf(
    (1, "Title_1"), (1, "Title_2"), (1, "Title_3"), (2, "Title_1"), (2, "Title_2"), (3, "Title_2"),
    (3, "Title_4"), (3, "Title_5")
  )
  rowsOptimized should contain allOf(
    (1, "Title_1"), (1, "Title_2"), (1, "Title_3"), (2, "Title_1"), (2, "Title_2"), (3, "Title_2"),
    (3, "Title_4"), (3, "Title_5")
  )
  totalUnoptimizedProcessingTime should be > totalOptimizedProcessingTime
}

"UDF" should "also be able to process 2 columns" in {
  def concatenateStrings(separator: String)(column1Value: String, column2Value: String): String = {
    s"${column1Value}${separator}${column2Value}"
  }

  val udfConcatenator = udf(concatenateStrings("-") _)
  val rows = letters.select(udfConcatenator($"lowercased_letter", $"uppercased_letter") as "word")
    .map(row => (row.getString(0)))
    .collectAsList()

  rows should contain allOf("a-A", "b-B", "c-C", "d-D", "e-E", "f-F")
}

"UDF" should "also be callable in nested way" in {
  def concatenateStrings(separator: String)(column1Value: String, column2Value: String): String = {
    s"${column1Value}${separator}${column2Value}"
  }

  def reverseText(text: String): String = {
    text.reverse
  }

  val udfConcatenator = udf(concatenateStrings("-") _)
  val udfReverser = udf(reverseText _)

  // Nested calls are easy to implement but starting from 3 functions
  // the code becomes less and less readable
  val rows = letters.select(udfReverser(udfConcatenator($"lowercased_letter", $"uppercased_letter")) as "reversed_word")
    .map(row => (row.getString(0)))
    .collectAsList()

  rows should contain allOf("A-a", "B-b", "C-c", "D-d", "E-e", "F-f")
}
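
When readability becomes a concern, one possible alternative is to compose the plain Scala functions first and wrap the result in a single UDF, which keeps the select expression flat. A simplified sketch, reusing the helpers defined in the test above:

def concatenateAndReverse(separator: String)(column1Value: String, column2Value: String): String = {
  reverseText(concatenateStrings(separator)(column1Value, column2Value))
}

val udfConcatenateAndReverse = udf(concatenateAndReverse("-") _)
letters.select(udfConcatenateAndReverse($"lowercased_letter", $"uppercased_letter") as "reversed_word")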

"UDF" should "be used in where clause" in {
  val evenNumbersFilter: (Int) => Boolean = (nr) => { nr%2 == 0 }
  sparkSession.udf.register("EvenFlagResolver_whereTest", evenNumbersFilter)

  val evenNumbers = letterNumbers.selectExpr("letter", "number")
    .where("EvenFlagResolver_whereTest(number) == true")
    .map(row => (row.getString(0), row.getInt(1)))

  val rows = evenNumbers.collectAsList()

  evenNumbers.explain(true)
  rows should contain allOf(("B", 2), ("D", 4), ("F", 6))
}

object WebService {

  val titresCache = mutable.Map[Int, String]()

  def callWebServiceWithoutCache(id: Int): String = {
    // Let's suppose that our function calls a web service
    // to enrich row data and that the call takes 200 ms every time
    Thread.sleep(200)
    s"Title_${id}"
  }

  def callWebServiceWithCache(id: Int): String = {
    titresCache.getOrElseUpdate(id, callWebServiceWithoutCache(id))
  }

}

User Defined Functions enhance the basic features of every SQL-related software, and Spark SQL is not an exception. Moreover, defining them is quite easy: they are simple Scala functions registered through the udf(...) or register(...) methods. But when we decide to define UDFs, we should keep in mind that they won't be optimized and that bad coding can drastically decrease overall application performance.