Python Redis Queue (rq) - how to avoid preloading ML model for each job?

I want to queue my ML predictions using rq. Example code (pseudo-ish):

predict.py:

import tensorflow as tf

def predict_stuff(foo):
    model = tf.load_model()
    result = model.predict(foo)
    return result

app.py:

from rq import Queue
from redis import Redis
from predict import predict_stuff

q = Queue(connection=Redis())
for foo in baz:
    job = q.enqueue(predict_stuff, foo)

worker.py:

import sys
from rq import Connection, Worker

# Preload libraries
import tensorflow as tf

with Connection():
    qs = sys.argv[1:] or ['default']

    w = Worker(qs)
    w.work()

I've read the rq docs explaining that you can preload libraries to avoid importing them every time a job is run (so in the example code I import tensorflow in the worker code). However, I also want to move the model loading out of predict_stuff so that the model isn't loaded every time the worker runs a job. How can I go about that?

Guile answered 30/8, 2018 at 14:1 Comment(1)
I tried moving the model load outside predict_stuff and importing predict inside the worker, but then the model is loaded in app as well, which is not desirable. – Guile
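
To make the failure mode in that comment concrete, here is a minimal sketch of the attempted layout (not the asker's exact code; tf.load_model() is the same placeholder used above). Because the load happens at module level, it runs in every process that imports predict, including app.py:

predict.py (attempted variant):

import tensorflow as tf

# Executed at import time, so app.py also pays the load cost
# when it imports predict_stuff from this module.
model = tf.load_model()

def predict_stuff(foo):
    result = model.predict(foo)
    return result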

I'm not sure if this will help, but following the example here:

https://github.com/rq/rq/issues/720

Instead of sharing a connection pool, you can share the model.

pseudo code:

import tensorflow as tf

from rq import Worker as _Worker
from rq.local import LocalStack

_model_stack = LocalStack()

def get_model():
    """Return the model stored on the worker's LocalStack."""
    m = _model_stack.top
    if m is None:
        raise RuntimeError('Run outside of worker context')
    return m

class Worker(_Worker):
    """Worker that loads the model once before processing jobs."""

    def work(self, burst=False, logging_level='WARN'):
        """Push the model onto the stack, then start the normal work loop."""
        _model_stack.push(tf.load_model())
        return super().work(burst=burst, logging_level=logging_level)

def predict_stuff_job(foo):
    model = get_model()
    result = model.predict(foo)
    return result

I use something similar to this for a "global" file reader I wrote. Load up the instance into the LocalStack and have the workers read off the stack.
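
A custom Worker subclass like this is not what the plain rq worker command starts by default (recent rq versions have a --worker-class option, if yours supports it), so one way to use it is a small launch script modeled on the question's worker.py. A sketch, assuming the subclass above lives in a module named custom_worker (that module name is made up for the example):

run_worker.py:

import sys
from redis import Redis
from rq import Connection
from custom_worker import Worker  # the Worker subclass defined above

with Connection(Redis()):
    qs = sys.argv[1:] or ['default']
    w = Worker(qs)
    w.work()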

Ankh answered 11/12, 2018 at 22:40 Comment(0)

In the end I didn't figure out how to do it with python-rq. I moved to Celery, where I did it like this:

app.py

from tasks import predict_stuff

for foo in baz:
    task = predict_stuff.delay(foo)

tasks.py

import tensorflow as tf
from celery import Celery
from celery.signals import worker_process_init

cel_app = Celery('tasks')
model = None

@worker_process_init.connect()
def on_worker_init(**_):
    global model
    model = tf.load_model()

@cel_app.task(name='predict_stuff')
def predict_stuff(foo):
    result = model.predict(foo)
    return result
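
The worker is started in the usual way (e.g. celery -A tasks worker). The worker_process_init signal fires once in each worker process as it starts, so with the default prefork pool every child process loads the model a single time and then reuses it for all the tasks it handles.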
Guile answered 14/9, 2018 at 9:13 Comment(0)
