DAG marked as "success" if one task fails, because of trigger rule ALL_DONE

I have the following DAG with 3 tasks:

start --> special_task --> end

The task in the middle can succeed or fail, but end must always be executed (imagine this is a task for cleanly closing resources). For that, I used the trigger rule ALL_DONE:

end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE

Using that, end is properly executed if special_task fails. However, since end is the last task and succeeds, the DAG is always marked as SUCCESS.

How can I configure my DAG so that if one of the tasks failed, the whole DAG is marked as FAILED?

Example to reproduce

import datetime

from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.utils import trigger_rule

dag = DAG(
    dag_id='my_dag',
    start_date=datetime.datetime.today(),
    schedule_interval=None
)

start = BashOperator(
    task_id='start',
    bash_command='echo start',
    dag=dag
)

special_task = BashOperator(
    task_id='special_task',
    bash_command='exit 1', # force failure
    dag=dag
)

end = BashOperator(
    task_id='end',
    bash_command='echo end',
    dag=dag
)
end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE

start.set_downstream(special_task)
special_task.set_downstream(end)

This post seems to be related, but the answer does not suit my needs, since the downstream task end must be executed (hence the mandatory trigger_rule).

Curren answered 7/8, 2018 at 13:49 Comment(4)
I'm not aware of a way to configure this at the DAG level. You could play with the task flow to make something else propagate the failure status, or use on_failure_callback to get notified about a failed task. – Upside
@JustinasMarozas Actually, I already have an on_failure_callback to get notified, but I would like my DAG to be marked as failed in the web UI. – Curren
If you create a dummy task and set it as downstream of special_task, I'd expect the failure to propagate. It is more of a bandage than a solution though. – Upside
@JustinasMarozas Indeed, your solution works, thanks! But I thought an out-of-the-box solution existed, since it's a pretty common use case. For people facing the same issue, I will answer the question with your solution and will mark it as the answer if no other solution is found. Thanks for your help. – Curren
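
For readers who have not used the on_failure_callback mentioned in these comments, here is a minimal sketch (the callback name and its body are illustrative, not part of the original question):

def notify_failure(context):
    # Airflow invokes this with the task context when the task instance fails.
    ti = context["task_instance"]
    print("Task {} in DAG {} failed".format(ti.task_id, ti.dag_id))

special_task = BashOperator(
    task_id='special_task',
    bash_command='exit 1',  # force failure
    on_failure_callback=notify_failure,  # notification only; it does not change the DAG run state
    dag=dag
)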

As @JustinasMarozas explained in a comment, a solution is to create a dummy task:

from airflow.operators.dummy_operator import DummyOperator

dummy = DummyOperator(
    task_id='test',
    dag=dag
)

and to bind it downstream of special_task:

special_task.set_downstream(dummy)

Thus, the DAG is marked as failed, and the dummy task is marked as upstream_failed.

Hopefully an out-of-the-box solution will show up, but until then, this workaround does the job.
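
Put together with the question's example (reusing start, special_task, end, and the dummy task defined above), the full dependency picture is roughly:

start.set_downstream(special_task)
special_task.set_downstream(end)    # end still runs because of its ALL_DONE trigger rule
special_task.set_downstream(dummy)  # dummy is set to upstream_failed, which marks the DAG run as failed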

Curren answered 9/8, 2018 at 8:5 Comment(0)

I thought it was an interesting question and spent some time figuring out how to achieve it without an extra dummy task. It turned into a somewhat superfluous exercise, but here's the end result.

This is the full DAG:

import airflow
from airflow import AirflowException
from airflow.models import DAG, TaskInstance, BaseOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

default_args = {"owner": "airflow", "start_date": airflow.utils.dates.days_ago(3)}

dag = DAG(
    dag_id="finally_task_set_end_state",
    default_args=default_args,
    schedule_interval="0 0 * * *",
    description="Answer for question https://stackoverflow.com/questions/51728441",
)

start = BashOperator(task_id="start", bash_command="echo start", dag=dag)
failing_task = BashOperator(task_id="failing_task", bash_command="exit 1", dag=dag)


@provide_session
def _finally(task, execution_date, dag, session=None, **_):
    upstream_task_instances = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date == execution_date,
            TaskInstance.task_id.in_(task.upstream_task_ids),
        )
        .all()
    )
    upstream_states = [ti.state for ti in upstream_task_instances]
    fail_this_task = State.FAILED in upstream_states

    print("Do logic here...")

    if fail_this_task:
        raise AirflowException("Failing task because one or more upstream tasks failed.")


finally_ = PythonOperator(
    task_id="finally",
    python_callable=_finally,
    trigger_rule=TriggerRule.ALL_DONE,
    provide_context=True,
    dag=dag,
)

successful_task = DummyOperator(task_id="successful_task", dag=dag)

start >> [failing_task, successful_task] >> finally_

Look at the _finally function, which is called by the PythonOperator. There are a few key points here:

  1. Annotate with @provide_session and add argument session=None, so you can query the Airflow DB with session.
  2. Query all upstream task instances for the current task:
upstream_task_instances = (
    session.query(TaskInstance)
    .filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.execution_date == execution_date,
        TaskInstance.task_id.in_(task.upstream_task_ids),
    )
    .all()
)
  3. From the returned task instances, get the states and check if State.FAILED is in there:
upstream_states = [ti.state for ti in upstream_task_instances]
fail_this_task = State.FAILED in upstream_states
  4. Perform your own logic:
print("Do logic here...")
  5. And finally, fail the task if fail_this_task=True:
if fail_this_task:
    raise AirflowException("Failing task because one or more upstream tasks failed.")

The end result: failing_task fails, the finally task still runs and then raises, so the DAG run is marked as failed.

Xylophagous answered 2/3, 2019 at 21:21 Comment(2)
This works, but it incorrectly sets "finally" to failed even though it didn't actually fail. It would be better if you could mark it as upstream_failed. – Langille
For the current version (2.4.x) it is State.UPSTREAM_FAILED rather than State.FAILED that should be looked for in upstream_states. – Interjacent
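
Following the last comment, on newer Airflow versions the relevant upstream state may be State.UPSTREAM_FAILED rather than State.FAILED; a hedged tweak to the check inside _finally (covering both states, variable names as in the answer above) could be:

from airflow.utils.state import State

# Fail this task if any upstream task failed outright or was marked upstream_failed.
bad_states = {State.FAILED, State.UPSTREAM_FAILED}
fail_this_task = any(ti.state in bad_states for ti in upstream_task_instances)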

To expand on Bas Harenslak's answer, a simpler _finally function that checks the state of all tasks (not only the upstream ones) can be:

from airflow.utils.state import State

def _finally(**kwargs):
    # Look at every task instance of the current DAG run except this one and
    # fail if any of them did not finish in the SUCCESS state.
    for task_instance in kwargs['dag_run'].get_task_instances():
        if task_instance.current_state() != State.SUCCESS and \
                task_instance.task_id != kwargs['task_instance'].task_id:
            raise Exception("Task {} failed. Failing this DAG run".format(task_instance.task_id))
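
As in Bas Harenslak's answer, this callable needs to run even when other tasks fail, so it would be wired up with the ALL_DONE trigger rule; a sketch, assuming the imports and dag from that answer (provide_context is only needed on Airflow 1.x):

finally_ = PythonOperator(
    task_id="finally",
    python_callable=_finally,
    trigger_rule=TriggerRule.ALL_DONE,  # run even when upstream tasks failed
    provide_context=True,               # Airflow 1.x only; Airflow 2.x passes the context automatically
    dag=dag,
)
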
Reggie answered 26/11, 2019 at 17:45 Comment(0)

Here is a solution for the case where you have many tasks that can fail, but the ALL_DONE trigger rule on one task leaves the DAG in a success state at the end of your pipeline:

  1. Collect the tasks that may fail into a list:

tasks = [failing_task, another_one]

  2. Create a dummy finish operator with trigger_rule="all_success":

finish = DummyOperator(task_id="finish", dag=dag, trigger_rule="all_success")

  3. Map each of those tasks onto the finish operator:

mapped = list(map(lambda x: x >> finish, tasks))

If one of the tasks fails, your DAG run is marked as failed.
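
Putting those steps together, a minimal sketch (task names are placeholders; the import path matches the older examples above, newer Airflow versions provide EmptyOperator instead):

from airflow.operators.dummy_operator import DummyOperator

# Tasks whose failure should fail the whole DAG run.
tasks = [failing_task, another_one]

# With all_success, this leaf task becomes upstream_failed as soon as any task
# in `tasks` fails, which in turn marks the DAG run as failed.
finish = DummyOperator(task_id="finish", dag=dag, trigger_rule="all_success")

for task in tasks:  # equivalent to the map() call above
    task >> finish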

Catanzaro answered 31/5, 2023 at 7:4 Comment(0)
