diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 38d598fcda..185e0bd9d9 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -569,23 +569,25 @@ def __call__(
   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     """Initializes cache for fast autoregressive decoding. When
     ``decode=True``, this method must be called first before performing
-    forward inference.
+    forward inference. When in decode mode, only one token must be passed
+    at a time.
 
     Example usage::
 
       >>> from flax import nnx
       >>> import jax.numpy as jnp
       ...
-      >>> rngs = nnx.Rngs(42)
+      >>> batch_size = 5
+      >>> embed_dim = 3
+      >>> x = jnp.ones((batch_size, 1, embed_dim))  # single token
       ...
-      >>> x = jnp.ones((1, 3))
       >>> model_nnx = nnx.MultiHeadAttention(
       ...   num_heads=2,
       ...   in_features=3,
       ...   qkv_features=6,
       ...   out_features=6,
       ...   decode=True,
-      ...   rngs=rngs,
+      ...   rngs=nnx.Rngs(42),
       ... )
       ...
       >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized