Skip to content

Commit

Permalink
Added environment config flags.
Browse files Browse the repository at this point in the history
These flags are `JAXTYPING_DISABLE` and `JAXTYPING_REMOVE_TYPECHECKER_STACK`.

In addition, have now added warnings when using old-style double-decorator syntax, which also serves to guard against the easy mistake of
```python
@jaxtyped(typechecker)
def foo(...)
```
which actually decorates the `typechecker`, not `foo`.
  • Loading branch information
patrick-kidger committed Dec 6, 2023
1 parent 8e47c90 commit 125bc89
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 46 deletions.
1 change: 1 addition & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
has_jax,
set_array_name_format as set_array_name_format,
)
from ._config import config as config
from ._decorator import jaxtyped as jaxtyped, TypeCheckError as TypeCheckError
from ._import_hook import install_import_hook as install_import_hook
from ._ipython_extension import load_ipython_extension as load_ipython_extension
Expand Down
48 changes: 48 additions & 0 deletions jaxtyping/_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from typing import Union


def _maybestr2bool(value: Union[bool, str], error: str) -> bool:
if isinstance(value, bool):
return value
elif isinstance(value, str):
if value.lower() in ("0", "false"):
return False
elif value.lower() in ("1", "true"):
return True
else:
raise ValueError(error)
else:
raise ValueError(error)


class _JaxtypingConfig:
def __init__(self):
self.update("jaxtyping_disable", os.environ.get("JAXTYPING_DISABLE", "0"))
self.update(
"jaxtyping_remove_typechecker_stack",
os.environ.get("JAXTYPING_REMOVE_TYPECHECKER_STACK", "0"),
)

def update(self, item: str, value):
if item.lower() == "jaxtyping_disable":
msg = (
"Unrecognised value for `JAXTYPING_DISABLE`. Valid values are "
"`JAXTYPING_DISABLE=0` (the default) or `JAXTYPING_DISABLE=1` (to "
"disable runtime type checking)."
)
self.jaxtyping_disable = _maybestr2bool(value, msg)
elif item.lower() == "jaxtyping_remove_typechecker_stack":
msg = (
"Unrecognised value for `JAXTYPING_REMOVE_TYPECHECKER_STACK`. Valid "
"values are `JAXTYPING_REMOVE_TYPECHECKER_STACK=0` (the default) or "
"`JAXTYPING_REMOVE_TYPECHECKER_STACK=1` (to remove the stack frames "
"from the typechecker in `jaxtyped(typechecker=...)`, when it raises a "
"runtime type-checking error)."
)
self.jaxtyping_remove_typechecker_stack = _maybestr2bool(value, msg)
else:
raise ValueError(f"Unrecognised config value {item}")


config = _JaxtypingConfig()
123 changes: 93 additions & 30 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import itertools as it
import sys
import types
import warnings
import weakref
from typing import Any, get_args, get_origin, get_type_hints, overload

Expand All @@ -35,6 +36,7 @@
traceback_util.register_exclusion(__file__)


from ._config import config
from ._storage import pop_shape_memo, push_shape_memo


Expand All @@ -48,17 +50,25 @@ class TypeCheckError(TypeError):
TypeCheckError.__module__ = "jaxtyping" # appears in error messages


class _Sentinel:
def __repr__(self):
return "sentinel"


_sentinel = _Sentinel()


@overload
def jaxtyped(*, typechecker=None):
def jaxtyped(*, typechecker=_sentinel):
...


@overload
def jaxtyped(fn, *, typechecker=None):
def jaxtyped(fn, *, typechecker=_sentinel):
...


