Subclassing builtin types in Python 2 and Python 3

When subclassing builtin types, I noticed a rather important difference between Python 2 and Python 3 in the return type of the methods of the built-in types. The following code illustrates this for sets:

class MySet(set):
    pass

s1 = MySet([1, 2, 3, 4, 5])
s2 = MySet([1, 2, 3, 6, 7])

print(type(s1.union(s2)))
print(type(s1.intersection(s2)))
print(type(s1.difference(s2)))

With Python 2, all the return values are of type MySet. With Python 3, the return types are set. I could not find any documentation on what the result is supposed to be, nor any documentation about the change in Python 3.

Anyway, what I really care about is this: is there a simple way in Python 3 to get the behavior seen in Python 2, without redefining every single method of the built-in types?

Kismet asked 2/11, 2011 at 15:48. Comments:
On Python 2, only the type of s1 is relevant, not the type of s2. (Shig)
It's kind of similar to the way that False + False is 0, not False (bool is a subclass of int, by the way). (Dyna)

This isn't a general change for built-in types when moving from Python 2.x to 3.x: list and int, for example, behave the same in 2.x and 3.x. Only the set type was changed, to bring it in line with the other types, as discussed in an issue on the Python bug tracker.
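On Python 3 this is easy to confirm. The quick check below is a sketch; the subclass names are illustrative only. Every operation is implemented by the built-in base type, which builds a plain instance of itself for the result, and set now follows the same rule as list and int.

```python
# Illustrative subclasses, used only to observe return types on Python 3.
class MyList(list):
    pass

class MyInt(int):
    pass

class MySet(set):
    pass

# Each operation runs the base type's implementation, which returns
# a plain instance of the base type rather than of the subclass.
results = {
    "list +": type(MyList([1]) + MyList([2])),
    "list slice": type(MyList([1, 2, 3])[1:]),
    "int +": type(MyInt(1) + MyInt(2)),
    "bool +": type(False + False),
    "set.union": type(MySet([1]).union(MySet([2]))),
    "set |": type(MySet([1]) | MySet([2])),
}

for operation, result_type in results.items():
    print(operation, "->", result_type.__name__)
```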

I'm afraid there is no really nice way to make it behave the old way. Here is some code I was able to come up with:

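class MySet(set):
    def copy(self):
        # Return the subclass rather than a plain set.
        return MySet(self)

    def _make_binary_op(in_place_method):
        # Build a binary operation from the matching in-place method:
        # copy self (keeping the subclass), then update the copy in place.
        def bin_op(self, other):
            new = self.copy()
            result = in_place_method(new, other)
            # The in-place operators return NotImplemented for non-set
            # operands; propagate it so Python raises TypeError as usual.
            if result is NotImplemented:
                return NotImplemented
            return new
        return bin_op

    __rand__ = __and__ = _make_binary_op(set.__iand__)
    intersection = _make_binary_op(set.intersection_update)
    __ror__ = __or__ = _make_binary_op(set.__ior__)
    union = _make_binary_op(set.update)
    __sub__ = _make_binary_op(set.__isub__)
    difference = _make_binary_op(set.difference_update)
    __rxor__ = __xor__ = _make_binary_op(set.__ixor__)
    symmetric_difference = _make_binary_op(set.symmetric_difference_update)
    del _make_binary_op

    def __rsub__(self, other):
        # other - self is not symmetric, so start from other.
        new = MySet(other)
        new -= self
        return new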

This overrides every set method and operator that returns a new set with a version that returns your own type. (There are a whole lot of methods!)
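To cut down on the boilerplate, the overrides can also be generated in a loop after the class is defined. This is a sketch, not a drop-in replacement: the method list is an assumption covering the set-returning methods and operators, and the helper name `_returning_subclass` is hypothetical. Unlike the code above, it calls the inherited method and converts the plain `set` it returns, which costs one extra copy per call.

```python
class MySet(set):
    pass

# Assumed list of set methods/operators that return a new set on Python 3.
_SET_RETURNING = [
    "copy", "union", "intersection", "difference", "symmetric_difference",
    "__or__", "__ror__", "__and__", "__rand__",
    "__sub__", "__rsub__", "__xor__", "__rxor__",
]

def _returning_subclass(name):
    """Hypothetical helper: wrap set.<name> so plain-set results become type(self)."""
    base_method = getattr(set, name)

    def method(self, *args, **kwargs):
        result = base_method(self, *args, **kwargs)
        # NotImplemented and non-set results pass through unchanged.
        if type(result) is set:
            return type(self)(result)
        return result

    method.__name__ = name
    method.__doc__ = base_method.__doc__
    return method

# Setting dunder names on the class after creation also updates the
# operator slots, so a | b, a - b, etc. pick up the wrappers.
for _name in _SET_RETURNING:
    setattr(MySet, _name, _returning_subclass(_name))
```

The special methods still have to be listed by name, since operators are looked up on the type, but each one costs a single list entry instead of a hand-written method.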

Maybe, for your application, you can get away with overriding copy() and sticking to the in-place methods.
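A minimal sketch of that lighter route, assuming `copy()` is the only method you override: copy first, then use the in-place operators (`|=`, `&=`, `-=`, `^=`) or the `*_update` methods. These mutate the existing object instead of building a new one, so the subclass is preserved.

```python
class MySet(set):
    def copy(self):
        # Preserve the subclass instead of returning a plain set.
        return type(self)(self)

s1 = MySet([1, 2, 3, 4, 5])
s2 = MySet([1, 2, 3, 6, 7])

# Instead of s1.union(s2): copy, then update in place.
u = s1.copy()
u |= s2

# Instead of s1.intersection(s2) and s1.difference(s2):
i = s1.copy()
i.intersection_update(s2)
d = s1.copy()
d.difference_update(s2)

print(type(u).__name__, type(i).__name__, type(d).__name__)  # MySet MySet MySet
```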

Paviour answered 2/11, 2011 at 16:42. Comments:
Right, Python 2 was not consistent here. If you create a class MySet(set): pass in Python 2, then print type(MySet().copy()) gives <class '__main__.MySet'>, but if you create a class MyDict(dict): pass, then print type(MyDict().copy()) gives <type 'dict'>. (Hartzog)
There's a way to handle at least the non-special methods in a single operation. I'll answer my own question to illustrate how (I can't put code into a comment). But it's still way more overhead than I'd like, with all the special methods to handle one by one. (Kismet)

Perhaps a metaclass to do all that humdrum wrapping for you would make it easier:

class Perpetuate(type):
    def __new__(metacls, cls_name, cls_bases, cls_dict):
        if len(cls_bases) > 1:
            raise TypeError("multiple bases not allowed")
        result_class = type.__new__(metacls, cls_name, cls_bases, cls_dict)
        base_class = cls_bases[0]
        # Attributes defined in the class body are never overwritten.
        known_attr = set(cls_dict)
        for attr in base_class.__dict__:
            # The trailing comma matters: ('__new__') is just a string.
            if attr in ('__new__',) or attr in known_attr:
                continue
            code = getattr(base_class, attr)
            if callable(code):
                # Wrap methods so base-type results become the subclass.
                setattr(result_class, attr, metacls._wrap(base_class, code))
            else:
                setattr(result_class, attr, code)
        return result_class

    @staticmethod
    def _wrap(base, code):
        def wrapper(*args, **kwargs):
            result = code(*args, **kwargs)
            # args[0] is the instance. Without it there is nothing to convert;
            # in-place operators return the instance itself, so leave that alone.
            if not args or result is args[0]:
                return result
            # type() instead of .__class__, because __getattribute__ is wrapped too.
            subclass = type(args[0])
            if type(result) is base:
                return subclass(result)
            elif isinstance(result, (tuple, list, set)):
                new_result = []
                for partial in result:
                    if type(partial) is base:
                        new_result.append(subclass(partial))
                    else:
                        new_result.append(partial)
                result = result.__class__(new_result)
            elif isinstance(result, dict):
                for key, value in result.items():
                    if type(value) is base:
                        result[key] = subclass(value)
            return result
        wrapper.__name__ = code.__name__
        wrapper.__doc__ = code.__doc__
        return wrapper

class MySet(set, metaclass=Perpetuate):
    pass

s1 = MySet([1, 2, 3, 4, 5])
s2 = MySet([1, 2, 3, 6, 7])

print(s1.union(s2))
print(type(s1.union(s2)))

print(s1.intersection(s2))
print(type(s1.intersection(s2)))

print(s1.difference(s2))
print(type(s1.difference(s2)))

Metachromatism answered 2/11, 2011 at 20:47. Comments:
A few comments: 1. This would fail to wrap a method called e(), but it does wrap __getattribute__(), which makes it impossible to store objects of the base type in attributes (they come back converted). 2. This will have a severe performance hit, especially for retrieving attributes. If you store a list in an attribute, it will be iterated over on every access. There are more performance problems, maybe too many to point out. (Paviour)
@SvenMarnach: Why will it fail to wrap e()? (Metachromatism)
Because for the name e, the condition attr in ('__new__') will hold. Admittedly, that's a cheap one, but there are more obscure bugs in this code. (Paviour)
@SvenMarnach: Ah -- not a tuple without the comma. Thanks. (Metachromatism)
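The pitfall is easy to reproduce: without the trailing comma, `('__new__')` is just a parenthesized string, so `in` performs a substring test rather than a tuple membership test.

```python
skip_str = ('__new__')     # parentheses only: this is the str '__new__'
skip_tuple = ('__new__',)  # trailing comma: a one-element tuple

print(type(skip_str).__name__, type(skip_tuple).__name__)  # str tuple
print('e' in skip_str)    # True: substring test
print('e' in skip_tuple)  # False: element test
```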

As a follow-up to Sven's answer, here is a universal wrapping solution that takes care of all non-special methods. The idea is to catch the first lookup coming from a method call, and install a wrapper method that does the type conversion. At subsequent lookups, the wrapper is returned directly.

Caveats:

1) This is more magic trickery than I like to have in my code.

2) I'd still need to wrap special methods (__and__ etc.) manually, because their implicit lookup bypasses __getattribute__.
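The second caveat can be observed directly. In this sketch the `Traced` class and its `lookups` list are made up for illustration: a custom `__getattribute__` records every name it sees. The explicit method call goes through it, while the `|` operator looks up `__or__` on the type and never touches it.

```python
class Traced(set):
    lookups = []  # class-level log of names seen by __getattribute__

    def __getattribute__(self, name):
        # type(self) and super() do not re-enter this method.
        type(self).lookups.append(name)
        return super().__getattribute__(name)

a = Traced([1, 2])
b = Traced([3])

a.union(b)                       # explicit attribute lookup
after_method = list(Traced.lookups)
print(after_method)              # ['union']

Traced.lookups.clear()
a | b                            # operator: __or__ is found on the type
after_operator = list(Traced.lookups)
print(after_operator)            # []
```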

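import types

class MySet(set):

    def __getattribute__(self, name):
        attr = super(MySet, self).__getattribute__(name)
        # Still the C method inherited from set? Then wrap it once.
        if isinstance(attr, types.BuiltinMethodType):
            base_attr = getattr(set, name, None)
            # Only wrap plain set methods (types.MethodDescriptorType: Python 3.7+).
            if isinstance(base_attr, types.MethodDescriptorType):
                def wrapper(self, *args, **kwargs):
                    # Call the unbound set method on the instance actually
                    # used, not on the one from the first lookup.
                    result = base_attr(self, *args, **kwargs)
                    if isinstance(result, set):
                        return MySet(result)
                    else:
                        return result
                wrapper.__name__ = name
                # Install on the class: later lookups find this Python
                # function and skip this branch.
                setattr(MySet, name, wrapper)
                # Bind it to self for the current call.
                return types.MethodType(wrapper, self)
        return attr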

Kismet answered 3/11, 2011 at 10:32
