How to write unit tests for @task-decorated Airflow tasks?
I am trying to write unit tests for some tasks built with the Airflow TaskFlow API. I have tried several approaches, for example creating a DAG run or calling only the task function, but nothing has worked.

Here is a task that downloads a file from S3. There is more going on, but I removed it for this example.

@task()
def updates_process(files):
    context = get_current_context()
    try:
        updates_file_path = utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return

    # Do something else

Now I am trying to write a test case that exercises the except clause. This is one of the examples I started with:

class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = updates_process({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()

I also tried creating a DAG run, as shown in the example in the docs, and fetching the task from it, but that didn't help either.

Abate answered 13/7, 2022 at 22:52 Comment(0)
4

I was struggling to do this myself, but I found that decorated tasks have a .function attribute.

You can then call task.function() to invoke the underlying function. Using your example:

class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = dags.delta_load.updates.updates_process
        # Call the function for testing
        task.function({"updates_file": "path/to/file.csv"})
        get_current_context.assert_called_once()
        log.error.assert_called_once()

This saves you from setting up any of the DAG infrastructure and lets you run the Python function directly, as intended!

Superhuman answered 22/7, 2022 at 16:6 Comment(2)
Thanks, but it didn't work for me. I got the error AttributeError: 'function' object has no attribute 'function'. After some investigation, I found that I can use task.__wrapped__() the way you mentioned. - Abate
.function worked for me too. So in your example it would be assert updates_process.function(["file1", "file2", ...]) - Jadejaded
0

This is what I could figure out. I'm not sure it is the right approach, but it works.

class TestAccountLinkUpdatesProcess(TestCase):
    TASK_ID = "updates_process"

    @classmethod
    def setUpClass(cls) -> None:
        cls.dag = dag_delta_load()

    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        task = self.dag.get_task(task_id=self.TASK_ID)
        task.op_args = [{"updates_file": "file.csv"}]
        task.execute(context={})
        log.error.assert_called_once()

UPDATE: Based on @AetherUnbound's answer, I investigated further and found that task.__wrapped__() calls the actual Python function.

class TestAccountLinkUpdatesProcess(TestCase):
    @mock.patch("dags.delta_load.updates.log")
    @mock.patch("dags.delta_load.updates.get_current_context")
    @mock.patch("dags.delta_load.updates.utils.download_file_from_s3_bucket")
    def test_file_not_found_error(self, download_file_from_s3_bucket, get_current_context, log):
        download_file_from_s3_bucket.side_effect = FileNotFoundError
        updates_process.__wrapped__({"updates_file": "file.csv"})
        log.error.assert_called_once()
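For anyone who wants to try the `__wrapped__` pattern without an Airflow installation, here is a self-contained sketch. It assumes a decorator built with `functools.wraps`, which is what makes `__wrapped__` available. `fake_task`, the `utils` namespace, and the downloader are hypothetical stand-ins, not Airflow's `@task` or the real helpers from the question:

```python
import functools
import logging
import types
import unittest
from unittest import mock

log = logging.getLogger("updates")


def _download_file_from_s3_bucket(path):
    """Hypothetical downloader; the test always replaces it with a mock."""
    raise NotImplementedError


# Stand-in for the `utils` module from the question (assumption).
utils = types.SimpleNamespace(
    download_file_from_s3_bucket=_download_file_from_s3_bucket
)


def fake_task(fn):
    """Hypothetical stand-in for a task decorator; NOT Airflow's @task."""

    @functools.wraps(fn)  # stores the original function as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        raise RuntimeError("calling the wrapper needs a DAG context")

    return wrapper


@fake_task
def updates_process(files):
    try:
        return utils.download_file_from_s3_bucket(files.get("updates_file"))
    except FileNotFoundError as e:
        log.error(e)
        return None


class TestUpdatesProcess(unittest.TestCase):
    def test_file_not_found_error(self):
        with mock.patch.object(
            utils, "download_file_from_s3_bucket", side_effect=FileNotFoundError("missing")
        ) as download, mock.patch.object(log, "error") as log_error:
            # Bypass the wrapper and call the undecorated function directly.
            result = updates_process.__wrapped__({"updates_file": "file.csv"})

        download.assert_called_once_with("file.csv")
        log_error.assert_called_once()
        self.assertIsNone(result)
```

Patching with `mock.patch.object` on the namespace and the logger avoids having to spell out dotted module paths, which is convenient in a single-file example. In a real project the `mock.patch("dags.delta_load.updates...")` form from the question does the same job.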
Abate answered 14/7, 2022 at 12:36 Comment(0)
0

If you are trying to run the DAG as part of your unit tests, and are finding it difficult to get access to the DAG object itself because of the Airflow TaskFlow API decorators, you can do something like this in your tests:

import unittest
from unittest.mock import patch

from airflow import DAG


class TestSomething(unittest.TestCase):

    def test_something(self):
        dags = []
        real_dag_enter = DAG.__enter__

        def fake_dag_enter(dag):
            # Whenever a DAG gets created (behind the scenes),
            # DAG.__enter__(self) gets called. You can use this
            # to keep a list of the DAGs that get created so that
            # you can run unit tests on them.
            dags.append(dag)
            return real_dag_enter(dag)

        with patch.object(DAG, "__enter__", new=fake_dag_enter):
            # Assumption: create your DAGs here (for example, by calling
            # your @dag-decorated function) so that they get recorded.

            for dag in dags:
                # Run your DAG here, and assert
                # that it does what it's supposed to.
                #
                # (I have a utility class that loops through the
                # tasks in the DAG and executes each one.)
                self.assertIsNotNone(dag)  # placeholder assertion
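The recording trick above can be tried in isolation. The sketch below assumes a plain context-manager class in place of Airflow's `DAG`: `Pipeline`, `build_pipelines`, and `recording_enter` are hypothetical names used only for illustration. It shows that patching `__enter__` on the class captures every instance entered inside the `with patch.object(...)` block:

```python
from unittest.mock import patch


class Pipeline:
    """Hypothetical stand-in for a DAG-like context manager."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False


def build_pipelines():
    """Hypothetical code that creates pipelines internally."""
    with Pipeline("daily"):
        pass
    with Pipeline("hourly"):
        pass


captured = []
real_enter = Pipeline.__enter__


def recording_enter(self):
    # Record the instance, then defer to the original __enter__.
    captured.append(self)
    return real_enter(self)


# The with statement looks up __enter__ on the class, so patching the
# class attribute intercepts every instance while the patch is active.
with patch.object(Pipeline, "__enter__", new=recording_enter):
    build_pipelines()

captured_names = [p.name for p in captured]
```

The important detail is that the objects must be created while the patch is active. Anything entered before the `with patch.object(...)` block starts is not recorded.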

Elegant answered 2/10, 2023 at 21:49 Comment(0)
