Skip to content

Commit

Permalink
Merge pull request #4145 from IvyZX:bdg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668633713
  • Loading branch information
Flax Authors committed Aug 28, 2024
2 parents afbc502 + 4bb2152 commit 839db8c
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 74 deletions.
4 changes: 3 additions & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from flax.typing import Initializer as Initializer

from .nnx.bridge import wrappers as wrappers
from .nnx.bridge.variables import (
register_variable_name_type_pair as register_variable_name_type_pair,
)
from .nnx import graph as graph
from .nnx import errors as errors
from .nnx import helpers as helpers
Expand Down Expand Up @@ -124,7 +127,6 @@
from .nnx.training import metrics as metrics
from .nnx.variables import (
Param as Param,
register_variable_name_type_pair as register_variable_name_type_pair,
)
# this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
Expand Down
3 changes: 2 additions & 1 deletion flax/nnx/nnx/bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
from .wrappers import ToNNX as ToNNX
from .wrappers import lazy_init as lazy_init
from .wrappers import ToLinen as ToLinen
from .wrappers import to_linen as to_linen
from .wrappers import to_linen as to_linen
from .variables import NNXMeta as NNXMeta
138 changes: 138 additions & 0 deletions flax/nnx/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, TypeVar

import jax
from flax import struct
from flax.core import meta
from flax.nnx.nnx import variables as variableslib
import typing as tp


# Generic value-type parameters used by `NNXMeta`: `A` is the type of the
# boxed value, `B` the type of a replacement value given to `replace_boxed`.
A = TypeVar('A')
B = TypeVar('B')


#######################################################
### Variable type <-> Linen collection name mapping ###
#######################################################
# Assumption: the mapping is 1-1 and unique.

# Registry from Linen-style collection name to NNX Variable type. It is filled
# at import time with the known collections and extended on demand by
# `variable_type()` and `variable_type_name()`.
VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {}


def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]:
  """Return the NNX Variable type registered for a Linen collection name.

  If `name` has no entry in `VariableTypeCache` yet, a new subclass of
  `variableslib.Variable` named `name` is created, cached and returned.
  """
  try:
    return VariableTypeCache[name]
  except KeyError:
    # First time this collection is seen: synthesize a dedicated type.
    created = type(name, (variableslib.Variable,), {})
    VariableTypeCache[name] = created
    return created


def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
  """Return the Linen-style collection name for an NNX Variable type.

  This is the inverse of `variable_type()`. If `typ` is already a value in
  `VariableTypeCache`, its registered name is returned. Otherwise the type is
  registered under `typ.__name__` and that name is returned.

  Raises:
    ValueError: if `typ` is not registered and `typ.__name__` is already
      mapped to a different type.
  """
  for name, t in VariableTypeCache.items():
    if typ == t:
      return name
  name = typ.__name__
  if name in VariableTypeCache:
    # The message must be an f-string, otherwise the placeholders are printed
    # literally instead of the conflicting name and types.
    raise ValueError(
      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name_type_pair(name, typ)
  return name


def register_variable_name_type_pair(name, typ, overwrite=False):
  """Record that collection `name` corresponds to the NNX type `typ`.

  An existing entry for `name` is only replaced when `overwrite` is true;
  otherwise a `ValueError` is raised and the registry is left untouched.
  """
  if overwrite or name not in VariableTypeCache:
    VariableTypeCache[name] = typ
    return
  raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
                   'To overwrite, call with `overwrite=True`.')


# Pre-register the built-in Linen collections and their NNX variable types.
for _collection, _var_type in (
  ('params', variableslib.Param),
  ('batch_stats', variableslib.BatchStat),
  ('cache', variableslib.Cache),
  ('intermediates', variableslib.Intermediate),
):
  register_variable_name_type_pair(_collection, _var_type)
# Do not leave the loop variables behind in the module namespace.
del _collection, _var_type


def sort_variable_types(types: tp.Iterable[type]):
  """Return `types` as a list ordered from most to least derived.

  Each type is ranked by how many classes in its MRO are subclasses of
  `variableslib.Variable`; higher counts come first. Types with the same
  count keep their input order.

  `types` is consumed exactly once, so one-shot iterators are handled as well
  as sets and lists.
  """
  def _variable_parents_count(t: type):
    return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable))
  # `sorted` evaluates the key once per element while materializing its
  # input, so no second pass over `types` is needed.
  return sorted(types, key=lambda t: -_variable_parents_count(t))


