Partially revert #4192, which set back a bunch of previously merged pushes. #4201

Merged
1 commit merged on Sep 16, 2024
14 changes: 14 additions & 0 deletions flax/core/meta.py
@@ -22,6 +22,7 @@
 """
 
 import abc
+import dataclasses
 import functools
 from typing import Any, Generic, TypeVar
 from collections.abc import Callable
@@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
     """Returns the ``NamedSharding`` for this partitioned value."""
     return jax.sharding.NamedSharding(mesh, self.get_partition_spec())
 
+  def to_nnx_metadata(self) -> dict[str, Any]:
+    """Return a dict of metadata that can translate into an `nnx.Variable`."""
+    metadata = vars(self)
+    metadata['sharding'] = metadata.pop('names')
+    return metadata
+
+  @classmethod
+  def from_nnx_metadata(cls, metadata: dict[str, Any]):
+    """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
+    metadata['names'] = metadata.pop('sharding')
+    fields = {x.name for x in dataclasses.fields(cls)}
+    return cls(**{k: v for k, v in metadata.items() if k in fields})
+
 
 def with_partitioning(
   fn: Callable[..., Any],
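For context, a minimal sketch of the round trip these two methods enable (illustrative usage, not part of the PR; only the APIs shown in the hunk above are used):

import jax.numpy as jnp
import flax.linen as nn

boxed = nn.Partitioned(jnp.ones((4, 8)), names=('data', 'model'))
md = boxed.to_nnx_metadata()          # {'value': ..., 'sharding': ('data', 'model'), 'mesh': None}
restored = nn.Partitioned.from_nnx_metadata(md)  # keys outside the dataclass fields are dropped
assert restored.names == ('data', 'model')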
9 changes: 9 additions & 0 deletions flax/errors.py
@@ -64,6 +64,15 @@ def __reduce__(self):
     return (FlaxError, (str(self),))
 
 
+#################################################
+# NNX errors                                    #
+#################################################
+
+
+class TraceContextError(FlaxError):
+  pass
+
+
 #################################################
 # lazy_init.py errors                           #
 #################################################
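Since the class now lives in the top-level errors module, downstream code can catch it as below (hypothetical usage sketch; the triggering condition is an example, not defined by this PR):

from flax import errors

try:
  ...  # e.g. mutating an nnx.Object from a different trace level
except errors.TraceContextError:
  ...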
15 changes: 15 additions & 0 deletions flax/linen/spmd.py
@@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
     else:
       return self.value
 
+  def to_nnx_metadata(self) -> dict[str, Any]:
+    """Return a dict of metadata that can translate into an `nnx.Variable`."""
+    metadata = vars(self)
+    metadata['sharding'] = metadata.pop('names')
+    metadata['sharding_rules'] = metadata.pop('rules')
+    return metadata
+
+  @classmethod
+  def from_nnx_metadata(cls, metadata: dict[str, Any]):
+    """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
+    metadata['names'] = metadata.pop('sharding')
+    metadata['rules'] = metadata.pop('sharding_rules')
+    fields = {x.name for x in dataclasses.fields(cls)}
+    return cls(**{k: v for k, v in metadata.items() if k in fields})
+
 
 def with_logical_partitioning(
   fn: Callable[..., Any],
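The same round trip works for `nn.LogicallyPartitioned`, with the extra `rules` to `sharding_rules` renaming (illustrative sketch, not part of the PR):

import jax.numpy as jnp
import flax.linen as nn

boxed = nn.LogicallyPartitioned(jnp.ones((4, 8)), names=('batch', 'embed'),
                                rules=(('batch', 'data'), ('embed', 'model')))
md = boxed.to_nnx_metadata()      # 'names' -> 'sharding', 'rules' -> 'sharding_rules'
restored = nn.LogicallyPartitioned.from_nnx_metadata(md)
assert restored.rules == (('batch', 'data'), ('embed', 'model'))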
30 changes: 15 additions & 15 deletions flax/nnx/bridge/variables.py
@@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
 
 
 def register_variable_name_type_pair(name, typ, overwrite = False):
-  """Register a pair of variable type name (like Linen collections) and its NNX type."""
+  """Register a pair of Linen collection name and its NNX type."""
   if not overwrite and name in VariableTypeCache:
     raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
-                     'To overwrite, call with `overwrite=True`.')
+                     'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
   VariableTypeCache[name] = typ
 
 
@@ -85,8 +85,7 @@ def _variable_parents_count(t: type):
 
 
 class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
-  """Default Flax metadata class for `nnx.VariableState`.
-  """
+  """Default Flax metadata class for `nnx.VariableState`."""
 
   var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
   value: Any = struct.field(pytree_node=True)
@@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
 def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
   metadata = vs.get_metadata()
   if 'linen_meta_type' in metadata:
-    if metadata['linen_meta_type'] is not meta.Partitioned:
-      raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
-    return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
-  return NNXMeta(vs.type, vs.value, vs.get_metadata())
+    linen_type = metadata['linen_meta_type']
+    if hasattr(linen_type, 'from_nnx_metadata'):
+      return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
+    return linen_type(vs.value, **metadata)
+  return NNXMeta(vs.type, vs.value, metadata)
 
 
 def get_col_name(keypath: tp.Sequence[Any]) -> str:
@@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:
 
 
 def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
-  """Convert a Linen variable to an NNX variable.
-  This process needs the collection name,
-  """
+  """Convert a Linen variable to an NNX variable."""
   vtype = variable_type(col)
   if isinstance(x, NNXMeta):
     assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
     return x.var_type(x.value, **x.metadata)
   if isinstance(x, meta.AxisMetadata):
-    if isinstance(x, meta.Partitioned):
-      return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
-    raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
-  return vtype(x)
+    x_metadata = vars(x)
+    if hasattr(x, 'to_nnx_metadata'):
+      x_metadata = x.to_nnx_metadata()
+    assert hasattr(x, 'value')
+    return vtype(**x_metadata, linen_meta_type=type(x))
+  return vtype(x)
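Taken together, these changes let any Linen metadata box that defines the two methods survive a bridge round trip. A sketch (the `bv` helpers are from this file; `Variable.to_state()` is assumed here to produce the `VariableState` that `to_linen_var` expects):

import jax.numpy as jnp
import flax.linen as nn
from flax.nnx.bridge import variables as bv

boxed = nn.Partitioned(jnp.ones((2, 3)), names=('in', 'out'))
nnx_var = bv.to_nnx_var('params', boxed)   # nnx.Param with sharding=('in', 'out'), linen_meta_type=nn.Partitioned
linen_again = bv.to_linen_var(nnx_var.to_state())
assert isinstance(linen_again, nn.Partitioned) and linen_again.names == ('in', 'out')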
53 changes: 27 additions & 26 deletions flax/nnx/bridge/wrappers.py
@@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
     module = fn
     assert callable(fn)
   else:
-    if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
+    if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
       raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
     module = fn.__self__
   _set_initializing(module, True)
@@ -124,6 +124,7 @@ def __init__(
     self.linen_collections: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
+    """A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
     return lazy_init(self, *args, **kwargs)
 
   def __call__(
@@ -224,28 +225,6 @@ class ToLinen(linen.Module):
   skip_rng: bool = False
   metadata_type: tp.Type = bv.NNXMeta
 
-  def update_variables(self, module):
-    """Store the NNX module's graph def and state inside Linen module variables."""
-    gdef, state = nnx.split(module)
-    # Save the graph def.
-    if self.is_mutable_collection('nnx'):
-      self.put_variable('nnx', 'graphdef', gdef)
-    # Sort all the variable types.
-    types = set(jax.tree.leaves(
-        jax.tree.map(lambda x: x.type, state,
-                     is_leaf=lambda x: isinstance(x, nnx.VariableState))))
-    types = bv.sort_variable_types(types)
-    _, *state_by_types = nnx.split(module, *types)
-    # Each variable type goes to its own linen collection, and
-    # each attribute goes to its own linen variable
-    for typ, state in zip(types, state_by_types):
-      collection = bv.variable_type_name(typ)
-      if self.is_mutable_collection(collection):
-        for k, v in state.raw_mapping.items():
-          v = jax.tree.map(bv.to_linen_var, v,
-                           is_leaf=lambda x: isinstance(x, nnx.VariableState))
-          self.put_variable(collection, k, v)
-
   @linen.compact
   def __call__(self, *args, **kwargs):
     # init codepath
@@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
         module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
       module = self.nnx_class(*self.args, **module_kwargs)
       # TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
-      self.update_variables(module)
+      self._update_variables(module)
       return module(*args, **kwargs)
 
     # apply codepath
@@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
     module = nnx.merge(gdef, nnx_state)
     nnx.reseed(module, **linen_rngs_dict(self))  # reseed with keys from linen apply call.
     out = module(*args, **kwargs)
-    self.update_variables(module)
+    self._update_variables(module)
     return out
 
+  def _update_variables(self, module):
+    """Store the NNX module's graph def and state inside Linen module variables."""
+    gdef, state = nnx.split(module)
+    # Save the graph def.
+    if self.is_mutable_collection('nnx'):
+      self.put_variable('nnx', 'graphdef', gdef)
+    # Sort all the variable types.
+    types = set(jax.tree.leaves(
+        jax.tree.map(lambda x: x.type, state,
+                     is_leaf=lambda x: isinstance(x, nnx.VariableState))))
+    types = bv.sort_variable_types(types)
+    _, *state_by_types = nnx.split(module, *types)
+    # Each variable type goes to its own linen collection, and
+    # each attribute goes to its own linen variable
+    for typ, state in zip(types, state_by_types):
+      collection = bv.variable_type_name(typ)
+      if self.is_mutable_collection(collection):
+        for k, v in state.raw_mapping.items():
+          v = jax.tree.map(bv.to_linen_var, v,
+                           is_leaf=lambda x: isinstance(x, nnx.VariableState))
+          self.put_variable(collection, k, v)
+
 
 def to_linen(nnx_class: tp.Callable[..., Module], *args,
              name: str | None = None, **kwargs):
-  """Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields."""
+  """Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
   return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
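A sketch of the ToLinen path end to end (based on the bridge API shown above; shapes and seeds are illustrative):

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

model = bridge.to_linen(nnx.Linear, 4, 3)          # a linen.Module wrapping nnx.Linear(4, 3)
variables = model.init(jax.random.key(0), jnp.ones((1, 4)))
y = model.apply(variables, jnp.ones((1, 4)))
# _update_variables stored the graph def under variables['nnx']['graphdef'];
# each nnx.Param landed in the Linen 'params' collection.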
17 changes: 0 additions & 17 deletions flax/nnx/errors.py

This file was deleted.

18 changes: 10 additions & 8 deletions flax/nnx/extract.py
@@ -22,7 +22,7 @@
 
 from flax import struct
 from flax.nnx.object import Object
-from flax.typing import MISSING, PathParts
+from flax.typing import Missing, PathParts
 from flax.nnx import graph
 
 
@@ -59,7 +59,7 @@ def extract_graph_nodes(
   pytree: A,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
 ) -> (
   tuple[A, tuple[tp.Any, ...]]
@@ -101,7 +101,7 @@
 
   pytree_out = jax.tree.unflatten(treedef, leaves)
 
-  if prefix is MISSING:
+  if prefix is Missing:
     return pytree_out, tuple(nodes)  # type: ignore[bad-return-type]
   else:
     return pytree_out, tuple(nodes), tuple(node_prefixes)  # type: ignore[bad-return-type]
@@ -330,12 +330,13 @@ def to_tree(
   tree,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   split_fn: tp.Callable[
     [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any
   ] = default_split_fn,
   map_non_graph_nodes: bool = False,
   ctxtag: str | None = None,
+  check_aliasing: bool = True,
 ) -> tp.Any:
   leaf_prefixes = broadcast_prefix(
     prefix,
@@ -351,9 +352,10 @@
   with graph.split_context(ctxtag) as split_ctx:
     for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
       if graph.is_graph_node(leaf):
-        check_consistent_aliasing(
-          leaf, leaf_prefix, node_prefixes=node_prefixes
-        )
+        if check_aliasing:
+          check_consistent_aliasing(
+            leaf, leaf_prefix, node_prefixes=node_prefixes
+          )
         tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf)
         leaves_out.append(tree_node)
       else:
@@ -381,7 +383,7 @@ def from_tree(
   tree: tp.Any,
   /,
   *,
-  prefix: tp.Any = MISSING,
+  prefix: tp.Any = Missing,
   merge_fn: tp.Callable[
     [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any
   ] = merge_tree_node,
2 changes: 1 addition & 1 deletion flax/nnx/object.py
@@ -25,13 +25,13 @@
 import numpy as np
 
 from flax.nnx import (
-  errors,
   reprlib,
   tracers,
 )
 from flax.nnx import graph
 from flax.nnx.variables import Variable, VariableState
 from flax.typing import Key
+from flax import errors
 
 G = tp.TypeVar('G', bound='Object')
 
10 changes: 8 additions & 2 deletions flax/nnx/spmd.py
@@ -44,7 +44,7 @@ def _add_axis(x: tp.Any):
       sharding.insert(index, axis_name)
       x.sharding = tuple(sharding)  # type: ignore
 
-    x.add_axis(axis_name, index)
+    x.add_axis(index, axis_name)
     return x
 
   return jax.tree.map(
@@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any):
       sharding = list(x.sharding)
       assert sharding.pop(index) == axis_name
       x.sharding = tuple(sharding)
-      x.remove_axis(axis_name, index)
+      x.remove_axis(index, axis_name)
     return x
 
   return jax.tree.map(
@@ -89,9 +89,15 @@ def _maybe_replicate(x):
     else:
       return None
 
+  def from_rules(sharding, sharding_rules):
+    rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
+    return (rules[s] if s in rules else s for s in sharding)
+
   def f(x):
     if isinstance(x, (variables.VariableState, variables.Variable)):
       if hasattr(x, 'sharding') and x.sharding:
+        if hasattr(x, 'sharding_rules') and x.sharding_rules:
+          return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
         return x.replace(PartitionSpec(*x.sharding))
       else:
         return x.replace(_maybe_replicate(x.value))
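With this change, `nnx.get_partition_spec` resolves logical axis names through `sharding_rules` before building the `PartitionSpec`. A sketch (the metadata kwargs on `nnx.Param` are illustrative):

import jax.numpy as jnp
from flax import nnx

state = {'w': nnx.Param(jnp.ones((2, 3)),
                        sharding=('batch', 'embed'),
                        sharding_rules=(('batch', 'data'), ('embed', 'model')))}
specs = nnx.get_partition_spec(state)
# specs['w'] now wraps PartitionSpec('data', 'model') rather than the raw logical names.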
1 change: 1 addition & 0 deletions flax/nnx/transforms/compilation.py
@@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs):
       (args, kwargs),
       prefix=(in_shardings, kwarg_shardings),
       split_fn=_jit_split_fn,
+      check_aliasing=in_shardings is not None,
       ctxtag='jit',
     )
     pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
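Together with the extract.py change above, this means `nnx.jit` only runs the consistent-aliasing check when explicit `in_shardings` are given. A minimal sketch of the call path (illustrative):

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit   # no in_shardings here, so to_tree is called with check_aliasing=False
def forward(model, x):
  return model(x)

y = forward(model, jnp.ones((1, 2)))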