Skip to content

Commit

Permalink
[nnx] fix nanobind
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Dec 4, 2024
1 parent 9578456 commit d088e18
Show file tree
Hide file tree
Showing 21 changed files with 1,045 additions and 183 deletions.
15 changes: 14 additions & 1 deletion .github/workflows/flax_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,23 @@ jobs:
python-version: ['3.10', '3.11']
test-type: [doctest, pytest, pytype, mypy]
jax-version: [newest]
use-flaxlib: [true, false]
exclude:
- test-type: pytype
python-version: '3.10'
- test-type: mypy
python-version: '3.11'
- use-flaxlib: true
test-type: doctest
- use-flaxlib: true
test-type: pytype
- use-flaxlib: true
test-type: mypy
include:
- python-version: '3.10'
test-type: pytest
jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml
use-flaxlib: false
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -119,12 +127,17 @@ jobs:
else
uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}"
fi
if [[ "${{ matrix.use-flaxlib }}" == "true" ]]; then
uv pip install "nanobind" "scikit-build-core[pyproject]"
uv pip install -e flaxlib_src
fi
- name: Test with ${{ matrix.test-type }}
run: |
if [[ "${{ matrix.test-type }}" == "doctest" ]]; then
uv run tests/run_all_tests.sh --only-doctest
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
uv run tests/run_all_tests.sh --only-pytest
FLAX_USE_FLAXLIB=${{ matrix.use-flaxlib }} \
uv run tests/run_all_tests.sh --only-pytest
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
uv run tests/run_all_tests.sh --only-pytype
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
Expand Down
1 change: 1 addition & 0 deletions benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import optax
from time import time


from flax import nnx

from absl import flags
Expand Down
5 changes: 1 addition & 4 deletions docs_nnx/api_reference/flax.nnx/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ helpers
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Dict
:members:
.. autoclass:: List
:members:

.. autoclass:: Sequential
:members:
.. autoclass:: TrainState
Expand Down
11 changes: 11 additions & 0 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@


class Config:
flax_use_flaxlib: bool
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True

Expand Down Expand Up @@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /):
raise LookupError(f'Unrecognized config option: {name}')
self._values[name] = value

def __repr__(self):
values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items())
return f'Config({values_repr}\n)'


config = Config()

Expand Down Expand Up @@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool):
' PRNG keys.'
),
)

flax_use_flaxlib = bool_flag(
name='flax_use_flaxlib',
default=False,
help='Whether to use flaxlib for C++ acceleration.',
)
6 changes: 3 additions & 3 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def extract_graph_nodes(
| tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
"""Extracts all graph nodes from a pytree."""
nodes = graph.RefMap[tp.Any, Index]()
nodes = graph.RefMap[tp.Any, Index]({})
node_prefixes = []
leaves = []

Expand Down Expand Up @@ -138,7 +138,7 @@ def check_consistent_aliasing(
| None = None,
):
if node_prefixes is None:
node_prefixes = graph.RefMap()
node_prefixes = graph.RefMap({})

# collect all paths and prefixes for each node
for path, value in graph.iter_graph(node):
Expand Down Expand Up @@ -324,7 +324,7 @@ def to_tree(

assert len(leaf_keys) == len(leaf_prefixes)
leaves_out = []
node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()
node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]({})

with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
Expand Down
Loading

0 comments on commit d088e18

Please sign in to comment.