#############################################
### NNX Variable <-> Linen metadata boxes ###
#############################################


class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
  """Default Flax metadata class for `nnx.VariableState`.

  Boxes the value of an NNX variable together with its variable type and its
  metadata dict. Only `value` is declared as a pytree node; `var_type` and
  `metadata` are declared with `pytree_node=False`.
  """

  # NNX Variable type the boxed value belongs to.
  var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
  # The boxed value itself.
  value: Any = struct.field(pytree_node=True)
  # Metadata carried along with the value.
  metadata: dict[str, tp.Any] = struct.field(pytree_node=False)

  def unbox(self) -> A:
    """Return the boxed value."""
    return self.value

  def replace_boxed(self, val: B) -> 'NNXMeta[B]':
    """Return a copy of this box whose value is `val`."""
    return self.replace(value=val)  # type: ignore

  def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
    """Currently a no-op: returns `self` unchanged."""
    # TODO: implement this, supporting hooks
    return self

  def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
    """Currently a no-op: returns `self` unchanged."""
    # TODO: implement this, supporting hooks
    return self


def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
  """Box an NNX `VariableState` into a Linen metadata object.

  A state whose metadata has a `'linen_meta_type'` entry is converted back to
  that Linen box; only `meta.Partitioned` is supported, built from the
  `'sharding'` and `'mesh'` metadata entries. Any other state is wrapped in
  `NNXMeta` with its type, value and metadata.

  Raises:
    ValueError: if `'linen_meta_type'` is present but is not
      `meta.Partitioned`.
  """
  metadata = vs.get_metadata()
  if 'linen_meta_type' not in metadata:
    return NNXMeta(vs.type, vs.value, vs.get_metadata())
  if metadata['linen_meta_type'] is meta.Partitioned:
    return meta.Partitioned(
      vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
  raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')


def get_col_name(keypath: tp.Sequence[Any]) -> str:
  """Given the keypath of a Flax variable, return its Linen collection name.

  The collection name is the key of the first path entry, converted to `str`.

  Args:
    keypath: a JAX key path whose first entry must be a
      `jax.tree_util.DictKey`.

  Raises:
    AssertionError: if the first path entry is not a `DictKey`.
  """
  head = keypath[0]
  # The outermost level of the tree is expected to be a dict keyed by
  # collection name; say what was found instead when that does not hold.
  assert isinstance(head, jax.tree_util.DictKey), (
    f'Expected the first key of the path to be a DictKey, got {head!r}')
  return str(head.key)


def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
  """Convert a Linen variable leaf into an NNX variable.

  The collection name `col` is needed to look up the NNX variable type via
  `variable_type(col)`.

  - An `NNXMeta` box is rebuilt with its stored type, value and metadata; the
    stored type is asserted to equal the type derived from `col`.
  - A `meta.Partitioned` box becomes a variable carrying `sharding`, `mesh`
    and a `linen_meta_type` marker.
  - Any other `meta.AxisMetadata` box raises `ValueError`.
  - Any other value is wrapped directly in the derived variable type.
  """
  vtype = variable_type(col)
  if isinstance(x, NNXMeta):
    assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
    return x.var_type(x.value, **x.metadata)
  if not isinstance(x, meta.AxisMetadata):
    # Plain leaf without a metadata box.
    return vtype(x)
  if not isinstance(x, meta.Partitioned):
    raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
  return vtype(
    x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
43 changes: 29 additions & 14 deletions flax/nnx/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from flax import nnx
from flax import linen
from flax.core import meta
from flax.nnx.nnx import graph
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.bridge import variables as bv
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(
):
self.module = module
self.rngs = rngs
self.linen_collections: set[str] = set()
self.linen_collections: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
return lazy_init(self, *args, **kwargs)
Expand All @@ -143,16 +144,20 @@ def __call__(
if 'params' not in _rngs and 'default' in _rngs:
_rngs['params'] = _rngs.pop('default')
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
def nn_var_to_nnx_state(kp, v):
assert isinstance(kp[0], jtu.DictKey)
vtype = variableslib.variable_type(kp[0].key)
return vtype(v)
for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items():
self._setattr(col, tree)
self.linen_collections.add(col)

nnx_vars = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
linen_collections = set()
for col, tree in nnx_vars.items():
setattr(self, col, tree)
linen_collections.add(col)
self.linen_collections = tuple(linen_collections) # make it hashable

else:
variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col))
variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()),
getattr(self, col),
is_leaf=lambda x: isinstance(x, nnx.Variable))
for col in self.linen_collections}
_rngs = (
{name: stream() for name, stream in rngs.items()} if rngs else {}
Expand All @@ -162,8 +167,11 @@ def nn_var_to_nnx_state(kp, v):
# Split out the updates if `mutable` is passed into the Flax module
if kwargs.get('mutable', False) != False:
out, updates = out
updates = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
for collection, value in updates.items():
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))
setattr(self, collection, value)

