SaveMode.Overwrite trap with RDBMS in Apache Spark SQL

Versions: Apache Spark 2.3.0

Some months ago I presented save modes in Spark SQL. However, that post was limited to their use with files. I was quite surprised to discover that some of them behave differently with RDBMS sinks, especially SaveMode.Overwrite.

The post begins with a short test case showing this intriguing problem. The second part presents how it can be solved with little effort. The final section covers the other save modes available for RDBMS sinks.

The problem

The problem occurs when SaveMode.Overwrite is used to write a Spark SQL DataFrame to an RDBMS sink, as the following test case shows:

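// The test runs on MySQL because H2 falls back to the NoopDialect, which returns None for
// org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils.isCascadingTruncateTable;
// because of that, the table would be recreated on every overwrite
private val Connection = "jdbc:mysql://127.0.0.1:3306/wfc_tests?serverTimezone=UTC"
private val User = "root"
private val Password = "root"
// JDBC credentials passed to DataFrameWriter.jdbc
private val ConnectionProperties = {
  val properties = new java.util.Properties()
  properties.setProperty("user", User)
  properties.setProperty("password", Password)
  properties
}
// MysqlConnector and UserInformation are small test helpers not shown in this post
private val mysqlConnector = new MysqlConnector(Connection, User, Password)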
before {
  mysqlConnector.executedSideEffectQuery("CREATE TABLE users (user_login VARCHAR(20) NOT NULL)")
  val users = (1 to 5).map(id => UserInformation(s"User${id}"))
  mysqlConnector.populateTable("INSERT INTO users (user_login) VALUES (?)", users)
}

after {
  mysqlConnector.executedSideEffectQuery("DROP TABLE users")
}

private def mapResultSetToColumnType: (ResultSet) => String = (resultSet) => {
  // index#1 = column name
  // index#2 = column type
  resultSet.getString(2)
}

"Spark" should "override the column type with overwrite save mode" in {
  val newUsers = Seq(
    ("User100"), ("User101"), ("User102")
  ).toDF("user_login")

  newUsers.write.mode(SaveMode.Overwrite).jdbc(Connection, "users", ConnectionProperties)

  val allUsers = mysqlConnector.getRows("SELECT * FROM users", (resultSet) => resultSet.getString("user_login"))
  allUsers should have size 3
  allUsers should contain allOf("User100", "User101", "User102")
  val columnTypes = mysqlConnector.getRows("DESC users", mapResultSetToColumnType)
  columnTypes should have size 1
  columnTypes(0) shouldEqual "text"
}

As you can see, the table was initially created with a single VARCHAR(20) column. However, after the write, the table was dropped and recreated from the schema inferred by Apache Spark, and the column type became TEXT. After some digging, we can find the source of this behavior in the JdbcRelationProvider.createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], df: DataFrame) method, which handles an existing table in overwrite mode as follows:

  val tableExists = JdbcUtils.tableExists(conn, options)
  if (tableExists) {
    mode match {
      case SaveMode.Overwrite =>
        if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
          // In this case, we should truncate table and then load.
          truncateTable(conn, options)
          val tableSchema = JdbcUtils.getSchemaOption(conn, options)
          saveTable(df, tableSchema, isCaseSensitive, options)
        } else {
          // Otherwise, do not truncate the table, instead drop and recreate it
          dropTable(conn, options.table)
          createTable(conn, df, options)
          saveTable(df, Some(df.schema), isCaseSensitive, options)
        }
// ...
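To make the branch easier to reason about, here is a simplified, Spark-free model of the decision taken for an already existing table in overwrite mode. It is only a sketch: the names `OverwriteAction`, `TruncateAndLoad`, `DropAndRecreate` and `OverwriteDecision.overwriteAction` are hypothetical, and `cascadingTruncate` stands for the `Option[Boolean]` returned by `isCascadingTruncateTable` for the connection URL's dialect.

```scala
// Simplified model of the SaveMode.Overwrite branch shown above (existing table only).
// All names are hypothetical; none of them is a Spark API.
sealed trait OverwriteAction
case object TruncateAndLoad extends OverwriteAction // TRUNCATE, then insert with the table's schema
case object DropAndRecreate extends OverwriteAction // DROP, CREATE from df.schema, then insert

object OverwriteDecision {
  /**
   * @param truncateOption    value of the "truncate" JDBC option (false by default)
   * @param cascadingTruncate dialect's answer to isCascadingTruncateTable:
   *                          None if unknown, Some(true) if TRUNCATE cascades
   */
  def overwriteAction(truncateOption: Boolean, cascadingTruncate: Option[Boolean]): OverwriteAction =
    // Truncation only happens when it is requested AND known not to cascade
    if (truncateOption && cascadingTruncate == Some(false)) TruncateAndLoad
    else DropAndRecreate
}
```

In this model, leaving truncate at its default, or using a dialect that answers None, always leads to `DropAndRecreate`, which is the path taken by the first test.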

Possible solution

The snippet above immediately suggests how to solve the issue. Spark provides a JDBC option called truncate, which is false by default. As the snippet shows, when it is false, the table is automatically dropped and recreated. The second condition checks whether TRUNCATE in the given database cascades to the tables referencing the truncated table through foreign keys: the table is truncated only when the dialect explicitly reports that it does not (Some(false)). Dialects returning None, such as the NoopDialect mentioned in the test comment, always take the drop-and-recreate path. Thus, a solution to our problem could look like:

"truncate option enabled" should "prevent table against recreating with unexpected type" in {
  val newUsers = Seq(
    ("User100"), ("User101"), ("User102")
  ).toDF("user_login")

  newUsers.write.mode(SaveMode.Overwrite)
    .option("truncate", true)
    .jdbc(Connection, "users", ConnectionProperties)

  val allUsers = mysqlConnector.getRows("SELECT * FROM users", (resultSet) => resultSet.getString("user_login"))
  allUsers should have size 3
  allUsers should contain allOf("User100", "User101", "User102")
  val columnTypes = mysqlConnector.getRows("DESC users", mapResultSetToColumnType)
  columnTypes should have size 1
  columnTypes(0) shouldEqual "varchar(20)"
}
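The two tests differ only in what happens to the table definition. The toy in-memory model below illustrates it under stated assumptions: `Table` and `OverwriteModel.overwrite` are hypothetical names, `truncate = true` stands for the case where truncation is actually performed (option enabled and a non-cascading dialect), and "TEXT" is used as the recreated column type because that is what the first test observed for a Spark StringType column on MySQL.

```scala
// Toy, in-memory model of the two overwrite strategies (hypothetical names).
final case class Table(columnType: String, rows: Seq[String])

object OverwriteModel {
  // Column type produced when the table is recreated from the DataFrame schema;
  // "TEXT" matches what the first test observed on MySQL for a string column
  val InferredStringType = "TEXT"

  def overwrite(existing: Table, newRows: Seq[String], truncate: Boolean): Table =
    if (truncate) existing.copy(rows = newRows) // TRUNCATE keeps the DDL, only rows change
    else Table(InferredStringType, newRows)     // DROP + CREATE replaces the DDL too
}
```

Both strategies end up with the same rows; only the drop-and-recreate one loses the original VARCHAR(20) definition.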

SaveModes in RDBMS

All save modes are supported for RDBMS sinks and, apart from overwrite, they behave as one would expect. SaveMode.Append adds new rows to the existing table. When the table already exists, SaveMode.ErrorIfExists throws an exception, while SaveMode.Ignore does nothing:

"append mode" should "add new rows without removing existing data or deleting the table" in {
  val newUsers = Seq(
    ("User100"), ("User101"), ("User102")
  ).toDF("user_login")

  newUsers.write.mode(SaveMode.Append)
    .jdbc(Connection, "users", ConnectionProperties)

  val allUsers = mysqlConnector.getRows("SELECT * FROM users", (resultSet) => resultSet.getString("user_login"))
  allUsers should have size 8
  allUsers should contain allOf("User1", "User2", "User3", "User4", "User5",  "User100", "User101", "User102")
  val columnTypes = mysqlConnector.getRows("DESC users", mapResultSetToColumnType)
  columnTypes should have size 1
  columnTypes(0) shouldEqual "varchar(20)"
}

"ignore mode" should "do nothing when the table already exists" in {
  val newUsers = Seq(
    ("User100"), ("User101"), ("User102")
  ).toDF("user_login")

  newUsers.write.mode(SaveMode.Ignore)
    .jdbc(Connection, "users", ConnectionProperties)

  val allUsers = mysqlConnector.getRows("SELECT * FROM users", (resultSet) => resultSet.getString("user_login"))
  allUsers should have size 5
  allUsers should contain allOf("User1", "User2", "User3", "User4", "User5")
  val columnTypes = mysqlConnector.getRows("show columns from users", mapResultSetToColumnType)
  columnTypes should have size 1
  columnTypes(0) shouldEqual "varchar(20)"
}

"error mode" should "throw an exception for insert to already existing table" in {
  import sparkSession.implicits._
  val newUsers = Seq(
    ("User100"), ("User101"), ("User102")
  ).toDF("user_login")

  val analysisException = intercept[AnalysisException] {
    newUsers.write.mode(SaveMode.ErrorIfExists)
      .jdbc(Connection, "users", ConnectionProperties)
  }
  analysisException.message should include("Table or view 'users' already exists.")
}
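The behavior verified by the tests above can be summarized with a small Spark-free model of a write to an already existing table. This is a sketch only: the `Mode` hierarchy and `ExistingTableWrite.writeToExistingTable` are hypothetical stand-ins for `org.apache.spark.sql.SaveMode` and the JDBC sink, and the error message reproduces only the part checked by the error mode test.

```scala
// Toy model of each save mode applied to an existing table (hypothetical names).
sealed trait Mode
case object Append extends Mode
case object Overwrite extends Mode
case object ErrorIfExists extends Mode
case object Ignore extends Mode

object ExistingTableWrite {
  /** Returns the rows stored after the write, or an error message. */
  def writeToExistingTable(mode: Mode, existing: Seq[String], incoming: Seq[String],
                           table: String): Either[String, Seq[String]] =
    mode match {
      case Append        => Right(existing ++ incoming) // insert new rows, keep the old ones
      case Overwrite     => Right(incoming)            // truncate or drop+recreate, then insert
      case Ignore        => Right(existing)            // leave the table untouched
      case ErrorIfExists => Left(s"Table or view '$table' already exists.")
    }
}
```

For example, appending the three new users to the five existing ones yields eight rows, as in the append mode test.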

This post presented some subtle differences between the overwrite save mode for RDBMS and for files. The most important one is that Spark recreates the database table when the truncate option is left at false. In that case, as we saw in the first section, the engine may recreate the table with a schema inferred from the DataFrame, which can differ from the original one (TEXT instead of VARCHAR(20) in our example). The second part showed how to overcome that issue with the truncate JDBC option, provided that the dialect reports a non-cascading TRUNCATE.