Python subclass method to inherit decorator from superclass method

I have a superclass that has a retrieve() method, and its subclasses each implement their own retrieve() method. I'd like every retrieve() method to be decorated so that it caches the return value when it receives the same args, without having to decorate the method in every subclass.

Decorators don't seem to be inherited. I could probably call the superclass's method, which would in turn set the cache, but currently my superclass raises a NotImplementedError, which I like.

import json
import operator
from cachetools import cachedmethod, TTLCache

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        # check cache (not implemented yet)
        print("simple decorator")
        result = func(*args, **kwargs)
        # set cache (not implemented yet)
        return result
    return wrapper


class AbstractInput(object):
    def __init__(self, cacheparams = {'maxsize': 10, 'ttl': 300}):
        self.cache = TTLCache(**cacheparams)
        super().__init__()

    @simple_decorator
    def retrieve(self, params):
        print("AbstractInput retrieve")
        raise NotImplementedError("AbstractInput inheritors must implement retrieve() method")

class JsonInput(AbstractInput):
    def retrieve(self, params):
        print("JsonInput retrieve")
        return json.dumps(params)

class SillyJsonInput(JsonInput):
    def retrieve(self, params):
        print("SillyJsonInput retrieve")
        params["silly"] = True
        return json.dumps(params)
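One detail any caching wrapper here has to handle, however it ends up being applied: `params` is a dict, and dicts are unhashable, so they cannot be used directly as keys in a `dict` or a `TTLCache`. A minimal sketch, assuming `params` is always JSON-serializable; `make_cache_key` is a hypothetical helper name:

```python
import json

def make_cache_key(params):
    # Hypothetical helper: serialize a JSON-compatible dict into a hashable string.
    # sort_keys=True makes dicts with the same items in a different insertion
    # order produce the same key.
    return json.dumps(params, sort_keys=True)

cache = {}  # a cachetools TTLCache needs hashable keys in the same way

params = {"happy": "go lucky", "angry": "as a wasp"}
try:
    cache[params] = "result"  # fails: dict objects are unhashable
except TypeError as exc:
    print("TypeError:", exc)

cache[make_cache_key(params)] = "result"
# The same items in a different order hit the same cache entry
print(make_cache_key({"angry": "as a wasp", "happy": "go lucky"}) in cache)
```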

Actual results:

>>> ai.retrieve(params)
simple decorator
AbstractInput retrieve
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 8, in wrapper
  File "<string>", line 22, in retrieve
NotImplementedError: AbstractInput inheritors must implement retrieve() method
>>> ji.retrieve(params)
JsonInput retrieve
'{"happy": "go lucky", "angry": "as a wasp"}'

Desired results:

>>> ai.retrieve(params)
simple decorator
AbstractInput retrieve
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<string>", line 8, in wrapper
  File "<string>", line 22, in retrieve
NotImplementedError: AbstractInput inheritors must implement retrieve() method
>>> ji.retrieve(params)
simple decorator
JsonInput retrieve
'{"happy": "go lucky", "angry": "as a wasp"}'
Cogitate answered 19/7, 2019 at 0:56

Yes, using a metaclass to force a decorator onto a specific method, as you did in your own answer, is correct. With a few changes, it can be made so that the method to be decorated is not fixed: for example, an attribute set on the decorated function can be used as a "mark" indicating that the decorator should be forced onto overriding methods.

Besides that, since Python 3.6 there is a new class-level mechanism, the special method __init_subclass__, whose specific purpose is to reduce the need for metaclasses. Metaclasses can be complicated, and if your class hierarchy needs to combine more than one metaclass, you may be in for a headache.
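To make that headache concrete: when two base classes use unrelated metaclasses, Python refuses to create a class that inherits from both, and the usual workaround is yet another metaclass that inherits from both. A minimal illustration; `MetaA`, `MetaB` and the other class names are hypothetical:

```python
class MetaA(type):
    """A metaclass used by one library (hypothetical)."""

class MetaB(type):
    """An unrelated metaclass used by another library (hypothetical)."""

class UsesA(metaclass=MetaA):
    pass

class UsesB(metaclass=MetaB):
    pass

conflict_message = None
try:
    # Neither MetaA nor MetaB is a subclass of the other, so Python cannot
    # choose a metaclass for the combined class and raises TypeError.
    class Combined(UsesA, UsesB):
        pass
except TypeError as exc:
    conflict_message = str(exc)

print(conflict_message)

# The usual workaround: a combined metaclass that inherits from both.
class MetaAB(MetaA, MetaB):
    pass

class Combined(UsesA, UsesB, metaclass=MetaAB):
    pass
```

With `__init_subclass__` the hook lives on an ordinary base class, so this kind of conflict does not arise from it.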

The __init_subclass__ method is placed on the base class, and it is called once each time a subclass of it is created (it is not called for the base class itself). The wrapping logic can be put there.
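For the simpler, fixed-name case from the question (always wrapping `retrieve`), the job the metaclass does can be done with `__init_subclass__` alone. A minimal sketch, not the general mixin: it records steps in a hypothetical `events` list instead of printing, and does no real caching:

```python
import json

events = []  # stands in for the print() calls, so the order of steps can be checked

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        events.append("check cache")
        rt = func(*args, **kwargs)
        events.append("set cache")
        return rt
    return wrapper

class AbstractInput:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once for every subclass, right after its class body is executed.
        # Only wrap retrieve() if this subclass defines its own; otherwise it
        # inherits the parent's already-wrapped method.
        if "retrieve" in cls.__dict__:
            cls.retrieve = simple_decorator(cls.__dict__["retrieve"])

    # __init_subclass__ is not called for AbstractInput itself, so decorate it here.
    @simple_decorator
    def retrieve(self, params):
        raise NotImplementedError("AbstractInput inheritors must implement retrieve() method")

class JsonInput(AbstractInput):
    def retrieve(self, params):  # no decorator needed
        events.append("JsonInput retrieve")
        return json.dumps(params)

result = JsonInput().retrieve({"happy": "go lucky"})
print(events)  # ['check cache', 'JsonInput retrieve', 'set cache']
```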

Basically, you can modify your decorator to set the mark mentioned above, and add this class to your inheritance hierarchy. It can be used as a mixin class with multiple inheritance, so it can be reused for different class trees if needed:

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        print("check cache")
        rt = func(*args, **kwargs)
        print("set cache")
        return rt
    # Mark the wrapper so InheritDecoratorsMixin knows which decorator to re-apply
    wrapper.inherit_decorator = simple_decorator
    return wrapper

class InheritDecoratorsMixin:
    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)
        # Each subclass gets its own copy of the registry inherited from its parent
        decorator_registry = getattr(cls, "_decorator_registry", {}).copy()
        cls._decorator_registry = decorator_registry
        # Check for decorated objects in the mixin itself (optional):
        for name, obj in __class__.__dict__.items():
            if getattr(obj, "inherit_decorator", False) and name not in decorator_registry:
                decorator_registry[name] = obj.inherit_decorator
        # Register newly decorated methods in the current subclass:
        for name, obj in cls.__dict__.items():
            if getattr(obj, "inherit_decorator", False) and name not in decorator_registry:
                decorator_registry[name] = obj.inherit_decorator
        # Finally, decorate every registered method this subclass defines itself
        # that is not already wrapped by the registered decorator:
        for name, decorator in decorator_registry.items():
            if name in cls.__dict__ and getattr(getattr(cls, name), "inherit_decorator", None) != decorator:
                setattr(cls, name, decorator(cls.__dict__[name]))

So, that's it: each new subclass has its own _decorator_registry attribute, which records the names of the decorated methods from all ancestors, along with the decorator to apply to each of them.
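Applied to the question's hierarchy with an actual cache instead of print statements, it could look like the sketch below. Assumptions: `cache_retrieve` is a hypothetical decorator that stores results in a plain per-instance dict (a `TTLCache` could take its place), `params` is JSON-serializable so it can be turned into a hashable key, and the mixin is a trimmed copy without the optional loop over the mixin's own attributes:

```python
import json

underlying_calls = []  # records every time the real retrieve() body runs

def cache_retrieve(func):
    # Hypothetical caching decorator: results live in a per-instance dict.
    def wrapper(self, params):
        key = json.dumps(params, sort_keys=True)  # assumes JSON-serializable params
        cache = self.__dict__.setdefault("_retrieve_cache", {})
        if key not in cache:
            cache[key] = func(self, params)
        return cache[key]
    wrapper.inherit_decorator = cache_retrieve  # the "mark" the mixin looks for
    return wrapper

class InheritDecoratorsMixin:
    def __init_subclass__(cls, *args, **kwargs):
        super().__init_subclass__(*args, **kwargs)
        registry = getattr(cls, "_decorator_registry", {}).copy()
        cls._decorator_registry = registry
        # Register methods decorated directly in this class body
        for name, obj in cls.__dict__.items():
            if getattr(obj, "inherit_decorator", False) and name not in registry:
                registry[name] = obj.inherit_decorator
        # Re-apply registered decorators to overrides defined in this class body
        for name, decorator in registry.items():
            method = cls.__dict__.get(name)
            if method is not None and getattr(method, "inherit_decorator", None) is not decorator:
                setattr(cls, name, decorator(method))

class AbstractInput(InheritDecoratorsMixin):
    @cache_retrieve
    def retrieve(self, params):
        raise NotImplementedError("AbstractInput inheritors must implement retrieve() method")

class JsonInput(AbstractInput):
    def retrieve(self, params):  # no decorator needed here
        underlying_calls.append("JsonInput")
        return json.dumps(params)

ji = JsonInput()
first = ji.retrieve({"happy": "go lucky"})
second = ji.retrieve({"happy": "go lucky"})  # served from ji's cache
print(first, underlying_calls)  # {"happy": "go lucky"} ['JsonInput']
```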

If the decorator should run only once per call, and not be repeated when the overriding method calls super() to reach its ancestors (not the case when you are decorating for caching, since the super-methods won't be called), things get trickier, but it can be done.

It is tricky because the decorator instances in the superclasses are different objects from the decorator on the subclass. One way to tell them that "the decorator code for this method has already run in this call chain" is to use an instance-level marker, which should be a thread-local variable if the code has to support multiple threads.
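Why thread-local: an attribute set on a `threading.local()` object in one thread is invisible in every other thread, so a "running" flag set during one call chain does not leak into a concurrent call on the same instance from another thread. A minimal illustration (the names `marker` and `worker` are just for this example):

```python
import threading

marker = threading.local()  # each thread sees its own independent attributes
marker.running = True       # set in the main thread only

seen_in_worker = []

def worker():
    # A new thread starts with no attributes on `marker`, so the default is used.
    seen_in_worker.append(getattr(marker, "running", False))

t = threading.Thread(target=worker)
t.start()
t.join()

print(marker.running, seen_in_worker)  # True [False]
```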

All this checking results in quite a lot of complicated boilerplate inside what could be a simple decorator, so we can create a "decorator" for the decorators that we want to run a single time. In other words, decorators decorated with childmost below will run only on the "childmost" class, not on the corresponding methods in the superclasses when those are reached through super().



import threading

def childmost(decorator_func):

    def inheritable_decorator_that_runs_once(func):
        decorated_func = decorator_func(func)
        name = func.__name__
        def wrapper(self, *args, **kw):
            # One thread-local marker per instance and method name
            if not hasattr(self, f"_running_{name}"):
                setattr(self, f"_running_{name}", threading.local())
            running_registry = getattr(self, f"_running_{name}")
            if getattr(running_registry, "running", False):
                # Already inside a decorated call for this method in this thread:
                # run the plain method so the decorator code is not repeated
                return func(self, *args, **kw)
            running_registry.running = True
            try:
                return decorated_func(self, *args, **kw)
            finally:
                # Only the outermost call, which set the flag, clears it
                running_registry.running = False

        wrapper.inherit_decorator = inheritable_decorator_that_runs_once
        return wrapper
    return inheritable_decorator_that_runs_once

Example using the first listing:

class A: pass

class B(A, InheritDecoratorsMixin):
    @simple_decorator
    def method(self):
        print(__class__, "method called")

class C(B):
    def method(self):
        print(__class__, "method called")
        super().method()

After pasting the first listing and these A, B and C classes into the interpreter, the result is:

In [9]: C().method()
check cache
<class '__main__.C'> method called
check cache
<class '__main__.B'> method called
set cache
set cache

(the "A" class here is entirely optional and can be left out)


Example using the second listing:


# Decorating the same decorator above:

@childmost
def simple_decorator2(func):
    def wrapper(*args, **kwargs):
        print("check cache")
        rt = func(*args, **kwargs)
        print("set cache")
        return rt
    return wrapper

class D: pass

class E(D, InheritDecoratorsMixin):
    @simple_decorator2
    def method(self):
        print(__class__, "method called")

class F(E):
    def method(self):
        print(__class__, "method called")
        super().method()

And the result:


In [19]: F().method()                                                                        
check cache
<class '__main__.F'> method called
<class '__main__.E'> method called
set cache

Portable answered 19/7, 2019 at 5:22
Hmm, I was hoping this approach might allow me to reference methods defined in the subclasses within the decorator call, but unfortunately I can't seem to get it to work. The line: setattr(cls, name) = decorator(cls.__dict__[name]) results in: SyntaxError: can't assign to function call and if I comment that out, I get: TypeError: Cannot create a consistent method resolution order (MRO) for bases object, InheritDecoratorsMixin - Cogitate
That line was a typo - what is after the = should have been the 3rd argument to setattr, of course. I will check what is issuing the TypeError. - Portable
So, I apologize for posting untested code at first. It is fixed now, with examples, and I also changed the way the decorator is made to run just once, factoring the bits out into a new decorator for the original decorator. - Portable
And as for "referring to methods defined in the subclasses": the decorator is passed "self" explicitly, so it will just work. - Portable

OK, it seems that I can "decorate" a method in a superclass and have the subclasses also inherit that decoration, even if the method is overridden in the subclass, by using a metaclass. In this case, I'm decorating every retrieve method in AbstractInput and its subclasses with simple_decorator, using a metaclass named CacheRetrieval.

import json
from cachetools import TTLCache

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        print("check cache")
        rt = func(*args, **kwargs)
        print("set cache")
        return rt
    return wrapper

class CacheRetrieval(type):
    def __new__(cls, name, bases, attr):
        # Wrap retrieve() with simple_decorator in every class whose body defines it;
        # classes that don't override retrieve() inherit the already-wrapped version
        if "retrieve" in attr:
            attr["retrieve"] = simple_decorator(attr["retrieve"])

        return super(CacheRetrieval, cls).__new__(cls, name, bases, attr)


class AbstractInput(object, metaclass=CacheRetrieval):
    def __init__(self, cacheparams = {'maxsize': 10, 'ttl': 300}):
        self.cache = TTLCache(**cacheparams)
        super().__init__()

    def retrieve(self, params):
        print("AbstractInput retrieve")
        raise NotImplementedError("AbstractInput inheritors must implement retrieve() method")


class JsonInput(AbstractInput):
    def retrieve(self, params):
        print("JsonInput retrieve")
        return json.dumps(params)


class SillyJsonInput(JsonInput):
    def retrieve(self, params):
        print("SillyJsonInput retrieve")
        params["silly"] = True
        return json.dumps(params)
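An unguarded `attr["retrieve"]` lookup raises `KeyError` for any subclass whose body does not define `retrieve()`. A sketch with an `if "retrieve" in attr` guard, using a hypothetical `PlainJsonInput` subclass and an `events` list in place of the prints, showing that such a subclass simply inherits the parent's wrapped method, so the decorator still runs exactly once per call:

```python
import json

events = []  # stands in for the print() calls

def simple_decorator(func):
    def wrapper(*args, **kwargs):
        events.append("check cache")
        rt = func(*args, **kwargs)
        events.append("set cache")
        return rt
    return wrapper

class CacheRetrieval(type):
    def __new__(mcs, name, bases, attr):
        # Without this check, attr["retrieve"] raises KeyError for any class
        # whose body does not define retrieve().
        if "retrieve" in attr:
            attr["retrieve"] = simple_decorator(attr["retrieve"])
        return super().__new__(mcs, name, bases, attr)

class AbstractInput(metaclass=CacheRetrieval):
    def retrieve(self, params):
        raise NotImplementedError("AbstractInput inheritors must implement retrieve() method")

class JsonInput(AbstractInput):
    def retrieve(self, params):
        events.append("JsonInput retrieve")
        return json.dumps(params)

class PlainJsonInput(JsonInput):
    """Hypothetical subclass that does not override retrieve()."""

result = PlainJsonInput().retrieve({"happy": "go lucky"})
print(events)  # ['check cache', 'JsonInput retrieve', 'set cache']
```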

I was helped by this page: https://stackabuse.com/python-metaclasses-and-metaprogramming/

Cogitate answered 19/7, 2019 at 3:52