def jaxtyped(fn=None, *, typechecker=None):
def jaxtyped(fn=_sentinel, *, typechecker=_sentinel):
"""Decorate a function with this to perform runtime type-checking of its arguments
and return value. Decorate a dataclass to perform type-checking of its attributes.
Expand Down Expand Up @@ -90,8 +100,9 @@ class MyDataclass:
**Arguments:**
- `fn`: The function or dataclass to decorate.
- `typechecker`: The runtime type-checker to use. This should be a function
decorator that will raise an exception if there is a type error, e.g.
- `typechecker`: Keyword-only argument: the runtime type-checker to use. This should
be a function decorator that will raise an exception if there is a type error,
e.g.
```python
@typechecker
def f(x: int):
Expand All @@ -104,7 +115,7 @@ def f(x: int):
skip automatic runtime type-checking, but still support manual `isinstance`
checks inside the function body:
```python
@jaxtyped
@jaxtyped(typechecker=None)
def f(x):
assert isinstance(x, Float[Array, "batch channel"])
```
Expand All @@ -126,10 +137,10 @@ def f(x):
@typechecker
def f(...): ...
```
This is still supported, but the `jaxtyped(typechecker=typechecker)` syntax
discussed above will produce easier-to-debug error messages. Under the hood, the
new syntax more carefully manipulates the typechecker so as to determine where
a type-check error arises.
This is still supported, but will now raise a warning recommending the
`jaxtyped(typechecker=typechecker)` syntax discussed above. (Which will produce
easier-to-debug error messages: under the hood, the new syntax more carefully
manipulates the typechecker so as to determine where a type-check error arises.)
??? Info "Notes for advanced users"
Expand Down Expand Up @@ -163,23 +174,75 @@ def f(...): ...
**Decoupling contexts from function calls:**
If you would like a new dynamic context *without* calling a new function, then
`jaxtyped` may be passed the string `"context"` and used as a context manager:
If you would like to call a new function *without* creating a new
dynamic context (and using the same set of axis and structure values), then
simply do not add a `jaxtyped` decorator to your inner function, whilst
continuing to perform type-checking in whatever way you prefer.
Conversely, if you would like a new dynamic context *without* calling a new
function, then in addition to the usage discussed above, `jaxtyped` also
supports being used as a context manager, by passing it the string `"context"`:
```python
with jaxtyped("context"):
assert isinstance(x, Float[Array, "batch channel"])
```
which is equivalent to placing this code inside a new function wrapped in
This is equivalent to placing this code inside a new function wrapped in
`jaxtyped(typechecker=None)`. Usage like this is very rare; it's mostly only
useful when working at the global scope.
Conversely, if you would like to call a new function *without* creating a new
dynamic context (and using the same set of axis and structure values), then
simply do not add a `jaxtyped` decorator to your inner function, whilst
continuing to perform type-checking in whatever way you prefer.
"""

if fn is None:
# First handle the `jaxtyped("context")` usage, which is a special case.
if fn == "context":
if typechecker is not _sentinel:
raise ValueError(
"Cannot use `jaxtyped` as a context with a typechecker. That is, "
"`with jaxtyped('context', typechecker=...):`. is not allowed. In this "
"case the type checker does not actually do anything, as there is no "
"function to type-check."
)
return _JaxtypingContext()

# Now check that a typechecker has been explicitly declared. (Or explicitly declared
# as not being used, via `typechecker=None`.)
# This is needed just for backward compatibility: an undeclared typechecker
# corresponds to the old double-decorator syntax.
if typechecker is _sentinel:
# This branch will also catch the easy-to-make mistake of
# ```python
# @jaxtyped(typechecker)
# def foo(...):
# ```
# which is a bug as `typechecker` is interpreted as the function to decorate!
warnings.warn(
"As of jaxtyping version 0.2.24, jaxtyping now prefers the syntax\n"
"```\n"
"from jaxtyping import jaxtyped\n"
"# Use your favourite typechecker: usually one of the two lines below.\n"
"from typeguard import typechecked as typechecker\n"
"from beartype import beartype as typechecker\n"
"\n"
"@jaxtyped(typechecker=typechecker)\n"
"def foo(...):\n"
"```\n"
"and the old double-decorator syntax\n"
"```\n"
"@jaxtyped\n"
"@typechecker\n"
"def foo(...):\n"
"```\n"
"should no longer be used. (It will continue to work as it did before, but "
"the new approach will produce more readable error messages.)\n"
"In particular note that `typechecker` must be passed via keyword "
"argument; the following is not valid:\n"
"```\n"
"@jaxtyped(typechecker)\n"
"def foo(...):\n"
"```\n",
stacklevel=2,
)
typechecker = None

