diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 401aee5..59f03db 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -379,7 +379,12 @@ def wrapped_fn(*args, **kwargs): raise except Exception as e: argmsg = _get_problem_arg( - param_signature, args, kwargs, module, typechecker + param_signature, + args, + kwargs, + bound.arguments, + module, + typechecker, ) try: name = fn.__name__ @@ -390,7 +395,8 @@ def wrapped_fn(*args, **kwargs): msg = ( "Type-check error whilst checking the parameters of " f"{name}.{argmsg}\n" - f"Called with arguments: {param_values}\n" + "----------------------\n" + f"Called with parameters: {param_values}\n" f"Parameter annotations: {param_hints}.\n" + _exc_shape_info(memos) ) @@ -440,10 +446,11 @@ def wrapped_fn(*args, **kwargs): msg = ( "Type-check error whilst checking the return value " f"of {name}.\n" - f"Called with arguments: {param_values}\n" - f"Return value: {return_value}\n" + f"Actual value: {return_value}\n" + f"Expected type: {return_hint}.\n" + "----------------------\n" + f"Called with parameters: {param_values}\n" f"Parameter annotations: {param_hints}.\n" - f"Return annotation: {return_hint}.\n" + _exc_shape_info(memos) ) if config.jaxtyping_remove_typechecker_stack: @@ -660,7 +667,7 @@ def _make_argpiece(p, name_to_annotation, name_to_default): def _get_problem_arg( - param_signature: inspect.Signature, args, kwargs, module, typechecker + param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker ) -> str: """Determines which argument was likely to be the problematic one responsible for raising a type-check error. @@ -669,13 +676,17 @@ def _get_problem_arg( # anyway. for keep_name in param_signature.parameters.keys(): new_parameters = [] + keep_annotation = sentinel = object() for p_name, p in param_signature.parameters.items(): if p_name == keep_name: new_parameters.append( inspect.Parameter(p.name, p.kind, annotation=p.annotation) ) + assert keep_annotation is sentinel + keep_annotation = _remove_typing(p.annotation) else: new_parameters.append(inspect.Parameter(p.name, p.kind)) + assert keep_annotation is not sentinel new_signature = inspect.Signature(new_parameters) fn = _make_fn_with_signature( "check_single_arg", new_signature, module, output=False @@ -684,7 +695,12 @@ def _get_problem_arg( try: fn(*args, **kwargs) except Exception: - return f"\nThe problem arose whilst typechecking argument '{keep_name}'." + keep_value = _pformat(arguments[keep_name], short_self=False) + return ( + f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n" + f"Actual value: {keep_value}\n" + f"Expected type: {keep_annotation}." + ) else: # Could not localise the problem to a single argument -- probably due to # e.g. a mismatched typevar, which each individual argument is okay with. diff --git a/test/test_messages.py b/test/test_messages.py index 2568baf..2cab03c 100644 --- a/test/test_messages.py +++ b/test/test_messages.py @@ -14,8 +14,8 @@ def f(x: str, y: str, z: int): matches = [ "Type-check error whilst checking the parameters of f", - "The problem arose whilst typechecking argument 'z'.", - "Called with arguments: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}", + "The problem arose whilst typechecking parameter 'z'.", + "Called with parameters: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}", r"Parameter annotations: \(x: str, y: str, z: int\).", ] for match in matches: @@ -30,8 +30,8 @@ def g(x: Float[Array, "a b"], y: Float[Array, "b c"]): y = jnp.zeros((4, 3)) matches = [ "Type-check error whilst checking the parameters of g", - "The problem arose whilst typechecking argument 'y'.", - r"Called with arguments: {'x': f32\[2,3\], 'y': f32\[4,3\]}", + "The problem arose whilst typechecking parameter 'y'.", + r"Called with parameters: {'x': f32\[2,3\], 'y': f32\[4,3\]}", ( r"Parameter annotations: \(x: Float\[Array, 'a b'\], y: " r"Float\[Array, 'b c'\]\)." @@ -54,9 +54,9 @@ def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]: y = {"a": 1} matches = [ "Type-check error whilst checking the return value of f", - r"Called with arguments: {'x': \(1, 2\), 'y': {'a': 1}}", - "Return value: 'foo'", - r"Return annotation: PyTree\[Any, \"T S\"\].", + r"Called with parameters: {'x': \(1, 2\), 'y': {'a': 1}}", + "Actual value: 'foo'", + r"Expected type: PyTree\[Any, \"T S\"\].", ( "The current values for each jaxtyping PyTree structure annotation are as " "follows." @@ -82,9 +82,9 @@ class M(eqx.Module): matches = [ "Type-check error whilst checking the parameters of M", - "The problem arose whilst typechecking argument 'z'.", + "The problem arose whilst typechecking parameter 'z'.", ( - r"Called with arguments: {'self': M\(\.\.\.\), 'x': f32\[2,3\], " + r"Called with parameters: {'self': M\(\.\.\.\), 'x': f32\[2,3\], " r"'y': \(1, \(3, 4\)\), 'z': 'not-an-int'}" ), (