Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support linen <-> nnx metadata box converging in nnx.bridge #4145

Merged
merged 1 commit into from
Aug 28, 2024

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Aug 27, 2024

Supports conversion between Linen and NNX metadata boxes.

  • NNX always has a metadata box (aka nnx.Variable), so make sure ToNNX always make them.
  • ToLinen will always convert nnx.Variable to a boxed Linen variable. By default, it would be nnx.bridge.NNXMeta, a class that captures everything that an nnx.Variable should.
  • Already existing boxes of nn.Partitioned should be preserved, and its axes names will be translated to nnx.Variable sharding fields inside ToNNX.

Also refactored variable-related functions that only bridge API uses outside of nnx/variables.py, into a separate nnx/bridge/variables.py.

@IvyZX IvyZX requested review from levskaya and cgarciae August 27, 2024 01:35
@IvyZX IvyZX force-pushed the bdg branch 2 times, most recently from aa7f87a to de17c01 Compare August 27, 2024 21:58

def register_variable_name_type_pair(name, typ):
"""Register a pair of variable type name (like Linen collections) and its NNX type."""
VariableTypeCache[name] = typ
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check that the name does exist already

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should allow overwrite of names if user wants to. I'll add an overwrite option there.

def unbox(self) -> A:
return self.value

def replace_boxed(self, val: B) -> 'NNXMeta[B]':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String types not needed in python >= 3.10

Suggested change
def replace_boxed(self, val: B) -> 'NNXMeta[B]':
def replace_boxed(self, val: B) -> NNXMeta[B]:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got error trying to do this, saying NNXMeta is not found. I'll leave it as is for now.


nnx_vars = jtu.tree_map_with_path(bv.to_nnx_var, variables,
is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
for col, tree in nnx_vars.items():
self._setattr(col, tree)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit more idiomatic:

Suggested change
self._setattr(col, tree)
setattr(self, col, tree)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return NNXMeta(vs.type, vs.value, vs.get_metadata())


def to_nnx_var(kp: tp.Sequence[Any], x: meta.AxisMetadata | Any) -> variableslib.Variable:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: We could avoid using key paths here by accepting the collection and iterating over all collections when used.

Suggested change
def to_nnx_var(kp: tp.Sequence[Any], x: meta.AxisMetadata | Any) -> variableslib.Variable:
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:

Its a bit more verbose but easier to understand.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. We need the key path to get the collection name, so I just moved that part to wrapper.py, to make the function in variables.py cleaner.

for collection, value in updates.items():
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))
self._setattr(collection, value)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._setattr(collection, value)
setattr(self, collection, value)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


nnx_vars = jtu.tree_map_with_path(bv.to_nnx_var, variables,
is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
for col, tree in nnx_vars.items():
self._setattr(col, tree)
self.linen_collections.add(col)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For linen_collections we should either implement this with a list or create a PytreeSet. The issue is that sets are not pytrees so they are treated as static but they are not hashable.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list is mutable so not hashable either. Curious why that matters though... but I'll make it a tuple.

@@ -225,14 +230,16 @@ def update_variables(self, module):
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = variableslib.sort_variable_types(types)
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use State.split:

Suggested change
_, *state_by_types = nnx.split(module, *types)
*state_by_types = state.split(*types)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks but nnx.split is more convenient because it always return a tuple...

@@ -225,14 +230,16 @@ def update_variables(self, module):
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = variableslib.sort_variable_types(types)
types = bv.sort_variable_types(types)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sort_variable_types expects type list but we can change the type to Iterable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch! Done.

@copybara-service copybara-service bot merged commit 839db8c into google:main Aug 28, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants