Equality Comparison for User-Defined Class Objects in Python

This article explains how to compare user-defined class objects for equality in Python.

Introduction

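In Python, instances of a user-defined class are not compared by value by default. Unless the class defines __eq__, the equality operator == falls back to comparing object identity, so two separately created objects are unequal even when their instance variables are identical.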

For example, with the following class and pytest test, comparing two objects created with the same value raises an AssertionError:

class MyClass:
    def __init__(self, amount):
        self.amount = amount


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # Fails: == compares identity by default


pytest output:

=============================== FAILURES ================================
_____________________________ test_equality _____________________________

    def test_equality():
        obj_1 = MyClass(5)
        obj_2 = MyClass(5)
>       assert obj_1 == obj_2
E       assert <equality.MyClass object at 0x0000014F05071848> == <equality.MyClass object at 0x0000014F05071C88>

equality.py:14: AssertionError
======================== short test summary info ========================
FAILED equality.py::test_equality - assert <equality.MyClass object at ...
=========================== 1 failed in 0.03s ===========================
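The same default behavior can be checked outside pytest. In this minimal sketch, MyClass is the class defined above; without a custom __eq__, == gives the same answer as the identity operator is, so an object equals only itself.

```python
class MyClass:
    def __init__(self, amount):
        self.amount = amount


obj_1 = MyClass(5)
obj_2 = MyClass(5)

print(obj_1 == obj_2)  # False: two distinct objects
print(obj_1 == obj_1)  # True: the same object
print(obj_1 is obj_2)  # False: matches the result of ==
```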


This article shows how to make equality comparison work for user-defined objects like the ones above.

Environment
Python 3.7.6

Note: This article was translated from my original post.

Comparing User-Defined Class Objects for Equality in Python

To enable equality comparison for user-defined class objects, use the __eq__ special method.

The __eq__ method is called when the equality operator == is used.
For example, evaluating x == y calls x.__eq__(y).
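To observe this dispatch, the sketch below uses a hypothetical Probe class (not from the original article) whose __eq__ records the argument of every call. It suggests that x == y invokes x.__eq__(y) with y as the argument, and that y.__eq__ is not consulted when x.__eq__ returns a real answer.

```python
class Probe:
    """Hypothetical class used only to observe how == calls __eq__."""

    def __init__(self, name):
        self.name = name
        self.calls = []  # records the argument of every __eq__ call

    def __eq__(self, other):
        self.calls.append(other)
        return isinstance(other, Probe) and self.name == other.name


x = Probe("a")
y = Probe("a")

result = x == y
print(result)           # True
print(x.calls[0] is y)  # True: x.__eq__ received y
print(y.calls)          # []: y.__eq__ was never called
```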
Reference: 3. Data model — Python 3.14.2 documentation

Let's look at how to implement this.

Comparing Instance Variables in the __eq__ Method

The simplest approach is to compare the instance variables explicitly in the __eq__ method.

Implementation example:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.amount == other.amount # Compare instance variables

def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # This test passes
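Defining __eq__ has a side effect: when a class defines __eq__ without also defining __hash__, Python sets __hash__ to None, so instances can no longer be used in sets or as dictionary keys. The sketch below shows this. HashableClass is a hypothetical subclass, and hashing on amount is an assumption that is only safe if amount is not changed while the object sits in a set or dict.

```python
class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.amount == other.amount


# Defining __eq__ without __hash__ sets __hash__ to None on the class
print(MyClass.__hash__ is None)  # True

unhashable = False
try:
    {MyClass(5)}  # building a set requires hashing its elements
except TypeError:
    unhashable = True
print(unhashable)  # True


class HashableClass(MyClass):
    # Hypothetical subclass: hash the same attribute that __eq__ compares,
    # so objects that compare equal also have equal hashes
    def __hash__(self):
        return hash(self.amount)


print(len({HashableClass(5), HashableClass(5)}))  # 1: equal objects collapse into one
```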


However, with this approach you must update the __eq__ method every time an instance variable is added or changed. The next approach avoids this by using __dict__.
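To make the maintenance problem concrete, suppose a currency attribute (a hypothetical addition for this sketch) is added to MyClass but __eq__ is not updated. Objects that differ only in currency still compare equal:

```python
class MyClass:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency  # hypothetical attribute added later

    def __eq__(self, other):
        # Not updated: still compares only amount
        return self.amount == other.amount


usd = MyClass(5, "USD")
chf = MyClass(5, "CHF")
print(usd == chf)  # True, even though the currencies differ
```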

Comparing __dict__ in the __eq__ Method

__dict__ is a special attribute that holds an object's instance variables in a dictionary. By comparing the __dict__ of both objects in the __eq__ method, you can compare all instance variables at once instead of one by one.
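For reference, this is roughly what __dict__ holds for a MyClass instance. The note attribute is hypothetical and only shows that attributes set after construction also appear there. Because the comparison is ordinary dict equality, two instances match when they have the same attribute names mapped to equal values.

```python
class MyClass:
    def __init__(self, amount):
        self.amount = amount


obj = MyClass(5)
print(obj.__dict__)  # {'amount': 5}

obj.note = "memo"  # hypothetical attribute added after construction
print(obj.__dict__)  # {'amount': 5, 'note': 'memo'}

# Dict equality: same keys mapped to equal values
print(MyClass(5).__dict__ == MyClass(5).__dict__)  # True
print(MyClass(5).__dict__ == MyClass(6).__dict__)  # False
```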

Implementation example:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__ # Compare instance variables

def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # This test passes
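The benefit shows up when the class grows. In this sketch, a hypothetical currency attribute is added and __eq__ is left untouched, yet the new attribute is compared automatically. Two caveats apply: comparing against an object without a __dict__, such as an int, raises AttributeError, and classes that define __slots__ (without a __dict__ slot) have no per-instance __dict__ at all.

```python
class MyClass:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency  # hypothetical attribute added later

    def __eq__(self, other):
        # Unchanged: __dict__ already includes currency
        return self.__dict__ == other.__dict__


print(MyClass(5, "USD") == MyClass(5, "USD"))  # True
print(MyClass(5, "USD") == MyClass(5, "CHF"))  # False: currency is compared too

raised = False
try:
    MyClass(5, "USD") == 5  # an int instance has no __dict__
except AttributeError:
    raised = True
print(raised)  # True
```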

Comparing Class Identity

The approaches above compare only instance variables, not the class itself. As a result, objects of two different classes with identical instance variables compare equal:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = DummyClass(5)
    assert obj_1 == obj_2 # Passes, even though the classes differ


To make objects of different classes compare unequal, check the class of other inside the __eq__ method with isinstance(other, self.__class__). __class__ is a special attribute that refers to the object's class.

Implementation example:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.__dict__ == other.__dict__
        )


class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.__dict__ == other.__dict__
        )


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # This test passes
    
    obj_3 = DummyClass(5)
    assert obj_1 != obj_3 # Passes: objects of different classes are now unequal
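A quick check of the isinstance version outside pytest, as a sketch. Besides rejecting DummyClass, the isinstance check short-circuits the and expression, so comparing with an unrelated object such as an int returns False instead of reaching other.__dict__ and raising AttributeError.

```python
class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.__dict__ == other.__dict__
        )


class DummyClass:
    def __init__(self, amount):
        self.amount = amount


print(MyClass(5) == MyClass(5))     # True
print(MyClass(5) == DummyClass(5))  # False: different classes
print(MyClass(5) == 5)              # False: no AttributeError
```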

Returning NotImplemented for Different Classes

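Alternatively, when the classes differ, __eq__ can return NotImplemented instead of False. NotImplemented does not mean that the comparison fails: it tells Python that this method cannot handle the other operand, so Python tries the reflected comparison other.__eq__(self). If that also returns NotImplemented, == falls back to an identity comparison, which is False for two distinct objects.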

Implementation example:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__

class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # This test passes
    
    obj_3 = DummyClass(5)
    assert obj_1 != obj_3 # Passes: objects of different classes are now unequal
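The practical benefit of NotImplemented is that the other operand gets a chance to answer. In this sketch, Amount is a hypothetical class whose __eq__ knows how to compare itself with MyClass. Because MyClass.__eq__ returns NotImplemented for an Amount, Python tries the reflected Amount.__eq__. When neither side handles the comparison, as with DummyClass, == falls back to identity.

```python
class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__


class Amount:
    """Hypothetical class that opts in to comparison with MyClass."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, MyClass):
            return self.value == other.amount
        return NotImplemented


class DummyClass:
    def __init__(self, amount):
        self.amount = amount


print(MyClass(5) == Amount(5))      # True: answered by the reflected Amount.__eq__
print(MyClass(5) == DummyClass(5))  # False: neither side handles it, identity fallback
print(MyClass(5) != DummyClass(5))  # True
```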

Conclusion

This article covered several ways to compare user-defined class objects for equality in Python.

I encountered this issue while working through Kent Beck's TDD book using Python.

While Java/JUnit allows easy object comparison with assertEquals(), I realized that Python requires a little extra work, so I wrote this as a reference.

I hope this will be helpful to someone.


References