How to recursively chain a Celery task that returns a list into a group?

I started from this question: How to chain a Celery task that returns a list into a group?

But I want to expand twice. In my use case I have:

  • task A: determines total number of items for a given date
  • task B: downloads 1000 metadata entries for that date
  • task C: download the content for one item

So at each step I'm expanding the number of items for the next step. I could do this by looping through the results inside each task and calling .delay() on the next task function, but I'd rather my main tasks didn't do that. Instead, they would return a list of tuples, and each tuple would be expanded into the arguments for a call to the next task.

The above question has an answer that appears to meet my need, but I can't work out the correct way of chaining it for a two-level expansion.

Here is a very cut down example of my code:

from celery import group
from celery import subtask  # celery.task is deprecated and was removed in Celery 5.0
from celery.utils.log import get_task_logger

from .celery import app

logger = get_task_logger(__name__)

@app.task
def task_range(upper=10):
    # zip() returns an iterator, so wrap it in a list to make the JSON serializer work
    return list(zip(range(upper), range(upper)))

@app.task
def add(x, y):
    logger.info(f'x is {x} and y is {y}')
    char = chr(ord('a') + x)
    char2 = chr(ord('a') + x*2)
    result = x + y
    logger.info(f'result is {result}')
    return list(zip(char * result, char2 * result))

@app.task
def combine_log(c1, c2):
    logger.info(f'combine log is {c1}{c2}')

@app.task
def dmap(args_iter, celery_task):
    """
    Takes an iterator of argument tuples and queues them up for celery to run with the function.
    """
    logger.info(f'in dmap, len iter: {len(args_iter)}')
    callback = subtask(celery_task)
    run_in_parallel = group(callback.clone(args) for args in args_iter)
    return run_in_parallel.delay()
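The comment in task_range() about the JSON serializer matters for everything that follows: Celery 4 and later default to JSON, and JSON has no tuple type. This standalone check with the standard library json module (no Celery involved) shows that a zip object cannot be encoded at all, and that the tuples inside a list come back as lists, which is the form each argument group has by the time dmap() iterates over it.

```python
import json

# A zip object is a lazy iterator, so json cannot encode it directly.
try:
    json.dumps(zip(range(3), range(3)))
    encodable = True
except TypeError:
    encodable = False

# Materialising it with list() works, but each tuple is encoded as a JSON array...
payload = json.dumps(list(zip(range(3), range(3))))

# ...so after decoding, every argument group is a list, not a tuple.
decoded = json.loads(payload)

print(encodable)  # False
print(decoded)    # [[0, 0], [1, 1], [2, 2]]
```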

I've then tried various ways to make my nested mapping work. First, a one-level mapping works fine:

pp = (task_range.s() | dmap.s(add.s()))
pp(2)

This produces the kind of results I'd expect, so I'm not totally off.

But when I try to add another level:

ppp = (task_range.s() | dmap.s(add.s() | dmap.s(combine_log.s())))

Then in the worker I see the error:

[2019-11-23 22:34:12,024: ERROR/ForkPoolWorker-2] Task proj.tasks.dmap[e92877a9-85ce-4f16-88e3-d6889bc27867] raised unexpected: TypeError("add() missing 2 required positional arguments: 'x' and 'y'",)
Traceback (most recent call last):
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/app/trace.py", line 385, in trace_task
    R = retval = fun(*args, **kwargs)
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/app/trace.py", line 648, in __protected_call__
    return self.run(*args, **kwargs)
  File "/home/hdowner/dev/playground/celery/proj/tasks.py", line 44, in dmap
    return run_in_parallel.delay()
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 186, in delay
    return self.apply_async(partial_args, partial_kwargs)
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 1008, in apply_async
    args=args, kwargs=kwargs, **options))
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 1092, in _apply_tasks
    **options)
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 578, in apply_async
    dict(self.options, **options) if options else self.options))
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 607, in run
    first_task.apply_async(**options)
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/canvas.py", line 229, in apply_async
    return _apply(args, kwargs, **options)
  File "/home/hdowner/.venv/play_celery/lib/python3.6/site-packages/celery/app/task.py", line 532, in apply_async
    check_arguments(*(args or ()), **(kwargs or {}))
TypeError: add() missing 2 required positional arguments: 'x' and 'y'

And I'm not sure why changing the argument to dmap() from a plain task signature to a chain changes how the arguments get passed into add(). My impression was that it shouldn't: using a chain should only mean that the return value of add() gets passed on to the next task. But apparently that is not the case...
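One way to picture the failure is a toy model. ToySig and ToyChain below are hypothetical classes written for illustration, not Celery's canvas classes. They assume only two things: cloning stores the new arguments on whatever object clone() was called on, and, as the traceback suggests (group, then chain.apply_async, then first_task.apply_async with no arguments), a chain hands its first task only the arguments supplied at apply time, which are empty when a group applies it. Under those assumptions, clone() on a plain signature puts (x, y) where add() will see them, while clone() on a chain leaves them on the chain wrapper and add() fails with the same TypeError.

```python
from dataclasses import dataclass, replace
from typing import Callable, List, Tuple


@dataclass
class ToySig:
    """Hypothetical stand-in for a task signature: a function plus partial args."""

    fn: Callable
    args: Tuple = ()

    def clone(self, args=()):
        # Cloning prepends the new args to any partial args already stored here.
        return replace(self, args=tuple(args) + self.args)

    def apply(self, args=()):
        # Apply-time args are prepended too, then the function is called.
        return self.fn(*(tuple(args) + self.args))


@dataclass
class ToyChain:
    """Hypothetical stand-in for a chain of ToySig objects."""

    tasks: List[ToySig]
    args: Tuple = ()

    def clone(self, args=()):
        # Assumption mimicking the symptom: cloned args land on the chain
        # wrapper, not on tasks[0].
        return replace(self, args=tuple(args) + self.args)

    def apply(self, args=()):
        # Only the apply-time args reach tasks[0]; self.args is never used.
        result = self.tasks[0].apply(args)
        for sig in self.tasks[1:]:
            result = sig.apply((result,))
        return result


def add(x, y):
    return x + y


def double(n):
    return n * 2


plain = ToySig(add).clone((1, 2))
print(plain.apply())  # 3

chained = ToyChain([ToySig(add), ToySig(double)]).clone((1, 2))
chain_error = None
try:
    chained.apply()
except TypeError as exc:
    chain_error = str(exc)
print(chain_error)

# The workaround from the answer below: put the args on the first task instead.
fixed = ToyChain([ToySig(add).clone((1, 2)), ToySig(double)])
print(fixed.apply())  # 6
```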

Nihility answered 23/11, 2019 at 22:49

It turns out the problem is that clone() on a chain instance does not pass the new arguments through to the first task in the chain; see https://mcmap.net/q/455341/-cloning-a-celery-chain for the full details. If I use the method from that answer, my dmap() code becomes:

from copy import deepcopy

from celery import group, subtask


@app.task
def dmap(args_iter, celery_task):
    """
    Takes an iterator of argument tuples and queues them up for celery to run with the function.
    """
    # Rebuild the (possibly chained) signature from its serialized dict form.
    callback = subtask(celery_task)
    run_in_parallel = group(clone_signature(callback, args) for args in args_iter)
    return run_in_parallel.delay()


def clone_signature(sig, args=(), kwargs=None, **opts):
    """
    Turns out that a chain clone() does not copy the arguments properly - this
    clone does.
    From: https://mcmap.net/q/455341/-cloning-a-celery-chain
    """
    if sig.subtask_type and sig.subtask_type != "chain":
        raise NotImplementedError(
            "Cloning only supported for Tasks and chains, not {}".format(sig.subtask_type)
        )
    clone = sig.clone()
    # For a chain, the new arguments must go to its first task, not to the chain itself.
    if hasattr(clone, "tasks"):
        task_to_apply_args_to = clone.tasks[0]
    else:
        task_to_apply_args_to = clone
    args, kwargs, opts = task_to_apply_args_to._merge(args=args, kwargs=kwargs, options=opts)
    task_to_apply_args_to.update(args=args, kwargs=kwargs, options=deepcopy(opts))
    return clone

And then when I do:

ppp = (task_range.s() | dmap.s(add.s() | dmap.s(combine_log.s())))

everything works as expected.
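For reference, this is what "as expected" amounts to for an input of 2 (as in pp(2) above), traced sequentially in plain Python with no broker or worker. The task bodies are copied from the question with the Celery decorators and logging removed; local_map and add_then_map are hypothetical stand-ins for dmap and for the inner chain. In the real pipeline the calls run on workers in parallel and the pairs arrive as JSON lists, which this model ignores.

```python
calls = []


def task_range(upper=10):
    return list(zip(range(upper), range(upper)))


def add(x, y):
    char = chr(ord('a') + x)
    char2 = chr(ord('a') + x * 2)
    result = x + y
    return list(zip(char * result, char2 * result))


def combine_log(c1, c2):
    # Record the call instead of logging it.
    calls.append(c1 + c2)


def local_map(args_iter, fn):
    """Hypothetical sequential stand-in for dmap: call fn once per argument group."""
    return [fn(*args) for args in args_iter]


def add_then_map(x, y):
    # Plays the role of the inner chain add.s() | dmap.s(combine_log.s()).
    pairs = add(x, y)
    local_map(pairs, combine_log)
    return pairs


# Plays the role of task_range.s() | dmap.s(...) called with 2.
second_level = local_map(task_range(2), add_then_map)
print(second_level)  # [[], [('b', 'c'), ('b', 'c')]]
print(calls)         # ['bc', 'bc']
```

Note that add(0, 0) returns an empty list, so the first branch fans out to zero combine_log calls.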

Nihility answered 24/11, 2019 at 22:53
Comment (Swanson): Can we do something like (task_range.s() | dmap.s(add.s() | dmap.s(combine_log.s())) | some_tasks.s())?

Thanks for the great answer. I had to tweak the code so that it could handle tasks that take a single argument, where each item in args_iter is a plain value rather than a list. I'm sure this is awful, but it works! Any improvements appreciated.

@celery_app.task(name='app.worker.dmap')
def dmap(args_iter, celery_task):
    """
    Takes an iterator of argument tuples and queues them up for celery to run with the function.
    """
    callback = subtask(celery_task)
    print(f"ARGS: {args_iter}")
    run_in_parallel = group(clone_signature(callback, args if type(args) is list else [args]) for args in args_iter)
    print(f"Finished Loops: {run_in_parallel}")
    return run_in_parallel.delay()

Specifically, I added:

if type(args) is list else [args]

to this line:

run_in_parallel = group(clone_signature(callback, args if type(args) is list else [args]) for args in args_iter)
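One possible refinement of that check, sketched as a hypothetical helper (as_arg_tuple is not part of Celery): using isinstance with (list, tuple) also accepts tuples, which is what would arrive if the pickle serializer were used instead of JSON, and anything else, including a string, is wrapped as a single positional argument rather than being split into characters.

```python
def as_arg_tuple(item):
    """Hypothetical helper: turn one element of args_iter into positional args.

    Lists (what JSON produces) and tuples (what pickle preserves) are treated
    as an argument group; anything else becomes a single positional argument.
    """
    if isinstance(item, (list, tuple)):
        return tuple(item)
    return (item,)


print(as_arg_tuple([1, 2]))  # (1, 2)
print(as_arg_tuple((1, 2)))  # (1, 2)
print(as_arg_tuple(5))       # (5,)
print(as_arg_tuple("abc"))   # ('abc',)
```

Inside dmap() the generator would then call clone_signature(callback, as_arg_tuple(args)) for each args in args_iter. One trade-off remains with either version: a task whose single argument is itself a list is ambiguous, so such a task would need its argument wrapped explicitly, e.g. [[1, 2]].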
Misanthropy answered 23/1, 2021 at 23:7
