How to compare equality of dataclasses holding numpy.ndarray (bool(a==b) raises ValueError)?

If I create a Python dataclass containing a NumPy ndarray, I can no longer use the automatically generated __eq__.

from dataclasses import dataclass
import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This is because ndarray.__eq__ compares element-wise (broadcasting the shorter array against the longer one) and returns an ndarray of truth values rather than a single bool. The generated __eq__ then calls bool() on that result, which raises ValueError whenever the result has more than one element. So whether the comparison fails depends on the shapes of the arrays involved, which is complex and unintuitive.
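The element-wise behaviour can be seen directly in a short sketch (the variable names are illustrative only):

```python
import numpy as np

a = np.array([1])
b = np.array([1, 2])

# == is element-wise; a is broadcast against b, so the result has two elements.
elementwise = (a == b)

# A single-element result converts to bool without complaint...
single_ok = bool(np.array([1]) == np.array([1]))

# ...but a result with more than one element does not.
try:
    bool(elementwise)
    raised = False
except ValueError:
    raised = True

print(elementwise, single_ok, raised)
```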

How do I safely compare @dataclasses holding Numpy arrays?


@dataclass's implementation of __eq__ is generated at runtime using exec(). Its source is missing from the stacktrace and cannot be viewed using inspect, but it's actually using a tuple comparison, which calls bool(self.foo == other.foo) on each pair of fields.

import dis
dis.dis(Instr.__eq__)

excerpt:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
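To see why a tuple comparison ends up calling bool(), here is a minimal sketch using a hypothetical stand-in class (Ambiguous and _NoTruthValue are made up for illustration and are not part of NumPy or dataclasses) whose == result refuses to be converted to bool, much like a multi-element array. It also shows that tuple comparison skips == when both items are the same object, which is why comparing an instance with itself does not raise.

```python
class Ambiguous:
    """Hypothetical stand-in for a multi-element ndarray."""

    def __eq__(self, other):
        # Return an object instead of a bool, as ndarray.__eq__ does.
        return _NoTruthValue()


class _NoTruthValue:
    def __bool__(self):
        raise ValueError("truth value is ambiguous")


x, y = Ambiguous(), Ambiguous()

# Same object in both tuples: the identity check short-circuits, no == call.
same_object_result = (x,) == (x,)

# Distinct objects: the tuple comparison calls x == y and then bool() on it.
try:
    (x,) == (y,)
    raised = False
except ValueError:
    raised = True

print(same_object_result, raised)
```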
Sherrie answered 8/8, 2018 at 9:58 Comment(3)
You could write your own __eq__ method on Instr, you can override any of the autogenerated methods. Just catch the ValueError and implement your own additional logic. – Sightless
For the record, the dataclass .__eq__ source is here github.com/python/cpython/blob/3.7/Lib/dataclasses.py#L884 – Griffiths
This bites me often. – Upsilon

The solution is to write your own __eq__ method and set eq=False so the dataclass doesn't generate its own. According to the docs that last step isn't strictly necessary, because the eq parameter is ignored when the class already defines __eq__, but I think it's nice to be explicit anyway.

from dataclasses import dataclass
import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
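As one of the comments on this answer suggests, numpy.array_equal performs a similar shape-and-value check, so the helper could be replaced by it. The following is a sketch of that variant, not the answer's original code. Note that array_equal also accepts array-likes such as lists, so it is slightly more permissive than the isinstance checks above.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(eq=False)
class Instr:
    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        # array_equal returns False for mismatched shapes instead of broadcasting.
        return bool(np.array_equal(self.foo, other.foo)
                    and np.array_equal(self.bar, other.bar))


arr = np.array([1])
arr2 = np.array([1, 2])

print(Instr(arr, arr) == Instr(arr2, arr2))                  # False
print(Instr(arr2, arr2) == Instr(arr2.copy(), arr2.copy()))  # True
```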

Edit

A general and quick solution for generic dataclasses where some values are NumPy arrays and others are not:

import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and (a == b).all()
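    try:
        return a == b
    except TypeError:
        # Must be a real bool: this result is consumed by all() in dc_eq,
        # where NotImplemented would not behave like False.
        return False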

def dc_eq(dc1, dc2) -> bool:
    """checks if two dataclasses which hold numpy arrays are equal"""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented  # better than False
    t1 = astuple(dc1)
    t2 = astuple(dc2)
    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:

    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq(self, other)
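One caveat with astuple() is that it recurses into nested dataclasses and deep-copies every other value, including the arrays. A possible variant, sketched here under the assumption that a shallow field-by-field comparison is enough, iterates over dataclasses.fields() instead and skips fields declared with compare=False. It reuses the names from above but swaps in numpy.array_equal for the shape check; treat it as an illustration, not a drop-in replacement.

```python
from dataclasses import dataclass, fields

import numpy as np


def array_safe_eq(a, b) -> bool:
    """Compare two values, treating NumPy arrays by shape and content."""
    if a is b:
        return True
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # array_equal is False for mismatched shapes; it never broadcasts.
        return bool(np.array_equal(a, b))
    return bool(a == b)


def dc_eq(dc1, dc2):
    """Field-by-field equality for dataclass instances of the same class."""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented
    return all(
        array_safe_eq(getattr(dc1, f.name), getattr(dc2, f.name))
        for f in fields(dc1)
        if f.compare  # skip fields declared with field(compare=False)
    )


@dataclass(eq=False)
class T:
    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq(self, other)


t1 = T(1, np.array([1, 1, 1]), np.array([[1]]))
t2 = T(1, np.array([1, 1, 1]), np.array([[1]]))
t3 = T(1, np.array([1, 1, 1]), np.array([1]))  # same values, different shape

print(t1 == t2)  # True
print(t1 == t3)  # False
```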
Griffiths answered 8/8, 2018 at 10:5 Comment(8)
That's a bit of work, but I guess I'll have to do it. – Sherrie
Well, remember that before attrs/dataclasses you'd have had to do this anyway. – Griffiths
I ended up comparing arr.tolist() so I don't have to worry about arrays being falsely equal due to broadcasting shenanigans, is that a good idea? – Sherrie
That will work but it's not very efficient. What broadcasting shenanigans are we talking about here? – Griffiths
This is primarily for unit tests, I don't want 1 == [1] == [[[1]]] or [1,1,1] == 1. – Sherrie
Ok, I have edited my answer for the solution you require (without using tolist()). – Griffiths
I think you could use numpy's array_equal() instead of your own def array_eq() implementation. – Narcose
Thanks for this! Please note a typo in line 20 (NotImplmeneted instead of NotImplemented). – Eartha

This can be customized if you use attrs instead of dataclasses:

from attrs import cmp_using, define, field
import numpy

@define
class C:
    an_array = field(eq=cmp_using(eq=numpy.array_equal))
Wauters answered 7/9, 2023 at 17:58 Comment(0)