Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nnx.fori_loop #4353

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ transforms
.. autofunction:: cond
.. autofunction:: switch
.. autofunction:: while_loop
.. autofunction:: fori_loop
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.iteration import while_loop as while_loop
from .transforms.iteration import fori_loop as fori_loop
from .transforms.iteration import StateAxes as StateAxes
from .variablelib import A as A
from .variablelib import BatchStat as BatchStat
Expand Down
91 changes: 90 additions & 1 deletion flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,7 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
"""NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the reference structure of `init_val` inside `body_fun`.
change the variable reference structure of `init_val` inside `body_fun`.

Example::

Expand Down Expand Up @@ -1448,4 +1448,93 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
pure_init_val,
)
out = extract.from_tree(pure_out, ctxtag='while_loop')
return out


@dataclasses.dataclass(eq=False)
class ForiLoopBodyFn:
  """Callable wrapper that adapts a user body function for `jax.lax.fori_loop`.

  The wrapped function `f` is called as `f(i, val)` on the value rebuilt from
  the pure tree, and its result is converted back to a pure tree so it can be
  carried by the underlying JAX loop.
  """

  # User-supplied loop body, invoked as `f(i, val)`.
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    # Copy name/doc/etc. from `f` so this wrapper presents itself like `f`.
    functools.update_wrapper(self, self.f)

  @graph.update_context('fori_loop_body')
  def __call__(self, i, pure_val):
    """Run one loop iteration on the pure-tree carry.

    Args:
      i: the loop index, forwarded unchanged to the wrapped function.
      pure_val: the pure-tree carry, including the dummy index mapping that
        is added outside of the body function.

    Returns:
      The pure tree built from the wrapped function's output.

    Raises:
      ValueError: if the input and output pure trees do not have the same
        structure.
    """
    # Removing the dummy index mapping being added outside of body function.
    pure_val_in = _remove_index_mapping(pure_val)

    val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body')
    out = self.f(i, val)
    pure_out = extract.to_tree(out, ctxtag='fori_loop_body')

    # Mapping a no-op over both trees is used purely as a structure check:
    # a mismatch surfaces as a ValueError, which is re-raised with guidance.
    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      msg = ("nnx.fori_loop requires body function's input and output to "
             "have the same reference and pytree structure, but they differ. "
             "If the mismatch comes from `index_mapping` field, you might "
             "have modified reference structure within the body function, "
             "which is not allowed. "
             f"Detail of the mismatch: \n {str(e)}")
      # Chain the original error so the underlying traceback is preserved.
      raise ValueError(msg) from e

    return pure_out


@graph.update_context('fori_loop')
def fori_loop(lower: int, upper: int,
              body_fun: tp.Callable[[int, T], T],
              init_val: T,
              *,
              unroll: int | bool | None = None) -> T:
  """NNX transform of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of `init_val` inside `body_fun`.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import numpy as np
    >>> from flax import nnx

    >>> def fwd_fn(i, input):
    ...   m, x = input
    ...   m.kernel.value = jnp.identity(10) * i
    ...   return m, m(x)

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
    >>> np.testing.assert_array_equal(y, x * 2 * 3)


  Args:
    lower: an integer representing the loop index lower bound (inclusive)
    upper: an integer representing the loop index upper bound (exclusive)
    body_fun: a function that takes the loop index and an input of type `T`,
      and outputs a `T`.
      Note that both data and modules of `T` must have the same reference
      structure between inputs and outputs.
    init_val: the initial input for body_fun. Must be of type `T`.
    unroll: An optional integer or boolean that determines how much to unroll
      the loop. If an integer is provided, it determines how many unrolled
      loop iterations to run within a single rolled iteration of the loop. If a
      boolean is provided, it will determine if the loop is completely unrolled
      (i.e. `unroll=True`) or left completely rolled (i.e. `unroll=False`).
      This argument is only applicable if the loop bounds are statically known.

  Returns:
    Loop value from the final iteration, of type ``T``.

  """

  pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop')

  # Adding the expected reference mapping to `pure_init_val` to match
  # `body_fun`'s output pytree structure, to make JAX happy.
  pure_init_val = _add_fake_index_mapping(pure_init_val)

  pure_out = jax.lax.fori_loop(lower, upper,
                               ForiLoopBodyFn(body_fun), pure_init_val,
                               unroll=unroll)
  out = extract.from_tree(pure_out, ctxtag='fori_loop')
  return out
13 changes: 13 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2922,6 +2922,19 @@ def body_fn(val):
(0, m, m),
)

def test_fori_loop_basic(self):
  # Each iteration overwrites the kernel with `i * I` and applies the layer,
  # so iterating over i = 2, 3 scales the input by 2 and then by 3.
  def scale_step(step, carry):
    layer, activations = carry
    layer.kernel.value = jnp.identity(10) * step
    return layer, layer(activations)

  linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
  inputs = jax.random.normal(jax.random.key(0), (10,))

  result = nnx.fori_loop(2, 4, scale_step, (linear, inputs))
  outputs = result[1]
  np.testing.assert_array_equal(outputs, inputs * 2 * 3)


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
class StatefulLinear(nnx.Linear):
Expand Down
Loading