Python dictionary with enum as key

Let's say I have an enum

from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

I wanted to create a ColorDict class that works as a native python dictionary but only takes the Color enum or its corresponding string value as key.

d = ColorDict() # I want to implement a ColorDict class such that ...

d[Color.RED] = 123
d["RED"] = 456  # I want this to override the previous value
d[Color.RED]    # ==> 456
d["foo"] = 789  # I want this to raise a KeyError exception

What's the "pythonic way" of implementing this ColorDict class? Shall I use inheritance (overriding python's native dict) or composition (keep a dict as a member)?

Comines answered 28/11, 2021 at 18:43 Comment(6)
Inheritance or composition is really up to you. If you use inheritance, you will have to override all the methods that accept inputs, so __setitem__, .update, and so on; it may be easy enough.Hitt
It's personal choice, but I generally prefer composition in these cases. It makes the interface a lot easier to understand by being explicit about what operations you want to expose, limiting the amount of work you need to do, especially if you don't care about implementing the entire dict interface.Abidjan
@Mark The code snippets contain the behavior I hope to achieve, not what I currently observe. I updated the comment to be clearer. Sorry for the confusion.Comines
An alternative is to inherit from collections.abc.MutableMapping, which would involve composition, but you would only have to implement a minimal number of methods.Hitt
Thanks @YingXiong, I realized I misread that right before you posted.Mudguard
KeyError is a lookup error (something not found), but the assignment is not a lookup. I would consider ValueError for a wrong key value, but to be honest, I'm not sure which one is more appropriate.Halfwitted
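As a small illustration of the KeyError versus ValueError point raised in the comments, the sketch below shows which exception the standard enum module itself raises for a bad lookup. The helper name lookup_error_name is hypothetical and exists only for this example.

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


def lookup_error_name(func):
    """Call func() and return the name of the exception it raises, or None."""
    try:
        func()
    except Exception as exc:
        return type(exc).__name__
    return None


# Lookup by value uses call syntax and fails with ValueError.
by_value = lookup_error_name(lambda: Color("foo"))

# Lookup by name uses subscript syntax and fails with KeyError.
by_name = lookup_error_name(lambda: Color["foo"])

print(by_value, by_name)  # ValueError KeyError
```

So a ColorDict that converts keys with Color(key) has to translate ValueError into KeyError itself if it wants dict-like errors.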
J
7

A simple solution would be to slightly modify your Color class and then subclass dict to add a check on the key. I would do something like this:

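from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        # True for a Color member, or for a string that matches a member's
        # value once upper-cased; False for anything else.
        if isinstance(color, cls):
            return True
        if isinstance(color, str):
            return color.upper() in {member.value for member in cls}
        return False


class ColorDict(dict):

    def __setitem__(self, k, v):
        if not Color.is_color(k):
            raise KeyError(f"Color {k} is not valid")
        if isinstance(k, str):
            k = Color(k.upper())  # always store the Enum member as the key
        super().__setitem__(k, v)

    def __getitem__(self, k):
        # Convert valid strings to the Enum member; anything else falls
        # through to dict, which raises KeyError for a missing key.
        if isinstance(k, str) and Color.is_color(k):
            k = Color(k.upper())
        return super().__getitem__(k)

d = ColorDict()

d[Color.RED] = 123
d["RED"] = 456   # overrides the previous value
d[Color.RED]     # ==> 456
d["foo"] = 789   # raises KeyError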

In the Color class, I have added an is_color class method that returns True if the key is an allowed color and False otherwise. The upper() call puts a string key in upper case so it can be compared to the predefined values.

Then I have subclassed the dict object to override the __setitem__ special method to include a test of the value passed, and an override of __getitem__ to convert any key passed as str into the correct Enum. Depending on the specifics of how you want to use the ColorDict class, you may need to override more functions. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
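To make the "you may need to override more functions" caveat concrete, here is a minimal sketch showing that dict.update does not go through an overridden __setitem__, plus one possible way to close that gap. The names to_color, NaiveColorDict and StricterColorDict are hypothetical and not part of the answer above, and other entry points (for example dict.fromkeys or the | operator) are deliberately left uncovered.

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


def to_color(k):
    """Hypothetical helper: convert a key to Color or raise KeyError."""
    try:
        return Color(k)
    except ValueError:
        raise KeyError(k) from None


class NaiveColorDict(dict):
    """Overrides only __setitem__, as a baseline for comparison."""

    def __setitem__(self, k, v):
        super().__setitem__(to_color(k), v)


class StricterColorDict(NaiveColorDict):
    """Routes the constructor, update and setdefault through __setitem__."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs):
        # Build a plain dict first, then insert pair by pair so that
        # every key passes through the validating __setitem__.
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def setdefault(self, k, default=None):
        k = to_color(k)
        if k not in self:
            self[k] = default
        return self[k]


naive = NaiveColorDict()
naive.update({"foo": 1})   # bypasses __setitem__, so the bad key is stored
print("foo" in naive)      # True

strict = StricterColorDict()
strict.update({"RED": 1})  # stored under Color.RED
try:
    strict.update({"foo": 1})
    rejected = False
except KeyError:
    rejected = True
print(rejected)            # True
```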

Jejunum answered 28/11, 2021 at 19:22 Comment(4)
Minor comment - it's probably a bad idea to name your filtering function (test_color) using the test_ prefix, as a number of test frameworks might inadvertently pick that up as a test case. is_color would be an idiomatic filter name.Essam
That is a very good point! I'll modify it.Jejunum
Same as @VPfB's comment. We probably need to do something like super().__setitem__(Color(k), v).Comines
Thank you both, I've added the suggested Color(k), which appears to solve the problem. But now __getitem__ needs overriding as well to match.Jejunum
S
4
from enum import Enum

class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3

# Using enum values as key (with example of type hinting)
color_dict: dict[Color, str] = {
    Color.RED: "red",
    Color.GREEN: "green",
    Color.BLUE: "blue"
}

# Accessing values using enum keys
print(color_dict[Color.RED])  # Outputs: red
Simms answered 12/1 at 14:59 Comment(3)
Why is this not the accepted answer?Whitehall
@Whitehall It doesn't fulfill the basic requirement that color_dict["foo"] = 789 raises a KeyError.Ciao
@Ciao In Python dictionaries, ANY hashable and comparable object can be the key. Strings are hashable+comparable objects. Enums are hashable+comparable objects. And much more. There is ZERO runtime type checking whatsoever in Python. Your example is a FULLY VALID Python dict key. That is why people use static MyPy type checking and type hinting to lint their program, to ensure that all calls and usages in the code adheres to the specified type-hints. However, I see that the original question was from a misguided person who wants runtime key-type checking. They SHOULD use MyPy!Koss
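The point argued in the comments above can be checked directly. This sketch (assuming Python 3.9+ for the dict[Color, str] annotation) shows that the annotation alone does not stop a string key from being inserted at runtime; a static checker such as mypy is expected to report the assignment, but the interpreter accepts it.

```python
from enum import Enum


class Color(Enum):
    RED = 1
    GREEN = 2
    BLUE = 3


color_dict: dict[Color, str] = {Color.RED: "red"}

# The annotation is not enforced at runtime, so this assignment succeeds
# even though "foo" is not a Color.
color_dict["foo"] = "not a color"

print(len(color_dict))  # 2
```

This is why a plain annotated dict satisfies static checking but not the question's requirement that a bad key raise an exception at runtime.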
H
2

One way is to use the abstract base class collections.abc.MutableMapping. You only need to implement its abstract methods, and you can then be sure that every access goes through your logic. You can do this with dict too, but overriding dict.__setitem__ will not affect dict.update, dict.setdefault, etc., so you would have to override those by hand as well. Usually, it is easier to just use the abstract base class:

from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):

    def __init__(self): # could handle more ways of initializing  but for simplicity...
        self._data = {}

    def __getitem__(self, item):
        color = self._handle_item(item)  # accept Color or its string value
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        color = self._handle_item(item)
        del self._data[color]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color

Note, you can also add:

    def __repr__(self):
        return repr(self._data)

For easier debugging.

An example in the repl:

In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
   ...:
   ...: d[Color.RED] = 123
   ...: d["RED"] = 456  # I want this to override the previous value
   ...: d[Color.RED]    # ==> 456
Out[3]: 456

In [4]: d["foo"] = 789  # I want this to produce an KeyError exception
   ...:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789  # I want this to produce an KeyError exception

<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
     17
     18     def __setitem__(self, item, value):
---> 19         color = self._handle_item(item)
     20         self._data[color] = value
     21

<ipython-input-2-a0780e16594b> in _handle_item(self, item)
     34             color = Color(item)
     35         except ValueError:
---> 36             raise KeyError(item) from None
     37         return color
     38     def __repr__(self): return repr(self._data)

KeyError: 'foo'

In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}
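As a hedged extension of the class above, the following sketch shows that the mixin methods MutableMapping provides (update, setdefault, get, in) all go through the five abstract methods, so the key validation applies to them without extra overrides. The __init__ that accepts initial data is an addition for illustration, not part of the original answer.

```python
from collections.abc import MutableMapping
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class ColorDict(MutableMapping):

    def __init__(self, *args, **kwargs):
        self._data = {}
        # MutableMapping.update calls __setitem__ for every pair,
        # so initial data is validated like any other assignment.
        self.update(*args, **kwargs)

    def __getitem__(self, item):
        return self._data[self._handle_item(item)]

    def __setitem__(self, item, value):
        self._data[self._handle_item(item)] = value

    def __delitem__(self, item):
        del self._data[self._handle_item(item)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    @staticmethod
    def _handle_item(item):
        try:
            return Color(item)
        except ValueError:
            raise KeyError(item) from None


d = ColorDict({"RED": 1}, GREEN=2)
d.update({Color.RED: 10})   # inherited update, routed through __setitem__
d.setdefault("BLUE", 3)     # inherited setdefault, also validated

print(sorted(color.name for color in d))  # ['BLUE', 'GREEN', 'RED']
print(d["RED"], d.get("foo"))             # 10 None
```

Note that get and the in operator treat an invalid key such as "foo" as simply absent, because they catch the KeyError raised by __getitem__.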
Hitt answered 28/11, 2021 at 19:8 Comment(0)
