Skip to content

Commit

Permalink
Resolver is now immutable (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse authored Aug 23, 2023
1 parent d9bc095 commit f5b21ea
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 60 deletions.
7 changes: 1 addition & 6 deletions plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def invalidate_resolver_and_cache(self):
@property
def resolver(self) -> Resolver:
if self._resolver is None:
self._resolver = Resolver()
self._resolve_pending_registrations()
self._resolver = Resolver(self.get_all_subsignatures())
return self._resolver

@property
Expand All @@ -144,10 +143,6 @@ def cache(self) -> dict:
self._cache = {}
return self._cache

def _resolve_pending_registrations(self) -> None:
for subsignature in self.get_all_subsignatures():
self._resolver.register(subsignature)

def get_all_subsignatures(self, strict: bool = True) -> Iterator[Signature]:
# Perform any pending registrations.
for f, signature, precedence in self._all_methods:
Expand Down
29 changes: 5 additions & 24 deletions plum/resolver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import Iterable, List, Tuple, Union

from plum.signature import Signature

Expand All @@ -21,29 +21,10 @@ class Resolver:
is_faithful (bool): Whether all signatures are faithful or not.
"""

def __init__(self):
self.signatures: List[Signature] = []
self.is_faithful: bool = True

def register(self, signature: Signature) -> None:
"""Register a new signature.
Args:
signature (:class:`.signature.Signature`): Signature to add.
"""
existing = [s == signature for s in self.signatures]
if any(existing):
if sum(existing) != 1:
raise AssertionError(
f"The added signature `{signature}` is equal to {sum(existing)} "
f"existing signatures. This should never happen."
)
self.signatures[existing.index(True)] = signature
else:
self.signatures.append(signature)

# Use a double negation for slightly better performance.
self.is_faithful = not any(not s.is_faithful for s in self.signatures)
def __init__(self, signatures: Iterable[Signature]):
signatures_dict = {hash(s): s for s in signatures}
self.signatures: List[Signature] = list(signatures_dict.values())
self.is_faithful: bool = all(s.is_faithful for s in self.signatures)

def __len__(self) -> int:
return len(self.signatures)
Expand Down
51 changes: 21 additions & 30 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,35 @@


def test_initialisation():
r = Resolver()
r = Resolver([])
# Without any registered signatures, the resolver should be faithful.
assert r.is_faithful


def test_register():
r = Resolver()

# Test that faithfulness is tracked correctly.
r.register(Signature(int))
r.register(Signature(float))
r = Resolver([Signature(int), Signature(float)])
assert r.is_faithful
r.register(Signature(Tuple[int]))
r = Resolver([Signature(int), Signature(float), Signature(Tuple[int])])
assert not r.is_faithful

# Test that signatures can be replaced.
new_s = Signature(float)
assert len(r) == 3
new_s = Signature(float)
assert r.signatures[1] is not new_s
r.register(new_s)
r = Resolver([Signature(int), Signature(float), Signature(Tuple[int]), new_s])
assert len(r) == 3
assert r.signatures[1] is new_s

# Test the edge case that should never happen.
r.signatures[2] = Signature(float)
with pytest.raises(
AssertionError,
match=r"(?i)the added signature `(.*)` is equal to 2 existing signatures",
):
r.register(Signature(float))


def test_len():
r = Resolver()
r = Resolver([])
assert len(r) == 0
r.register(Signature(int))
r = Resolver([Signature(int)])
assert len(r) == 1
r.register(Signature(float))
r = Resolver([Signature(int), Signature(float)])
assert len(r) == 2
r.register(Signature(float))
r = Resolver([Signature(int), Signature(float), Signature(float)])
assert len(r) == 2


Expand Down Expand Up @@ -80,16 +69,18 @@ class Missing:
s_u = Signature(Unrelated)
s_m = Signature(Missing)

r = Resolver()
r.register(s_b1)
# Import this after `s_b1` to test all branches.
r.register(s_a)
r.register(s_b2)
# Do not register `s_c1`.
r.register(s_c2)
r.register(s_u)
# Also do not register `s_m`.

r = Resolver(
[
s_b1,
# Import this after `s_b1` to test all branches.
s_a,
s_b2,
# Do not register `s_c1`.
s_c2,
s_u,
# Also do not register `s_m`.
]
)
# Resolve by signature.
assert r.resolve(s_a) == s_a
assert r.resolve(s_b1) == s_b1
Expand Down

0 comments on commit f5b21ea

Please sign in to comment.