Conditional function chaining in Python

Imagine that I want to implement a function g by chaining sub-functions. This is easy to do:

def f1(a):
    return a+1

def f2(a):
    return a*2

def f3(a):
    return a**3

g = lambda x: f1(f2(f3(x)))

Now suppose that which sub-functions get chained together depends on conditions, specifically user-specified options that are known in advance. One could of course do:

def g(a, cond1, cond2, cond3):

    res = a
    if cond1:
        res = f3(res)
    if cond2:
        res = f2(res)
    if cond3:
        res = f1(res)
    return res

However, instead of checking these static conditions every time the function is called, I assume it's better to build g from its constituent functions once, in advance. Unfortunately, the following raises RuntimeError: maximum recursion depth exceeded (in Python 3.5+ this is RecursionError, a subclass of RuntimeError):

g = lambda x: x
if cond1:
    g = lambda x: f3(g(x))
if cond2:
    g = lambda x: f2(g(x))
if cond3:
    g = lambda x: f1(g(x))

Is there a good way to do this conditional chaining in Python? Note that there can be N functions to chain, so separately defining all 2^N combinations (8 in this example) is not an option.

Halflength asked 12/3, 2019 at 19:08
g = lambda x: f(g(x)) will blow up your stack because the recursive call never ends. – Kovno
How do you intend to define those N functions? Will they all be named functions, so f1, f2, ..., fN, or will you put them into a dict or something? I'm asking because this will pretty much determine how to chain them efficiently. – Vestment
You can assume they are all named functions. – Halflength
@RockyLi Yes, basically I want to avoid recursion and instead use g as a function object defined earlier in the code. – Halflength
If you say that for N functions there are N^2 combinations, then a single function can be called up to N times. If each function appears only once in the chain, you'd have N! (and not N*N) combinations. – Vestment
Does the order matter? – Vestment
@jojo I wrote 2^N, not N^2. Basically, there are N available defined functions, each potentially applied in a specific, fixed order, and for each we have the option of either using it (once) or not. To me, that is 2^N. N! would mean that you certainly want to use all of them, and the options specify the order. – Halflength
Sorry, my bad. You are right: if the order is fixed, we're down to 2^N. I was assuming that the order might vary, in which case 2*N! would be the number of potential configs (not N! as I wrote initially). I wrote a solution that can also deal with variable ordering, if needed. – Vestment
This could be made to work (not saying that it's necessarily a good idea) by writing your conditionals as g = lambda x, g=g: f3(g(x)) (the default parameter captures the previous value of g, rather than recursively referring to the new value). – Malinin
@Malinin That's a great one-line solution! Besides being a bit cryptic, why do you think it's not a good idea? – Halflength
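A minimal sketch contrasting the failing chain from the question with the default-argument fix from the comments. The helper names build_naive and build_with_defaults are hypothetical, and the condition values are arbitrary examples. The naive lambdas look up g when they are called, so each ends up calling itself; a default value is evaluated when the lambda is created, so it captures the previous g.

```python
def f1(a):
    return a + 1

def f2(a):
    return a * 2

def f3(a):
    return a ** 3

def build_naive(cond1, cond2, cond3):
    # `g` inside each lambda is resolved at call time, so after the last
    # reassignment every lambda refers to the final g: infinite recursion.
    g = lambda x: x
    if cond1:
        g = lambda x: f3(g(x))
    if cond2:
        g = lambda x: f2(g(x))
    if cond3:
        g = lambda x: f1(g(x))
    return g

def build_with_defaults(cond1, cond2, cond3):
    # The default `g=g` is evaluated when each lambda is defined,
    # so each new lambda holds a reference to the previous g.
    g = lambda x: x
    if cond1:
        g = lambda x, g=g: f3(g(x))
    if cond2:
        g = lambda x, g=g: f2(g(x))
    if cond3:
        g = lambda x, g=g: f1(g(x))
    return g

try:
    build_naive(True, True, True)(2)
except RecursionError:
    print("naive version: RecursionError")

print(build_with_defaults(True, True, True)(2))   # 17, i.e. f1(f2(f3(2)))
print(build_with_defaults(True, False, True)(2))  # 9, i.e. f1(f3(2))
```

One practical drawback of the trick: the resulting g silently accepts a second positional argument, so a mistaken call such as g(2, 3) replaces the inner function instead of raising a TypeError.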

I found a solution using decorator-style wrappers (closures). Have a look:

def f1(x):
    return x + 1

def f2(x):
    return x + 2

def f3(x):
    return x ** 2


conditions = [True, False, True]
functions = [f1, f2, f3]


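def apply_one(func, function):
    # Return a new function that applies `function` first, then `func`.
    def wrapped(x):
        return func(function(x))
    return wrapped


def apply_conditions_and_functions(conditions, functions):
    def initial(x):
        return x  # identity: the chain starts here

    function = initial

    # `functions` is listed outermost-first (as in f1(f2(f3(x)))), so walk it
    # in reverse; conditions[0] therefore controls f3, the innermost function.
    for cond, func in zip(conditions, reversed(functions)):
        if cond:
            function = apply_one(func, function)
    return function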


g = apply_conditions_and_functions(conditions, functions)

print(g(10)) # 101, because f1(f3(10)) = (10 ** 2) + 1 = 101

The conditions are checked only once, when g is defined; they are not checked again when g is called.
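For comparison, a compact sketch of the same idea using functools.reduce. The helper name compose_selected is hypothetical. Unlike the answer above, it takes the functions in application order (innermost first), so no reversal is needed; the conditions are still evaluated once, when the chain is built.

```python
from functools import reduce

def f1(x):
    return x + 1

def f2(x):
    return x + 2

def f3(x):
    return x ** 2

def compose_selected(conditions, functions):
    """Compose the functions whose condition is True.

    `functions` is in application order: the first selected function
    is applied first. Conditions are checked here, not on each call.
    """
    selected = [func for cond, func in zip(conditions, functions) if cond]

    def identity(x):
        return x

    # Fold left: each step wraps the chain built so far in the next function.
    # `inner` and `outer` are parameters, so each step binds its own values.
    return reduce(
        lambda inner, outer: (lambda x: outer(inner(x))),
        selected,
        identity,
    )

# Same selection as the answer: f3 and f1 are used, f2 is skipped.
g = compose_selected([True, False, True], [f3, f2, f1])
print(g(10))  # 101
```

Because the selection is just a list, the same helper would also accept an ordering chosen at run time, which the comments raise as a possibility.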

Sperrylite answered 12/3, 2019 at 20:02

The most structurally similar code I can think of requires turning your f1..f3 into pseudo-decorators, like this:

def f1(a):
    def wrapper(*args):
        return a(*args)+1
    return wrapper

def f2(a):
    def wrapper(*args):
        return a(*args)*2
    return wrapper

def f3(a):
    def wrapper(*args):
        return a(*args)**3
    return wrapper

Then you can apply them to g one at a time:

cond1, cond2, cond3 = True, True, True  # example values
g = lambda x: x
if cond1:
    g = f3(g)
if cond2:
    g = f2(g)
if cond3:
    g = f1(g)
g(2)

Returns:

# Assume cond1..3 are all True
17 # (2**3*2+1)
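With these pseudo-decorators, the if-chain generalizes to N functions by looping over (condition, decorator) pairs. A minimal sketch, assuming the conditions arrive as a list aligned with the decorators in application order; the name build_chain is hypothetical. Each deco(g) call runs immediately with the current g, so there is no late-binding problem.

```python
def f1(a):
    def wrapper(*args):
        return a(*args) + 1
    return wrapper

def f2(a):
    def wrapper(*args):
        return a(*args) * 2
    return wrapper

def f3(a):
    def wrapper(*args):
        return a(*args) ** 3
    return wrapper

def build_chain(conditions, decorators):
    # Start from the identity and wrap it once per enabled decorator;
    # the conditions are checked here, not on every call to the result.
    g = lambda x: x
    for cond, deco in zip(conditions, decorators):
        if cond:
            g = deco(g)
    return g

g = build_chain([True, True, True], [f3, f2, f1])
print(g(2))  # 17, i.e. 2**3*2+1
```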
Kovno answered 12/3, 2019 at 20:07
Thank you for this solution. It's basically the same as Sanyash's solution, which also uses decorators. Yours is structurally closest to the question, as you mentioned, while his adds an extra abstraction layer that keeps the definitions of f1...fN clean. In the end I accepted his answer because he was 5 minutes faster. I am not very experienced with Stack Overflow, so let me know if you think I should have chosen yours instead. – Halflength
Don't worry about choosing: answering is about helping, not scoring. I'm glad his answer worked for you. – Kovno
