From eb9a23df63a315c433b2b7c4a0075bcfd019b4f0 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 5 Jan 2024 05:17:45 -0800 Subject: [PATCH] Update dataclass docs (#155) * Update dataclass docs --- jaxtyping/_decorator.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 549d725..af18be9 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -79,6 +79,8 @@ def batch_outer_product(x: Float[Array, "b c1"], return x[:, :, None] * y[:, None, :] # Type-check a dataclass + from dataclasses import dataclass + @jaxtyped(typechecker=typechecker) @dataclass class MyDataclass: @@ -88,7 +90,18 @@ class MyDataclass: **Arguments:** - - `fn`: The function or dataclass to decorate. + - `fn`: The function or dataclass to decorate. In practice if you want to use + dataclasses with JAX, then + [`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) is our + recommended approach: + ```python + import equinox as eqx + + @jaxtyped(typechecker=typechecker) + class MyModule(eqx.Module): + ... + ``` + - `typechecker`: Keyword-only argument: the runtime type-checker to use. This should be a function decorator that will raise an exception if there is a type error, e.g.