Chain a celery task's results into a distributed group

Like in this other question, I want to create a celery group from a list that's returned by a celery task. The idea is that the first task will return a list, and the second task will explode that list into concurrent tasks for every item in the list.

The plan is to use this while downloading content. The first task gets links from a website, and the second task is a chain that downloads the page, processes it, and then uploads it to s3. Finally, once all the subpages are done, the website is marked as done in our DB. Something like:

chain(
    get_links_from_website.si('https://www.google.com'),
    dmap.s(  # <-- Distributed map
        download_sub_page.s() | 
        process_sub_page.s() | 
        upload_sub_page_to_s3.s()
    ),
    mark_website_done.s()
)

The solution I've seen so far seems to do an adequate job of this, but fails when the second task is a chain, due to issues with clone not doing a deepcopy (see the comments on this answer for details):

@task
def dmap(it, callback):
    # Map a callback over an iterator and return as a group
    callback = subtask(callback)
    return group(callback.clone([arg,]) for arg in it)()

It also has the problem that if the iterable is 10,000 items long, it will create a group with 10,000 items. That is blowing up our memory usage, as you can imagine.

So, what I'm looking for is a way to do dmap that:

  • Doesn't blow up RAM by creating monstrous groups (maybe there's a way to chunk through the iterable?)
  • Works on celery chains without issues with deepcopy.
Blossomblot answered 29/3, 2017 at 23:45 Comment(3)
I would create a chain for every link you find on the website and not use dmap. Create one task that does get_links_from_website and submits one chain for each link. - Azores
@bernhard: Isn't it a no-no to create tasks in tasks? - Blossomblot
I don't know, I assume it's not. We do have beat tasks that do some regular status checks and launch other long-running tasks to react to status changes. Works nicely. - Azores

Celery canvas provides chunks, which splits a task's list of arguments into batches. Unfortunately, it doesn't work with primitives like chain and group.

You can use celery signals to avoid the issues with dmap/clone:

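from celery import chain, chord
from celery.signals import task_success

def page_chain(url):
    # Build a fresh chain per URL instead of cloning a shared one
    return chain(
        download_sub_page.s(url),
        process_sub_page.s(),
        upload_sub_page_to_s3.s(),
    )

# task_success is sent with the task object as sender and the return value as `result`
@task_success.connect(sender=get_links_from_website)
def task_success_handler(sender=None, result=None, **kwargs):
    header = [page_chain(url) for url in result]
    callback = mark_website_done.si()
    chord(header)(callback)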

Create a chain for processing pages and hook the last task to it using a chord. This function gets executed whenever get_links_from_website runs successfully.

Depending on how long the chain takes, you can also save the results of get_links_from_website somewhere, then iterate over them in batches to queue up chains, and hook the callback onto the last task of the last batch.

Metric answered 4/4, 2017 at 11:57 Comment(4)
This still doesn't handle the issue of flooding the queue. Crawling is an exponential process. How can I prevent my queue from growing exponentially as the crawler runs? There's no use for it to have more tasks than roughly 5× the number of workers at any time. More than that just wastes space. Put another way, even if I don't use dmap/groups, I'll still be able to create tasks faster than they can be completed. - Blossomblot
@Blossomblot In this case, you should store the URLs and use the celery scheduler to create a job that periodically checks the number of existing/scheduled tasks and queues up an appropriate number of tasks. - Metric
It's an ongoing crawler and we want to get the links as quickly as possible. A scheduled task would create delays. - Blossomblot
I created a throttling tool here: https://mcmap.net/q/1624962/-how-to-throttle-script-that-creates-celery-tasks-faster-than-they-39-re-consumed - Blossomblot

This is a bit hacky, but we use deepcopy to clone the callback, which fixes the bug with Signature's shallow copy:

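import copy

from celery import chord, group, subtask


def dmap(it, callback, final=None):
    # Map a callback over an iterable and run the results as a group
    callback = subtask(callback)

    # Deep-copy the signature's dict for every item so clones share no nested state
    run_in_parallel = group(subtask(copy.deepcopy(dict(callback))).clone([arg, ]) for arg in it)

    # Nothing to run
    if len(run_in_parallel.tasks) == 0:
        return []

    # With a final callback, run the group as a chord header
    if final:
        return chord(run_in_parallel)(final)

    return run_in_parallel.delay()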

Note that this only works for one nesting level (i.e. the callback is a chain/group/chord); it will not work for deeply nested callbacks.
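To make the shallow-copy problem concrete, here is a minimal sketch that uses plain dicts as a stand-in for a chain signature. The dict layout and the helpers bind_shallow and bind_deep are hypothetical and only illustrate the copying behaviour; they are not Celery's actual internals.

```python
import copy

# Plain dicts standing in for a chain signature (hypothetical layout, not Celery's real one)
template = {"task": "chain", "tasks": [{"task": "download", "args": []}]}


def bind_deep(sig, arg):
    # Deep copy: the nested "tasks" list and its dicts are duplicated
    clone = copy.deepcopy(sig)
    clone["tasks"][0]["args"].append(arg)
    return clone


def bind_shallow(sig, arg):
    # Shallow copy: only the top-level dict is new, the nested list is shared
    clone = copy.copy(sig)
    clone["tasks"][0]["args"].append(arg)
    return clone


deep = [bind_deep(template, url) for url in ("a", "b")]
deep_args = [c["tasks"][0]["args"] for c in deep]        # each clone has its own args

shallow = [bind_shallow(template, url) for url in ("a", "b")]
shallow_args = [c["tasks"][0]["args"] for c in shallow]  # every clone sees every arg

print(deep_args)
print(shallow_args)
```

With the shallow copy, every clone ends up holding the arguments of all the others, and the template itself is modified as well.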

For deeply nested callback graphs we use this hack, which is a bit slower but has worked reliably for us:

import pickle

# Hack to completely clone a signature with possibly complex subtasks (chains, chords, etc...)
run_in_parallel = group(pickle.loads(pickle.dumps(callback)).clone([arg, ]) for arg in it)

As for the size of the groups, you can always split the iterator into chunks.
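One possible way to do the splitting is sketched below. The helper name chunked and the batch size are assumptions made for illustration; the code is plain Python and does not use any Celery API.

```python
from itertools import islice


def chunked(iterable, size):
    """Yield lists of at most `size` items from any iterable."""
    if size < 1:
        raise ValueError("size must be at least 1")
    iterator = iter(iterable)
    while True:
        # Pull the next `size` items without materialising the whole iterable
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


batches = list(chunked(range(7), 3))
print(batches)
```

Each batch could then be turned into its own, smaller group, so that no single group has to hold every item at once.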

Interjoin answered 26/6, 2017 at 10:41 Comment(0)

If anyone runs into this, Jether's answer helped a lot, but it wasn't perfect. For us, there were four issues:

  1. If the callback is itself a chain, the answer doesn't pass arguments onto the chain. https://mcmap.net/q/451677/-how-to-recursively-chain-a-celery-task-that-returns-a-list-into-a-group helps provide a solution to this, via clone_signature. This seems to work for reasonably nested chains using RabbitMQ as a broker, but we didn't try anything extreme (and thus didn't need to adapt it to use pickle).
  2. If the callback is a group or chord, we need to apply the arguments to each of the clone's tasks, so we modified the clone_signature from (1) to accommodate this case.
  3. After adding (1), passing final broke - we adopted the solution from https://github.com/celery/celery/issues/5265 to convert final from a dict to a Signature.
  4. Finally, we found that final wouldn't actually execute in many cases because chord was receiving a Group rather than a list of tasks.

For anyone curious, here's our final solution:

import copy

from celery import Signature, chord, group, shared_task, subtask


def clone_signature(sig, args=(), kwargs=(), **opts):
    """
    Turns out that a chain clone() does not copy the arguments properly - this
    clone does.
    From: https://mcmap.net/q/455341/-cloning-a-celery-chain
    """
    if sig.subtask_type and sig.subtask_type not in ["chain", "group", "chord"]:
        raise NotImplementedError(
            "Cloning only supported for tasks, chains, groups, and chords, not {}".format(
                sig.subtask_type
            )
        )
    clone = sig.clone()
    # if the task we're cloning is a group or chord, apply the arguments to each of the children
    if sig.subtask_type and sig.subtask_type in ["group", "chord"]:
        clone.tasks = [
            clone_signature(task, args=args, kwargs=kwargs, **opts)
            for task in clone.tasks
        ]
    # otherwise, apply the arguments to either the task itself (if it's a single task)
    # or the first child task (if it's a chain)
    else:
        if hasattr(clone, "tasks"):
            task_to_apply_args_to = clone.tasks[0]
        else:
            task_to_apply_args_to = clone
        args, kwargs, opts = task_to_apply_args_to._merge(
            args=args, kwargs=kwargs, options=opts
        )
        task_to_apply_args_to.update(
            args=args, kwargs=kwargs, options=copy.deepcopy(opts)
        )
    return clone


@shared_task
def dmap(it, callback, final=None):
    if not len(it):
        return []

    callback = subtask(callback)
    run_in_parallel = [
        clone_signature(callback, args if type(args) is list else [args]) for args in it
    ]

    if not final:
        return group(*run_in_parallel).delay()

    # see https://github.com/celery/celery/issues/5265
    if not isinstance(final, Signature):
        final["immutable"] = True
        final = Signature.from_dict(final)
    return chord(run_in_parallel)(final)

This allowed us to successfully execute nested dmaps like the following:

chain(
    taskA.s(),
    dmap.s(
        chain(
            taskB.s(),
            taskC.s(),
            dmap.s(
                taskD.s(),
                final=chain(
                    taskE.s(),
                    taskF.s(),
                ),
            ),
        ),
    ),
).delay()
Levon answered 30/8, 2022 at 22:2 Comment(0)
