How can I wrap a Python function in a way that works with inspect.signature?

Some uncontroversial background experimentation up front:

import inspect

def func(foo, bar):
  pass

print(inspect.signature(func))  # Prints "(foo, bar)" like you'd expect

def decorator(fn):
  def _wrapper(baz, *args, **kwargs):
    return fn(*args, **kwargs)

  return _wrapper

wrapped = decorator(func)
print(inspect.signature(wrapped))  # Prints "(baz, *args, **kwargs)" which is totally understandable

The Question

How can I implement my decorator so that print(inspect.signature(wrapped)) prints "(baz, foo, bar)"? Can I build _wrapper dynamically, taking the parameters of whatever fn is passed in and prepending baz to that list?

The answer is NOT

import functools

def decorator(fn):
  @functools.wraps(fn)
  def _wrapper(baz, *args, **kwargs):
    return fn(*args, **kwargs)

  return _wrapper

That gives "(foo, bar)" again, which is wrong: calling wrapped(foo=1, bar=2) raises a TypeError, "missing 1 required positional argument: 'baz'".
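The "(foo, bar)" result comes from functools.wraps itself: it stores the original function as `_wrapper.__wrapped__`, and `inspect.signature` follows `__wrapped__` by default. A short sketch, using the standard `follow_wrapped` keyword of `inspect.signature`, that shows both views of the same wrapper:

```python
import functools
import inspect


def func(foo, bar):
    pass


def decorator(fn):
    @functools.wraps(fn)
    def _wrapper(baz, *args, **kwargs):
        return fn(*args, **kwargs)

    return _wrapper


wrapped = decorator(func)

# functools.wraps records the original function on the wrapper.
print(wrapped.__wrapped__ is func)  # True

# By default inspect.signature unwraps via __wrapped__ and reports func's signature.
print(inspect.signature(wrapped))  # (foo, bar)

# With follow_wrapped=False it reports the wrapper's own parameters instead.
print(inspect.signature(wrapped, follow_wrapped=False))  # (baz, *args, **kwargs)
```

Neither view is "(baz, foo, bar)", which is why wraps alone does not answer the question.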

I don't think it's necessary to be this pedantic, but

def decorator(fn):
  def _wrapper(baz, foo, bar):
    fn(foo=foo, bar=bar)

  return _wrapper

Is also not the answer I'm looking for - I'd like the decorator to work for all functions.

Wigfall asked 14/5, 2022 at 0:13

You can set the __signature__ attribute (PEP 362) on the wrapper to override the signature that inspect.signature reports for it. For example:

import inspect


def func(foo, bar):
    pass


def decorator(fn):
    def _wrapper(baz, *args, **kwargs):
        return fn(*args, **kwargs)

    f = inspect.getfullargspec(fn)

    fn_params = []
    if f.args:
        for a in f.args:
            fn_params.append(
                inspect.Parameter(a, inspect.Parameter.POSITIONAL_OR_KEYWORD)
            )

    if f.varargs:
        fn_params.append(
            inspect.Parameter(f.varargs, inspect.Parameter.VAR_POSITIONAL)
        )

    if f.varkw:
        fn_params.append(
            inspect.Parameter(f.varkw, inspect.Parameter.VAR_KEYWORD)
        )

    _wrapper.__signature__ = inspect.Signature(
        [
            inspect.Parameter("baz", inspect.Parameter.POSITIONAL_OR_KEYWORD),
            *fn_params,
        ]
    )
    return _wrapper


wrapped = decorator(func)
print(inspect.signature(wrapped))

Prints:

(baz, foo, bar)

If the func is:

def func(foo, bar, *xxx, **yyy):
    pass

Then print(inspect.signature(wrapped)) prints:

(baz, foo, bar, *xxx, **yyy)
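A possible variation, sketched here as an assumption rather than as part of the answer above: the loop above copies only positional arguments, *varargs and **varkw, so defaults, keyword-only arguments and annotations are lost (for example, `def func(foo, *, qux)` would come out as "(baz, foo)"). Building the new signature from `inspect.signature(fn)` with `Signature.replace` keeps all of them. It also combines with functools.wraps, because `inspect.signature` stops unwrapping at an object that has its own `__signature__`. This sketch assumes `fn` has no positional-only parameters and no parameter already named `baz`; otherwise the `Signature` constructor raises `ValueError`.

```python
import functools
import inspect


def decorator(fn):
    @functools.wraps(fn)  # keeps __name__, __doc__, etc. of fn
    def _wrapper(baz, *args, **kwargs):
        return fn(*args, **kwargs)

    sig = inspect.signature(fn)
    # Prepend "baz" to fn's own parameters. Signature validates the result, so
    # this raises ValueError if fn has positional-only parameters or its own "baz".
    baz = inspect.Parameter("baz", inspect.Parameter.POSITIONAL_OR_KEYWORD)
    # An explicit __signature__ takes priority over the __wrapped__ link set by wraps.
    _wrapper.__signature__ = sig.replace(parameters=[baz, *sig.parameters.values()])
    return _wrapper


@decorator
def func(foo, bar: int = 2, *args, qux, **kwargs) -> None:
    pass


print(inspect.signature(func))  # (baz, foo, bar: int = 2, *args, qux, **kwargs) -> None
print(func.__name__)  # func
```

Calling works as before: func(1, 2, 3, qux=4) binds baz=1 in the wrapper and forwards (2, 3, qux=4) to the original function.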
Umbelliferous answered 14/5, 2022 at 17:40

© 2022 - 2024 — McMap. All rights reserved.