-
Notifications
You must be signed in to change notification settings - Fork 652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support linen <-> nnx metadata box converging in nnx.bridge
#4145
Conversation
aa7f87a
to
de17c01
Compare
|
||
def register_variable_name_type_pair(name, typ): | ||
"""Register a pair of variable type name (like Linen collections) and its NNX type.""" | ||
VariableTypeCache[name] = typ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should check that the name
does exist already
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should allow overwrite of names if user wants to. I'll add an overwrite option there.
def unbox(self) -> A: | ||
return self.value | ||
|
||
def replace_boxed(self, val: B) -> 'NNXMeta[B]': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
String types not needed in python >= 3.10
def replace_boxed(self, val: B) -> 'NNXMeta[B]': | |
def replace_boxed(self, val: B) -> NNXMeta[B]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got error trying to do this, saying NNXMeta is not found
. I'll leave it as is for now.
flax/nnx/nnx/bridge/wrappers.py
Outdated
|
||
nnx_vars = jtu.tree_map_with_path(bv.to_nnx_var, variables, | ||
is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) | ||
for col, tree in nnx_vars.items(): | ||
self._setattr(col, tree) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit more idiomatic:
self._setattr(col, tree) | |
setattr(self, col, tree) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
flax/nnx/nnx/bridge/variables.py
Outdated
return NNXMeta(vs.type, vs.value, vs.get_metadata()) | ||
|
||
|
||
def to_nnx_var(kp: tp.Sequence[Any], x: meta.AxisMetadata | Any) -> variableslib.Variable: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NIT: We could avoid using key paths here by accepting the collection and iterating over all collections when used.
def to_nnx_var(kp: tp.Sequence[Any], x: meta.AxisMetadata | Any) -> variableslib.Variable: | |
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: |
Its a bit more verbose but easier to understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. We need the key path to get the collection name, so I just moved that part to wrapper.py, to make the function in variables.py cleaner.
flax/nnx/nnx/bridge/wrappers.py
Outdated
for collection, value in updates.items(): | ||
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value)) | ||
self._setattr(collection, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self._setattr(collection, value) | |
setattr(self, collection, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
flax/nnx/nnx/bridge/wrappers.py
Outdated
|
||
nnx_vars = jtu.tree_map_with_path(bv.to_nnx_var, variables, | ||
is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) | ||
for col, tree in nnx_vars.items(): | ||
self._setattr(col, tree) | ||
self.linen_collections.add(col) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For linen_collections
we should either implement this with a list
or create a PytreeSet
. The issue is that sets are not pytrees so they are treated as static but they are not hashable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
list
is mutable so not hashable either. Curious why that matters though... but I'll make it a tuple.
@@ -225,14 +230,16 @@ def update_variables(self, module): | |||
types = set(jax.tree.leaves( | |||
jax.tree.map(lambda x: x.type, state, | |||
is_leaf=lambda x: isinstance(x, nnx.VariableState)))) | |||
types = variableslib.sort_variable_types(types) | |||
types = bv.sort_variable_types(types) | |||
_, *state_by_types = nnx.split(module, *types) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use State.split
:
_, *state_by_types = nnx.split(module, *types) | |
*state_by_types = state.split(*types) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks but nnx.split
is more convenient because it always return a tuple...
@@ -225,14 +230,16 @@ def update_variables(self, module): | |||
types = set(jax.tree.leaves( | |||
jax.tree.map(lambda x: x.type, state, | |||
is_leaf=lambda x: isinstance(x, nnx.VariableState)))) | |||
types = variableslib.sort_variable_types(types) | |||
types = bv.sort_variable_types(types) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sort_variable_types
expects type list
but we can change the type to Iterable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the catch! Done.
Supports conversion between Linen and NNX metadata boxes.
nnx.Variable
), so make sureToNNX
always make them.ToLinen
will always convertnnx.Variable
to a boxed Linen variable. By default, it would bennx.bridge.NNXMeta
, a class that captures everything that annnx.Variable
should.nn.Partitioned
should be preserved, and its axes names will be translated tonnx.Variable
sharding fields insideToNNX
.Also refactored variable-related functions that only
bridge
API uses outside ofnnx/variables.py
, into a separatennx/bridge/variables.py
.