Unit tests are the backbone of any software, including data-oriented software. However, testing some parts that way can be difficult, especially when they interact with the external world. An Apache Airflow sensor is an example from that category. Fortunately, thanks to Python's dynamic nature, testing sensors can be simplified a lot.
In this post, I will show you how to use Python's dynamic features to test sensors in Apache Airflow. I'll start by presenting the sensor I would like to test. In the next part, I will show examples of unit tests for it.
Tested sensor
Let's say that we want to integrate AWS Athena into our batch pipeline. Fortunately, a sensor for query execution is already provided; it looks like this (comments omitted):
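```python
# Import paths as in Airflow 1.10 (contrib module layout).
from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook
from airflow.exceptions import AirflowException
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults


class AthenaSensor(BaseSensorOperator):
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)
    SUCCESS_STATES = ('SUCCEEDED',)

    template_fields = ['query_execution_id']
    template_ext = ()
    ui_color = '#66c3ff'

    @apply_defaults
    def __init__(self, query_execution_id, max_retires=None, aws_conn_id='aws_default',
                 sleep_time=10, *args, **kwargs):
        super(AthenaSensor, self).__init__(*args, **kwargs)
        self.aws_conn_id = aws_conn_id
        self.query_execution_id = query_execution_id
        self.hook = None
        self.sleep_time = sleep_time
        # `max_retires` (sic) is the parameter name used by this sensor.
        self.max_retires = max_retires

    def poke(self, context):
        self.hook = self.get_hook()
        self.hook.get_conn()
        state = self.hook.poll_query_status(self.query_execution_id, self.max_retires)

        # A failed or cancelled query fails the task immediately.
        if state in self.FAILURE_STATES:
            raise AirflowException('Athena sensor failed')

        # A query still in progress means "not ready yet, poke again later".
        if state in self.INTERMEDIATE_STATES:
            return False
        return True

    def get_hook(self):
        return AWSAthenaHook(self.aws_conn_id, self.sleep_time)
```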
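To see why the tests focus on poke, it helps to recall how a sensor runs: the operator calls poke repeatedly, sleeping between calls, until it returns True or a timeout expires, while an exception raised by poke makes the task attempt fail immediately. The sketch below is a simplified, Airflow-free model of that loop, not Airflow's actual implementation; `run_sensor`, `SensorTimeout`, and the default interval values are hypothetical names chosen for illustration.

```python
import time


class SensorTimeout(Exception):
    """Hypothetical stand-in for the exception raised on sensor timeout."""


def run_sensor(poke, context, poke_interval=0.01, timeout=1.0):
    """Simplified model of a poke-mode sensor loop (not Airflow's code).

    Calls poke() until it returns True and raises SensorTimeout if that
    does not happen within `timeout` seconds. Exceptions raised by poke()
    (e.g. for a FAILED query) propagate immediately.
    """
    started_at = time.monotonic()
    while not poke(context):
        if time.monotonic() - started_at > timeout:
            raise SensorTimeout("condition not met in time")
        time.sleep(poke_interval)
    return True


# Usage: a fake poke that sees QUEUED, then RUNNING, then SUCCEEDED.
states = iter(["QUEUED", "RUNNING", "SUCCEEDED"])
calls = []


def fake_poke(context):
    state = next(states)
    calls.append(state)
    return state == "SUCCEEDED"


print(run_sensor(fake_poke, context=None))  # True
print(calls)  # ['QUEUED', 'RUNNING', 'SUCCEEDED']
```

This is why asserting on the value returned by a single poke call is enough to unit test the sensor's decision logic: the looping and timing belong to the framework.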
I'm using an existing sensor here just to keep things simple. That doesn't mean you should test built-in sensors; that's the responsibility of the Apache Airflow committers. That's why, for the scope of this article, we assume that AthenaSensor is a sensor we developed specifically for our project.
Test
Since the task is a sensor, we want to assert on the result of its poke method. In our case, AthenaSensor exposes a method called get_hook, which returns the hook object responsible for the Athena connection. That hook is the first class we have to mock:
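```python
class MockedAthenaHook:
    """Fake hook returning a fixed query state, without any AWS call."""

    def __init__(self, query_status):
        self.query_status = query_status

    def get_conn(self):
        # No connection is needed in unit tests.
        pass

    def poll_query_status(self, query_execution_id, max_retries):
        return self.query_status
```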
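If you'd rather not hand-write the fake, the standard library's unittest.mock can build an equivalent object. The following is a minimal sketch; the query id value is illustrative, and restricting the mock with `spec` to the two methods the sensor calls is a design choice rather than a requirement.

```python
from unittest.mock import Mock

# Equivalent of MockedAthenaHook('SUCCEEDED'): only the two methods the
# sensor calls are allowed; any other attribute raises AttributeError.
mocked_hook = Mock(spec=["get_conn", "poll_query_status"])
mocked_hook.poll_query_status.return_value = "SUCCEEDED"

state = mocked_hook.poll_query_status("query-id", None)
print(state)  # SUCCEEDED

# Unlike the hand-written class, a Mock also records its calls, so a test
# can check which arguments were passed to the hook.
mocked_hook.poll_query_status.assert_called_once_with("query-id", None)
```

The hand-written class is more explicit and easier to read; the Mock version trades that for call recording and less code.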
The first test checks whether the sensor reports readiness for a successful query execution. To test this case, I will simply override the get_hook method in a subclass of the tested sensor:
```python
class TerminatedAthenaSensor(AthenaSensor):
    def __init__(self):
        super(TerminatedAthenaSensor, self).__init__(task_id="aa", query_execution_id=10)

    def get_hook(self):
        # Replace the real AWS hook with a fake reporting a finished query.
        return MockedAthenaHook('SUCCEEDED')


tested_sensor = TerminatedAthenaSensor()
is_ready = tested_sensor.poke(None)
assert is_ready
```
Testing the opposite case is quite similar:
```python
class NotTerminatedAthenaSensor(AthenaSensor):
    def __init__(self):
        super(NotTerminatedAthenaSensor, self).__init__(task_id="aa", query_execution_id=10)

    def get_hook(self):
        # The fake hook reports a query that is still waiting to run.
        return MockedAthenaHook('QUEUED')


tested_sensor = NotTerminatedAthenaSensor()
is_ready = tested_sensor.poke(None)
assert not is_ready
```
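One way to reduce the one-class-per-state repetition while keeping the object-oriented style is a single subclass that receives the state as a constructor argument, combined with unittest's subTest to cover every state, including the failure branch that the two tests above leave out. To keep the sketch runnable without Airflow, `StandInAthenaSensor` is a hypothetical copy of the poke logic and `AirflowException` is a local stand-in; in a real project you would subclass your actual sensor instead.

```python
import unittest


class AirflowException(Exception):
    """Local stand-in for airflow.exceptions.AirflowException."""


class MockedAthenaHook:
    def __init__(self, query_status):
        self.query_status = query_status

    def get_conn(self):
        pass

    def poll_query_status(self, query_execution_id, max_retries):
        return self.query_status


class StandInAthenaSensor:
    """Hypothetical, Airflow-free copy of AthenaSensor's poke logic."""
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)

    def __init__(self, query_execution_id):
        self.query_execution_id = query_execution_id

    def get_hook(self):
        raise RuntimeError("would create a real AWS hook")

    def poke(self, context):
        hook = self.get_hook()
        hook.get_conn()
        state = hook.poll_query_status(self.query_execution_id, None)
        if state in self.FAILURE_STATES:
            raise AirflowException('Athena sensor failed')
        if state in self.INTERMEDIATE_STATES:
            return False
        return True


class FakeStateAthenaSensor(StandInAthenaSensor):
    """One subclass for all tests: the fake state is a constructor argument."""

    def __init__(self, query_status):
        super().__init__(query_execution_id="test-query")
        self.query_status = query_status

    def get_hook(self):
        return MockedAthenaHook(self.query_status)


class AthenaSensorPokeTest(unittest.TestCase):
    def test_poke_result_per_state(self):
        expected = {'SUCCEEDED': True, 'QUEUED': False, 'RUNNING': False}
        for state, is_ready in expected.items():
            with self.subTest(state=state):
                self.assertEqual(FakeStateAthenaSensor(state).poke(None), is_ready)

    def test_poke_raises_on_failure_states(self):
        for state in ('FAILED', 'CANCELLED'):
            with self.subTest(state=state):
                with self.assertRaises(AirflowException):
                    FakeStateAthenaSensor(state).poke(None)
```

Saved in a test module, it can be run with `python -m unittest`; each state is reported separately thanks to subTest.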
As you can see, it works, but it involves some boilerplate: every time, we need to create a new class just to mock the tested behavior, a verbose style reminiscent of Java tests. We can do it more simply by dynamically overriding the method on a single instance instead of on the class:
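```python
def mocked_hook():
    return MockedAthenaHook('SUCCEEDED')


tested_sensor = AthenaSensor(task_id="aa", query_execution_id=10)
# Assigning to the instance shadows the class method for this object only.
# The function is stored as a plain attribute, so it is called without `self`.
tested_sensor.get_hook = mocked_hook
is_ready = tested_sensor.poke(None)
assert is_ready

# Other instances still use the original get_hook.
real_athena_sensor = AthenaSensor(task_id="aa", query_execution_id=10)
assert real_athena_sensor.get_hook().__class__ is AWSAthenaHook
```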
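The replacement function above takes no arguments because a function assigned to an instance attribute is not bound to that instance. If the fake needs access to the sensor itself (for example, to read query_execution_id), it can be bound explicitly with types.MethodType. A minimal sketch; `StandInSensor` and `fake_get_hook` are hypothetical names used only for illustration.

```python
import types


class StandInSensor:
    """Hypothetical minimal sensor; get_hook would normally reach AWS."""

    def __init__(self, query_execution_id):
        self.query_execution_id = query_execution_id

    def get_hook(self):
        raise RuntimeError("would create a real AWS hook")


def fake_get_hook(self):
    # Bound below, so `self` is the sensor instance being tested.
    return f"fake hook for {self.query_execution_id}"


tested = StandInSensor("query-1")
other = StandInSensor("query-2")

# Bind the function to this one instance only.
tested.get_hook = types.MethodType(fake_get_hook, tested)

print(tested.get_hook())  # fake hook for query-1
print("get_hook" in vars(tested), "get_hook" in vars(other))  # True False
```

As with the plain-function version, only the patched instance is affected; `other` still uses the class-defined method.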
As you can see in this example, we override only the method returning the hook, and that modification is limited to the tested instance: a second AthenaSensor still gets a real AWSAthenaHook. And we did it with much less code than before.
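Because the override is stored on the instance, it stays there for as long as that object lives. When a test reuses the same object, unittest.mock.patch.object can apply the override only for the duration of a `with` block and restore the class-defined method afterwards. A minimal sketch with a hypothetical `StandInSensor` class:

```python
from unittest import mock


class StandInSensor:
    """Hypothetical minimal sensor used only for this sketch."""

    def get_hook(self):
        return "real hook"

    def poke(self, context):
        return self.get_hook() == "fake hook"


sensor = StandInSensor()

# The override is active only inside the `with` block.
with mock.patch.object(sensor, "get_hook", return_value="fake hook") as fake:
    assert sensor.poke(None) is True
    fake.assert_called_once_with()

# Afterwards the original, class-defined method is visible again.
print(sensor.get_hook())  # real hook
```

The same call also works as a decorator on a test method, which keeps the setup and the automatic cleanup next to the test.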
In this post, you saw how to use Python to test Apache Airflow components that look hard to test, such as sensors. The first part presented a sensor waiting for AWS Athena query results, which therefore requires an AWS connection. The second part gave two different approaches to unit testing a sensor: an object-oriented one, based on subclassing, and a dynamic one, based on overriding a method on a single instance.