From f5b21ea1581d8984f3c80b848302c9cd46307892 Mon Sep 17 00:00:00 2001 From: Gabriel de Marmiesse Date: Wed, 23 Aug 2023 22:43:17 +0200 Subject: [PATCH] Resolver is now immutable (#4) --- plum/function.py | 7 +----- plum/resolver.py | 29 +++++------------------- tests/test_resolver.py | 51 +++++++++++++++++------------------------- 3 files changed, 27 insertions(+), 60 deletions(-) diff --git a/plum/function.py b/plum/function.py index a37a1030..ca82a8e4 100644 --- a/plum/function.py +++ b/plum/function.py @@ -134,8 +134,7 @@ def invalidate_resolver_and_cache(self): @property def resolver(self) -> Resolver: if self._resolver is None: - self._resolver = Resolver() - self._resolve_pending_registrations() + self._resolver = Resolver(self.get_all_subsignatures()) return self._resolver @property @@ -144,10 +143,6 @@ def cache(self) -> dict: self._cache = {} return self._cache - def _resolve_pending_registrations(self) -> None: - for subsignature in self.get_all_subsignatures(): - self._resolver.register(subsignature) - def get_all_subsignatures(self, strict: bool = True) -> Iterator[Signature]: # Perform any pending registrations. for f, signature, precedence in self._all_methods: diff --git a/plum/resolver.py b/plum/resolver.py index 66d1a97b..2d452544 100644 --- a/plum/resolver.py +++ b/plum/resolver.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import Iterable, List, Tuple, Union from plum.signature import Signature @@ -21,29 +21,10 @@ class Resolver: is_faithful (bool): Whether all signatures are faithful or not. """ - def __init__(self): - self.signatures: List[Signature] = [] - self.is_faithful: bool = True - - def register(self, signature: Signature) -> None: - """Register a new signature. - - Args: - signature (:class:`.signature.Signature`): Signature to add. - """ - existing = [s == signature for s in self.signatures] - if any(existing): - if sum(existing) != 1: - raise AssertionError( - f"The added signature `{signature}` is equal to {sum(existing)} " - f"existing signatures. This should never happen." - ) - self.signatures[existing.index(True)] = signature - else: - self.signatures.append(signature) - - # Use a double negation for slightly better performance. - self.is_faithful = not any(not s.is_faithful for s in self.signatures) + def __init__(self, signatures: Iterable[Signature]): + signatures_dict = {hash(s): s for s in signatures} + self.signatures: List[Signature] = list(signatures_dict.values()) + self.is_faithful: bool = all(s.is_faithful for s in self.signatures) def __len__(self) -> int: return len(self.signatures) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index b622d972..53adbd91 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -7,46 +7,35 @@ def test_initialisation(): - r = Resolver() + r = Resolver([]) # Without any registered signatures, the resolver should be faithful. assert r.is_faithful def test_register(): - r = Resolver() - # Test that faithfulness is tracked correctly. - r.register(Signature(int)) - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float)]) assert r.is_faithful - r.register(Signature(Tuple[int])) + r = Resolver([Signature(int), Signature(float), Signature(Tuple[int])]) assert not r.is_faithful # Test that signatures can be replaced. - new_s = Signature(float) assert len(r) == 3 + new_s = Signature(float) assert r.signatures[1] is not new_s - r.register(new_s) + r = Resolver([Signature(int), Signature(float), Signature(Tuple[int]), new_s]) assert len(r) == 3 assert r.signatures[1] is new_s - # Test the edge case that should never happen. - r.signatures[2] = Signature(float) - with pytest.raises( - AssertionError, - match=r"(?i)the added signature `(.*)` is equal to 2 existing signatures", - ): - r.register(Signature(float)) - def test_len(): - r = Resolver() + r = Resolver([]) assert len(r) == 0 - r.register(Signature(int)) + r = Resolver([Signature(int)]) assert len(r) == 1 - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float)]) assert len(r) == 2 - r.register(Signature(float)) + r = Resolver([Signature(int), Signature(float), Signature(float)]) assert len(r) == 2 @@ -80,16 +69,18 @@ class Missing: s_u = Signature(Unrelated) s_m = Signature(Missing) - r = Resolver() - r.register(s_b1) - # Import this after `s_b1` to test all branches. - r.register(s_a) - r.register(s_b2) - # Do not register `s_c1`. - r.register(s_c2) - r.register(s_u) - # Also do not register `s_m`. - + r = Resolver( + [ + s_b1, + # Import this after `s_b1` to test all branches. + s_a, + s_b2, + # Do not register `s_c1`. + s_c2, + s_u, + # Also do not register `s_m`. + ] + ) # Resolve by signature. assert r.resolve(s_a) == s_a assert r.resolve(s_b1) == s_b1