-
Notifications
You must be signed in to change notification settings - Fork 652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add FlatState #4410
[nnx] add FlatState #4410
Conversation
4a65bda
to
1da88f8
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
6c135f7
to
9578456
Compare
flax/nnx/graph.py
Outdated
def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): | ||
for key, value in items: | ||
self.set_key(node, key, value) | ||
# def init(self, node: Node, items: tp.Iterable[tuple[Key, Leaf]]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove?
flax/nnx/graph.py
Outdated
@@ -202,8 +205,8 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: | |||
|
|||
|
|||
class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): | |||
def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]): | |||
self._mapping = dict(mapping) | |||
def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be copy=True
instead? no_copy=False
is double negation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True
flax/nnx/graph.py
Outdated
@@ -1787,7 +1794,7 @@ def is_pytree_node(x: tp.Any) -> bool: | |||
elif isinstance(x, Variable): | |||
return False | |||
# knon pytree types |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: known
@@ -1787,7 +1794,7 @@ def is_pytree_node(x: tp.Any) -> bool: | |||
elif isinstance(x, Variable): | |||
return False | |||
# knon pytree types | |||
elif isinstance(x, (VariableState, State)): | |||
elif type(x) is VariableState or type(x) is State: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not just a performance optimization, right? You no longer allow subclasses of VariableState
or State
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. These should be treated more like Rust enums.
flax/nnx/statelib.py
Outdated
flat_state[path] = f(path, variable_state) | ||
return State.from_flat_path(flat_state) | ||
result = [] | ||
for path, variable_state in flat_state: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: this can be a list comprehension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed, updated it to a list comprehension
flax/nnx/traversals.py
Outdated
keys and returns True if the nested mapping is a leaf (i.e., should not be | ||
flattened further). | ||
sep: if specified, then the keys of the returned mapping will be | ||
``sep``-joined strings (if ``None``, then keys will be tuples). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: blank line before "Returns".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
flax/nnx/traversals.py
Outdated
is_leaf: an optional function that takes the next nested mapping and nested | ||
keys and returns True if the nested mapping is a leaf (i.e., should not be | ||
flattened further). | ||
sep: if specified, then the keys of the returned mapping will be |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
flax/nnx/traversals.py
Outdated
xs: Mapping[Any, Any], | ||
/, | ||
*, | ||
is_leaf: None | IsLeafCallable = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: None should come last.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done!
4aaed5f
to
363acd9
Compare
363acd9
to
d9acfb5
Compare
d9acfb5
to
217f5ba
Compare
What does this PR do?
FlatState[V]
which is aSequence[tuple[PathParts, V]]
that implements a trivial / low overhead pytree definition. The idea for subsequent optimization PRs is to haveto_tree
andfrom_tree
primarily useFlatState
instead ofState
to speed up training loops.graph.py
.