If I create a Python dataclass containing a NumPy `ndarray`, I can no longer use the automatically generated `__eq__`.
```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray


arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
```

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
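This is because `ndarray.__eq__` does not return a single bool. It compares the arrays elementwise, broadcasting them against each other (so the single element of `arr` is compared with each element of `arr2`), and returns an `ndarray` of truth values. The error is raised when something then calls `bool()` on that result and it has more than one element. That makes the behaviour unintuitive: whether it fails depends on the shapes involved (a one-element result converts to a bool without complaint), not on whether the values match. Two distinct arrays that both hold `[1, 2]` still raise.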
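A quick way to see both halves of this, the elementwise `==` and the failing `bool()`, is a few lines at the interpreter. This is a minimal sketch using only NumPy; the variable names are illustrative.

```python
import numpy as np

a = np.array([1])
b = np.array([1, 2])

# == is elementwise and broadcasts: shape (1,) is stretched to shape (2,)
result = a == b
print(result)  # [ True False]

# Calling bool() on a multi-element array is what raises the ValueError
try:
    bool(result)
except ValueError as exc:
    print("ValueError:", exc)

# Equal values do not help: the result still has two elements
c = np.array([1, 2])
try:
    bool(b == c)
except ValueError:
    print("equal values still raise")

# A one-element result converts to a plain bool without error
print(bool(a == np.array([1])))  # True
```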
How do I safely compare `@dataclass`es holding NumPy arrays?
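`@dataclass`'s implementation of `__eq__` is generated from source text with `exec()`. That source is missing from the stack trace and cannot be viewed using `inspect`, but the method actually uses a tuple comparison, `(self.foo, self.bar) == (other.foo, other.bar)`, and comparing tuples calls `bool()` on the result of each element comparison, i.e. on `self.foo == other.foo`.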
```python
import dis
dis.dis(Instr.__eq__)
```

Excerpt:

```
  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
```
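Read as Python, the bytecode above corresponds roughly to the sketch below. `InstrLike` is a hypothetical hand-written class modeled on the CPython 3.7 `dataclasses` source linked in the comments, not the actual generated code, and newer Python versions may generate `__eq__` differently. It also shows a detail that makes the bug look intermittent: tuple comparison checks identity before calling `==`, so the same array object on both sides never reaches `bool()`.

```python
import numpy as np


class InstrLike:
    """Hypothetical hand-written stand-in for the generated Instr.__eq__."""

    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

    def __eq__(self, other):
        # Roughly what dataclasses generated in CPython 3.7: compare the
        # fields as tuples when both objects have exactly the same class.
        if other.__class__ is self.__class__:
            return (self.foo, self.bar) == (other.foo, other.bar)
        return NotImplemented


x = np.array([1, 2])
y = np.array([1, 2])

# Tuple comparison checks identity first, so identical array objects compare fine
print(InstrLike(x, x) == InstrLike(x, x))  # True

# Distinct arrays force x == y, then bool() on a 2-element array raises
try:
    InstrLike(x, x) == InstrLike(y, y)
except ValueError as exc:
    print("ValueError:", exc)
```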
Write an `__eq__` method on `Instr`: you can override any of the autogenerated methods. Just catch the `ValueError` and implement your own additional logic. – Sightless

`__eq__` source is here: github.com/python/cpython/blob/3.7/Lib/dataclasses.py#L884 – Griffiths
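Following the suggestion to override `__eq__`, here is a minimal sketch. Rather than catching the `ValueError`, it swaps in `np.array_equal` for any field holding an ndarray; that function returns a single bool (including `False` for mismatched shapes), so nothing raises. Other fields fall back to `!=`. It assumes each field is either an ndarray or an ordinary comparable value; containers of arrays would need more work. `@dataclass` keeps an `__eq__` defined in the class body instead of generating its own.

```python
from dataclasses import dataclass, fields

import numpy as np


@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

    # Defined in the class body, so @dataclass leaves it in place
    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            return NotImplemented
        for f in fields(self):
            if not f.compare:
                continue  # respect field(compare=False)
            a = getattr(self, f.name)
            b = getattr(other, f.name)
            if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
                # Same shape and same elements -> a single Python bool
                if not np.array_equal(a, b):
                    return False
            elif a != b:
                return False
        return True


arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))  # False
print(Instr(arr2, arr2) == Instr(np.array([1, 2]), np.array([1, 2])))  # True
```

Note that `np.array_equal` treats NaN as unequal to NaN by default; recent NumPy versions accept `equal_nan=True` if that matters for your data.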