Type Narrowing of Class Attributes in Python (TypeGuard) without Subclassing

Suppose I have a Python class with attributes (e.g. a dataclass, a Pydantic model, an attrs class, a Django model, ...) whose types are unions, e.g. None and a state. I also have a complex checking function that checks some of these values.

When this checking function passes, I want to tell the type checker that some of my class attributes are narrowed.

For instance, consider this simplified example:

import dataclasses
from typing import TypeGuard


@dataclasses.dataclass
class SomeDataClass:
    state: tuple[int, int] | None
    name: str
    # Assume many more data attributes


class SomeDataClassWithSetState(SomeDataClass):
    state: tuple[int, int]


def complex_check(data: SomeDataClass) -> TypeGuard[SomeDataClassWithSetState]:
    # Assume some complex checks here, for simplicity it is only:
    return data.state is not None and data.name.startswith("SPECIAL")


def get_sum(data: SomeDataClass) -> int:
    if complex_check(data):
        return data.state[0] + data.state[1]
    return 0

As shown, this is possible with a subclass, but for various reasons that is not an option for me:

  • it introduces a lot of duplication
  • some libraries used for data classes do not handle being subclassed without additional conditions
  • there could be metaclass or __subclasses__ magic that treats every subclass specially, e.g. creating a database table for each data class

So is there a way to narrow the type of one or more attributes of a class without introducing a new class solely for that purpose, as done here?

Herbertherbicide asked on 30/11/2022 at 8:11. Comments (2):
You seem intent on using typing.TypeGuard, which relies on two distinct types to be useful. It is not entirely clear what your actual goal is, but my guess is that TypeGuard is not the way to get there. Based on your example, is your goal to have a type-safe way to assume the SomeDataClass.state attribute is a tuple (and not None) in the get_sum function? - Zebada
The idea is that the complex_check function is meant to be used everywhere in the code, and I don't want assert data.state is not None calls in all the if branches, because this has already been verified by complex_check. I only need the type checker to understand this. Also, this is just a simple example; the complex_check function could check a number of arguments that could potentially be type narrowed. - Herbertherbicide

TL;DR: You cannot narrow the type of an attribute. You can only narrow the type of an object.

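As I already mentioned in my comment, typing.TypeGuard relies on two distinct types T and S to be useful: the guard function takes an argument of type T and is annotated to return TypeGuard[S]. If it returns True, the type checker assumes the argument is an S; if it returns False, the argument simply keeps its declared type T (TypeGuard does not narrow in the negative branch).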

You say you don't want to have another class/subclass alongside SomeDataClass for various (vaguely valid) reasons. But if you don't have another type, then TypeGuard is useless. So that is not the route to take here.

I understand that you want to reduce type-safety checks like if obj.state is None, because you may need to access the state attribute in many different places in your code. You must have some place in your code where you create or mutate a SomeDataClass instance in a way that ensures its state attribute is not None. One solution, then, is a getter for that attribute that performs the type-safety check and either returns the narrower type or raises an error. I typically do this via @property for improved readability. Example:

from dataclasses import dataclass


@dataclass
class SomeDataClass:
    name: str
    optional_state: tuple[int, int] | None = None

    @property
    def state(self) -> tuple[int, int]:
        if self.optional_state is None:
            raise RuntimeError("or some other appropriate exception")
        return self.optional_state


def set_state(obj: SomeDataClass, value: tuple[int, int]) -> None:
    obj.optional_state = value


if __name__ == "__main__":
    foo = SomeDataClass(optional_state=(1, 2), name="foo")
    bar = SomeDataClass(name="bar")
    baz = SomeDataClass(name="baz")
    set_state(bar, (2, 3))
    print(foo.state)
    print(bar.state)
    try:
        print(baz.state)
    except RuntimeError:
        print("baz has no state")

I realize you mean there are many more checks happening in complex_check, but either that function doesn't change the type of data or it does. If the type remains the same, you need to introduce type-safety for attributes like state in some other place, which is why I suggest a getter method.

Another option is, of course, a separate class, which is what is typically done with FastAPI/Pydantic/SQLModel, for example, using clever inheritance to reduce code duplication. You mentioned this may cause problems because of subclassing magic. If it does, use the other approach, but I can't think of an example that would cause the problems you mentioned. Maybe you can be more specific and show a case where subclassing would lead to problems.

Zebada answered on 30/11/2022 at 10:10. Comments (2):
The property idea is great. - Herbertherbicide
Just found this issue again; I'm facing the same problem in a codebase with a lot of dataclasses with optional attributes. - Herbertherbicide

You can use a Protocol class and a TypeVar to support a wider range of potential state options.

from __future__ import annotations

import dataclasses
from typing import TypeGuard, TYPE_CHECKING, Literal

if TYPE_CHECKING:
    from typing import Protocol, TypeVar

    _T = TypeVar("_T")

    class _StateProtocol(Protocol[_T]):
        state: _T


@dataclasses.dataclass
class SomeDataClass:
    state: tuple[int, int] | Literal["FOO"] | None
    name: str


def is_tuple_state(data: SomeDataClass) -> TypeGuard[_StateProtocol[tuple[int, int]]]:
    # Assume some complex checks here, for simplicity it is only:
    return isinstance(data.state, tuple) and data.name.startswith("SPECIAL")


def is_foo_state(data: SomeDataClass) -> TypeGuard[_StateProtocol[Literal["FOO"]]]:
    # Assume some complex checks here, for simplicity it is only:
    return data.state == "FOO"


def get_sum(data: SomeDataClass) -> int:
    if is_tuple_state(data):
        return data.state[0] + data.state[1]
    if is_foo_state(data):
        return -1

    return 0


def main() -> None:
    assert get_sum(SomeDataClass(None, "SPECIAL")) == 0
    assert get_sum(SomeDataClass((1, 2), "SPECIAL")) == 3
    assert get_sum(SomeDataClass("FOO", "SPECIAL")) == -1


if __name__ == "__main__":
    main()

Aikido answered on 11/10/2023 at 18:29. Comments (0).

© 2022 - 2024 — McMap. All rights reserved.