From 363f3dfa2ee181c56b8d92875e786396988af33b Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Sat, 12 Oct 2024 16:13:08 +0100
Subject: [PATCH] [nnx] improve init_cache docs

---
 flax/nnx/nn/attention.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py
index 38d598fcda..185e0bd9d9 100644
--- a/flax/nnx/nn/attention.py
+++ b/flax/nnx/nn/attention.py
@@ -569,23 +569,25 @@ def __call__(
   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     """Initializes cache for fast autoregressive decoding. When
     ``decode=True``, this method must be called first before performing
-    forward inference.
+    forward inference. When in decode mode, only one token must be passed
+    at a time.
 
     Example usage::
 
       >>> from flax import nnx
       >>> import jax.numpy as jnp
       ...
-      >>> rngs = nnx.Rngs(42)
+      >>> batch_size = 5
+      >>> embed_dim = 3
+      >>> x = jnp.ones((batch_size, 1, embed_dim))  # single token
       ...
-      >>> x = jnp.ones((1, 3))
       >>> model_nnx = nnx.MultiHeadAttention(
       ...   num_heads=2,
       ...   in_features=3,
       ...   qkv_features=6,
       ...   out_features=6,
       ...   decode=True,
-      ...   rngs=rngs,
+      ...   rngs=nnx.Rngs(42),
       ... )
       ...
       >>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
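
For reference, a minimal sketch of the full decode flow the updated docstring
describes, written against the nnx API surface visible in the diff. The
`max_length` value and the decoding loop are illustrative assumptions, not
taken from the patch:

    from flax import nnx
    import jax.numpy as jnp

    batch_size, embed_dim = 5, 3
    max_length = 4  # illustrative; bounds how many tokens can be decoded

    model = nnx.MultiHeadAttention(
      num_heads=2,
      in_features=3,
      qkv_features=6,
      out_features=6,
      decode=True,
      rngs=nnx.Rngs(42),
    )

    # The cache must be created before the first decode step; its second
    # dimension sets the maximum decode length.
    model.init_cache((batch_size, max_length, embed_dim))

    # In decode mode, feed exactly one token per call; the module updates
    # its key/value cache in place between calls.
    for _ in range(max_length):
      token = jnp.ones((batch_size, 1, embed_dim))
      out = model(token)  # shape: (batch_size, 1, 6)

Calling the model before `init_cache`, or with more than one token per call,
raises an error in decode mode, which is the behavior the docstring change
spells out.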