Photo by Jez Timms on Unsplash
October 20, 2020

Unit Testing PySpark Code

Navin Vembar

CTO, Camber


Testing code in a distributed data pipeline is not always easy. PySpark is the Python-language API for the Spark analytics engine - extremely useful for manipulating large data sets for ETL, machine learning and other big data tasks. But, the distributed architecture of PySpark and some of its operational quirks can create pitfalls. Sometimes that means that you have to expend some effort to make tests run at all or that your tests aren’t testing what you expect. And errors in PySpark code often translate directly to money, time, and storage spent on producing data that’s not useful. At the least, running unit tests will prevent spinning up an EMR cluster that’s destined to fail.

Normally, unit-testing Python code is fairly straightforward. The unittest builtin libraries and additional libraries such as pytest are fantastic and allow for robust testing of Python code. Unit testing data transformation code is just one part of making sure that your pipeline is producing data fit for the decisions it’s supporting.

Let’s start with PySpark 3.x - the most recent major version of PySpark. There are some differences in setup with PySpark 2.4.x, which we’ll cover at the end. We’re going to assume that you know a little bit about package management and pytest below, but the code should be pretty readable.

Using poetry as our package manager, we install our packages:

$ poetry init # Follow the instructions that Poetry gives you
# Install the version that's most recent as of this writing
$ poetry add pyspark=3.0.1 
$ poetry add --dev pytest

Let’s create a file called adding.py that looks like this:

from pyspark.sql import Row, SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T


@F.udf(returnType=T.IntegerType())
def add_columns(value1, value2):
    """
    A UDF which sums two integer columns
    """
    return value1 + value2


def build_dataframe(spark, count):
    """
    Builds a dataframe with integer columns `col1` and `col2`
    and `count` rows.
    """
    rows = [Row(col1=i, col2=i + 2) for i in range(count)]
    df = spark.createDataFrame(rows)
    return df


def add_dataframe_columns(df):
    """
    Adds a column `added_column` which sums `col1` and `col2`
    """
    return df.withColumn("added_column",
                         add_columns(F.col("col1"), F.col("col2")))

The add_columns function is a user-defined function that can be used natively by PySpark to enhance the already rich set of functions that PySpark supports for manipulating data. UDFs can accomplish sophisticated tasks and should be independently tested. Our testing strategy here is not to test the native functionality of PySpark, but to test whether our functions act as they should.
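
As an aside, the object that F.udf returns still carries the original Python callable - as far as I recall it’s exposed on the wrapper as add_columns.func - so the pure Python logic can also be checked without a SparkSession at all. A minimal sketch of that (my own addition, relying on that .func attribute):

def test_add_columns_pure_python():
    # No Spark involved: call the underlying Python function directly.
    # .func is the unwrapped callable that F.udf stored on the wrapper.
    assert adding.add_columns.func(1, 2) == 3
    assert adding.add_columns.func(-5, 5) == 0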

We want to test whether our user-defined function (UDF) actually adds columns, whether we construct a dataframe with the right number of rows, and whether our function which modifies a dataframe acts as expected. (Note: In PySpark 3.x, pandas UDFs can use Python type hints to declare what kind of UDF is being defined. We don’t do that here because we’ll be using this same code for PySpark 2.4 later.)
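
For reference, that newer style looks roughly like the sketch below - a pandas UDF using type hints, which needs pandas and pyarrow installed (neither of which we added to the project above), so treat it as illustration only:

import pandas as pd
from pyspark.sql.functions import pandas_udf


@pandas_udf("integer")
def add_columns_hinted(value1: pd.Series, value2: pd.Series) -> pd.Series:
    # A Series-to-Series pandas UDF: in PySpark 3.x the Python type hints
    # tell Spark which kind of pandas UDF is being declared.
    return value1 + value2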

So, let’s create our tests in test_add.py with the following imports.

import pytest
import adding
import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession
from pyspark.sql.utils import PythonException

The first thing we need to do is make sure that a SparkSession is actually available to our test functions. We need a fixture.

@pytest.fixture(scope="session")
def spark_session():
    return SparkSession.builder.getOrCreate()

This is going to get called once for the entire run (scope="session"); creating a new session repeatedly would just make the tests take longer. Much of what we’re going to do is use this fixture to check the validity of the code’s outputs, and there are a few important decisions worth calling out along the way.
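
If you want every test module to share this fixture, one variation (my own preference, nothing the rest of this post depends on) is to move it into a conftest.py and add explicit teardown:

# conftest.py
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark_session():
    spark = (
        SparkSession.builder
        .master("local[2]")    # run Spark locally with two worker threads
        .appName("pyspark-unit-tests")
        .getOrCreate()
    )
    yield spark
    spark.stop()    # tear the session down once the whole run is finished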

def test_build(spark_session):
    df = adding.build_dataframe(spark_session, 2000)
    assert df.count() == 2000

Note first that test_build takes spark_session as an argument, using the fixture defined above it. But it’s important to note that the build_dataframe function also takes a SparkSession as an argument. In fact, whenever a function needs a session to run, passing that session in as an argument rather than constructing it inside the function makes for a much more easily testable codebase.

Now let’s go ahead and test our UDF. Add the following to test_add.py and run poetry run pytest.

def test_add_udf(spark_session):
    df = spark_session.createDataFrame([Row(a=1, b=7), Row(a=2, b=8)])
    df = df.withColumn("added", adding.add_columns(F.col("a"), F.col("b")))
    rows = df.collect()
    for row in rows:
        assert row["added"] == row["a"] + row["b"]

Here, we’re making sure that the UDF is actually adding. The df.collect() executes the Spark queries. But how do we make sure that the addition fails where it makes sense to fail?

def test_udf_raises_type_exception(spark_session):
    df = spark_session.createDataFrame(
        [Row(a=1, b="not an integer"), Row(a=2, b="not an integer")]
    )
    with pytest.raises(PythonException):
        df = df.withColumn(
            "added",
            adding.add_columns(F.col("a"), F.col("b"))
        )

Hmmm… running poetry run pytest fails - no exception is thrown. The reason is that PySpark’s queries are lazy: they don’t get executed until something actually requires computation. In the first test function, df.collect() made that happen; writing to a file, calling count(), or other actions can also realize the resulting DataFrame. So let’s modify our test function:

def test_udf_raises_type_exception(spark_session):
    df = spark_session.createDataFrame(
        [Row(a=1, b="not an integer"), Row(a=2, b="not an integer")]
    )
    df = df.withColumn(
        "added",
        adding.add_columns(F.col("a"), F.col("b"))
    )
    with pytest.raises(PythonException):
        _ = df.collect()

Now it successfully passes!
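
Any action that forces the UDF to run will surface the same failure - collect() isn’t special here. For example, a variation of the test (my own addition, not from the original suite) that uses take() instead:

def test_udf_raises_with_take(spark_session):
    df = spark_session.createDataFrame([Row(a=1, b="not an integer")])
    df = df.withColumn("added", adding.add_columns(F.col("a"), F.col("b")))
    with pytest.raises(PythonException):
        # take(1) is just limit(1).collect() under the hood, so the UDF
        # still has to run for the returned row and the type error surfaces.
        _ = df.take(1)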

Lastly, let’s test the add_dataframe_columns function to make sure it modifies our DataFrame as expected.

def test_add_dataframe_columns(spark_session):
    data = [Row(col1=1, col2=7), Row(col1=2, col2=8)]
    df = spark_session.createDataFrame(data)
    with_added = adding.add_dataframe_columns(df)
    rows = with_added.collect()
    for row in rows:
        assert row["added_column"] == row["col1"] + row["col2"]

It passes! Notice we didn’t use build_dataframe() here, to focus our test only on add_dataframe_columns().
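
If you did want a broader test that chains the two helpers together, something like this (an extra test of my own, not part of the original suite) would also work:

def test_build_then_add(spark_session):
    # build_dataframe sets col2 = col1 + 2, so the added column
    # should always equal 2 * col1 + 2.
    df = adding.add_dataframe_columns(adding.build_dataframe(spark_session, 10))
    for row in df.collect():
        assert row["added_column"] == 2 * row["col1"] + 2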

Rewinding a bit, what if you - like many - are still using PySpark 2.4? Let’s try it out.

$ poetry remove pyspark
$ poetry add pyspark=2.4.7
$ poetry run pytest

Well, that blows up - the first thing we’ll have to resolve is this:

_________________________________________ ERROR collecting test_add.py _________________________________________
ImportError while importing test module '/..../test_add.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../.pyenv/versions/3.7.6/lib/python3.7/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
test_add.py:5: in <module>
    from pyspark.sql.utils import PythonException
E   ImportError: cannot import name 'PythonException' from 'pyspark.sql.utils'

PySpark 3.x made some changes to exception handling, and PythonException doesn’t exist in 2.4, so we need to make some adjustments there. Remove the PythonException import and change the code for our exception test to:

def test_udf_raises_type_exception(spark_session):
    df = spark_session.createDataFrame(
        [Row(a=1, b="not an integer"), Row(a=2, b="not an integer")]
    )
    df = df.withColumn("added", adding.add_columns(F.col("a"), F.col("b")))
    with pytest.raises(Exception):
        _ = df.collect()
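
As an aside, if the same test file has to run under both 2.4 and 3.x, one workaround (my own shortcut, not something the original setup requires) is a guarded import, so the test can keep asserting on PythonException wherever it exists:

try:
    # PySpark 3.x has a dedicated exception type for Python UDF failures.
    from pyspark.sql.utils import PythonException
except ImportError:
    # PySpark 2.4 doesn't, so fall back to the broadest possible match.
    PythonException = Exception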

OK, let’s try again.

=========================================== short test summary info ============================================
ERROR test_add.py::test_build - Exception: Java gateway process exited before sending its port number
ERROR test_add.py::test_add_udf - Exception: Java gateway process exited before sending its port number
ERROR test_add.py::test_udf_raises_type_exception - Exception: Java gateway process exited before sending its...
ERROR test_add.py::test_add_dataframe_columns - Exception: Java gateway process exited before sending its por...

Well, at least the error is different this time. But, what’s going on? Here’s where PySpark 2.4 isn’t so friendly to users just trying to run things locally.

The underlying problem is that PySpark 2.4 only runs on Java 8 - and, if you’re like me and installed Java with Homebrew, you don’t have JDK 8. So, head on over to AdoptOpenJDK, download Java 8, and install it. Set JAVA_HOME and your PATH so that JDK 8 is the running JDK. You’ll also need to get Spark itself - select the 2.4.7 release with Hadoop included and unpack it to ~/spark.

For me, that meant installing the JDK file at the link above and then running:

$ export JAVA_HOME=/Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home
$ export PATH=${JAVA_HOME}/bin:${PATH}
$ export SPARK_HOME=${HOME}/spark

Once that’s done, poetry run pytest should succeed again!

What happened here? The environment needed to be set up explicitly before Spark could start for the tests: Java had to point at a compatible JDK, and a Spark distribution (with Hadoop bundled) had to be available for everything to execute successfully.

We’ll see next time how to package up all of this with Docker to make things a bit easier.