I know this is almost exactly what you said you don't want to do, so apologies if it's no use (perhaps it will be useful to someone else). It differs slightly from your requirement: it sets `retries` to a fixed number up front and then decides, based on the exception raised, whether to honour that retry. It does not set `retries` dynamically when an error message is encountered. In other words, it uses `on_retry_callback` instead of `on_failure_callback`.
You can set `retries` to the number you want, and then use an `on_retry_callback` to alter the task's `State`. Here's an example task that always raises an exception. Its retry callback then manipulates the task `State` based on the type of exception raised:
from airflow.decorators import dag,task
from airflow.utils.state import State
from datetime import timedelta, datetime
import random
@dag(
    dag_id="retry_testing",
    tags=["utils", "experimental"],
    schedule_interval=None,
    start_date=datetime(2020, 1, 1),
    description="Testing the on_retry_callback parameter",
    params={"key": "value"},
    render_template_as_native_obj=True,
)
def taskflow():
    """Demo DAG: decide per exception type whether a failing task retries.

    The single task is given a fixed retry budget (``retries=10``) up front.
    Its ``on_retry_callback`` then inspects the exception that triggered the
    retry and may override the task instance's state to FAILED or SUCCESS
    instead of letting the retry go ahead.
    """

    def exception_parser(context):
        """Override the retry outcome based on the exception that was raised.

        Registered as the task's ``on_retry_callback``; *context* is the
        Airflow task context dict.

        - ZeroDivisionError -> set the task instance to FAILED.
        - TypeError         -> set the task instance to SUCCESS.
        - anything else     -> leave the state untouched, so the retry proceeds.
        """
        print('retrying...')
        ti = context["task_instance"]
        # .get() yields None if no exception is present in the context;
        # isinstance(None, ...) is False, so that case falls through to "retry".
        exception_raised = context.get('exception')
        # isinstance() is the idiomatic type test: unlike comparing
        # __class__.__name__ strings, it is not fooled by an unrelated class
        # that merely shares the name, and it also matches subclasses.
        if isinstance(exception_raised, ZeroDivisionError):
            print("div/0 error, setting task to failed")
            ti.set_state(State.FAILED)
        elif isinstance(exception_raised, TypeError):
            print("Type Error - setting task to success")
            ti.set_state(State.SUCCESS)
        else:
            print("Not div/0 error, trying again...")

    @task(
        retries=10,
        retry_delay=timedelta(seconds=3),
        on_retry_callback=exception_parser,
    )
    def random_error():
        """Always raise one of three common runtime errors, chosen at random.

        A random integer r in [0, 10) selects the error:
        0-2 -> ZeroDivisionError, 3-5 -> TypeError, 6-9 -> KeyError.
        Every path raises, so every attempt fails and triggers the callback.
        """
        r = random.randrange(0, 10)
        print(f"random integer = {r}")
        if r in (0, 1, 2):
            # Produce a ZeroDivisionError
            x = 1 / 0
            print(x)
        elif r in (3, 4, 5):
            # Produce a TypeError (str + int)
            x = 'not a number' + 1
        else:
            # Produce a KeyError
            mydict = {"thiskey": "foo"}
            get_missing_key = mydict["thatkey"]
            print(get_missing_key)

    random_error()


taskflow()
This has been tested on MWAA (Amazon Managed Workflows for Apache Airflow) running Airflow v2.5.1.