Make built-in lru_cache skip caching when function returns None

Here's a simplified function that I'm trying to add an lru_cache to -

from functools import lru_cache, wraps

@lru_cache(maxsize=1000)
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.cache_info())

outputs -

CacheInfo(hits=0, misses=1000, maxsize=1000, currsize=1000)

As we can see, it also caches the arguments and return values for the calls that return None. In the example above, I want currsize to be 334, counting only the calls that return a non-None value. In my real code, the function takes a large number of arguments and might return a different value on a later call if the previous result was None, so I want to avoid caching None values.
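
A minimal sketch of why a cached None is a problem, assuming a backing store that changes between calls. The names ISSUED_TOKENS and lookup_token are hypothetical and only serve the illustration:

```python
from functools import lru_cache

# Hypothetical backing store that changes over time (illustration only).
ISSUED_TOKENS = set()

@lru_cache(maxsize=1000)
def lookup_token(token):
    # Returns True for a known token, None otherwise.
    return True if token in ISSUED_TOKENS else None

first = lookup_token("abc")    # token not issued yet, so None is cached
ISSUED_TOKENS.add("abc")       # the token becomes valid afterwards
second = lookup_token("abc")   # answered from the cache, so still None

print(first, second, lookup_token.cache_info().currsize)
```

The second call never reaches the function body, so the stale None keeps being returned until the entry is evicted or the cache is cleared.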

I want to avoid reinventing the wheel by implementing an lru_cache from scratch. Is there a good way to do this?

Here are some of my attempts -

1. Implementing my own cache (which is not LRU here) -

from functools import wraps 

# global cache object
MY_CACHE = {}

def get_func_hash(func):
    # generates unique key for a function. TODO: fix what if function gets redefined?
    return func.__module__ + '|' + func.__name__

def my_lru_cache(func):
    name = get_func_hash(func)
    if name not in MY_CACHE:
        MY_CACHE[name] = {}
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        # include keyword arguments in the key so f(1) and f(1, flag=True) don't collide
        key = (args, tuple(sorted(kwargs.items())))
        if key in MY_CACHE[name]:
            return MY_CACHE[name][key]
        value = func(*args, **kwargs)
        if value is not None:
            # only non-None results are stored
            MY_CACHE[name][key] = value
        return value
    return function_wrapper

@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(get_func_hash(validate_token))
print(len(MY_CACHE[get_func_hash(validate_token)]))

outputs -

__main__|validate_token
334

2. I realised that lru_cache doesn't cache anything when the wrapped function raises an exception -

from functools import wraps, lru_cache

def my_lru_cache(func):
    @wraps(func)
    @lru_cache(maxsize=1000)
    def function_wrapper(*args, **kwargs):
        value = func(*args, **kwargs)
        if value is None:
            # TODO: change this to a custom exception
            raise KeyError
        return value
    return function_wrapper

def handle_exception(func):
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        try:
            value = func(*args, **kwargs)
            return value
        except KeyError:
            return None
    return function_wrapper    

@handle_exception
@my_lru_cache
def validate_token(token):
    if token % 3:
        return None
    return True

for x in range(1000):
    validate_token(x)

print(validate_token.__wrapped__.cache_info())

outputs -

CacheInfo(hits=0, misses=334, maxsize=1000, currsize=334)

The above correctly caches only the 334 values, but it requires wrapping the function twice and accessing cache_info in an awkward way: func.__wrapped__.cache_info().

How can I skip caching when None (or another specific value) is returned, using the built-in lru_cache decorator in a Pythonic way?

Spohr answered 18/2, 2022 at 14:41 Comment(1)
It surprises me that cachetools doesn't have a basic way to skip caching None values. - Zygote

You are missing the two lines marked here:

def handle_exception(func):
    @wraps(func)
    def function_wrapper(*args, **kwargs):
        try:
            value = func(*args, **kwargs)
            return value
        except KeyError:
            return None

    function_wrapper.cache_info = func.cache_info    # Add this
    function_wrapper.cache_clear = func.cache_clear  # Add this
    return function_wrapper

You can do both wrappers in one function:

from functools import lru_cache, wraps

def my_lru_cache(maxsize=128, typed=False):
    class CustomException(Exception):
        pass

    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def raise_exception_wrapper(*args, **kwargs):
            value = func(*args, **kwargs)
            if value is None:
                raise CustomException
            return value

        @wraps(func)
        def handle_exception_wrapper(*args, **kwargs):
            try:
                return raise_exception_wrapper(*args, **kwargs)
            except CustomException:
                return None

        handle_exception_wrapper.cache_info = raise_exception_wrapper.cache_info
        handle_exception_wrapper.cache_clear = raise_exception_wrapper.cache_clear
        return handle_exception_wrapper

    # support bare @my_lru_cache usage: maxsize is then the decorated function itself
    if callable(maxsize):
        user_function, maxsize = maxsize, 128
        return decorator(user_function)

    return decorator
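
The question also asks about skipping specific values other than None. Here is a minimal sketch of one way to generalise the decorator above, assuming the caller passes a predicate that decides which results to leave uncached. The names lru_cache_unless, skip and _Skip are made up for this illustration and are not part of functools:

```python
from functools import lru_cache, wraps


def lru_cache_unless(skip=lambda value: value is None, maxsize=128, typed=False):
    """Hypothetical decorator: like lru_cache, but results for which
    skip(result) is true are returned without being stored."""

    class _Skip(Exception):
        # Private signal that carries the uncached result out of lru_cache.
        def __init__(self, value):
            self.value = value

    def decorator(func):
        @lru_cache(maxsize=maxsize, typed=typed)
        def cached(*args, **kwargs):
            value = func(*args, **kwargs)
            if skip(value):
                raise _Skip(value)  # nothing is stored when the call raises
            return value

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return cached(*args, **kwargs)
            except _Skip as signal:
                return signal.value

        # expose the usual cache helpers on the outer wrapper
        wrapper.cache_info = cached.cache_info
        wrapper.cache_clear = cached.cache_clear
        return wrapper

    return decorator


@lru_cache_unless(skip=lambda value: value in (None, False), maxsize=1000)
def validate_token(token):
    if token % 3 == 1:
        return None
    if token % 3 == 2:
        return False
    return True


results = [validate_token(x) for x in range(1000)]
print(validate_token.cache_info().currsize)
```

Carrying the skipped value on the exception lets the wrapper hand back whatever the function returned, instead of hard-coding None. Only the True results should end up in the cache here.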
Ethereal answered 29/8, 2022 at 16:28 Comment(0)

Use exceptions to prevent caching:

from functools import lru_cache

@lru_cache(maxsize=None)
def fname(x):
    print('worked')

    # raising here means lru_cache stores nothing for this call
    raise Exception('')

    return 1  # never reached


for _ in range(10):
    try:
        fname(1)
    except Exception:
        pass

In the example above, "worked" is printed 10 times, because a call that raises an exception is never cached.
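
One way to check this without relying on printed output is to count the calls and inspect cache_info(). A minimal sketch; flaky and calls are made-up names for illustration:

```python
from functools import lru_cache

calls = 0  # how many times the function body actually runs

@lru_cache(maxsize=None)
def flaky(x):
    global calls
    calls += 1
    if x < 0:
        raise ValueError("negative input")
    return x * 2

# Three failing calls: each one runs the body again, nothing is stored.
for _ in range(3):
    try:
        flaky(-1)
    except ValueError:
        pass

# Two successful calls: the body runs once, the second call is a cache hit.
flaky(2)
flaky(2)

info = flaky.cache_info()
print(calls, info.hits, info.currsize)
```

The only cached entry is the one for the successful argument; the raising argument leaves no trace in the cache.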

Kulp answered 29/8, 2022 at 14:44 Comment(0)
