User Defined Aggregate Functions

Versions: Spark 2.2.0

User Defined Functions are not the only way to extend Spark SQL. Another one is provided by User Defined Aggregate Functions.

This post, divided in 2 parts, presents User Defined Aggregate Functions. The first part describes how they integrate into the data processing workflow. The second part shows sample User Defined Aggregate Functions.


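The User Defined Aggregate Function (UDAF) is not a feature reserved to Spark SQL. Like User Defined Functions, it comes from the RDBMS world. It exists, for instance, in PostgreSQL, where it's created with the CREATE AGGREGATE statement. In Spark SQL, a UDAF must extend the org.apache.spark.sql.expressions.UserDefinedAggregateFunction abstract class. This class is composed of methods describing the data types and the aggregation process: inputSchema, bufferSchema and dataType define respectively the types of the input, of the intermediate aggregation buffer and of the result; deterministic tells whether the function always returns the same output for the same input; initialize, update and merge implement the aggregation itself; and evaluate computes the final result from the buffer.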

Partial aggregation

The idea of partial aggregation consists of executing aggregation functions on subsets of the data, in parallel, on different machines. The result must be the same as if the function were executed sequentially on a single server.

The functions supporting partial aggregation must provide a combiner function. Its role is to merge the results computed by the partial aggregations. In UserDefinedAggregateFunction, this combiner is the merge method.
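To make the idea concrete, here is a minimal plain-Scala sketch (no Spark involved) that mimics the initialize/update/merge/evaluate cycle of the session duration aggregation shown later. Each partition is folded into its own (first time, last time) buffer, the partial buffers are combined with a merge function, and the result matches a sequential computation. PartialAggregationSketch and its members are hypothetical names chosen for this illustration, not part of Spark's API.

```scala
// Hypothetical, Spark-free model of partial aggregation (illustration only).
object PartialAggregationSketch {
  // Aggregation buffer: (first log time, last log time)
  type Buffer = (Long, Long)

  // Neutral buffer: any real log time replaces these bounds
  val initial: Buffer = (Long.MaxValue, Long.MinValue)

  // update: folds one input value into a partial buffer
  def update(buffer: Buffer, time: Long): Buffer =
    (math.min(buffer._1, time), math.max(buffer._2, time))

  // merge: the combiner joining two partial buffers
  def merge(b1: Buffer, b2: Buffer): Buffer =
    (math.min(b1._1, b2._1), math.max(b1._2, b2._2))

  // evaluate: turns the final buffer into the result (session duration)
  def evaluate(buffer: Buffer): Long = buffer._2 - buffer._1

  // Partial aggregation: every partition is aggregated independently
  // (possibly on different machines), then the partial buffers are merged
  def partial(partitions: Seq[Seq[Long]]): Long =
    evaluate(partitions.map(_.foldLeft(initial)(update)).foldLeft(initial)(merge))

  // Sequential aggregation on a single "server"
  def sequential(times: Seq[Long]): Long =
    evaluate(times.foldLeft(initial)(update))
}
```

For instance, partial(Seq(Seq(100L, 200L), Seq(150L))) and sequential(Seq(100L, 200L, 150L)) both give 100. Since the order in which partial buffers get combined isn't fixed, the merge function should be associative and commutative, which min and max are.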

The buffers manipulated during the aggregation (in initialize, update and merge) are implementations of MutableAggregationBuffer. The implementation used by default is MutableAggregationBufferImpl. It's a mutable row that persists the intermediate aggregation result between subsequent calls of the UserDefinedAggregateFunction#update method for a given group of input rows.

A UDAF can be invoked in the same ways as a UDF. It can be registered through org.apache.spark.sql.UDFRegistration#register(name: String, udaf: UserDefinedAggregateFunction) and later referenced by name in expressions, for instance in SQL queries. It can also be instantiated and called directly, since its apply(exprs: Column*) method returns an aggregate Column.
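Both invocation styles can be sketched as follows, assuming Spark 2.2 (spark-sql) on the classpath and a local master. LastLogTimeAggregator, UdafInvocationSketch, the last_log_time function name and the logs view are hypothetical names chosen for this illustration; the UDAF simply keeps the latest log time of each group.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// Hypothetical UDAF keeping the latest log time of a group (illustration only)
class LastLogTimeAggregator extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(Seq(StructField("time", LongType, false)))
  override def bufferSchema: StructType = StructType(Seq(StructField("last_time", LongType, false)))
  override def dataType: DataType = LongType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, Long.MinValue)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    buffer.update(0, math.max(buffer.getLong(0), input.getLong(0)))
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1.update(0, math.max(buffer1.getLong(0), buffer2.getLong(0)))
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}

// Hypothetical driver comparing registration-based and direct invocation
object UdafInvocationSketch {
  def run(): (Map[Int, Long], Map[Int, Long]) = {
    val spark = SparkSession.builder().appName("UDAF invocation sketch").master("local").getOrCreate()
    import spark.implicits._
    try {
      val logs = Seq((1, 100L), (1, 200L), (2, 150L)).toDF("user", "time")

      // 1) Registration: the UDAF becomes callable by name, e.g. in SQL
      spark.udf.register("last_log_time", new LastLogTimeAggregator)
      logs.createOrReplaceTempView("logs")
      val viaSql = spark.sql("SELECT user, last_log_time(time) FROM logs GROUP BY user")
        .collect().map(row => row.getInt(0) -> row.getLong(1)).toMap

      // 2) Direct call: apply(Column*) builds an aggregate Column
      val lastLogTime = new LastLogTimeAggregator
      val viaApply = logs.groupBy("user").agg(lastLogTime(col("time")))
        .collect().map(row => row.getInt(0) -> row.getLong(1)).toMap

      (viaSql, viaApply)
    } finally {
      spark.stop()
    }
  }
}
```

Both paths should return the same per-user values. Registration suits SQL queries and string-based agg calls, while the direct call doesn't need any name in the session's function registry.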

UDAF example

To illustrate UDAF use, we'll take the example of session reconstruction. Let's imagine that our dataset groups users' activity logs as JSON files in partitioned directories. Aggregation is one possible way to reconstruct user activity sessions. For simplicity, the tests below build the logs in memory as (user, time, page) rows:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

import scala.collection.immutable.TreeMap

class UserDefinedAggregationFunctionTest extends FlatSpec with Matchers with BeforeAndAfterAll {

  private val sparkSession: SparkSession = SparkSession.builder().appName("UDAF test").master("local").getOrCreate()

  import sparkSession.implicits._
  // (user, time, page) rows; times are Longs to match the UDAFs' LongType inputs
  private val Sessions = Seq(
    (1, 100L, "categories.html"), (2, 100L, "index.html"), (3, 100L, "contact.html"),
    (1, 150L, "categories/home.html"), (1, 200L, "cart.html"), (3, 300L, "reclaim_form.html")
  ).toDF("user", "time", "page")

  override def afterAll(): Unit = {
    sparkSession.stop()
  }

  "sessions" should "be aggregated with UDAF" in {
    val sessionsAggregator = new SessionsAggregator

    // Direct call: the UDAF instance is applied to the columns like a function
    val aggregatedSessions = Sessions.groupBy("user").agg(sessionsAggregator(Sessions.col("time"), Sessions.col("page")))

    val expectedResults = Map(
      1 -> Map(100L -> "categories.html", 150L -> "categories/home.html", 200L -> "cart.html"),
      2 -> Map(100L -> "index.html"),
      3 -> Map(100L -> "contact.html", 300L -> "reclaim_form.html")
    )
    // collect() brings the rows to the driver, where the assertions are evaluated
    aggregatedSessions.collect().foreach(aggregationResult => {
      val aggregationKey = aggregationResult.getInt(0)
      val sessionMap = aggregationResult.getMap[Long, String](1)
      sessionMap should equal(expectedResults(aggregationKey))
    })
  }

  "sessions length" should "be computed with UDAF registered with UDF" in {
    // Registration: the UDAF is then referenced by its name
    sparkSession.udf.register("SessionLength_registerTest", new SessionDurationAggregator)

    val sessionsDurations = Sessions.groupBy("user").agg("time" -> "SessionLength_registerTest")

    val expectedResults = Map(1 -> 100L, 2 -> 0L, 3 -> 200L)
    sessionsDurations.collect().foreach(aggregationResult => {
      val aggregationKey = aggregationResult.getInt(0)
      val sessionDuration = aggregationResult.getLong(1)
      sessionDuration should equal(expectedResults(aggregationKey))
    })
  }

}

class SessionsAggregator extends UserDefinedAggregateFunction {

  private val AggregationMapType = MapType(LongType, StringType, valueContainsNull = false)

  // Input columns: log time and visited page
  override def inputSchema: StructType = StructType(Seq(
    StructField("time", LongType, false), StructField("page", StringType, false)
  ))

  // Intermediate buffer: the time -> page map accumulated so far
  override def bufferSchema: StructType = StructType(Seq(
    StructField("visited_pages", AggregationMapType, false)
  ))

  override def dataType: DataType = AggregationMapType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, new TreeMap[Long, String]())
  }

  // Adds one input row to the group's buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val currentTimes = buffer.getMap[Long, String](0)
    val newTimes = currentTimes + (input.getLong(0) -> input.getString(1))
    buffer.update(0, newTimes)
  }

  // Combines two partial buffers (partial aggregation)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val currentTimes = buffer1.getMap[Long, String](0)
    val newTimes = currentTimes ++ buffer2.getMap[Long, String](0)
    buffer1.update(0, newTimes)
  }

  // The final result is the accumulated map itself
  override def evaluate(buffer: Row): Any = {
    buffer.getMap[Long, String](0)
  }
}

class SessionDurationAggregator extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(Seq(StructField("time", LongType, false)))

  // Buffer keeps the earliest and the latest log time seen so far
  override def bufferSchema: StructType = StructType(Seq(
    StructField("first_log_time", LongType, false), StructField("last_log_time", LongType, false)
  ))

  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  // Neutral values: any real log time replaces them
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Long.MaxValue)
    buffer.update(1, Long.MinValue)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val sessionStartTime = buffer.getLong(0)
    val sessionLastLogTime = buffer.getLong(1)
    val logTime = input.getLong(0)
    if (logTime < sessionStartTime) {
      buffer.update(0, logTime)
    }
    if (logTime > sessionLastLogTime) {
      buffer.update(1, logTime)
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val buffer1StartTime = buffer1.getLong(0)
    val buffer2StartTime = buffer2.getLong(0)
    if (buffer2StartTime < buffer1StartTime) {
      buffer1.update(0, buffer2StartTime)
    }
    val buffer1EndTime = buffer1.getLong(1)
    val buffer2EndTime = buffer2.getLong(1)
    if (buffer2EndTime > buffer1EndTime) {
      buffer1.update(1, buffer2EndTime)
    }
  }

  // Session duration = last log time - first log time
  override def evaluate(buffer: Row): Any = {
    val sessionStartTime = buffer.getLong(0)
    val sessionLastLogTime = buffer.getLong(1)
    sessionLastLogTime - sessionStartTime
  }
}
Spark SQL brings a lot of mechanisms extending its basic features. One of them is the UDAF, which applies custom aggregations. As explained in the first section, UDAFs must extend the UserDefinedAggregateFunction abstract class and implement the methods defining the data types and the aggregation operation. The second section showed two use cases where UDAFs were used to aggregate sessions and to compute the duration of each of them.