if fn is _sentinel:
return ft.partial(jaxtyped, typechecker=typechecker)
elif type(fn) is types.FunctionType and fn in _jaxtyped_fns:
return fn
Expand Down Expand Up @@ -235,15 +298,6 @@ def __init__(self, *args, **kwargs):
else:
fdel = jaxtyped(fn.fdel, typechecker=typechecker)
return property(fget=fget, fset=fset, fdel=fdel)
elif fn == "context":
if typechecker is not None:
raise ValueError(
"Cannot use `jaxtyped` as a context with a typechecker. That is, "
"`with jaxtyped('context', typechecker=...):`. is not allowed. In this "
"case the type checker does not actually do anything, as there is no "
"function to type-check."
)
return _JaxtypingContext()
else:
if typechecker is None:
# Probably being used in the old style as
Expand Down Expand Up @@ -321,6 +375,9 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore

@ft.wraps(fn)
def wrapped_fn(*args, **kwargs):
if config.jaxtyping_disable:
return fn(*args, **kwargs)

# Raise bind-time errors before we do any shape analysis. (I.e. skip
# the pointless jaxtyping information for a non-typechecking failure.)
bound = param_signature.bind(*args, **kwargs)
Expand Down Expand Up @@ -351,7 +408,10 @@ def wrapped_fn(*args, **kwargs):
f"Parameter annotations: {param_hints}.\n"
+ _exc_shape_info(memos)
)
raise TypeCheckError(msg) from e
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
else:
raise TypeCheckError(msg) from e

# Actually call the function.
out = fn(*args, **kwargs)
Expand Down Expand Up @@ -403,7 +463,10 @@ def wrapped_fn(*args, **kwargs):
f"Return annotation: {return_hint}.\n"
+ _exc_shape_info(memos)
)
raise TypeCheckError(msg) from e
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
else:
raise TypeCheckError(msg) from e

return out
finally:
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def jaxtyp(request):
# def f(...)
def impl(typechecker):
def decorator(fn):
return jaxtyping.jaxtyped(typechecker(fn))
with pytest.warns(match="As of jaxtyping version 0.2.24"):
return jaxtyping.jaxtyped(typechecker(fn))

return decorator

Expand Down
28 changes: 15 additions & 13 deletions test/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,49 @@


class M(metaclass=abc.ABCMeta):
@jaxtyped
@jaxtyped(typechecker=None)
def f(self):
...

@jaxtyped
@jaxtyped(typechecker=None)
@classmethod
def g1(cls):
return 3

@classmethod
@jaxtyped
@jaxtyped(typechecker=None)
def g2(cls):
return 4

@jaxtyped
@jaxtyped(typechecker=None)
@staticmethod
def h1():
return 3

@staticmethod
@jaxtyped
@jaxtyped(typechecker=None)
def h2():
return 4

@jaxtyped
@jaxtyped(typechecker=None)
@abc.abstractmethod
def i1(self):
...

@abc.abstractmethod
@jaxtyped
@jaxtyped(typechecker=None)
def i2(self):
...


class N:
@jaxtyped
@jaxtyped(typechecker=None)
@property
def j1(self):
return 3

@property
@jaxtyped
@jaxtyped(typechecker=None)
def j2(self):
return 4

Expand Down Expand Up @@ -154,10 +154,12 @@ def f(x: "LocalFoo") -> "LocalFoo":

f(LocalFoo())

@jaxtyped
@typecheck
def g(x: "LocalFoo") -> "LocalFoo":
return x
with pytest.warns(match="As of jaxtyping version 0.2.24"):

@jaxtyped
@typecheck
def g(x: "LocalFoo") -> "LocalFoo":
return x

g(LocalFoo())

Expand Down
3 changes: 1 addition & 2 deletions test/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def join(self, timeout=None):


def test_threading_jaxtyped():
@jaxtyped
@typechecked
@jaxtyped(typechecker=typechecked)
def add(x: Float[Array, "a b"], y: Float[Array, "a b"]) -> Float[Array, "a b"]:
return x + y

Expand Down

0 comments on commit 125bc89

Please sign in to comment.