Skip to content

Commit

Permalink
Improved error messages a little bit, in particular to highlight indi…
Browse files Browse the repository at this point in the history
…vidual problematic arguments.
  • Loading branch information
patrick-kidger committed Dec 6, 2023
1 parent 33cf4fc commit 3378dac
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 16 deletions.
30 changes: 23 additions & 7 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,12 @@ def wrapped_fn(*args, **kwargs):
raise
except Exception as e:
argmsg = _get_problem_arg(
param_signature, args, kwargs, module, typechecker
param_signature,
args,
kwargs,
bound.arguments,
module,
typechecker,
)
try:
name = fn.__name__
Expand All @@ -390,7 +395,8 @@ def wrapped_fn(*args, **kwargs):
msg = (
"Type-check error whilst checking the parameters of "
f"{name}.{argmsg}\n"
f"Called with arguments: {param_values}\n"
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
)
Expand Down Expand Up @@ -440,10 +446,11 @@ def wrapped_fn(*args, **kwargs):
msg = (
"Type-check error whilst checking the return value "
f"of {name}.\n"
f"Called with arguments: {param_values}\n"
f"Return value: {return_value}\n"
f"Actual value: {return_value}\n"
f"Expected type: {return_hint}.\n"
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
f"Return annotation: {return_hint}.\n"
+ _exc_shape_info(memos)
)
if config.jaxtyping_remove_typechecker_stack:
Expand Down Expand Up @@ -660,7 +667,7 @@ def _make_argpiece(p, name_to_annotation, name_to_default):


def _get_problem_arg(
param_signature: inspect.Signature, args, kwargs, module, typechecker
param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker
) -> str:
"""Determines which argument was likely to be the problematic one responsible for
raising a type-check error.
Expand All @@ -669,13 +676,17 @@ def _get_problem_arg(
# anyway.
for keep_name in param_signature.parameters.keys():
new_parameters = []
keep_annotation = sentinel = object()
for p_name, p in param_signature.parameters.items():
if p_name == keep_name:
new_parameters.append(
inspect.Parameter(p.name, p.kind, annotation=p.annotation)
)
assert keep_annotation is sentinel
keep_annotation = _remove_typing(p.annotation)
else:
new_parameters.append(inspect.Parameter(p.name, p.kind))
assert keep_annotation is not sentinel
new_signature = inspect.Signature(new_parameters)
fn = _make_fn_with_signature(
"check_single_arg", new_signature, module, output=False
Expand All @@ -684,7 +695,12 @@ def _get_problem_arg(
try:
fn(*args, **kwargs)
except Exception:
return f"\nThe problem arose whilst typechecking argument '{keep_name}'."
keep_value = _pformat(arguments[keep_name], short_self=False)
return (
f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
f"Actual value: {keep_value}\n"
f"Expected type: {keep_annotation}."
)
else:
# Could not localise the problem to a single argument -- probably due to
# e.g. a mismatched typevar, which each individual argument is okay with.
Expand Down
18 changes: 9 additions & 9 deletions test/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def f(x: str, y: str, z: int):

matches = [
"Type-check error whilst checking the parameters of f",
"The problem arose whilst typechecking argument 'z'.",
"Called with arguments: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}",
"The problem arose whilst typechecking parameter 'z'.",
"Called with parameters: {'x': 'hi', 'y': 'bye', 'z': 'not-an-int'}",
r"Parameter annotations: \(x: str, y: str, z: int\).",
]
for match in matches:
Expand All @@ -30,8 +30,8 @@ def g(x: Float[Array, "a b"], y: Float[Array, "b c"]):
y = jnp.zeros((4, 3))
matches = [
"Type-check error whilst checking the parameters of g",
"The problem arose whilst typechecking argument 'y'.",
r"Called with arguments: {'x': f32\[2,3\], 'y': f32\[4,3\]}",
"The problem arose whilst typechecking parameter 'y'.",
r"Called with parameters: {'x': f32\[2,3\], 'y': f32\[4,3\]}",
(
r"Parameter annotations: \(x: Float\[Array, 'a b'\], y: "
r"Float\[Array, 'b c'\]\)."
Expand All @@ -54,9 +54,9 @@ def f(x: PyTree[Any, " T"], y: PyTree[Any, " S"]) -> PyTree[Any, "T S"]:
y = {"a": 1}
matches = [
"Type-check error whilst checking the return value of f",
r"Called with arguments: {'x': \(1, 2\), 'y': {'a': 1}}",
"Return value: 'foo'",
r"Return annotation: PyTree\[Any, \"T S\"\].",
r"Called with parameters: {'x': \(1, 2\), 'y': {'a': 1}}",
"Actual value: 'foo'",
r"Expected type: PyTree\[Any, \"T S\"\].",
(
"The current values for each jaxtyping PyTree structure annotation are as "
"follows."
Expand All @@ -82,9 +82,9 @@ class M(eqx.Module):

matches = [
"Type-check error whilst checking the parameters of M",
"The problem arose whilst typechecking argument 'z'.",
"The problem arose whilst typechecking parameter 'z'.",
(
r"Called with arguments: {'self': M\(\.\.\.\), 'x': f32\[2,3\], "
r"Called with parameters: {'self': M\(\.\.\.\), 'x': f32\[2,3\], "
r"'y': \(1, \(3, 4\)\), 'z': 'not-an-int'}"
),
(
Expand Down

0 comments on commit 3378dac

Please sign in to comment.