From 1d4d40294ca6b007c5e4e30699b13332f9346214 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 12 Feb 2024 01:07:02 +0000 Subject: [PATCH] Added support for beartype 0.17.0's __instancecheck_str__. Recall that jaxtyping will currently generate rich error messages in precisely one scenario: about the arguments and return types when doing: ```python @jaxtyped(typechecker=beartype) def foo(...): ... ``` With this commit we add support for beartype 0.17.0's pseudo-standard `__instancecheck_str__`, which means the following: 1. For those using beartype decorators, the following will *also* generate an informative error message, and moreover it will state exactly why (shape mismatch, dtype mismatch etc): ```python @jaxtyped(typechecker=None) @beartype def foo(...): ... ``` (In practice we probably won't recommend the above combination in the docs just to keep things simple.) 2. For those using the beartype import hook together with the jaxtyping import hook, we can probably also check `assert isinstance(x, Float[Array, "foo"])` statements with rich error messages. (#153) We'll need to test + document that though. (@jeezrick interested?) 3. For those using plain `assert isinstance(...)` statements without beartype (#167, tagging @reinerp), then they can *also* get rich error messages by doing ```python tt = Float[Array, "foo"] assert isinstance(x, tt), tt.__instancecheck_str__(x) + "\n" + print_bindings() ``` which is still a bit long-winded right now but is a step in the right direction. (CC @leycec for interest.) 
--- jaxtyping/_array_types.py | 69 +++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index 417dc69..9f18045 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -115,7 +115,7 @@ def _check_dims( obj_shape: tuple[int, ...], single_memo: dict[str, int], arg_memo: dict[str, Any], -) -> bool: +) -> str: assert len(cls_dims) == len(obj_shape) for cls_dim, obj_size in zip(cls_dims, obj_shape): if cls_dim is _anonymous_dim: @@ -124,7 +124,7 @@ def _check_dims( pass elif type(cls_dim) is _FixedDim: if cls_dim.size != obj_size: - return False + return f"the dimension size {obj_size} does not equal {cls_dim.size} as expected by the type hint" # noqa: E501 elif type(cls_dim) is _SymbolicDim: try: # Support f-string syntax. @@ -141,7 +141,7 @@ def _check_dims( "arguments." ) from e if eval_size != obj_size: - return False + return f"the dimension size {obj_size} does not equal the existing value of {cls_dim.elem}={eval_size}" # noqa: E501 else: assert type(cls_dim) is _NamedDim if cls_dim.treepath: @@ -154,16 +154,19 @@ def _check_dims( single_memo[name] = obj_size else: if cls_size != obj_size: - return False - return True + return f"the size of dimension {cls_dim.name} is {obj_size} which does not equal the existing value of {cls_size}" # noqa: E501 + return "" class _MetaAbstractArray(type): - def __instancecheck__(cls, obj): + def __instancecheck__(cls, obj: Any) -> bool: + return cls.__instancecheck_str__(obj) == "" + + def __instancecheck_str__(cls, obj: Any) -> str: if not isinstance(obj, cls.array_type): - return False + return f"this value is not an instance of the underlying array type {cls.array_type}" # noqa: E501 if get_treeflatten_memo(): - return True + return "" if hasattr(obj.dtype, "type") and hasattr(obj.dtype.type, "__name__"): # JAX, numpy @@ -193,7 +196,10 @@ def __instancecheck__(cls, obj): if in_dtypes: break if not in_dtypes: 
- return False + if len(cls.dtypes) == 1: + return f"this array has dtype {dtype}, not {cls.dtypes[0]} as expected by the type hint" # noqa: E501 + else: + return f"this array has dtype {dtype}, not any of {cls.dtypes} as expected by the type hint" # noqa: E501 single_memo, variadic_memo, pytree_memo, arg_memo = get_shape_memo() single_memo_bak = single_memo.copy() @@ -207,13 +213,13 @@ def __instancecheck__(cls, obj): single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak ) raise - if check: - return True + if check == "": + return check else: set_shape_memo( single_memo_bak, variadic_memo_bak, pytree_memo_bak, arg_memo_bak ) - return False + return check def _check_shape( cls, @@ -221,27 +227,32 @@ single_memo: dict[str, int], variadic_memo: dict[str, tuple[bool, tuple[int, ...]]], arg_memo: dict[str, Any], - ): + ) -> str: if cls.index_variadic is None: if obj.ndim != len(cls.dims): - return False + return f"this array has {obj.ndim} dimensions, not the {len(cls.dims)} expected by the type hint" # noqa: E501 return _check_dims(cls.dims, obj.shape, single_memo, arg_memo) else: if obj.ndim < len(cls.dims) - 1: - return False + return f"this array has {obj.ndim} dimensions, which is fewer than {len(cls.dims) - 1} that is the minimum expected by the type hint" # noqa: E501 i = cls.index_variadic j = -(len(cls.dims) - i - 1) if j == 0: j = None - if not _check_dims(cls.dims[:i], obj.shape[:i], single_memo, arg_memo): - return False - if j is not None and not _check_dims( - cls.dims[j:], obj.shape[j:], single_memo, arg_memo - ): - return False + prefix_check = _check_dims( + cls.dims[:i], obj.shape[:i], single_memo, arg_memo + ) + if prefix_check != "": + return prefix_check + if j is not None: + suffix_check = _check_dims( + cls.dims[j:], obj.shape[j:], single_memo, arg_memo + ) + if suffix_check != "": + return suffix_check variadic_dim = cls.dims[i] if variadic_dim is _anonymous_variadic_dim: - return True + return "" else: assert
type(variadic_dim) is _NamedVariadicDim if variadic_dim.treepath: @@ -253,16 +264,16 @@ def _check_shape( prev_broadcastable, prev_shape = variadic_memo[name] except KeyError: variadic_memo[name] = (broadcastable, obj.shape[i:j]) - return True + return "" else: new_shape = obj.shape[i:j] if prev_broadcastable: try: broadcast_shape = np.broadcast_shapes(new_shape, prev_shape) except ValueError: # not broadcastable e.g. (3, 4) and (5,) - return False + return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501 if not broadcastable and broadcast_shape != new_shape: - return False + return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which the existing value of {prev_shape} cannot be broadcast to" # noqa: E501 variadic_memo[name] = (broadcastable, broadcast_shape) else: if broadcastable: @@ -271,13 +282,13 @@ def _check_shape( new_shape, prev_shape ) except ValueError: # not broadcastable e.g. (3, 4) and (5,) - return False + return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast with the existing value of {prev_shape}" # noqa: E501 if broadcast_shape != prev_shape: - return False + return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which cannot be broadcast to the existing value of {prev_shape}" # noqa: E501 else: if new_shape != prev_shape: - return False - return True + return f"the shape of its variadic dimensions '*{variadic_dim.name}' is {new_shape}, which does not equal the existing value of {prev_shape}" # noqa: E501 + return "" assert False