Skip to content

Commit

Permalink
Update dataclass docs (#155)
Browse files Browse the repository at this point in the history
* Update dataclass docs
  • Loading branch information
patrick-kidger authored Jan 5, 2024
1 parent adf1a5e commit eb9a23d
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def batch_outer_product(x: Float[Array, "b c1"],
return x[:, :, None] * y[:, None, :]
# Type-check a dataclass
from dataclasses import dataclass
@jaxtyped(typechecker=typechecker)
@dataclass
class MyDataclass:
Expand All @@ -88,7 +90,18 @@ class MyDataclass:
**Arguments:**
- `fn`: The function or dataclass to decorate.
- `fn`: The function or dataclass to decorate. In practice if you want to use
dataclasses with JAX, then
[`equinox.Module`](https://docs.kidger.site/equinox/api/module/module/) is our
recommended approach:
```python
import equinox as eqx
@jaxtyped(typechecker=typechecker)
class MyModule(eqx.Module):
...
```
- `typechecker`: Keyword-only argument: the runtime type-checker to use. This should
be a function decorator that will raise an exception if there is a type error,
e.g.
Expand Down

0 comments on commit eb9a23d

Please sign in to comment.