PySpark schema inference and 'Can not infer schema for type str' error

Versions: Apache Spark 3.0.1

The error in the title of this blog post is probably one of the first problems you will encounter with PySpark (it was for me). Even though it looks quite mysterious, it makes sense once you take a look at the root cause.

To introduce the problem, let's take this code executed with Apache Spark's Scala API:

  val singleColumn = Seq(
    ("a"), ("b"), ("c")
  ).toDF("letter")
  singleColumn.show()

It runs without problems and prints:

+------+
|letter|
+------+
|     a|
|     b|
|     c|
+------+

However, if you translate this code to PySpark:

letters = [('a'), ('b'), ('c')]
spark.createDataFrame(letters, ['letter']).show(truncate=False)

You will get this exception instead of the DataFrame:

Traceback (most recent call last):
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/inference_from_one_column.py", line 14, in <module>
    spark.createDataFrame(letters, ['letter']).show(truncate=False)
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/session.py", line 605, in createDataFrame
    return self._create_dataframe(data, schema, samplingRatio, verifySchema)
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/session.py", line 630, in _create_dataframe
    rdd, schema = self._createFromLocal(map(prepare, data), schema)
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/session.py", line 451, in _createFromLocal
    struct = self._inferSchemaFromList(data, names=schema)
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/session.py", line 383, in _inferSchemaFromList
    schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/session.py", line 383, in <genexpr>
    schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
  File "/home/bartosz/workspace/spark-playground/pyspark-schema-inference/.venv/lib/python3.6/site-packages/pyspark/sql/types.py", line 1067, in _infer_schema
    raise TypeError("Can not infer schema for type: %s" % type(row))
TypeError: Can not infer schema for type: <class 'str'>
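
The `<class 'str'>` at the end of the message is the first hint. In Python, parentheses alone only group an expression, so ('a') is a plain string and not a one-element tuple. The short check below, which does not need Spark at all, shows the difference (the letters_as_tuples name is only used for this illustration):

```python
# Parentheses alone only group an expression; the trailing comma makes the tuple.
letters = [('a'), ('b'), ('c')]
print([type(item).__name__ for item in letters])
# ['str', 'str', 'str']

letters_as_tuples = [('a',), ('b',), ('c',)]
print([type(item).__name__ for item in letters_as_tuples])
# ['tuple', 'tuple', 'tuple']
```

Assuming the rest of the snippet stays unchanged, passing letters_as_tuples together with ['letter'] to spark.createDataFrame should give the same single-column DataFrame as the Scala version.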

As you can see in the stack trace, the error is raised by the _infer_schema function. Inside it you will find an if-else chain that shows pretty clearly why the DataFrame creation failed:

    if isinstance(row, dict):
        # ...
    elif isinstance(row, (tuple, list)):
        # ...
    elif hasattr(row, "__dict__"):  # object
        # ...
    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))
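
To see which branch a given row would take, you can replay the same checks in plain Python. The which_branch function below is a hypothetical helper written for this post, not a part of PySpark. It only mirrors the order of the conditions shown above:

```python
from collections import namedtuple


def which_branch(row):
    """Hypothetical helper: same checks, in the same order, as the excerpt above."""
    if isinstance(row, dict):
        return "dict"
    elif isinstance(row, (tuple, list)):
        return "tuple or list"
    elif hasattr(row, "__dict__"):
        return "object"
    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))


class LetterObject:
    def __init__(self, letter):
        self.letter = letter


Letter = namedtuple('Letter', ['letter'])

print(which_branch({'letter': 'a'}))    # dict
print(which_branch(('a', 1)))           # tuple or list
print(which_branch(Letter('a')))        # tuple or list
print(which_branch(LetterObject('a')))  # object

try:
    which_branch('a')
except TypeError as error:
    print(error)  # Can not infer schema for type: <class 'str'>
```

A plain string is neither a dict nor a tuple or list, and it has no __dict__ attribute, so it falls through to the final else.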

There is nothing you can do here except change the way the rows are created. Let's check the first branch. Here, instead of a plain string, you have to use a dictionary, which automatically provides the structure of your DataFrame:

letters = [{'letter': 'a'}, {'letter': 'b'}, {'letter': 'c'}]
spark.createDataFrame(letters).show(truncate=False)

This code is not ideal, though. Once executed, you will see a warning saying that "inferring schema from dict is deprecated, please use pyspark.sql.Row instead". However, this deprecation is supposed to be reverted in one of the next releases because the feature mirrors one of the Pandas functionalities and is judged Pythonic enough to stay in the code. If you want to learn more about this topic, follow the Project Zen epic on JIRA. Anyway, that's not the topic of this blog post (maybe the next one 🤔). Let's go back to the DataFrame initialization.

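The second branch from the list covers the quite mysterious (tuple, list) type. What does it mean? Is it for (('letter'), ['a'])? Not really. That value is accepted, but it is read as a two-element row (a string and a list), so you get two unnamed columns instead of a column called letter. If you check the internals, you will see that the branch is more for the classes exposing the __fields__ or _fields attributes. An example of the former is Row, whereas an example of the latter is namedtuple: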

letters_tuple_list = [(('letter'), ['a']), (('letter'), ['b'])]
spark.createDataFrame(letters_tuple_list).show(truncate=False)

from pyspark.sql import Row

LetterRow = Row('letter')
letters_from_row = [LetterRow('a'), LetterRow('b')]
spark.createDataFrame(letters_from_row).show(truncate=False)

from collections import namedtuple

LetterNamedTuple = namedtuple('Letter', ['letter'])
letters_from_named_tuple = [LetterNamedTuple('a'), LetterNamedTuple('b')]
# Pass the namedtuple list here, not letters_from_row
spark.createDataFrame(letters_from_named_tuple).show(truncate=False)

This branch is currently also responsible for converting a plain tuple, such as a pair, into a DataFrame's row. That's why one of the fixes to the initial problem of converting ('a') into a DataFrame could be adding a new column to get ('a', 1):

letters_tuple_pair = [('a', 1), ('b', 1)]
spark.createDataFrame(letters_tuple_pair).show(truncate=False)
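
Since these pairs carry no field names, the columns get positional names. The sketch below is a simplified model of how the (tuple, list) branch seems to pick the column names in Spark 3.0.1, based on my reading of the source. The default_field_names function is a hypothetical helper and does not exist in PySpark:

```python
from collections import namedtuple


def default_field_names(row, names=None):
    """Hypothetical helper: choose column names for a tuple-like row."""
    if hasattr(row, "__fields__"):  # Row-like classes
        return list(row.__fields__)
    if hasattr(row, "_fields"):  # namedtuple
        return list(row._fields)
    names = list(names or [])
    # Pad the missing names with positional ones: _1, _2, ...
    names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
    return names[:len(row)]


Letter = namedtuple('Letter', ['letter'])

print(default_field_names(('a', 1)))              # ['_1', '_2']
print(default_field_names(('a', 1), ['letter']))  # ['letter', '_2']
print(default_field_names(Letter('a')))           # ['letter']
```

Under this model, a plain pair ends up with the columns _1 and _2 unless you pass the names yourself, whereas a namedtuple brings its own names.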

The last accepted branch works with an instance of a class, so an object type:

class LetterObject:
    def __init__(self, letter):
        self.letter = letter

letters_from_object = [LetterObject('a'), LetterObject('b')]
spark.createDataFrame(letters_from_object).show(truncate=False)
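
The hasattr(row, "__dict__") check suggests where the column names come from in this case: the instance attributes. As far as I can tell, Spark 3.0.1 reads them sorted by name, but take that detail as an assumption. The plain Python snippet below only shows what the __dict__ of such an object contains:

```python
class LetterObject:
    def __init__(self, letter):
        self.letter = letter


row = LetterObject('a')
# vars(row) returns row.__dict__, so the attribute names become the field names
fields = sorted(vars(row).items())
print(fields)  # [('letter', 'a')]
```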

As you can see, there are multiple solutions to the problem of initializing a DataFrame with a single column from an in-memory dataset. If for whatever reason you have to do so, you don't have to add another column.