How to use a User Defined Function in column-based operations in Apache Spark SQL?

Using a UDF in a SQL statement or programmatically is quite easy: you either refer to the function by its registered name or call the object returned by the registration. Using it in column-based operations like withColumn is a little harder, at least the first time.
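
To make the two easy cases concrete, here is a minimal sketch, assuming a local SparkSession is passed in. The object name UdfByNameAndByObject, the method run and the temporary view name "orders" are hypothetical and only serve the illustration:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object UdfByNameAndByObject {
  // Returns the same projection computed twice:
  // once through a SQL statement, once through the object returned by register().
  def run(spark: SparkSession): (DataFrame, DataFrame) = {
    import spark.implicits._

    // register() exposes the function under a name and also returns a UserDefinedFunction
    val generateUserId =
      spark.udf.register("generate_user_id", (orderId: Long) => s"user${orderId}")

    val orders = Seq(1L, 2L).toDF("order_id")
    orders.createOrReplaceTempView("orders") // hypothetical view name

    // SQL statement: the UDF is referenced by its registered name
    val viaSql = spark.sql(
      "SELECT order_id, generate_user_id(order_id) AS user_id FROM orders")

    // Programmatic way: the returned object is applied to a Column
    val viaObject = orders.select($"order_id", generateUserId($"order_id").as("user_id"))

    (viaSql, viaObject)
  }
}
```

Both DataFrames should hold the same rows; the only difference is how the function is looked up.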

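To call a UDF by its registered name in column-based operations, you can use the functions.callUDF() method, passing the name and the input columns, exactly like in the snippet below (Spark 3.2 and later deprecate callUDF in favor of the equivalent functions.call_udf()):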

  "a UDF" should "be called in withColumn method" in {
    val testedSparkSession: SparkSession = SparkSession.builder()
      .appName("UDF from withColumn").master("local[*]").getOrCreate()
    import testedSparkSession.implicits._
    val orders = Seq((1L), (2L), (3L), (4L)).toDF("order_id")

    // Register the function under a name; callUDF looks it up by that name
    testedSparkSession.udf.register("generate_user_id", (orderId: Long) => s"user${orderId}")

    val ordersWithUserId =
      orders.withColumn("user_id", functions.callUDF("generate_user_id", $"order_id"))
        .map(row => (row.getAs[Long]("order_id"), row.getAs[String]("user_id")))
        .collect()

    ordersWithUserId should have size 4
    ordersWithUserId should contain allOf((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"))
  }

That was the more complicated option. A simpler one (thanks Mikhail for the tip!) calls the object returned by the registration directly in the withColumn expression:

"a UDF" should "be called in withColumn method with direct invocation" in {
  val testedSparkSession: SparkSession = SparkSession.builder()
    .appName("UDF from withColumn").master("local[*]").getOrCreate()
  import testedSparkSession.implicits._
  val orders = Seq((1L), (2L), (3L), (4L)).toDF("order_id")

  // register() returns a UserDefinedFunction that can be applied to columns directly
  val generateUserId =
    testedSparkSession.udf.register("generate_user_id", (orderId: Long) => s"user${orderId}")

  val ordersWithUserId =
    orders.withColumn("user_id", generateUserId($"order_id"))
      .map(row => (row.getAs[Long]("order_id"), row.getAs[String]("user_id")))
      .collect()

  ordersWithUserId should have size 4
  ordersWithUserId should contain allOf((1L, "user1"), (2L, "user2"), (3L, "user3"), (4L, "user4"))
}
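
If the function is only ever used from the DataFrame API, the registration under a name may not be needed at all. The sketch below assumes that case and wraps the same lambda with functions.udf instead of udf.register; the object name UnregisteredUdfSketch and the method withUserId are hypothetical:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.udf

object UnregisteredUdfSketch {
  def withUserId(spark: SparkSession): DataFrame = {
    import spark.implicits._
    val orders = Seq(1L, 2L, 3L, 4L).toDF("order_id")

    // functions.udf wraps a Scala function into a UserDefinedFunction.
    // No name is registered here, so SQL strings cannot refer to it.
    val generateUserId = udf((orderId: Long) => s"user${orderId}")

    // The UserDefinedFunction is applied to a Column, just like the registered variant
    orders.withColumn("user_id", generateUserId($"order_id"))
  }
}
```

The trade-off is visibility: an unregistered function is not reachable by name, so callUDF and SQL statements still require the registration shown earlier.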