Skip to content

Commit

Permalink
Merge pull request #4379 from google:nnx-fix-fori-loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696274913
  • Loading branch information
Flax Authors committed Nov 13, 2024
2 parents 480a196 + c48905a commit ac3e85a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
6 changes: 5 additions & 1 deletion flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,9 +1336,13 @@ def per_node_state(ns: extract.NodeStates | tp.Any):
):
return ns

def per_node_def(nd: graph.NodeDef | tp.Any):
def per_node_def(nd: graph.NodeDef | graph.NodeRef):
if nd.index >= 0:
global_index_mapping[nd.index] = nd.index

if isinstance(nd, graph.NodeRef):
return

for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
Expand Down
34 changes: 34 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2934,6 +2934,40 @@ def fwd_fn(i, input):
_, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
np.testing.assert_array_equal(y, x * 2 * 3)

def test_fori_loop_with_sharing(self):
class A(nnx.Object):
def __init__(self):
self.params = nnx.Param(jnp.zeros((10,), dtype=int))

class B(nnx.Object):
def __init__(self, a: A):
self.a = a

class C(nnx.Object):
def __init__(self, a: A):
self.a = a

class D(nnx.Object):
def __init__(self):
self.a = A()
self.b = B(self.a)
self.c = C(self.a)

def increment(_, d: D) -> D:
d.a.params += 1
return d

@nnx.jit
def rollout(d: D):
nnx.fori_loop(0, 10, increment, d)

d = D()
rollout(d)

np.testing.assert_array_equal(
d.a.params.value, np.full((10,), 10, dtype=int)
)


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
Expand Down

0 comments on commit ac3e85a

Please sign in to comment.