Skip to content

Commit

Permalink
[nnx] add PrefixMapping
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 10, 2024
1 parent fec3284 commit 39565d3
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 14 deletions.
26 changes: 20 additions & 6 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import contextlib
import dataclasses
import threading
Expand All @@ -23,7 +24,7 @@
from flax import struct
from flax.nnx.object import Object
from flax.typing import Missing, PathParts
from flax.nnx import graph
from flax.nnx import graph, variablelib


A = tp.TypeVar('A')
Expand Down Expand Up @@ -119,6 +120,14 @@ def _maybe_insert(x):
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, ExtractionIndex)
)

class PrefixMapping(abc.ABC):
  """Abstract base class for prefixes that are resolved per variable.

  Subclasses implement ``map_prefix`` to return the concrete prefix value that
  applies to a given path/variable pair.
  """

  @abc.abstractmethod
  def map_prefix(
    self, path: variablelib.PathParts, variable: variablelib.Variable, /
  ) -> tp.Any:
    """Return the prefix to use for ``variable`` found at ``path``.

    Both arguments are positional-only in this declaration.
    """

def check_consistent_aliasing(
node: tuple[tp.Any, ...],
Expand All @@ -143,11 +152,16 @@ def check_consistent_aliasing(
raise ValueError(
f'Cannot extract graph node from different trace level, got {value!r}'
)
if value in node_prefixes:
paths_prefixes = node_prefixes[value]
paths_prefixes.append((path, prefix))
else:
node_prefixes[value] = [(path, prefix)]
if isinstance(prefix, PrefixMapping):
variable_prefix = prefix.map_prefix(path, value)
else:
variable_prefix = prefix

if value in node_prefixes:
paths_prefixes = node_prefixes[value]
paths_prefixes.append((path, variable_prefix))
else:
node_prefixes[value] = [(path, variable_prefix)]

# check for inconsistent aliasing
node_msgs = []
Expand Down
12 changes: 11 additions & 1 deletion flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
extract,
filterlib,
graph,
variablelib,
)
import jax
import jax.core
Expand All @@ -36,7 +37,7 @@
# -------------------------------


class StateSharding:
class StateSharding(extract.PrefixMapping):
def __init__(
self,
filter_sharding: tp.Mapping[filterlib.Filter, tp.Any]
Expand All @@ -59,6 +60,15 @@ def filters(self) -> tuple[filterlib.Filter, ...]:
def shardings(self) -> tuple[tp.Any, ...]:
return self._shardings

def map_prefix(
  self, path: variablelib.PathParts, variable: variablelib.Variable
) -> tp.Any:
  """Return the sharding paired with the first filter matching ``variable``.

  Filters from ``self.filters`` are tried in order, each paired positionally
  with the entry of ``self.shardings`` at the same index.

  Args:
    path: path of the variable; forwarded to each filter predicate.
    variable: the variable to resolve a sharding for.

  Returns:
    The sharding associated with the first filter whose predicate accepts
    ``(path, variable)``.

  Raises:
    ValueError: if no filter matches.
  """
  # Named `state_filter` so the builtin `filter` is not shadowed.
  for state_filter, sharding in zip(self.filters, self.shardings):
    predicate = filterlib.to_predicate(state_filter)
    if predicate(path, variable):
      return sharding
  # This mapping resolves shardings (not axes), so the message says so.
  raise ValueError(f'No sharding found for {path=}, {variable=}')

def __repr__(self):
  """Render as ``StateSharding({filter: sharding, ...})``."""
  # Pair each filter with the sharding at the same position.
  filter_to_sharding = dict(zip(self.filters, self.shardings))
  return f'StateSharding({filter_to_sharding})'

Expand Down
13 changes: 11 additions & 2 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graph, spmd
from flax.nnx import extract, filterlib, graph, spmd, variablelib
from flax.nnx.module import Module
from flax.nnx.statelib import State
from flax.nnx.transforms.transforms import resolve_kwargs
Expand Down Expand Up @@ -54,7 +54,7 @@ class Carry:
# -------------------------------


class StateAxes:
class StateAxes(extract.PrefixMapping):

def __init__(
self,
Expand All @@ -80,6 +80,15 @@ def filters(self) -> tuple[filterlib.Filter, ...]:
def axes(self) -> tuple[Index | type[Carry] | None, ...]:
return self._axes

def map_prefix(
  self, path: variablelib.PathParts, variable: variablelib.Variable
) -> tp.Any:
  """Return the axis paired with the first filter matching ``variable``.

  Filters from ``self.filters`` are tried in order, each paired positionally
  with the entry of ``self.axes`` at the same index.

  Args:
    path: path of the variable; forwarded to each filter predicate.
    variable: the variable to resolve an axis for.

  Returns:
    The axis associated with the first filter whose predicate accepts
    ``(path, variable)``.

  Raises:
    ValueError: if no filter matches.
  """
  for state_filter, axis in zip(self.filters, self.axes):
    # Predicates are built lazily, only until the first match.
    if filterlib.to_predicate(state_filter)(path, variable):
      return axis
  raise ValueError(f'No axis found for {path=}, {variable=}')

def __repr__(self):
  """Render as ``StateAxes({...})`` using the pairs from ``self.items()``."""
  axes_by_filter = dict(self.items())
  return f'StateAxes({axes_by_filter})'

Expand Down
38 changes: 33 additions & 5 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2135,7 +2135,6 @@ def f(m):

def test_consistent_aliasing_shared(self):
class Shared(nnx.Module):

def __init__(self):
self.a = nnx.Param(jnp.zeros((3, 3)))

Expand All @@ -2148,17 +2147,46 @@ def __init__(self, shared: Shared):
m1 = Foo(shared)
m2 = Foo(shared)

@partial(nnx.vmap, in_axes=(0, 1))
@nnx.vmap(in_axes=(0, 1))
def f(m1, m2):
pass

with self.assertRaisesRegex(
ValueError,
r'Inconsistent aliasing detected([\s\S]*)Shared([\s\S]*)a:'
r' 0([\s\S]*)a: 1',
ValueError,
r'Inconsistent aliasing detected([\s\S]*)Param([\s\S]*)a:'
r' 0([\s\S]*)a: 1',
):
f(m1, m2)

def test_equivalent_state_axes_mapping(self):
  """Pass one module three times under the prefixes 0 and two StateAxes.

  There are no assertions: the test passes when the vmapped call does not
  raise.
  """
  linear = nnx.Linear(3, 3, rngs=nnx.Rngs(0))

  # Two StateAxes that both assign axis 0: one through the catch-all filter,
  # one through the `nnx.Param` filter.
  catch_all_axes = nnx.StateAxes({...: 0})
  param_axes = nnx.StateAxes({nnx.Param: 0})

  @nnx.vmap(in_axes=(0, catch_all_axes, param_axes))
  def f(m1, m2, m3):
    pass

  # The same object is deliberately aliased across all three arguments.
  f(linear, linear, linear)

def test_equivalent_state_sharding_mapping(self):
  """Pass one module three times under a sharding and two StateShardings.

  There are no assertions: the test passes when the jitted call does not
  raise.
  """
  linear = nnx.Linear(3, 3, rngs=nnx.Rngs(0))

  # NOTE(review): the mesh spans every device from `jax.devices()`, while the
  # module is built with sizes of 3 -- presumably this relies on the device
  # count dividing 3 (e.g. a single device); confirm on multi-device runners.
  device_mesh = jax.sharding.Mesh(jax.devices(), ('mp',))
  sharding = jax.sharding.NamedSharding(
    device_mesh, jax.sharding.PartitionSpec('mp')
  )

  # Two StateShardings that both assign `sharding`: one through the catch-all
  # filter, one through the `nnx.Param` filter.
  catch_all_sharding = nnx.StateSharding({...: sharding})
  param_sharding = nnx.StateSharding({nnx.Param: sharding})

  @nnx.jit(in_shardings=(sharding, catch_all_sharding, param_sharding))
  def f(m1, m2, m3):
    pass

  # The same object is deliberately aliased across all three arguments.
  f(linear, linear, linear)

@absltest.skip('Enable once jax#19586 resolved')
def test_captured_module_in_return_error(self):
class Foo(nnx.Module):
Expand Down

0 comments on commit 39565d3

Please sign in to comment.