From 6ad09a3dd6e12492017d614186f4f35d2e0562a4 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Thu, 31 Oct 2024 16:39:12 -0700 Subject: [PATCH] add --- .../api_reference/flax.nnx/transforms.rst | 1 + flax/nnx/__init__.py | 1 + flax/nnx/transforms/iteration.py | 91 ++++++++++++++++++- tests/nnx/transforms_test.py | 13 +++ 4 files changed, 105 insertions(+), 1 deletion(-) diff --git a/docs_nnx/api_reference/flax.nnx/transforms.rst b/docs_nnx/api_reference/flax.nnx/transforms.rst index aead2f7841..54ba3399a3 100644 --- a/docs_nnx/api_reference/flax.nnx/transforms.rst +++ b/docs_nnx/api_reference/flax.nnx/transforms.rst @@ -24,3 +24,4 @@ transforms .. autofunction:: cond .. autofunction:: switch .. autofunction:: while_loop +.. autofunction:: fori_loop diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index c670cc8556..affa691d07 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -152,6 +152,7 @@ from .transforms.transforms import cond as cond from .transforms.transforms import switch as switch from .transforms.iteration import while_loop as while_loop +from .transforms.iteration import fori_loop as fori_loop from .transforms.iteration import StateAxes as StateAxes from .variablelib import A as A from .variablelib import BatchStat as BatchStat diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 466e307e90..a59fdbd8fa 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1409,7 +1409,7 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any], """NNX transform of `jax.lax.while_loop `_. Caution: for the NNX internal reference tracing mechanism to work, you cannot - change the reference structure of `init_val` inside `body_fun`. + change the variable reference structure of `init_val` inside `body_fun`. Example:: @@ -1448,4 +1448,93 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any], pure_init_val, ) out = extract.from_tree(pure_out, ctxtag='while_loop') + return out + + +@dataclasses.dataclass(eq=False) +class ForiLoopBodyFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + @graph.update_context('fori_loop_body') + def __call__(self, i, pure_val): + # Removing the dummy index mapping being added outside of body function. + pure_val_in = _remove_index_mapping(pure_val) + + val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body') + out = self.f(i, val) + pure_out = extract.to_tree(out, ctxtag='fori_loop_body') + + try: + jax.tree.map(lambda a, b: None, pure_val, pure_out) + except ValueError as e: + msg = ("nnx.fori_loop requires body function's input and output to " + "have the same reference and pytree structure, but they differ. " + "If the mismatch comes from `index_mapping` field, you might " + "have modified reference structure within the body function, " + "which is not allowed." + f"Detail of the mismatch: \n {str(e)}") + raise ValueError(msg) + + return pure_out + + +@graph.update_context('fori_loop') +def fori_loop(lower: int, upper: int, + body_fun: tp.Callable[[int, T], T], + init_val: T, + *, + unroll: int | bool | None = None) -> T: + """NNX transform of `jax.lax.fori_loop `_. + + Caution: for the NNX internal reference tracing mechanism to work, you cannot + change the variable reference structure of `init_val` inside `body_fun`. + + Example:: + + >>> import jax + >>> from flax import nnx + + >>> def fwd_fn(i, input): + ... m, x = input + ... m.kernel.value = jnp.identity(10) * i + ... return m, m(x) + + >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + >>> x = jax.random.normal(jax.random.key(0), (10,)) + >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) + >>> np.testing.assert_array_equal(y, x * 2 * 3) + + + Args: + lower: an integer representing the loop index lower bound (inclusive) + upper: an integer representing the loop index upper bound (exclusive) + body_fun: a function that takes an input of type `T` and outputs an `T`. + Note that both data and modules of `T` must have the same reference + structure between inputs and outputs. + init_val: the initial input for body_fun. Must be of type `T`. + unroll: An optional integer or boolean that determines how much to unroll + the loop. If an integer is provided, it determines how many unrolled + loop iterations to run within a single rolled iteration of the loop. If a + boolean is provided, it will determine if the loop is competely unrolled + (i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`). + This argument is only applicable if the loop bounds are statically known. + + Returns: + Loop value from the final iteration, of type ``T``. + + """ + + pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop') + + # Adding the expected reference mapping to `pure_init_val` to match + # `body_fun`'s output pytree structure, to make JAX happy. + pure_init_val = _add_fake_index_mapping(pure_init_val) + + pure_out = jax.lax.fori_loop(lower, upper, + ForiLoopBodyFn(body_fun), pure_init_val, + unroll=unroll) + out = extract.from_tree(pure_out, ctxtag='fori_loop') return out \ No newline at end of file diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 160855158b..4c327e1970 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2922,6 +2922,19 @@ def body_fn(val): (0, m, m), ) + def test_fori_loop_basic(self): + def fwd_fn(i, input): + m, x = input + m.kernel.value = jnp.identity(10) * i + return m, m(x) + + module = nnx.Linear(10, 10, rngs=nnx.Rngs(0)) + x = jax.random.normal(jax.random.key(0), (10,)) + + _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) + np.testing.assert_array_equal(y, x * 2 * 3) + + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self): class StatefulLinear(nnx.Linear):