diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 1ea33d78d5..09ed2a124c 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1336,9 +1336,13 @@ def per_node_state(ns: extract.NodeStates | tp.Any): ): return ns - def per_node_def(nd: graph.NodeDef | tp.Any): + def per_node_def(nd: graph.NodeDef | graph.NodeRef): if nd.index >= 0: global_index_mapping[nd.index] = nd.index + + if isinstance(nd, graph.NodeRef): + return + for sub_nd in nd.subgraphs.values(): per_node_def(sub_nd) for l in nd.leaves.values(): diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 4c327e1970..ad10a606ea 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2934,6 +2934,40 @@ def fwd_fn(i, input): _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x)) np.testing.assert_array_equal(y, x * 2 * 3) + def test_fori_loop_with_sharing(self): + class A(nnx.Object): + def __init__(self): + self.params = nnx.Param(jnp.zeros((10,), dtype=int)) + + class B(nnx.Object): + def __init__(self, a: A): + self.a = a + + class C(nnx.Object): + def __init__(self, a: A): + self.a = a + + class D(nnx.Object): + def __init__(self): + self.a = A() + self.b = B(self.a) + self.c = C(self.a) + + def increment(_, d: D) -> D: + d.a.params += 1 + return d + + @nnx.jit + def rollout(d: D): + nnx.fori_loop(0, 10, increment, d) + + d = D() + rollout(d) + + np.testing.assert_array_equal( + d.a.params.value, np.full((10,), 10, dtype=int) + ) + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self):