diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 549d725..af18be9 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -79,6 +79,8 @@ def batch_outer_product(x: Float[Array, "b c1"], return x[:, :, None] * y[:, None, :] # Type-check a dataclass + from dataclasses import dataclass + @jaxtyped(typechecker=typechecker) @dataclass class MyDataclass: @@ -88,7 +90,18 @@ class MyDataclass: **Arguments:** - - `fn`: The function or dataclass to decorate. + - `fn`: The function or dataclass to decorate. In practice if you want to use + dataclasses with JAX, then + [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) is our + recommended approach: + ```python + import equinox as eqx + + @jaxtyped(typechecker=typechecker) + class MyModule(eqx.Module): + ... + ``` + - `typechecker`: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.