Fetchsize in Spark SQL

Versions: Spark 2.1.0

Spark SQL reads from RDBMS through classic JDBC drivers. It thus supports some of their options, such as the fetch size described in the sections below.

This post explains what this fetch size parameter is. It's sometimes misunderstood and considered as an alternative to the LIMIT statement. The first part of the post will show that these two features have nothing in common. The next section will explain the role of the fetch size in Spark SQL queries. The last part, through learning tests using a decorator written with the help of Javassist (more about it in the post Manipulate bytecode with Javassist), will show that the theoretical explanation was correct.

Fetch size in JDBC

The fetch size is a JDBC property defining the number of rows fetched per round trip. After reading this definition we could think that it behaves like the LIMIT statement, but that's not true. To see why, let's also define the round trips. A round trip can be considered as a kind of open network tunnel between the database and the application reading the data. Through this tunnel the database sends the queried rows in batches. The size of each batch corresponds to the value of the fetch size property.

To be clearer, let's take an example. Our SELECT * FROM query returns 100 rows. But since the fetch size is set to 5, the database will send the data in batches of 5 rows. In consequence, the database will send the rows in 20 round trips.

What is the interest of the fetch size? Among other things, it helps to avoid memory problems. Without it, the client application risks reading and processing all rows returned by the SELECT query at once. That can become a very memory-intensive operation leading to increased GC activity (and thus GC pauses) and, in the worst cases, to OOM problems.

If we go back to the example described above, we could imagine the following code:

// sum the order amounts, stopping as soon as 1000 is reached
var accumulatedAmount = 0.0d
val resultsSet = statement.executeQuery("SELECT * FROM orders")
while (resultsSet.next() && accumulatedAmount < 1000d) {
  val currentRowAmount = resultsSet.getDouble("amount")
  accumulatedAmount += currentRowAmount
}

As you can see, the code above iterates over the orders until the accumulated sum reaches 1000 (I know it could be written as a single SQL query, but it's just an example). Now imagine that the fetch size is 10. If the amounts of the 10 first orders fetched in the first trip already sum up to 1000 or more, the JDBC client won't fetch all the results into memory. If they don't, the client will get the next 10 rows, and so on, until the amount condition is reached.
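
To make this scenario concrete, the fetch size is declared on the statement before executing the query. Below is a minimal sketch, assuming a MySQL connection with cursor-based fetching enabled (useCursorFetch=true, as in the tests later in this post); the connection URL and credentials are only placeholders, the same ones as in the tests:

import java.sql.{DriverManager, ResultSet}

// illustrative connection; useCursorFetch=true lets the MySQL driver retrieve the rows in batches
val connection = DriverManager.getConnection(
  "jdbc:mysql://127.0.0.1:3306/wfc_tests?useCursorFetch=true", "root", "root")
val statement = connection.prepareStatement("SELECT * FROM orders",
  ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
// hint for the driver: fetch the rows in batches of 10 per round trip
statement.setFetchSize(10)

var accumulatedAmount = 0.0d
val resultsSet = statement.executeQuery()
while (resultsSet.next() && accumulatedAmount < 1000d) {
  accumulatedAmount += resultsSet.getDouble("amount")
}
// if the 10 rows of the first batch already sum up to 1000, only 1 batch is ever fetched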

Fetch size in Spark SQL

As you can imagine now, Spark SQL uses the fetch size to control the number of rows fetched per round trip, and thus the number of round trips the selected rows need before arriving at the Spark application. Spark SQL reading RDBMS data defines the fetch size in JDBCRDD's compute(thePart: Partition, context: TaskContext) method:

val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
stmt = conn.prepareStatement(sqlText,
    ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
rs = stmt.executeQuery()
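
From the user's point of view, options.fetchSize comes from the fetchsize option passed to the JDBC data source. A minimal sketch of a read configured this way (the connection parameters are the same placeholders as in the tests of the next section):

val jdbcDataFrame = sparkSession.read
  .format("jdbc")
  .option("url", "jdbc:mysql://127.0.0.1:3306/wfc_tests?useCursorFetch=true")
  .option("dbtable", "orders")
  .option("user", "root")
  .option("password", "root")
  .option("driver", "com.mysql.cj.jdbc.Driver")
  // the number of rows retrieved from the database per round trip
  .option("fetchsize", "10")
  .load()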

The rest of the work is done by the JDBC driver in use. In the case of MySQL, used in the tests of the next section, the code making the round trips looks like the snippet below (some fragments are cut off for brevity):

@Override
public Row next() {
  // ...
  if ((this.fetchedRows == null) 
      || (this.currentPositionInFetchedRows > (this.fetchedRows.size() - 1))) {
    fetchMoreRows();
    this.currentPositionInFetchedRows = 0;
  }
  // ...
}

private void fetchMoreRows() {
  if (this.lastRowFetched) {
    this.fetchedRows = new ArrayList<Row>(0);
    return;
  }

  synchronized (this.owner.getConnection().getConnectionMutex()) {
    try {
      boolean oldFirstFetchCompleted = this.firstFetchCompleted;

      if (!this.firstFetchCompleted) {
        this.firstFetchCompleted = true;
      }

      int numRowsToFetch = this.owner.getOwnerFetchSize();

      if (numRowsToFetch == 0) {
        numRowsToFetch = this.owner.getOwningStatementFetchSize();
      }

      if (numRowsToFetch == Integer.MIN_VALUE) {
        // Handle the case where the user used 'old' streaming result sets

        numRowsToFetch = 1;
      }

      if (this.fetchedRows == null) {
        this.fetchedRows = new ArrayList<Row>(numRowsToFetch);
      } else {
        this.fetchedRows.clear();
      }

      // TODO this is not the right place for this code, should be in protocol
      PacketPayload sharedSendPacket = this.protocol.getSharedSendPacket();
      sharedSendPacket.setPosition(0);

      sharedSendPacket.writeInteger(IntegerDataType.INT1, MysqlaConstants.COM_STMT_FETCH);
      sharedSendPacket.writeInteger(IntegerDataType.INT4, this.owner.getOwningStatementServerId());
      sharedSendPacket.writeInteger(IntegerDataType.INT4, numRowsToFetch);

      this.protocol.sendCommand(MysqlaConstants.COM_STMT_FETCH, null, sharedSendPacket, true, null, 0);

      Row row = null;

      while ((row = this.protocol.read(ResultsetRow.class, this.rowFactory)) != null) {
        this.fetchedRows.add(row);
      }

      this.currentPositionInFetchedRows = BEFORE_START_OF_ROWS;

      if (this.protocol.getServerSession().isLastRowSent()) {
        this.lastRowFetched = true;

        if (!oldFirstFetchCompleted && this.fetchedRows.size() == 0) {
          this.wasEmpty = true;
        }
      }
    } catch (Exception ex) {
        throw ExceptionFactory.createException(ex.getMessage(), ex);
    }
  }
}

Fetch size example

To see the use of the fetch size we'll use, unlike in other learning tests, not a standalone database (H2) but a MySQL instance installed outside the scope of this post. The reason for this choice was the will to run the tests against a more production-like system:

private val FetchMoreRowsKey = "fetchMoreRows"
MethodInvocationDecorator.decorateClass("com.mysql.cj.mysqla.result.ResultsetRowsCursor", FetchMoreRowsKey).toClass

private val Connection = "jdbc:mysql://127.0.0.1:3306/wfc_tests?serverTimezone=UTC&useCursorFetch=true&autocommit=false"
private val User = "root"
private val Password = "root"
private val mysqlConnector = new MysqlConnector(Connection, User, Password)

private var sparkSession: SparkSession = null

before {
  sparkSession = SparkSession.builder().appName("Spark SQL fetch size test").master("local[*]").getOrCreate()
  MethodInvocationCounter.methodInvocations.remove(FetchMoreRowsKey)
  case class Order(customer: String, amount: Double) extends DataOperation {
    override def populatePreparedStatement(preparedStatement: PreparedStatement): Unit = {
      preparedStatement.setString(1, customer)
      preparedStatement.setDouble(2, amount)
    }
  }
  val ordersToInsert = mutable.ListBuffer[Order]()
  for (i <- 1 to 1000) {
    val amount = ThreadLocalRandom.current().nextDouble(1000)
    ordersToInsert.append(Order(UUID.randomUUID().toString, amount))
  }
  mysqlConnector.populateTable("INSERT INTO orders (customer, amount) VALUES (?, ?)", ordersToInsert)
}

after {
  mysqlConnector.cleanTable("orders")
  sparkSession.stop()
}

"fetch size smaller than the number of rows" should "make more than 1 round trip" in {
  val jdbcDataFrame = sparkSession.read
    .format("jdbc")
    .options(getOptionsMap(10))
    .load()

  jdbcDataFrame.foreach(row => {})

  MethodInvocationCounter.methodInvocations("fetchMoreRows") shouldEqual(101)
}

"fetch size smaller than the number of rows with internal filter" should "make more less than 101 round trip" in {
  // As you can see, this test has the same parameters as the previous one
  // But in the action, instead of iterating over all rows, it ignores
  // all rows with id smaller than 20. Thus logically, it should make
  // only 2 round trips to get rows 1-10 and 11-20
  // Please note that this sample works also because we use a single
  // partition. Otherwise, it should call rows.next() at least
  // once for each of partitions
  val jdbcDataFrame = sparkSession.read
    .format("jdbc")
    .options(getOptionsMap(10))
    .load()

  jdbcDataFrame.foreachPartition(rows => {
    while (rows.hasNext && rows.next.getAs[Int]("id") < 20) {
      // Do nothing, only to show the case
    }
  })

  MethodInvocationCounter.methodInvocations("fetchMoreRows") shouldEqual(2)
}


"fetch size greater than the number of rows" should "make only 1 round trip" in {
  val jdbcDataFrame = sparkSession.read
    .format("jdbc")
    .options(getOptionsMap(50000))
    .load()

  jdbcDataFrame.foreach(row => {})

  MethodInvocationCounter.methodInvocations("fetchMoreRows") shouldEqual(2)
}

private def getOptionsMap(fetchSize: Int): Map[String, String] = {
  Map("url" -> s"${Connection}",
    "dbtable" -> "orders", "user" -> s"${User}", "password" -> s"${Password}",
    "driver" -> "com.mysql.cj.jdbc.Driver", "fetchsize" -> s"${fetchSize}")
} 

This post explains the fetch size option, available in the configuration of Spark SQL's JDBC data source. The first part showed theoretically the role of the fetch size parameter. We could learn that it has nothing in common with the LIMIT statement. Instead of limiting the returned rows, it acts more like a lazy iterator returning chunks of X next rows. One of the tests in the 3rd section also proved that the fetchMoreRows method, responsible for getting the subsequent chunks, can be called for only a subset of the rows.