return out

Expand Down Expand Up @@ -214,6 +222,7 @@ class ToLinen(linen.Module):
args: tp.Sequence = ()
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
Expand All @@ -225,14 +234,16 @@ def update_variables(self, module):
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = variableslib.sort_variable_types(types)
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = variableslib.variable_type_name(typ)
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)

@linen.compact
Expand All @@ -250,7 +261,11 @@ def __call__(self, *args, **kwargs):
# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
assert gdef, 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?'
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
variables = {col: v for col, v in self.variables.items() if col != 'nnx'}
states = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
states = [State(v) for v in states.values()]
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
Expand Down
51 changes: 3 additions & 48 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2023 The Flax Authors.
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Expand Down Expand Up @@ -998,48 +998,3 @@ def wrapper(*args):

return wrapper # type: ignore


### Variable type <-> name mapping ###
# Assumption: the mapping is 1-1 and unique.

def variable_type(name: str) -> tp.Type[Variable[tp.Any]]:
  """Return the NNX Variable type registered for a Linen collection name.

  If `name` has no entry in `VariableTypeCache` yet, a new subclass of
  `Variable` named `name` is created, cached and returned.
  """
  try:
    return VariableTypeCache[name]
  except KeyError:
    # First time this collection is seen: synthesize a dedicated type.
    created = type(name, (Variable,), {})
    VariableTypeCache[name] = created
    return created


def variable_type_name(typ: tp.Type[Variable[tp.Any]]) -> str:
  """Return the Linen-style collection name for an NNX Variable type.

  This is the inverse of `variable_type()`. If `typ` is already a value in
  `VariableTypeCache`, its registered name is returned. Otherwise the type is
  registered under `typ.__name__` and that name is returned.

  Raises:
    ValueError: if `typ` is not registered and `typ.__name__` is already
      mapped to a different type.
  """
  for name, t in VariableTypeCache.items():
    if typ == t:
      return name
  name = typ.__name__
  if name in VariableTypeCache:
    # The message must be an f-string, otherwise the placeholders are printed
    # literally instead of the conflicting name and types.
    raise ValueError(
      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name_type_pair(name, typ)
  return name


def register_variable_name_type_pair(name, typ):
  """Register a pair of variable type name (like Linen collections) and its NNX type.

  Any existing entry for `name` in `VariableTypeCache` is replaced without
  warning.
  """
  VariableTypeCache[name] = typ


# Pre-register the built-in Linen collections and their NNX variable types.
for _collection, _var_type in (
  ('params', Param),
  ('batch_stats', BatchStat),
  ('cache', Cache),
  ('intermediates', Intermediate),
):
  register_variable_name_type_pair(_collection, _var_type)
# Do not leave the loop variables behind in the module namespace.
del _collection, _var_type


def sort_variable_types(types: tp.Iterable[type]):
  """Return `types` as a list ordered from most to least derived.

  Each type is ranked by how many classes in its MRO are subclasses of
  `Variable`; higher counts come first. Types with the same count keep their
  input order.

  `types` is consumed exactly once, so sets, lists and one-shot iterators are
  all accepted.
  """
  def _variable_parents_count(t: type):
    # Use the module-local `Variable`, as the sibling helpers in this file do,
    # rather than reaching for it through the `nnx` package.
    return sum(1 for p in t.mro() if issubclass(p, Variable))
  return sorted(types, key=lambda t: -_variable_parents_count(t))
Loading

0 comments on commit 839db8c

Please sign in to comment.