Skip to content

Commit

Permalink
[nnx] fix nanobind
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 2, 2024
1 parent 9578456 commit cf8fbfc
Show file tree
Hide file tree
Showing 18 changed files with 1,056 additions and 131 deletions.
1 change: 1 addition & 0 deletions benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import optax
from time import time


from flax import nnx

from absl import flags
Expand Down
9 changes: 9 additions & 0 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@


class Config:
flax_use_flaxlib: bool
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True


def __init__(self):
self._values = {}

Expand Down Expand Up @@ -69,6 +71,7 @@ def update(self, name_or_holder, value, /):


class FlagHolder(Generic[_T]):

def __init__(self, name, help):
self.name = name
self.__name__ = name[4:] if name.startswith('flax_') else name
Expand Down Expand Up @@ -201,3 +204,9 @@ def temp_flip_flag(var_name: str, var_value: bool):
' PRNG keys.'
),
)

flax_use_flaxlib = bool_flag(
name='flax_use_flaxlib',
default=False,
help='Whether to use flaxlib for C++ acceleration.',
)
4 changes: 2 additions & 2 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def extract_graph_nodes(
| tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
"""Extracts all graph nodes from a pytree."""
nodes = graph.RefMap[tp.Any, Index]()
nodes = graph.RefMap[tp.Any, Index]({})
node_prefixes = []
leaves = []

Expand Down Expand Up @@ -324,7 +324,7 @@ def to_tree(

assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []
node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()
node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]({})

with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
Expand Down
Loading

0 comments on commit cf8fbfc

Please sign in to comment.