Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add FlatState #4410

Merged
merged 1 commit into from
Dec 4, 2024
Merged

[nnx] add FlatState #4410

merged 1 commit into from
Dec 4, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Nov 29, 2024

What does this PR do?

  • Adds FlatState[V] which is a Sequence[tuple[PathParts, V]] that implements a trivial / low overhead pytree definition. The idea for subsequent optimization PRs is to have to_tree and from_tree primarily use FlatState instead of State to speed up training loops.
  • Optimizers some functions in graph.py.

@cgarciae cgarciae force-pushed the nnx-optimizations-p3 branch from 4a65bda to 1da88f8 Compare November 29, 2024 17:59
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae changed the title [nnx] optimize graph [nnx] add FlatState Nov 29, 2024
@cgarciae cgarciae force-pushed the nnx-optimizations-p3 branch 2 times, most recently from 6c135f7 to 9578456 Compare December 2, 2024 22:03
def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]):
for key, value in items:
self.set_key(node, key, value)
# def init(self, node: Node, items: tp.Iterable[tuple[Key, Leaf]]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove?

@@ -202,8 +205,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
self._mapping = dict(mapping)
def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be copy=True instead? no_copy=False is double negation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True

@@ -1787,7 +1794,7 @@ def is_pytree_node(x: tp.Any) -> bool:
elif isinstance(x, Variable):
return False
# knon pytree types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: known

@@ -1787,7 +1794,7 @@ def is_pytree_node(x: tp.Any) -> bool:
elif isinstance(x, Variable):
return False
# knon pytree types
elif isinstance(x, (VariableState, State)):
elif type(x) is VariableState or type(x) is State:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not just a performance optimization, right? You no longer allow subclasses of VariableState or State.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. These should be treated more like Rust enums.

flat_state[path] = f(path, variable_state)
return State.from_flat_path(flat_state)
result = []
for path, variable_state in flat_state:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this can be a list comprehension.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, updated it to a list comprehension

keys and returns True if the nested mapping is a leaf (i.e., should not be
flattened further).
sep: if specified, then the keys of the returned mapping will be
``sep``-joined strings (if ``None``, then keys will be tuples).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: blank line before "Returns".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

is_leaf: an optional function that takes the next nested mapping and nested
keys and returns True if the nested mapping is a leaf (i.e., should not be
flattened further).
sep: if specified, then the keys of the returned mapping will be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

xs: Mapping[Any, Any],
/,
*,
is_leaf: None | IsLeafCallable = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: None should come last.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

@cgarciae cgarciae force-pushed the nnx-optimizations-p3 branch 2 times, most recently from 4aaed5f to 363acd9 Compare December 4, 2024 15:07
@cgarciae cgarciae force-pushed the nnx-optimizations-p3 branch from 363acd9 to d9acfb5 Compare December 4, 2024 15:19
@cgarciae cgarciae force-pushed the nnx-optimizations-p3 branch from d9acfb5 to 217f5ba Compare December 4, 2024 15:33
@copybara-service copybara-service bot merged commit 7eac051 into main Dec 4, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-optimizations-p3 branch December 4, 2024 15:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants