Simple way to do multiple dispatch in python? (No external libraries or class building?)
I'm writing a throwaway script to compute some analytical solutions to a few simulations I'm running.

I would like to implement a function in a way that, based on its inputs, will compute the right answer. So for instance, say I have the following math equation:

tmax = (s1 - s2) / 2 = q * (a^2 / (a^2 - b^2))

It seems simple to me that I should be able to do something like:

def tmax(s1, s2):
    return (s1 - s2) / 2

def tmax(a, b, q):
    return q * (a**2 / (a**2 - b**2))

I may have gotten too used to writing in Julia, but I really don't want to complicate this script more than I need to.

Trajectory answered 16/1, 2020 at 0:40 Comment(3)
I really have no idea what you are asking. You can't define 2 functions with the same name. – Rhino
@Rhino The question is asking about a specific case of multiple dispatch, where only the number of arguments affects the implementation to call, and not the runtime type. – Nagari
@Rhino Like Fengyang said, it’s called multiple dispatch and it is possible in python, but everything I’ve read so far requires external packages or building a custom class. I have a minor but specific use case. So I want to know if there are features inherent to the language that would let me do this. – Trajectory
M
4

In statically typed languages like C++, you can overload functions based on the input parameter types (and quantity), but that's not really possible in Python. A name can refer to only one function at a time, so a second def with the same name simply replaces the first.
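To make that concrete, here is a minimal sketch using the two definitions from the question. The variable names `result` and `two_arg_call_failed` are only there for illustration.

```python
def tmax(s1, s2):
    return (s1 - s2) / 2


def tmax(a, b, q):  # rebinds the name tmax; the definition above is discarded
    return q * (a**2 / (a**2 - b**2))


# Only the three-argument version is reachable now.
result = tmax(2.0, 1.0, 3.0)

# Calling with two arguments no longer works: q is missing.
try:
    tmax(10.0, 4.0)
    two_arg_call_failed = False
except TypeError:
    two_arg_call_failed = True

print(result, two_arg_call_failed)
```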

What you can do is to use the default argument feature to select one of two pathways within that function, something like:

def tmax(p1, p2, p3 = None):
    # Two-argument variant has p3 as None.

    if p3 is None:
        return (p1 - p2) / 2

    # Otherwise, we have three arguments.

    return (p1 * p1 / (p1 * p1 - p2 * p2)) * p3

If you're wondering why I've changed the squaring operations from n ** 2 to n * n, it's because the latter is faster (or at least it was at some point, for small integral powers like 2; this is probably still the case, but you may want to confirm).

A possible case where it may be faster to do g1 ** 2 rather than g1 * g1 is where g1 is a global rather than a local (it takes longer for the Python VM to LOAD_GLOBAL rather than LOAD_FAST). This is not the case with the code posted since the argument is inherently non-global.
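If you want to confirm this on your own interpreter, a rough sketch with the standard timeit module could look like the following. It assumes that names assigned in `setup` end up as locals of the function timeit generates, while names passed through `globals` are looked up as globals. The numbers vary by machine and Python version, so no expected output is shown.

```python
import timeit

RUNS = 200_000

# Local-variable case: n is assigned inside timeit's generated function.
local_mul = timeit.timeit("n * n", setup="n = 7", number=RUNS)
local_pow = timeit.timeit("n ** 2", setup="n = 7", number=RUNS)

# Global-variable case: g is supplied through the globals mapping.
global_mul = timeit.timeit("g * g", globals={"g": 7}, number=RUNS)
global_pow = timeit.timeit("g ** 2", globals={"g": 7}, number=RUNS)

print(f"local  n * n : {local_mul:.4f} s")
print(f"local  n ** 2: {local_pow:.4f} s")
print(f"global g * g : {global_mul:.4f} s")
print(f"global g ** 2: {global_pow:.4f} s")
```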

Mceachern answered 16/1, 2020 at 0:48 Comment(2)
One reason why n ** 2 may be faster is that it involves loading n only once, instead of twice. This would be essentially irrelevant for local variables (which are quickly loaded) but can be important for global ones (which are expensive to load). – Nagari
Good point, @Fengyang, doesn't matter in this particular case since the arguments are grabbed via LOAD_FAST rather than LOAD_GLOBAL. Will update the answer to clarify. – Mceachern
C
7

Just thought I'd offer two more options.

Multiple Dispatch Type Checking

Python's standard typing module provides the @overload decorator.

It won't impact runtime, but it will notify your IDE & static analysis tools if you elect to use them.

Fundamentally your implementation will be the same nasty series of hacks Python wants you to use, but you'll have better debugging support.

This SO post explains it better, I've modified the code example to show multiple parameters:

# << Beginning of additional stuff >>
from typing import overload


@overload
def hello(s: int) -> str:
    ...


@overload
def hello(s: str) -> str:
    ...


@overload
def hello(s: int, b: int | float | str) -> str:
    ...
# << End of additional stuff >>

# Janky python overload
def hello(s, b=None):
    if b is None:
        if isinstance(s, int):
            return "s is an integer!"
        if isinstance(s, str):
            return "s is a string!"
    if b is not None:
        if isinstance(s, int) and isinstance(b, int | float | str):
            return "s is an integer & b is an int / float / string!"

    raise ValueError('You must pass either int or str')

print(hello(1))           # s is an integer!
print(hello("Blah"))      # s is a string!
print(hello(11, 1))       # s is an integer & b is an int / float / string!
print(hello(11, "Blah"))  # s is an integer & b is an int / float / string!

My IDE puts a normal error line under the offending argument.

print(hello("Blah", "Blah"))  
# >> ValueError: You must pass either int or str
# PyCharm warns "Blah" w/ "Expected type 'int', got 'str' instead"

print(hello(1, [0, 1]))  
# >> ValueError: You must pass either int or str
# PyCharm warns [0, 1] w/ "Expected type 'int | float | str', got 'list[int]' instead"

print(hello(1, 1) + 1)  
# >> TypeError: can only concatenate str (not "int") to str
# PyCharm warns "+ 1" w/ "Expected type 'str', got 'int' instead"

This is the most direct answer to the post.
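Applied to the tmax function from the question, the same pattern might look like the sketch below. The parameter names p1, p2, p3 and the float annotations are assumptions made for illustration; the question does not state the argument types.

```python
from typing import Optional, overload


@overload
def tmax(p1: float, p2: float) -> float:
    ...


@overload
def tmax(p1: float, p2: float, p3: float) -> float:
    ...


# Single runtime implementation; the overloads above only inform type checkers.
def tmax(p1: float, p2: float, p3: Optional[float] = None) -> float:
    if p3 is None:
        # Two-argument form: (s1 - s2) / 2
        return (p1 - p2) / 2
    # Three-argument form: q * (a^2 / (a^2 - b^2))
    return p3 * (p1**2 / (p1**2 - p2**2))


print(tmax(10.0, 4.0))
print(tmax(2.0, 1.0, 3.0))
```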

Single Dispatch

If you only need single dispatch for functions or class methods, have a look at the relatively recent @singledispatch (Python 3.4+) and @singledispatchmethod (Python 3.8+) decorators from functools.

Functions:

from functools import singledispatch


@singledispatch
def coolAdd(a, b):
    raise NotImplementedError('Unsupported type')

@coolAdd.register(int)
@coolAdd.register(float)
def _(a, b):
    print(a + b)

@coolAdd.register(str)
def _(a, b):
    print((a + " " + b).upper())

coolAdd(1, 2)                     # 3
coolAdd(0.1, 0.2)                 # 0.30000000000000004
coolAdd('Python', 'Programming')  # PYTHON PROGRAMMING
coolAdd(b"hi", b"hello")          # NotImplementedError: Unsupported type

From Python 3.11, register() should also accept union types for even easier reading (on earlier versions you stack one decorator per type, as above).
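A hedged sketch of what the union form might look like, with a fallback to stacked decorators on older interpreters. The function name cool_add is hypothetical, and it returns the result instead of printing so it can be checked.

```python
import sys
from functools import singledispatch


@singledispatch
def cool_add(a, b):
    raise NotImplementedError("Unsupported type")


if sys.version_info >= (3, 11):
    # Python 3.11+: the annotation on the first parameter may be a union.
    @cool_add.register
    def _(a: int | float, b):
        return a + b
else:
    # Older versions: one register() call per type.
    @cool_add.register(int)
    @cool_add.register(float)
    def _(a, b):
        return a + b


print(cool_add(1, 2))
print(cool_add(0.5, 1))
```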

Methods:

from functools import singledispatchmethod


class CoolClassAdd:

    @singledispatchmethod
    def addMethod(self, arg1, arg2):
        raise NotImplementedError('Unsupported type')

    @addMethod.register(int)
    @addMethod.register(float)
    def _(self, arg1, arg2):
        print(f"Numbers = {arg1 + arg2}")

    @addMethod.register(str)
    def _(self, arg1, arg2):
        print(f"Strings = {arg1} {arg2.upper()}")

c = CoolClassAdd()
c.addMethod(1, 2)           # Numbers = 3
c.addMethod(0.1, 0.2)       # Numbers = 0.30000000000000004
c.addMethod(0.1, 2)         # Numbers = 2.1
c.addMethod("hi", "hello")  # Strings = hi HELLO

Static & class methods are also supported (and most bugs are resolved as of 3.9.7).

However, beware! Dispatch checks only the type of the first (non-self) argument when choosing which function/method to call.

c.addMethod(1, "hello")     
# >> TypeError: unsupported operand type(s) for +: 'int' and 'str'

Of course, this would normally call for advanced error handling OR implementing multiple dispatch, and now we're back to where we started!
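For a throwaway script, one way to get dispatch on every argument without external libraries is a small registry keyed on the tuple of argument types. This is only a sketch: the names `_registry`, `register` and `add_method` are hypothetical, and it matches exact types only (no subclass or MRO lookup, unlike functools.singledispatch).

```python
_registry = {}


def register(*types):
    """Register an implementation for an exact tuple of argument types."""
    def decorator(func):
        _registry[types] = func
        return func
    return decorator


def add_method(*args):
    """Look up the implementation by the types of all arguments."""
    key = tuple(type(a) for a in args)
    try:
        impl = _registry[key]
    except KeyError:
        names = ", ".join(t.__name__ for t in key)
        raise TypeError(f"No implementation for ({names})") from None
    return impl(*args)


@register(int, int)
def _(a, b):
    return a + b


@register(int, str)
def _(a, b):
    return f"{a} {b.upper()}"


print(add_method(1, 2))
print(add_method(1, "hello"))
```

Because the key includes the number of arguments as well as their types, the same registry could also separate a two-argument call from a three-argument one.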

Configurationism answered 29/8, 2022 at 7:3 Comment(0)
N
1

You can do this using an optional argument:

def tmax_2(s1, s2):
    return (s1 - s2) / 2

def tmax_3(a, b, q):
    return q * (a**2 / (a**2 - b**2))

def tmax(a, b, c=None):
    if c is None:
        return tmax_2(a, b)
    else:
        return tmax_3(a, b, c)

Nagari answered 16/1, 2020 at 0:44 Comment(0)
