From 3476348c995af2ce5dfcbcc688e9ddf98fa36360 Mon Sep 17 00:00:00 2001 From: Anthonios Partheniou Date: Mon, 29 Jul 2024 13:10:12 -0400 Subject: [PATCH] fix: fix issue with equality comparison of repeated field with None (#477) * fix: fix issue with equality comparison of repeated field with None * style --- proto/marshal/collections/repeated.py | 9 +++++++-- tests/test_fields_repeated_composite.py | 1 + tests/test_fields_repeated_scalar.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/proto/marshal/collections/repeated.py b/proto/marshal/collections/repeated.py index 29bacff6..a6560411 100644 --- a/proto/marshal/collections/repeated.py +++ b/proto/marshal/collections/repeated.py @@ -14,6 +14,7 @@ import collections import copy +from typing import Iterable from proto.utils import cached_property @@ -48,7 +49,7 @@ def __delitem__(self, key): def __eq__(self, other): if hasattr(other, "pb"): return tuple(self.pb) == tuple(other.pb) - return tuple(self.pb) == tuple(other) + return tuple(self.pb) == tuple(other) if isinstance(other, Iterable) else False def __getitem__(self, key): """Return the given item.""" @@ -119,7 +120,11 @@ def _pb_type(self): def __eq__(self, other): if super().__eq__(other): return True - return tuple([i for i in self]) == tuple(other) + return ( + tuple([i for i in self]) == tuple(other) + if isinstance(other, Iterable) + else False + ) def __getitem__(self, key): return self._marshal.to_python(self._pb_type, self.pb[key]) diff --git a/tests/test_fields_repeated_composite.py b/tests/test_fields_repeated_composite.py index db6be27b..7c9f73e4 100644 --- a/tests/test_fields_repeated_composite.py +++ b/tests/test_fields_repeated_composite.py @@ -47,6 +47,7 @@ class Baz(proto.Message): baz = Baz(foos=[Foo(bar=42)]) assert baz.foos == baz.foos + assert baz.foos != None def test_repeated_composite_init_struct(): diff --git a/tests/test_fields_repeated_scalar.py b/tests/test_fields_repeated_scalar.py index e07cf130..d6383132 100644 --- a/tests/test_fields_repeated_scalar.py +++ b/tests/test_fields_repeated_scalar.py @@ -72,6 +72,7 @@ class Foo(proto.Message): foo = Foo(bar=[1, 1, 2, 3, 5, 8, 13]) assert foo.bar == copy.copy(foo.bar) assert foo.bar != [1, 2, 4, 8, 16] + assert foo.bar != None def test_repeated_scalar_del():