Airflow - Use TaskGroup and BranchPythonOperator in the same DAG
I am using the TaskFlow API in Airflow 2.0, and I am having trouble combining TaskGroup with BranchPythonOperator.

Below is my code:

import airflow
from airflow.models import DAG
from airflow.decorators import task
# airflow.operators.python replaces the deprecated airflow.operators.python_operator module
from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.utils.task_group import TaskGroup


default_args = {
    'owner': 'Airflow',
    'start_date': airflow.utils.dates.days_ago(2),
}

@task
def dummy_task():
    return {}


@task
def task_b():
    return {}

@task
def task_c():
    return {}

def final_step():
    return {}

def get_tasks(**kwargs):
    task = 'task_a'

    return task


with DAG(dag_id='branch_dag', 
    default_args=default_args, 
    schedule_interval=None) as dag:

    with TaskGroup('task_a') as task_a:
        obj = dummy_task()

    tasks = BranchPythonOperator(
        task_id='check_api',
        python_callable=get_tasks,
        provide_context=True
    )

    final_step = PythonOperator(
        task_id='final_step',
        python_callable=final_step,
        trigger_rule='one_success'
    )

    b = task_b()
    c = task_c()

    tasks >> task_a >> final_step
    tasks >> b >> final_step
    tasks >> c >> final_step

When I trigger this DAG, I get the following error in the check_api task:

airflow.exceptions.TaskNotFound: Task task_a not found

Is it possible to use TaskGroup in conjunction with BranchPythonOperator and get this working?

Thanks,

Revenge answered 27/5, 2021 at 10:38 Comment(1)
Aside: provide_context is deprecated as of v2.0 and is no longer required. Reference: github.com/apache/airflow/blob/main/airflow/operators/… - Nitrite

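BranchPythonOperator expects its callable to return a task_id (or a list of task_ids). Here task_a is the group_id of the TaskGroup, not a task_id, which is why Airflow raises TaskNotFound. By default, tasks inside a TaskGroup get the group_id as a prefix, so the task to return is task_a.dummy_task.

You need to change the get_tasks function to: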

def get_tasks(**kwargs):
    task = 'task_a.dummy_task'
    return task
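
To make the naming rule concrete, here is a minimal sketch in plain Python that does not call Airflow at all. qualified_task_id is a hypothetical helper name used only for illustration, and it assumes the TaskGroup was created with the default prefix_group_id=True:

```python
def qualified_task_id(group_id: str, task_id: str) -> str:
    """Build the full task_id of a task that lives inside a TaskGroup.

    Assumes the default behaviour where the group_id is prepended to the
    task_id with a dot. This helper is illustrative, not part of Airflow.
    """
    return f"{group_id}.{task_id}"


# The id that the branch callable should return for the DAG in the question.
branch_target = qualified_task_id("task_a", "dummy_task")
print(branch_target)
```

The branch callable would then return this string instead of the bare group_id.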


Slap answered 27/5, 2021 at 11:4 Comment(0)

What if your task group has more than one parallel task? Returning a list with all of the task group's task ids works, but I'm wondering if there is a better way to do it.

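from airflow.decorators import task
from airflow.models import Variable

# grupo_regiones is not defined in this snippet; it is assumed to be a dict
# mapping each region_id to a list of municipio ids, defined elsewhere in the DAG file.

@task.branch(task_id='branch_task')
def branch_func():
    regiones = Variable.get('regiones', deserialize_json=True)
    municipios_seleccionados = []
    for region_id, municipios in grupo_regiones.items():
        if region_id in regiones:
            # Full task_id = TaskGroup id + '.' + task id inside the group
            municipios_seleccionados += [f'region_{region_id}.municipio_{m}' for m in municipios]
    return municipios_seleccionados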
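
One way to keep the list-building manageable is to move it into a plain function that can be checked without Airflow, and have the branch function call it. The sketch below is only an illustration of that idea: select_task_ids and the sample data are hypothetical, and it assumes grupo_regiones maps each region_id to a list of municipio ids.

```python
def select_task_ids(regiones, grupo_regiones):
    """Return the full task_ids of every municipio task in the selected regions.

    regiones: iterable of region ids to run (hypothetical input).
    grupo_regiones: dict mapping region_id -> list of municipio ids (assumed shape).
    """
    return [
        f"region_{region_id}.municipio_{municipio_id}"
        for region_id, municipios in grupo_regiones.items()
        if region_id in regiones
        for municipio_id in municipios
    ]


# Hypothetical sample data, for illustration only.
sample_grupo_regiones = {"1": [10, 11], "2": [20], "3": [30, 31]}
selected = select_task_ids(["1", "3"], sample_grupo_regiones)
print(selected)
```

Inside the branch function, the body would then reduce to reading the Variable and returning the result of this helper.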


Hay answered 7/4, 2023 at 1:57 Comment(0)
