diff --git a/tests/test_method.py b/tests/test_method.py index 19fcfea6..8919ba4b 100644 --- a/tests/test_method.py +++ b/tests/test_method.py @@ -1,5 +1,9 @@ +import textwrap from copy import copy +import rich + +from plum import Dispatcher from plum.method import Method from plum.signature import Signature @@ -90,20 +94,83 @@ def test_equality(): assert m != 1 -def test_repr(): - def f(x) -> float: +def _rich_render(x): + console = rich.get_console() + with console.capture() as capture: + console.print(x) + return capture.get() + + +def test_repr_simple(): + def f(x, *args) -> float: return x m = Method( f, - Signature(int), + Signature(int, varargs=object), return_type=complex, function_name="different_name", ) result = ( - f"different_name(x: int) -> complex\n" - f" .f at {hex(id(f))}> @" + f"different_name(x: int, *args: object) -> complex\n" + f" .f at {hex(id(f))}> @" ) assert repr(m).startswith(result) + # Also render the fully mismatched version. When rendered to text, that should + # give the same. + assert _rich_render(m.repr_mismatch({0}, False)).startswith(result) + + +def test_repr_complex(): + def f(x, *, option, **kw_args) -> float: + return x + + m = Method( + f, + Signature(int, precedence=1), + return_type=complex, + function_name="different_name", + ) + + result = ( + f"different_name(x: int, *, option, **kw_args) -> complex\n" + f" precedence=1\n" + f" .f at {hex(id(f))}> @" + ) + + assert repr(m).startswith(result) + # Also render the fully mismatched version. When rendered to text, that should + # give the same. + assert _rich_render(m.repr_mismatch({0}, False)).startswith(result) + + +def test_methodlist_repr(monkeypatch): + dispatch = Dispatcher() + + @dispatch + def f(x: int): + pass + + @dispatch + def f(x: float): + pass + + imp1 = f.methods[0].implementation + imp2 = f.methods[1].implementation + + result = textwrap.dedent( + f""" + List of 2 method(s): + [0] f(x: int) + .f at {hex(id(imp1))}> @ + [1] f(x: float) + .f at {hex(id(imp2))}> @ + """ + ) + lines = repr(f.methods).strip().splitlines() + # Remove the lines corresponding to the source of the functions. These are very + # conveniently broken onto new lines. + lines = [lines[0], lines[1], lines[2], lines[4], lines[5]] + assert "\n".join(line.rstrip() for line in lines) == result.strip()