Testing sensors in Apache Airflow

Versions: Apache Airflow 1.10.3

Unit tests are the backbone of any software, including data-oriented software. However, testing some parts this way can be difficult, especially when they interact with the external world. Apache Airflow sensors belong to that category. Fortunately, thanks to Python's dynamic nature, testing sensors can be simplified a lot.

In this post, I will show you how to use Python's dynamic features to test sensors in Apache Airflow. I'll start by presenting the sensor I would like to test. In the next part, I will show examples of unit tests for it.

Tested sensor

Let's say that we want to integrate AWS Athena into our batch pipeline. Fortunately, Apache Airflow already provides a sensor for the query execution, and it looks like this (comments omitted):

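from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults

class AthenaSensor(BaseSensorOperator):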
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)
    SUCCESS_STATES = ('SUCCEEDED',)

    template_fields = ['query_execution_id']
    template_ext = ()
    ui_color = '#66c3ff'

    @apply_defaults
    def __init__(self,
                 query_execution_id,
                 max_retires=None,
                 aws_conn_id='aws_default',
                 sleep_time=10,
                 *args, **kwargs):
        super(AthenaSensor, self).__init__(*args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.query_execution_id = query_execution_id
        self.hook = None
        self.sleep_time = sleep_time
        self.max_retires = max_retires

    def poke(self, context):
        self.hook = self.get_hook()
        self.hook.get_conn()
        state = self.hook.poll_query_status(self.query_execution_id, self.max_retires)

        if state in self.FAILURE_STATES:
            raise AirflowException('Athena sensor failed')

        if state in self.INTERMEDIATE_STATES:
            return False
        return True

    def get_hook(self):
        return AWSAthenaHook(self.aws_conn_id, self.sleep_time)

I'm using an already existing sensor here just to keep things simple. It doesn't mean that you should test built-in sensors; that's the responsibility of the Apache Airflow committers. For the scope of this article, let's assume that AthenaSensor is a sensor we've developed specifically for our project.
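The decision logic of poke depends only on the state returned by the hook, which is what makes it testable with a fake hook. The sketch below isolates that logic in a hypothetical `interpret_state` function (not part of Airflow), with a `SensorFailed` class standing in for AirflowException. It also makes one detail of the code above visible: any state that is neither a failure nor an intermediate state, even an unexpected one, is treated as success.

```python
# Hypothetical helper (not part of Airflow) mirroring AthenaSensor.poke's
# decision logic; SensorFailed stands in for AirflowException.
INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
FAILURE_STATES = ('FAILED', 'CANCELLED',)


class SensorFailed(Exception):
    """Stand-in for airflow.exceptions.AirflowException."""


def interpret_state(state):
    if state in FAILURE_STATES:
        raise SensorFailed('Athena sensor failed')
    if state in INTERMEDIATE_STATES:
        return False
    # Everything else means "done": SUCCEEDED, but also any unexpected value
    return True


for state in ('SUCCEEDED', 'QUEUED', 'UNKNOWN'):
    print(state, interpret_state(state))
# SUCCEEDED True
# QUEUED False
# UNKNOWN True
```

Calling `interpret_state('FAILED')` or `interpret_state('CANCELLED')` raises SensorFailed, just as poke raises AirflowException for those states.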

Test

Since the task is a sensor, we want to assert on the result of poke. AthenaSensor exposes a method called get_hook that returns the hook responsible for the Athena connection. That hook is the first class we have to mock:

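class MockedAthenaHook:
    # Fake hook returning a fixed query status instead of calling Athena
    def __init__(self, query_status):
        self.query_status = query_status

    def get_conn(self):
        # No connection is needed in tests
        pass

    def poll_query_status(self, query_execution_id, max_retries):
        return self.query_status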
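As an alternative to the hand-written MockedAthenaHook, the standard library's `unittest.mock.Mock` can play the hook's role: any method called on it succeeds, and `return_value` fixes what `poll_query_status` returns. It also records calls, so the arguments passed by poke can be checked. The sketch below uses a hypothetical `StubSensor` class that copies the poke logic shown above instead of importing the real AthenaSensor, so it runs without Airflow installed, and it replaces get_hook directly on the instance.

```python
from unittest import mock


class StubSensor:
    """Hypothetical stand-in for AthenaSensor: same poke logic, no Airflow."""
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)

    def __init__(self, query_execution_id, max_retires=None):
        self.query_execution_id = query_execution_id
        self.max_retires = max_retires
        self.hook = None

    def get_hook(self):
        raise RuntimeError('the real hook would connect to AWS')

    def poke(self, context):
        self.hook = self.get_hook()
        self.hook.get_conn()
        state = self.hook.poll_query_status(self.query_execution_id, self.max_retires)
        if state in self.FAILURE_STATES:
            raise RuntimeError('Athena sensor failed')
        if state in self.INTERMEDIATE_STATES:
            return False
        return True


# Mock accepts any method call; return_value fixes the polled state
hook = mock.Mock()
hook.poll_query_status.return_value = 'SUCCEEDED'

sensor = StubSensor(query_execution_id='query-1')
# Instance-level override: the lambda is not bound, so it takes no self
sensor.get_hook = lambda: hook

assert sensor.poke(None)
# Unlike MockedAthenaHook, the Mock remembers how it was called
hook.get_conn.assert_called_once_with()
hook.poll_query_status.assert_called_once_with('query-1', None)
```

The trade-off is readability: the hand-written class states its behavior explicitly, while Mock is shorter and adds call verification for free.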

The first test checks whether the sensor reports readiness (poke returns True) for a successful query execution. To test this case, I simply override the get_hook method in a subclass of the tested sensor:

class TerminatedAthenaSensor(AthenaSensor):
    # Subclass whose hook always reports a finished (SUCCEEDED) query
    def __init__(self):
        super(TerminatedAthenaSensor, self).__init__(task_id="aa", query_execution_id=10)

    def get_hook(self):
        return MockedAthenaHook('SUCCEEDED')

tested_sensor = TerminatedAthenaSensor()

# poke doesn't use the context, so None is enough here
is_ready = tested_sensor.poke(None)

assert is_ready

Testing the opposite case, a query that is still queued, is quite similar:

class NotTerminatedAthenaSensor(AthenaSensor):
    # Subclass whose hook always reports a query still waiting in the queue
    def __init__(self):
        super(NotTerminatedAthenaSensor, self).__init__(task_id="aa", query_execution_id=10)

    def get_hook(self):
        return MockedAthenaHook('QUEUED')

tested_sensor = NotTerminatedAthenaSensor()

is_ready = tested_sensor.poke(None)

assert not is_ready
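The two subclasses above cover the success and the queued cases, but not the failure path, where poke raises an exception. Here is a sketch of one way to cover all of them with `unittest`, using a single subclass that takes the state as a constructor argument instead of one subclass per state. `StubAthenaSensor`, `FixedStateHook`, and `SensorFailed` are hypothetical stand-ins for AthenaSensor, MockedAthenaHook, and AirflowException, so the code runs without Airflow:

```python
import unittest


class SensorFailed(Exception):
    """Stand-in for airflow.exceptions.AirflowException."""


class StubAthenaSensor:
    """Hypothetical stand-in: AthenaSensor.poke's logic without Airflow."""
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)

    def get_hook(self):
        raise RuntimeError('the real hook would connect to AWS')

    def poke(self, context):
        hook = self.get_hook()
        hook.get_conn()
        state = hook.poll_query_status('query-1', None)
        if state in self.FAILURE_STATES:
            raise SensorFailed('Athena sensor failed')
        if state in self.INTERMEDIATE_STATES:
            return False
        return True


class FixedStateHook:
    """Fake hook always reporting the same query state."""
    def __init__(self, state):
        self.state = state

    def get_conn(self):
        pass

    def poll_query_status(self, query_execution_id, max_retries):
        return self.state


class FixedStateSensor(StubAthenaSensor):
    """A single subclass for all states: the state is a constructor argument."""
    def __init__(self, state):
        self.state = state

    def get_hook(self):
        return FixedStateHook(self.state)


class AthenaSensorStatesTest(unittest.TestCase):
    def test_success_state_is_ready(self):
        self.assertTrue(FixedStateSensor('SUCCEEDED').poke(None))

    def test_intermediate_states_are_not_ready(self):
        for state in StubAthenaSensor.INTERMEDIATE_STATES:
            with self.subTest(state=state):
                self.assertFalse(FixedStateSensor(state).poke(None))

    def test_failure_states_raise(self):
        for state in StubAthenaSensor.FAILURE_STATES:
            with self.subTest(state=state):
                with self.assertRaises(SensorFailed):
                    FixedStateSensor(state).poke(None)
```

Saved in a module, it runs with `python -m unittest <module_name>`. With the real sensor, the subclass constructor would also call the parent one with task_id and query_execution_id, and the failure test would expect AirflowException.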

As you can see, it works, but it involves some boilerplate: every time, we need to create a new class to mock the tested behavior, similar to Java-style, Mockito-based tests. We can do it more easily by dynamically overriding the method on the tested instance instead of on the class:

def mocked_hook():
    return MockedAthenaHook('SUCCEEDED')

tested_sensor = AthenaSensor(task_id="aa", query_execution_id=10)
# Assigning to the instance shadows the class method for this object only;
# the function isn't bound, so it takes no self argument
tested_sensor.get_hook = mocked_hook

is_ready = tested_sensor.poke(None)

assert is_ready

# A fresh instance still returns the real hook
real_athena_sensor = AthenaSensor(task_id="aa", query_execution_id=10)
assert type(real_athena_sensor.get_hook()) is AWSAthenaHook

As you can see in this example, we override only the method returning the hook, and the modification is limited to the tested instance: a new AthenaSensor still returns a real AWSAthenaHook. And we did it with much less code than before.
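The standard library provides the same instance-level override through `unittest.mock.patch.object`, with one extra guarantee: the original method is restored when the `with` block exits, even if an assertion inside it fails. Here is a sketch, assuming hypothetical `StubAthenaSensor` and `RealHook` classes in place of AthenaSensor and AWSAthenaHook so that it runs without Airflow:

```python
from unittest import mock


class RealHook:
    """Hypothetical stand-in for AWSAthenaHook."""


class StubAthenaSensor:
    """Hypothetical stand-in reproducing AthenaSensor.poke without Airflow."""
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)

    def get_hook(self):
        return RealHook()

    def poke(self, context):
        hook = self.get_hook()
        hook.get_conn()
        state = hook.poll_query_status('query-1', None)
        if state in self.FAILURE_STATES:
            raise RuntimeError('Athena sensor failed')
        if state in self.INTERMEDIATE_STATES:
            return False
        return True


sensor = StubAthenaSensor()
fake_hook = mock.Mock()
fake_hook.poll_query_status.return_value = 'RUNNING'

# Patch get_hook on this instance only; calling it returns fake_hook
with mock.patch.object(sensor, 'get_hook', return_value=fake_hook):
    assert sensor.poke(None) is False

# Outside the with block, the sensor is back to its real hook...
assert type(sensor.get_hook()) is RealHook
# ...and other instances were never affected
assert type(StubAthenaSensor().get_hook()) is RealHook
```

With the real classes, the patch target would be an AthenaSensor instance and the final checks would compare against AWSAthenaHook. Compared to assigning the attribute by hand, the `with` block makes the scope of the override explicit.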

In this post, you saw how to use Python to test seemingly hard-to-test Apache Airflow components such as sensors. The first part presented a sensor waiting for AWS Athena query results, which therefore requires an AWS connection. The second part gave two different approaches to unit test a sensor: one object-oriented, based on subclassing, and one dynamic, based on overriding a method on a single instance.