From 76913cf9f766c36ca1dbc9e4be7607496aa866d0 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 30 Nov 2024 18:05:22 +0000 Subject: [PATCH] [nnx] fix nanobind --- .github/workflows/flax_test.yml | 14 +- benchmarks/nnx_graph_overhead.py | 1 + docs_nnx/api_reference/flax.nnx/helpers.rst | 5 +- flax/configurations.py | 11 + flax/nnx/extract.py | 4 +- flax/nnx/graph.py | 159 +++-- flax/nnx/transforms/autodiff.py | 6 +- flax/nnx/transforms/iteration.py | 8 +- flax/nnx/variablelib.py | 8 +- flax/typing.py | 31 + flaxlib_src/CMakeLists.txt | 57 ++ flaxlib_src/flaxlib.pyi | 15 - flaxlib_src/meson.build | 14 - flaxlib_src/pyproject.toml | 17 +- flaxlib_src/src/flaxlib.cpp | 720 ++++++++++++++++++++ flaxlib_src/src/flaxlib/__init__.py | 15 + flaxlib_src/src/flaxlib/flaxlib_cpp.pyi | 55 ++ flaxlib_src/src/lib.cc | 14 - pyproject.toml | 4 + tests/flaxlib_test.py | 25 - tests/nnx/graph_utils_test.py | 48 +- tests/nnx/variable_test.py | 2 +- tests/run_all_tests.sh | 1 + uv.lock | 66 ++ 24 files changed, 1114 insertions(+), 186 deletions(-) create mode 100644 flaxlib_src/CMakeLists.txt delete mode 100644 flaxlib_src/flaxlib.pyi delete mode 100644 flaxlib_src/meson.build create mode 100644 flaxlib_src/src/flaxlib.cpp create mode 100644 flaxlib_src/src/flaxlib/__init__.py create mode 100644 flaxlib_src/src/flaxlib/flaxlib_cpp.pyi delete mode 100644 flaxlib_src/src/lib.cc delete mode 100644 tests/flaxlib_test.py diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index 4bed8d8179..01727d38a0 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -88,15 +88,23 @@ jobs: python-version: ['3.10', '3.11'] test-type: [doctest, pytest, pytype, mypy] jax-version: [newest] + use-flaxlib: [true, false] exclude: - test-type: pytype python-version: '3.10' - test-type: mypy python-version: '3.11' + - use-flaxlib: true + test-type: doctest + - use-flaxlib: true + test-type: pytype + - use-flaxlib: true + test-type: mypy include: - python-version: '3.10' test-type: pytest jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml + use-flaxlib: false steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -119,12 +127,16 @@ jobs: else uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi + if [[ "${{ matrix.use-flaxlib }}" == "true" ]]; then + uv pip install -e flaxlib_src + fi - name: Test with ${{ matrix.test-type }} run: | if [[ "${{ matrix.test-type }}" == "doctest" ]]; then uv run tests/run_all_tests.sh --only-doctest elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then - uv run tests/run_all_tests.sh --only-pytest + FLAX_USE_FLAXLIB=${{ matrix.use-flaxlib }} \ + uv run tests/run_all_tests.sh --only-pytest elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then uv run tests/run_all_tests.sh --only-pytype elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index 73cff6d6d6..bd32fc7883 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -19,6 +19,7 @@ import optax from time import time + from flax import nnx from absl import flags diff --git a/docs_nnx/api_reference/flax.nnx/helpers.rst b/docs_nnx/api_reference/flax.nnx/helpers.rst index f2b67522d7..7ff94de201 100644 --- a/docs_nnx/api_reference/flax.nnx/helpers.rst +++ b/docs_nnx/api_reference/flax.nnx/helpers.rst @@ -4,10 +4,7 @@ helpers .. automodule:: flax.nnx .. currentmodule:: flax.nnx -.. autoclass:: Dict - :members: -.. autoclass:: List - :members: + .. autoclass:: Sequential :members: .. autoclass:: TrainState diff --git a/flax/configurations.py b/flax/configurations.py index ba19a572fc..5e1a492fcf 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -22,6 +22,7 @@ class Config: + flax_use_flaxlib: bool # See https://google.github.io/pytype/faq.html. _HAS_DYNAMIC_ATTRIBUTES = True @@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /): raise LookupError(f'Unrecognized config option: {name}') self._values[name] = value + def __repr__(self): + values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items()) + return f'Config({values_repr}\n)' + config = Config() @@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool): ' PRNG keys.' ), ) + +flax_use_flaxlib = bool_flag( + name='flax_use_flaxlib', + default=False, + help='Whether to use flaxlib for C++ acceleration.', +) \ No newline at end of file diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 191a0c195a..13441c7370 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -67,7 +67,7 @@ def extract_graph_nodes( | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]] ): """Extracts all graph nodes from a pytree.""" - nodes = graph.RefMap[tp.Any, Index]() + nodes = graph.RefMap[tp.Any, Index]({}) node_prefixes = [] leaves = [] @@ -324,7 +324,7 @@ def to_tree( assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] - node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() + node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]({}) with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index be04b279c8..4832476245 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -20,6 +20,7 @@ import threading import typing as tp +from flax import config import jax import numpy as np import typing_extensions as tpe @@ -33,15 +34,14 @@ from flax.nnx.statelib import State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState -from flax.typing import Key, PathParts, is_key_like +from flax.typing import HashableMapping, Key, PathParts, is_key_like A = tp.TypeVar('A') B = tp.TypeVar('B') C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable) -HA = tp.TypeVar('HA', bound=tp.Hashable) -HB = tp.TypeVar('HB', bound=tp.Hashable) + KeyT = tp.TypeVar('KeyT', bound=Key) Index = int @@ -66,9 +66,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): """A mapping that uses object id as the hash for the keys.""" - def __init__( - self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / - ): + def __init__(self, mapping: tp.Mapping[A, B], /): self._mapping: dict[int, tuple[A, B]] = {} self.update(mapping) @@ -90,8 +88,14 @@ def __iter__(self) -> tp.Iterator[A]: def __len__(self) -> int: return len(self._mapping) - def __str__(self) -> str: - return repr(self) +RefIndexMapping = RefMap[tp.Any, Index] +IndexRefMapping = dict[Index, tp.Any] + +if config.flax_use_flaxlib and not tp.TYPE_CHECKING: + import flaxlib + + RefIndexMapping = flaxlib.RefIndexMapping + IndexRefMapping = flaxlib.IndexRefMapping @dataclasses.dataclass(frozen=True, slots=True) @@ -204,32 +208,13 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: return GRAPH_REGISTRY[x] -class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): - def __init__(self, mapping: tp.Mapping[HA, HB], no_copy: bool = False): - self._mapping = mapping if no_copy else dict(mapping) - - def __contains__(self, key: object) -> bool: - return key in self._mapping - - def __getitem__(self, key: HA) -> HB: - return self._mapping[key] - - def __iter__(self) -> tp.Iterator[HA]: - return iter(self._mapping) - - def __len__(self) -> int: - return len(self._mapping) - def __hash__(self) -> int: - return hash(tuple(sorted(self._mapping.items()))) +IndexMapping = HashableMapping[int, int] - def __eq__(self, other: tp.Any) -> bool: - return ( - isinstance(other, HashableMapping) and self._mapping == other._mapping - ) +if config.flax_use_flaxlib and not tp.TYPE_CHECKING: + import flaxlib - def __repr__(self) -> str: - return repr(self._mapping) + IndexMapping = flaxlib.IndexMapping class GraphDef(tp.Generic[Node]): @@ -321,7 +306,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable): index: int attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] metadata: tp.Any - index_mapping: HashableMapping[Index, Index] | None + index_mapping: IndexMapping | None @classmethod def create( @@ -330,14 +315,14 @@ def create( index: int, attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], metadata: tp.Any, - index_mapping: tp.Mapping[Index, Index] | None, + index_mapping: IndexMapping | None, ): return cls( type=type, index=index, attributes=attributes, metadata=metadata, - index_mapping=HashableMapping(index_mapping) + index_mapping=IndexMapping(index_mapping) # type: ignore if index_mapping is not None else None, ) @@ -392,7 +377,7 @@ def _apply( def flatten( - node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None + node: Node, /, ref_index: RefIndexMapping | None = None ) -> tuple[GraphDef[Node], GraphState]: """Flattens a graph node into a (graphdef, state) pair. @@ -403,15 +388,22 @@ def flatten( nodes that share references. """ if ref_index is None: - ref_index = RefMap() - flat_state: list[tuple[PathParts, StateLeaf]] = [] - graphdef = _graph_flatten((), ref_index, flat_state, node) + ref_index = RefIndexMapping({}) + + flat_state: list[tuple[PathParts, StateLeaf]] + if config.flax_use_flaxlib: + import flaxlib + + graphdef, flat_state = flaxlib._graph_flatten_top(ref_index, node) + else: + flat_state = [] + graphdef = _graph_flatten([], ref_index, flat_state, node) return graphdef, GraphState.from_flat_path(flat_state) def _graph_flatten( - path: PathParts, - ref_index: RefMap[tp.Any, Index], + path: list[Key], + ref_index: RefIndexMapping, flat_state: list[tuple[PathParts, StateLeaf]], node: Node, ) -> NodeDef[Node] | NodeRef: @@ -434,8 +426,9 @@ def _graph_flatten( values, metadata = node_impl.flatten(node) for key, value in values: + path.append(key) if is_node(value): - nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) + nodedef = _graph_flatten(path, ref_index, flat_state, value) # subgraphs.append((key, nodedef)) attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): @@ -444,7 +437,7 @@ def _graph_flatten( LeafAttribute(key, NodeRef(type(value), ref_index[value])) ) else: - flat_state.append(((*path, key), value.to_state())) + flat_state.append((tuple(path), value.to_state())) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value._var_metadata) @@ -452,12 +445,13 @@ def _graph_flatten( attributes.append(LeafAttribute(key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): - path_str = '/'.join(map(str, (*path, key))) + path_str = '/'.join(map(str, path)) raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' ) # static_fields.append((key, value)) attributes.append(StaticAttribute(key, value)) + path.pop() nodedef = NodeDef.create( type=node_impl.type, @@ -468,14 +462,22 @@ def _graph_flatten( ) return nodedef +if config.flax_use_flaxlib and not tp.TYPE_CHECKING: + print('flaxlib used') + import flaxlib + + _graph_flatten = flaxlib._graph_flatten +else: + print('flaxlib not used') + def unflatten( graphdef: GraphDef[Node], state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], /, *, - index_ref: dict[Index, tp.Any] | None = None, - index_ref_cache: dict[Index, tp.Any] | None = None, + index_ref: IndexRefMapping | None = None, + index_ref_cache: IndexRefMapping | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -495,7 +497,7 @@ def unflatten( if isinstance(state, State): state = state.raw_mapping # type: ignore if index_ref is None: - index_ref = {} + index_ref = IndexRefMapping({}) assert isinstance(graphdef, (NodeDef, NodeRef)) node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) return node @@ -503,8 +505,8 @@ def unflatten( def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], - index_ref: dict[Index, tp.Any], - index_ref_cache: dict[Index, tp.Any] | None, + index_ref: IndexRefMapping, + index_ref_cache: IndexRefMapping | None, ) -> Node: """Recursive helper for graph_unflatten. @@ -792,7 +794,7 @@ class GraphContext(threading.local): @dataclasses.dataclass class SplitContext: ctxtag: str | None - ref_index: RefMap[tp.Any, Index] + ref_index: RefIndexMapping @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... @@ -819,9 +821,9 @@ def split( states = _split_state(state, filters) if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(ctx.index_ref, self.ref_index) + index_to_index = create_index_mapping(ctx.index_ref, self.ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) + graphdef, index_mapping=IndexMapping(index_to_index, no_copy=True) ) return graphdef, *states @@ -829,7 +831,7 @@ def split( @contextlib.contextmanager def split_context(ctxtag: str | None = None): - index_ref: RefMap[tp.Any, Index] = RefMap() + index_ref = RefIndexMapping({}) flatten_ctx = SplitContext(ctxtag, index_ref) GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx) @@ -847,7 +849,7 @@ def split_context(ctxtag: str | None = None): @dataclasses.dataclass class MergeContext: ctxtag: str | None - index_ref: dict[Index, tp.Any] + index_ref: IndexRefMapping def merge( self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState @@ -862,9 +864,7 @@ def merge( ): # outer merge (4), create index_ref_cache assert ctx.ref_index is not None - index_ref_cache = compose_mapping_reversed( - ctx.ref_index, graphdef.index_mapping - ) + index_ref_cache = create_index_ref(ctx.ref_index, graphdef.index_mapping) else: # inner merge (2) index_ref_cache = None @@ -881,7 +881,7 @@ def merge( @contextlib.contextmanager def merge_context(ctxtag: str | None = None): - index_ref: dict[Index, tp.Any] = {} + index_ref = IndexRefMapping({}) unflatten_ctx = MergeContext(ctxtag, index_ref) GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx) @@ -902,7 +902,7 @@ class UpdateContext: """A context manager for handling complex state updates.""" tag: str - ref_index: RefMap[tp.Any, Index] | None + ref_index: RefIndexMapping | None index_ref: dict[Index, tp.Any] | None # define hash and eq to make this an opaque object @@ -912,7 +912,7 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, UpdateContext) - def flatten_end(self, ref_index: RefMap[tp.Any, Index]): + def flatten_end(self, ref_index: RefIndexMapping): if self.ref_index is None: # outer split (1), store the references self.ref_index = ref_index @@ -1004,14 +1004,14 @@ def split( :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no filters are passed, a single :class:`State` is returned. """ - ref_index: RefMap[tp.Any, Index] = RefMap() + ref_index = RefIndexMapping({}) graphdef, state = flatten(node, ref_index) states = _split_state(state, filters) if self.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(self.index_ref, ref_index) + index_to_index = create_index_mapping(self.index_ref, ref_index) graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, no_copy=True) + graphdef, index_mapping=IndexMapping(index_to_index, no_copy=True) ) self.flatten_end(ref_index) @@ -1035,15 +1035,13 @@ def merge( if graphdef.index_mapping is not None: # outer merge (4), create index_ref_cache assert self.ref_index is not None - index_ref_cache = compose_mapping_reversed( - self.ref_index, graphdef.index_mapping - ) + index_ref_cache = create_index_ref(self.ref_index, graphdef.index_mapping) else: # inner merge (2) index_ref_cache = None state = State.merge(state, *states) - index_ref: dict[Index, tp.Any] = {} + index_ref = IndexRefMapping({}) node = unflatten( graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache ) @@ -1755,16 +1753,29 @@ def _iter_graph( yield path_parts, node -def compose_mapping( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[A, C]: - return {a: map_bc[b] for a, b in map_ab.items() if b in map_bc} +def create_index_mapping( + index_ref: IndexRefMapping, ref_index: RefIndexMapping, / +) -> IndexMapping: + return IndexMapping( + {a: ref_index[b] for a, b in index_ref.items() if b in ref_index}, + no_copy=True, + ) + + +def create_index_ref( + ref_index: RefIndexMapping, index_mapping: IndexMapping, / +) -> IndexRefMapping: + return { + index_mapping[index]: ref + for ref, index in ref_index.items() + if index in index_mapping + } + +if config.flax_use_flaxlib and not tp.TYPE_CHECKING: + import flaxlib -def compose_mapping_reversed( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[C, A]: - return {map_bc[b]: a for a, b in map_ab.items() if b in map_bc} + # create_index_ref = flaxlib.create_index_ref @dataclasses.dataclass(frozen=True) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 5ef0d183b7..d200955be9 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -427,7 +427,7 @@ def _custom_vjp_split_fn( nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False) tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False) -def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]): +def _extract_index_mappings(x, *, index_mappings: deque[graph.IndexMapping]): if isinstance(x, graph.NodeDef): assert x.index_mapping is not None index_mappings.append(x.index_mapping) @@ -465,7 +465,7 @@ def __call__(self, *pure_args): (args_out, out), ctxtag=self.ctxtag ) # remove index_mapping from NodeDef's but store them in global context - index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state( + index_mappings: deque[graph.IndexMapping] = extract.get_broadcast_state( self.ctxtag ) @@ -664,7 +664,7 @@ def __call__( # insert index_mappings def _insert_index_mappings(x): if isinstance(x, graph.NodeDef): - index_mapping: graph.HashableMapping = index_mappings.popleft() + index_mapping: tp.Mapping = index_mappings.popleft() return dataclasses.replace(x, index_mapping=index_mapping) return x diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 994e582862..903ca668d1 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -650,7 +650,7 @@ def check_carry_same_references(key_path, arg, out): def _extract_index_mappings( pure_carry_arg_out, - carry_index_mappings: list[graph.HashableMapping[int, int]], + carry_index_mappings: list[graph.IndexMapping], /, ): def extract_index_mappings(x): @@ -675,7 +675,7 @@ def extract_index_mappings(x): def _insert_index_mappings( pure_carry_arg_out, - carry_index_mappings: deque[graph.HashableMapping[int, int]], + carry_index_mappings: deque[graph.IndexMapping], /, ): def insert_index_mappings(x): @@ -1096,7 +1096,7 @@ def __call__( # next we have to remove all the index_mappings from the NodeDefs # in the carry outputs because they are not present in the inputs - carry_index_mappings: list[graph.HashableMapping[int, int]] = [] + carry_index_mappings: list[graph.IndexMapping] = [] pure_carry_arg_out = _extract_index_mappings( pure_carry_arg_out, carry_index_mappings ) @@ -1357,7 +1357,7 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): return dataclasses.replace( ns, _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping) + ns._graphdef, index_mapping=graph.IndexMapping(global_index_mapping) ), ) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..239292cc35 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -308,7 +308,7 @@ def copy(self: Variable[A]) -> Variable[A]: return obj def to_state(self: Variable[A]) -> VariableState[A]: - return VariableState(type(self), self.raw_value, **self._var_metadata) + return VariableState(type(self), self.raw_value, self._var_metadata.copy()) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) @@ -739,7 +739,7 @@ def __init__( self, type: type[Variable[A]], # type: ignore [valid-type] value: A, - **metadata, + metadata: dict[str, tp.Any], ): object.__setattr__(self, 'type', type) object.__setattr__(self, 'value', value) @@ -783,7 +783,7 @@ def __treescope_repr__(self, path, subtree_renderer): ) def replace(self, value: B) -> VariableState[B]: - return VariableState(self.type, value, **self.get_metadata()) + return VariableState(self.type, value, self.get_metadata().copy()) def to_variable(self) -> Variable[A]: # we use object.__new__ to avoid calling __init__ and bypass the @@ -826,7 +826,7 @@ def _variable_state_unflatten( return VariableState( type=static[0], value=children[0], - **dict(static[1]), + metadata=dict(static[1]), ) diff --git a/flax/typing.py b/flax/typing.py index a630a3571e..6493108e67 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -23,6 +23,7 @@ TypeVar, Union, ) +from collections.abc import Iterator from collections.abc import Callable, Hashable, Mapping, Sequence import jax @@ -161,3 +162,33 @@ class Missing: MISSING = Missing() +HA = TypeVar('HA', bound=Hashable) +HB = TypeVar('HB', bound=Hashable) + + +class HashableMapping(Mapping[HA, HB], Hashable): + def __init__(self, mapping: Mapping[HA, HB], no_copy: bool = False): + self._mapping = mapping if no_copy else dict(mapping) + + def __contains__(self, key: object) -> bool: + return key in self._mapping + + def __getitem__(self, key: HA) -> HB: + return self._mapping[key] + + def __iter__(self) -> Iterator[HA]: + return iter(self._mapping) + + def __len__(self) -> int: + return len(self._mapping) + + def __hash__(self) -> int: + return hash(tuple(sorted(self._mapping.items()))) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, HashableMapping) and self._mapping == other._mapping + ) + + def __repr__(self) -> str: + return repr(self._mapping) \ No newline at end of file diff --git a/flaxlib_src/CMakeLists.txt b/flaxlib_src/CMakeLists.txt new file mode 100644 index 0000000000..28b2b8cf36 --- /dev/null +++ b/flaxlib_src/CMakeLists.txt @@ -0,0 +1,57 @@ +# Set the minimum CMake version and policies for highest tested version +cmake_minimum_required(VERSION 3.15...3.27) + +# Set up the project and ensure there is a working C++ compiler +project(flaxlib LANGUAGES CXX) + +# Warn if the user invokes CMake directly +if (NOT SKBUILD) + message(WARNING "\ + This CMake file is meant to be executed using 'scikit-build-core'. + Running it directly will almost certainly not produce the desired + result. If you are a user trying to install this package, use the + command below, which will install all necessary build dependencies, + compile the package in an isolated environment, and then install it. + ===================================================================== + $ pip install . + ===================================================================== + If you are a software developer, and this is your own package, then + it is usually much more efficient to install the build dependencies + in your environment once and use the following command that avoids + a costly creation of a new virtual environment at every compilation: + ===================================================================== + $ pip install nanobind scikit-build-core[pyproject] + $ pip install --no-build-isolation -ve . + ===================================================================== + You may optionally add -Ceditable.rebuild=true to auto-rebuild when + the package is imported. Otherwise, you need to rerun the above + after editing C++ files.") +endif() + +# Try to import all Python components potentially needed by nanobind +find_package(Python 3.8 + REQUIRED COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.SABIModule) + +# Import nanobind through CMake's find_package mechanism +find_package(nanobind CONFIG REQUIRED) +find_package(OpenSSL REQUIRED) + +# We are now ready to compile the actual extension module +nanobind_add_module( + # Name of the extension + flaxlib_cpp + + # Target the stable ABI for Python 3.12+, which reduces + # the number of binary wheels that must be built. This + # does nothing on older Python versions + STABLE_ABI + + # Source code goes here + src/flaxlib.cpp +) + +target_link_libraries(flaxlib_cpp PRIVATE OpenSSL::SSL OpenSSL::Crypto) + +# Install directive for scikit-build-core +install(TARGETS flaxlib_cpp LIBRARY DESTINATION flaxlib) \ No newline at end of file diff --git a/flaxlib_src/flaxlib.pyi b/flaxlib_src/flaxlib.pyi deleted file mode 100644 index 505fd3d0f0..0000000000 --- a/flaxlib_src/flaxlib.pyi +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -def sum_as_string(a: int, b: int) -> str: ... diff --git a/flaxlib_src/meson.build b/flaxlib_src/meson.build deleted file mode 100644 index 0d78d9436b..0000000000 --- a/flaxlib_src/meson.build +++ /dev/null @@ -1,14 +0,0 @@ -project( - 'flaxlib', - 'cpp', - version: '0.0.1', - default_options: ['cpp_std=c++17'], -) -py = import('python').find_installation() -nanobind_dep = dependency('nanobind', static: true) -py.extension_module( - 'flaxlib', - sources: ['src/lib.cc'], - dependencies: [nanobind_dep], - install: true, -) \ No newline at end of file diff --git a/flaxlib_src/pyproject.toml b/flaxlib_src/pyproject.toml index 0afc7699a5..fd6c0b61b4 100644 --- a/flaxlib_src/pyproject.toml +++ b/flaxlib_src/pyproject.toml @@ -1,17 +1,28 @@ [build-system] -requires = ['meson-python'] -build-backend = 'mesonpy' +requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] +build-backend = "scikit_build_core.build" [project] name = "flaxlib" +version = "0.0.1" requires-python = ">=3.10" classifiers = [ "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = ["version"] + [project.optional-dependencies] tests = [ "pytest", ] + +[tool.scikit-build] +# Protect the configuration against future changes in scikit-build-core +minimum-version = "0.4" + +# Setuptools-style build caching in a local directory +build-dir = "build/{wheel_tag}" + +# Build stable ABI wheels for CPython 3.12+ +wheel.py-api = "cp312" \ No newline at end of file diff --git a/flaxlib_src/src/flaxlib.cpp b/flaxlib_src/src/flaxlib.cpp new file mode 100644 index 0000000000..193b47edae --- /dev/null +++ b/flaxlib_src/src/flaxlib.cpp @@ -0,0 +1,720 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +using namespace nb::literals; + +namespace flaxlib +{ + // ----------------------------------- + // helper functions + // ----------------------------------- + intptr_t get_id(nb::object obj) + { + // Get the object ID + return reinterpret_cast(obj.ptr()); + } + + bool nb_isinstance(nanobind::handle inst, nanobind::handle cls) + { + int ret = PyObject_IsInstance(inst.ptr(), cls.ptr()); + if (ret == -1) + { + throw nb::python_error(); + } + return ret; + } + + nb::object vector_to_tuple(const std::vector &vec) + { + + if (vec.empty()) + { + return nb::tuple(); + } + else + { + auto ls = nb::list(); + for (const auto &item : vec) + { + ls.append(item); + } + auto result = nb::tuple(ls); + return result; + } + } + + // ----------------------------------- + // IndexMapping + // ----------------------------------- + class IndexMappingKeysIterator + { + public: + IndexMappingKeysIterator(const std::unordered_map &data) : it(data.begin()), end(data.end()) {} + + int next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->first; + } + + IndexMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map::const_iterator it; + std::unordered_map::const_iterator end; + }; + + struct IndexMapping + { + std::unordered_map mapping; + + IndexMapping(std::unordered_map &mapping, bool no_copy) + { + if (no_copy) + { + this->mapping = std::move(mapping); + } + else + { + this->mapping = mapping; + } + } + + // define the python __hash__ method + uint64_t __hash__() + { + EVP_MD_CTX *mdctx; + const EVP_MD *md; + unsigned char md_value[EVP_MAX_MD_SIZE]; + unsigned int md_len; + + // Serialize the map + std::stringstream ss; + for (const auto &pair : mapping) + { + ss << pair.first << ":" << pair.second << ","; + } + std::string serializedData = ss.str(); + + OpenSSL_add_all_digests(); + + md = EVP_get_digestbyname("SHA256"); + if (!md) + { + throw std::runtime_error("Unknown message digest BLAKE3"); + } + + mdctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(mdctx, md, NULL); + EVP_DigestUpdate(mdctx, serializedData.c_str(), serializedData.size()); + EVP_DigestFinal_ex(mdctx, md_value, &md_len); + EVP_MD_CTX_free(mdctx); + + // Convert (part of) the digest to a 64-bit integer + uint64_t result = 0; + for (size_t i = 0; i < 8 && i < md_len; ++i) + { + result = (result << 8) | static_cast(md_value[i]); + } + + return result; + } + + // define the python __repr__ method + std::string __repr__() + { + std::string repr; + if (mapping.size() == 1) + { + repr = "IndexMapping({"; + for (const auto &pair : mapping) + { + repr += std::to_string(pair.first) + ": " + std::to_string(pair.second); + } + repr += "})"; + } + else + { + repr = "IndexMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + std::to_string(pair.first) + ": " + std::to_string(pair.second) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + // define the python __getitem__ method + int __getitem__(int key) const + { + return mapping.at(key); + } + + // define __iter__ method + IndexMappingKeysIterator __iter__() const + { + return IndexMappingKeysIterator(mapping); + } + + // define the python __len__ method + size_t __len__() const + { + return mapping.size(); + } + + // define the python __contains__ method + bool __contains__(int key) const + { + return mapping.find(key) != mapping.end(); + } + + bool __eq__(const nb::object &other) const + { + if (!nb::isinstance(other)) + { + return false; + } + + auto other_mapping = nb::cast(other); + return mapping == other_mapping.mapping; + } + + nb::object items() const + { + return nb::make_iterator( + nb::type>>(), "IndexMappingItemsIterator", mapping.begin(), mapping.end()); + } + }; + + // ----------------------------------- + // RefIndexMapping + // ----------------------------------- + + struct RefIndexMappingKeysIterator + { + public: + RefIndexMappingKeysIterator(const std::unordered_map> &data) : it(data.begin()), end(data.end()) {} + + nb::object next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->second.first; + } + + RefIndexMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map>::const_iterator it; + std::unordered_map>::const_iterator end; + }; + + struct RefIndexMappingItemsIterator + { + public: + RefIndexMappingItemsIterator(const std::unordered_map> &data) : it(data.begin()), end(data.end()) {} + + std::pair next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->second; + } + + RefIndexMappingItemsIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map>::const_iterator it; + std::unordered_map>::const_iterator end; + }; + + struct RefIndexMapping + { + std::unordered_map> mapping; + + RefIndexMapping(std::map ref_mapping) + { + for (const auto &pair : ref_mapping) + { + mapping[get_id(pair.first)] = {pair.first, pair.second}; + } + } + + int __getitem__(nb::object key) const + { + return mapping.at(get_id(key)).second; + } + + bool __contains__(nb::object key) const + { + return mapping.find(get_id(key)) != mapping.end(); + } + + void __setitem__(nb::object key, int value) + { + mapping[get_id(key)] = {key, value}; + } + + void __delitem__(nb::object key) + { + mapping.erase(get_id(key)); + } + + RefIndexMappingKeysIterator __iter__() const + { + return RefIndexMappingKeysIterator(mapping); + } + + size_t __len__() const + { + return mapping.size(); + } + + // __repr__ method + std::string __repr__() + { + std::string repr; + if (mapping.size() == 1) + { + repr = "RefIndexMapping({"; + for (const auto &pair : mapping) + { + repr += nb::cast(nb::repr(pair.second.first)) + ": " + std::to_string(pair.second.second); + } + repr += "})"; + } + else + { + repr = "RefIndexMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + nb::cast(nb::repr(pair.second.first)) + ": " + std::to_string(pair.second.second) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + RefIndexMappingItemsIterator items() const + { + return RefIndexMappingItemsIterator(mapping); + } + }; + + // ------------------------------------- + // IndexRefMapping + // ------------------------------------- + + struct IndexRefMappingKeysIterator + { + public: + IndexRefMappingKeysIterator(const std::unordered_map &data) : it(data.begin()), end(data.end()) {} + + int next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return get_id(it++->second); + } + + IndexRefMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map::const_iterator it; + std::unordered_map::const_iterator end; + }; + + struct IndexRefMapping + { + std::unordered_map mapping; + + IndexRefMapping(std::unordered_map mapping) : mapping(mapping) {} + + nb::object __getitem__(int key) const + { + return mapping.at(key); + } + + bool __contains__(int key) const + { + return mapping.find(key) != mapping.end(); + } + + void __setitem__(int key, nb::object value) + { + mapping[key] = value; + } + + void __delitem__(int key) + { + mapping.erase(key); + } + + IndexRefMappingKeysIterator __iter__() const + { + return IndexRefMappingKeysIterator(mapping); + } + + size_t __len__() const + { + return mapping.size(); + } + + std::string __repr__() + { + std::string repr; + if (mapping.size() <= 1) + { + repr = "IndexRefMapping({"; + for (const auto &pair : mapping) + { + repr += std::to_string(pair.first) + ": " + nb::cast(nb::repr(pair.second)); + } + repr += "})"; + } + else + { + repr = "IndexRefMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + std::to_string(pair.first) + ": " + nb::cast(nb::repr(pair.second)) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + nb::object items() const + { + return nb::make_iterator(nb::type>>(), "IndexRefMappingItemsIterator", mapping.begin(), mapping.end()); + } + }; + + // ------------------------------------- + // functions + // ------------------------------------- + + IndexRefMapping create_index_ref(RefIndexMapping ref_index, IndexMapping index_mapping) + { + std::unordered_map new_mapping; + for (const auto &pair : ref_index.mapping) + { + auto a = pair.second.first; + auto b = pair.second.second; + + auto b_pos = index_mapping.mapping.find(b); + if (b_pos != index_mapping.mapping.end()) + { + new_mapping[b_pos->second] = a; + } + } + return IndexRefMapping(new_mapping); + } + + // def _graph_flatten( + // path: list[Key], + // ref_index: RefIndexMapping, + // flat_state: list[tuple[PathParts, StateLeaf]], + // node: Node, + // ) -> NodeDef[Node] | NodeRef: + // if not is_node(node): + // raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') + + // if node in ref_index: + // return NodeRef(type(node), ref_index[node]) + + // node_impl = get_node_impl(node) + + // # only cache graph nodes + // if isinstance(node_impl, GraphNodeImpl): + // index = len(ref_index) + // ref_index[node] = index + // else: + // index = -1 + + // attributes: list[SubGraphAttribute | StaticAttribute | LeafAttribute] = [] + + // values, metadata = node_impl.flatten(node) + // for key, value in values: + // path.append(key) + // if is_node(value): + // nodedef = _graph_flatten(path, ref_index, flat_state, value) + // # subgraphs.append((key, nodedef)) + // attributes.append(SubGraphAttribute(key, nodedef)) + // elif isinstance(value, Variable): + // if value in ref_index: + // attributes.append( + // LeafAttribute(key, NodeRef(type(value), ref_index[value])) + // ) + // else: + // flat_state.append((tuple(path), value.to_state())) + // variable_index = ref_index[value] = len(ref_index) + // variabledef = VariableDef( + // type(value), variable_index, HashableMapping(value._var_metadata) + // ) + // attributes.append(LeafAttribute(key, variabledef)) + // else: + // if isinstance(value, (jax.Array, np.ndarray)): + // path_str = '/'.join(map(str, path)) + // raise ValueError( + // f'Arrays leaves are not supported, at {path_str!r}: {value}' + // ) + // attributes.append(StaticAttribute(key, value)) + // path.pop() + + // nodedef = NodeDef.create( + // type=node_impl.type, + // index=index, + // attributes=tuple(attributes), + // metadata=metadata, + // index_mapping=None, + // ) + // return nodedef + + nb::object _graph_flatten( + std::vector &path, + RefIndexMapping &ref_index, + std::vector> &flat_state, + nb::object node) + { + // import graph Module from flax.nnx + auto graph = nb::module_::import_("flax.nnx.graph"); + auto jax = nb::module_::import_("jax"); + auto np = nb::module_::import_("numpy"); + + auto jax_Array = nb::getattr(jax, "Array"); + auto np_ndarray = nb::getattr(np, "ndarray"); + auto GraphNodeImpl = nb::getattr(graph, "GraphNodeImpl"); + auto Variable = nb::getattr(graph, "Variable"); + auto SubGraphAttribute = nb::getattr(graph, "SubGraphAttribute"); + auto StaticAttribute = nb::getattr(graph, "StaticAttribute"); + auto LeafAttribute = nb::getattr(graph, "LeafAttribute"); + auto NodeRef = nb::getattr(graph, "NodeRef"); + auto NodeDef = nb::getattr(graph, "NodeDef"); + auto VariableDef = nb::getattr(graph, "VariableDef"); + auto HashableMapping = nb::getattr(graph, "HashableMapping"); + + if (!nb::bool_(nb::getattr(graph, "is_node")(node))) + { + throw std::runtime_error("Unsupported type: " + nb::cast(node.type().attr("__name__")) + ", this is a bug."); + } + + if (ref_index.__contains__(node)) + { + return NodeRef(node.type(), ref_index.__getitem__(node)); + } + + auto node_impl = nb::getattr(graph, "get_node_impl")(node); + + int index; + // only cache graph nodes + if (nb_isinstance(node_impl, GraphNodeImpl)) + { + index = ref_index.__len__(); + ref_index.__setitem__(node, index); + } + else + { + index = -1; + } + + std::vector attributes; + + auto values_metadata = nb::getattr(node_impl, "flatten")(node); + auto values = values_metadata[0]; + auto metadata = values_metadata[1]; + + for (const auto &key_value : values) + { + auto key = key_value[0]; + auto value = key_value[1]; + + path.push_back(key); + + if (nb::bool_(nb::getattr(graph, "is_node")(value))) + { + auto nodedef = _graph_flatten(path, ref_index, flat_state, value); + attributes.push_back(SubGraphAttribute(key, nodedef)); + } + else if (nb_isinstance(value, Variable)) + { + if (ref_index.__contains__(value)) + { + attributes.push_back(LeafAttribute(key, NodeRef(value.type(), ref_index.__getitem__(value)))); + } + else + { + auto path_tuple = vector_to_tuple(path); + flat_state.push_back({path_tuple, nb::getattr(value, "to_state")()}); + auto variable_index = ref_index.__len__(); + ref_index.__setitem__(value, variable_index); + auto var_meta = HashableMapping(nb::getattr(value, "_var_metadata")); + auto variabledef = VariableDef(value.type(), variable_index, var_meta); + attributes.push_back(LeafAttribute(key, variabledef)); + } + } + else + { + if (nb_isinstance(value, jax_Array) || nb_isinstance(value, np_ndarray)) + { + std::string path_str; + for (const auto &part : path) + { + path_str += nb::cast(nb::repr(part)) + "/"; + } + throw std::runtime_error("Arrays leaves are not supported, at " + path_str + ": " + nb::cast(nb::repr(value))); + } + attributes.push_back(StaticAttribute(key, value)); + } + path.pop_back(); + } + + auto attributes_tuple = vector_to_tuple(attributes); + auto nodedef = nb::getattr(NodeDef, "create")( + nb::getattr(node_impl, "type"), index, attributes_tuple, metadata, nb::none()); + + return nodedef; + } + + std::pair _graph_flatten_top( + RefIndexMapping &ref_index, + nb::object node) + { + // print "here" + std::vector path = {}; + std::vector> flat_state = {}; + auto nodedef = _graph_flatten(path, ref_index, flat_state, node); + + auto flat_state_out = nb::list(); + for (const auto &pair : flat_state) + { + flat_state_out.append(nb::make_tuple(pair.first, pair.second)); + } + return {nodedef, flat_state_out}; + } + + NB_MODULE(flaxlib_cpp, m) + { + //------------------------------------------------------------------------- + // IndexMapping + //------------------------------------------------------------------------- + nb::class_(m, "IndexMapping") + // no_copy defaults to false + .def(nb::init &, bool>(), nb::arg("mapping"), nb::arg("no_copy") = false) + .def("__hash__", &IndexMapping::__hash__) + .def("__repr__", &IndexMapping::__repr__) + .def("__getitem__", &IndexMapping::__getitem__) + .def("__iter__", &IndexMapping::__iter__) + .def("__len__", &IndexMapping::__len__) + .def("__contains__", &IndexMapping::__contains__, nb::arg("key").none()) + .def("__eq__", &IndexMapping::__eq__) + .def("items", &IndexMapping::items); + + nb::class_(m, "IndexMappingIterator") + .def("__iter__", &IndexMappingKeysIterator::__iter__) + .def("__next__", &IndexMappingKeysIterator::next); + + //------------------------------------------------------------------------- + // RefIndexMapping + //------------------------------------------------------------------------- + nb::class_(m, "RefIndexMapping") + .def(nb::init>()) + .def("__getitem__", &RefIndexMapping::__getitem__) + .def("__contains__", &RefIndexMapping::__contains__, nb::arg("key").none()) + .def("__setitem__", &RefIndexMapping::__setitem__) + .def("__delitem__", &RefIndexMapping::__delitem__) + .def("__iter__", &RefIndexMapping::__iter__) + .def("__len__", &RefIndexMapping::__len__) + .def("__repr__", &RefIndexMapping::__repr__) + .def("items", &RefIndexMapping::items); + + nb::class_(m, "RefIndexMappingKeysIterator") + .def("__iter__", &RefIndexMappingKeysIterator::__iter__) + .def("__next__", &RefIndexMappingKeysIterator::next); + + nb::class_(m, "RefIndexMappingItemsIterator") + .def("__iter__", &RefIndexMappingItemsIterator::__iter__) + .def("__next__", &RefIndexMappingItemsIterator::next); + + //------------------------------------------------------------------------- + // IndexRefMapping + //------------------------------------------------------------------------- + nb::class_(m, "IndexRefMapping") + .def(nb::init &>()) + .def("__getitem__", &IndexRefMapping::__getitem__) + .def("__contains__", &IndexRefMapping::__contains__, nb::arg("key").none()) + .def("__setitem__", &IndexRefMapping::__setitem__) + .def("__delitem__", &IndexRefMapping::__delitem__) + .def("__iter__", &IndexRefMapping::__iter__) + .def("__len__", &IndexRefMapping::__len__) + .def("__repr__", &IndexRefMapping::__repr__) + .def("items", &IndexRefMapping::items); + + nb::class_(m, "IndexRefMappingKeysIterator") + .def("__iter__", &IndexRefMappingKeysIterator::__iter__) + .def("__next__", &IndexRefMappingKeysIterator::next); + + //------------------------------------------------------------------------- + // functions + //------------------------------------------------------------------------- + m.def("create_index_ref", &create_index_ref); + m.def("_graph_flatten_top", &_graph_flatten_top); + m.def("_graph_flatten", &_graph_flatten); + } + +} // namespace flaxlib \ No newline at end of file diff --git a/flaxlib_src/src/flaxlib/__init__.py b/flaxlib_src/src/flaxlib/__init__.py new file mode 100644 index 0000000000..a2138d52d8 --- /dev/null +++ b/flaxlib_src/src/flaxlib/__init__.py @@ -0,0 +1,15 @@ +from .flaxlib_cpp import IndexMapping as IndexMapping +from .flaxlib_cpp import RefIndexMapping as RefIndexMapping +from .flaxlib_cpp import IndexRefMapping as IndexRefMapping +from .flaxlib_cpp import create_index_ref as create_index_ref +from .flaxlib_cpp import _graph_flatten as _graph_flatten +from .flaxlib_cpp import _graph_flatten_top as _graph_flatten_top + +# ----------------------------- +# Register pytrees types +# ----------------------------- +import jax + +jax.tree_util.register_static(IndexMapping) + +del jax \ No newline at end of file diff --git a/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi new file mode 100644 index 0000000000..1ab90b0b7c --- /dev/null +++ b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi @@ -0,0 +1,55 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterator + +def sum_as_string(a: int, b: int) -> str: ... + +def create_index_ref( + ref_index: RefIndexMapping, index_mapping: IndexMapping +) -> IndexRefMapping: ... + +class IndexMapping: + def __init__(self, mapping: dict[int, int], /) -> None: ... + def __hash__(self) -> int: ... + def __getitem__(self, key: int) -> int: ... + def __len__(self) -> int: ... + def __contains__(self, key: int) -> bool: ... + def __iter__(self) -> Iterator[int]: ... + def items(self) -> Iterator[tuple[int, int]]: ... + +class RefIndexMapping: + def __init__(self, ref_mapping: dict[Any, int], /) -> None: ... + def __getitem__(self, key: Any) -> int: ... + def __contains__(self, key: Any) -> bool: ... + def __setitem__(self, key: Any, value: int) -> None: ... + def __delitem__(self, key: Any) -> None: ... + def __iter__(self) -> Iterator[Any]: ... + def __len__(self) -> int: ... + def items(self) -> Iterator[tuple[Any, int]]: ... + +class IndexRefMapping: + def __init__(self, mapping: dict[int, Any], /) -> None: ... + def __getitem__(self, key: int) -> Any: ... + def __contains__(self, key: int) -> bool: ... + def __setitem__(self, key: int, value: Any) -> None: ... + def __delitem__(self, key: int) -> None: ... + def __iter__(self) -> Iterator[int]: ... + def __len__(self) -> int: ... + def items(self) -> Iterator[tuple[int, Any]]: ... + +def _graph_flatten_top(ref_index: RefIndexMapping, node: Any) -> Any: ... +def _graph_flatten( + path: list, ref_index: RefIndexMapping, flat_state: list, node: Any +) -> Any: ... diff --git a/flaxlib_src/src/lib.cc b/flaxlib_src/src/lib.cc deleted file mode 100644 index c714588118..0000000000 --- a/flaxlib_src/src/lib.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" - -namespace flaxlib { -std::string sum_as_string(int a, int b) { - return std::to_string(a + b); -} - -NB_MODULE(flaxlib, m) { - m.def("sum_as_string", &sum_as_string); -} -} // namespace flaxlib \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 658b2f15d5..5d6dfc29a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,11 @@ docs = [ "ipywidgets>=8.1.5", ] dev = [ + "build>=1.2.2.post1", + "nanobind>=2.2.0", + "ninja>=1.11.1.1", "pre-commit>=3.8.0", + "scikit-build-core[pyproject]>=0.10.7", ] [project.urls] diff --git a/tests/flaxlib_test.py b/tests/flaxlib_test.py deleted file mode 100644 index c23f70baa7..0000000000 --- a/tests/flaxlib_test.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# TODO: Re-enable this test after setting up CI build for flaxlib CC. - -# from absl.testing import absltest -# import flaxlib - - -# class TestFlaxlib(absltest.TestCase): - -# def test_flaxlib(self): -# self.assertEqual(flaxlib.sum_as_string(1, 2), '3') diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index a7bbf178cb..f3b47c1162 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -59,11 +59,15 @@ def __call__(self, x): class TestGraphUtils(absltest.TestCase): + def test_flatten_basic(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + graphdef, state = nnx.split(m) + def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - refmap = nnx.graph.RefMap() + refmap = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(g, ref_index=refmap) state[0]['b'].raw_value = 2 @@ -326,7 +330,7 @@ def f(m: Foo): a = m.a b = m.b - ref_out_idx_out = nnx.graph.RefMap() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @@ -335,19 +339,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b @@ -366,7 +370,7 @@ def f(m: Foo): a = m.a b = m.b - ref_out_idx_out = nnx.graph.RefMap[Any, int]() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @@ -375,19 +379,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b @@ -403,7 +407,7 @@ def f(m: Foo): m = Foo() - ref_out_idx_out = nnx.graph.RefMap() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @@ -412,19 +416,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.ref is m2 diff --git a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 9729828278..40dc8b8941 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -24,7 +24,7 @@ class TestVariableState(absltest.TestCase): def test_pytree(self): - r1 = nnx.VariableState(nnx.Param, 1) + r1 = nnx.VariableState(nnx.Param, 1, {}) self.assertEqual(r1.value, 1) r2 = jax.tree.map(lambda x: x + 1, r1) diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 920d71017b..2c210b18d0 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -63,6 +63,7 @@ echo "GH_VENV: $GH_VENV" echo "WHICH PYTHON: $(which python)" echo "jax: $(python -c 'import jax; print(jax.__version__)')" echo "flax: $(python -c 'import flax; print(flax.__version__)')" +echo "flax config: $(python -c 'from flax import config; print(config)')" echo "==========================" echo "" diff --git a/uv.lock b/uv.lock index a30155113e..539338ad7f 100644 --- a/uv.lock +++ b/uv.lock @@ -179,6 +179,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/fe/e8c672695b37eecc5cbf43e1d0638d88d66ba3a44c4d321c796f4e59167f/beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed", size = 147925 }, ] +[[package]] +name = "build" +version = "1.2.2.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "os_name == 'nt'" }, + { name = "importlib-metadata", marker = "python_full_version < '3.10.2'" }, + { name = "packaging" }, + { name = "pyproject-hooks" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/46/aeab111f8e06793e4f0e421fcad593d547fb8313b50990f31681ee2fb1ad/build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7", size = 46701 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/c2/80633736cd183ee4a62107413def345f7e6e3c01563dbca1417363cf957e/build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5", size = 22950 }, +] + [[package]] name = "certifi" version = "2024.7.4" @@ -793,7 +809,11 @@ all = [ { name = "matplotlib" }, ] dev = [ + { name = "build" }, + { name = "nanobind" }, + { name = "ninja" }, { name = "pre-commit" }, + { name = "scikit-build-core" }, ] docs = [ { name = "dm-haiku" }, @@ -841,6 +861,7 @@ testing = [ [package.metadata] requires-dist = [ + { name = "build", marker = "extra == 'dev'", specifier = ">=1.2.2.post1" }, { name = "cloudpickle", marker = "extra == 'testing'", specifier = ">=3.0.0" }, { name = "clu", marker = "python_full_version < '3.10' and extra == 'testing'", specifier = "<=0.0.9" }, { name = "clu", marker = "extra == 'testing'" }, @@ -865,7 +886,9 @@ requires-dist = [ { name = "msgpack" }, { name = "mypy", marker = "extra == 'testing'" }, { name = "myst-nb", marker = "extra == 'docs'" }, + { name = "nanobind", marker = "extra == 'dev'", specifier = ">=2.2.0" }, { name = "nbstripout", marker = "extra == 'docs'" }, + { name = "ninja", marker = "extra == 'dev'", specifier = ">=1.11.1.1" }, { name = "numpy", marker = "python_full_version >= '3.11'", specifier = ">=1.23.2" }, { name = "numpy", marker = "python_full_version >= '3.12'", specifier = ">=1.26.0" }, { name = "opencv-python", marker = "extra == 'testing'" }, @@ -881,6 +904,7 @@ requires-dist = [ { name = "pyyaml", specifier = ">=5.4.1" }, { name = "recommonmark", marker = "extra == 'docs'" }, { name = "rich", specifier = ">=11.1" }, + { name = "scikit-build-core", extras = ["pyproject"], marker = "extra == 'dev'", specifier = ">=0.10.7" }, { name = "scikit-learn", marker = "extra == 'docs'" }, { name = "sentencepiece", marker = "extra == 'testing'" }, { name = "sphinx", marker = "extra == 'docs'", specifier = ">=3.3.1" }, @@ -1936,6 +1960,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/59/7854fbfb59f8ae35483ce93493708be5942ebb6328cd85b3a609df629736/namex-0.0.8-py3-none-any.whl", hash = "sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487", size = 5806 }, ] +[[package]] +name = "nanobind" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/f2/f1e6c86edf90caf04a4c28e789b15a6a5aa87a5e037e0bf03bbfcc4937b6/nanobind-2.2.0.tar.gz", hash = "sha256:53fa7a6227bddecaa4a0710e0b8dc18fad4c8ded7a0a31d6eddcf68009ead603", size = 944277 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/af/7032b05a35284e741666acbf3eac3a14b5e81cd92264ac775426884ed460/nanobind-2.2.0-py3-none-any.whl", hash = "sha256:138685ec9c5de4f57dd02d715b89ffcbcabae39c4e36b8b2c40eea2f1aa2f0d7", size = 231618 }, +] + [[package]] name = "nbclient" version = "0.10.0" @@ -2304,6 +2337,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, ] +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2604,6 +2646,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/ea/6d76df31432a0e6fdf81681a895f009a4bb47b3c39036db3e1b528191d52/pyparsing-3.1.2-py3-none-any.whl", hash = "sha256:f9db75911801ed778fe61bb643079ff86601aca99fcae6345aa67292038fb742", size = 103245 }, ] +[[package]] +name = "pyproject-hooks" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/82/28175b2414effca1cdac8dc99f76d660e7a4fb0ceefa4b4ab8f5f6742925/pyproject_hooks-1.2.0.tar.gz", hash = "sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8", size = 19228 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl", hash = "sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913", size = 10216 }, +] + [[package]] name = "pytest" version = "8.3.2" @@ -2959,6 +3010,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/ea/6f121d1802f3adae1981aea4209ea66f9d3c7f2f6d6b85ef4f13a61d17ef/rpds_py-0.20.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bb273176be34a746bdac0b0d7e4e2c467323d13640b736c4c477881a3220a989", size = 213529 }, ] +[[package]] +name = "scikit-build-core" +version = "0.10.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/75/ad5664c8050bbbea46a5f2b6a3dfbc6e6cf284826c0eee0a12f861364b3f/scikit_build_core-0.10.7.tar.gz", hash = "sha256:04cbb59fe795202a7eeede1849112ee9dcbf3469feebd9b8b36aa541336ac4f8", size = 255019 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/fe/90476c4f6a1b2f922efa00d26e876dd40c7279e28ec18f08f0851ad21ba6/scikit_build_core-0.10.7-py3-none-any.whl", hash = "sha256:5e13ab7ca7c3c6dd019607c3a6f53cba67dade8757c4c4f75b459e2f90e4dbc3", size = 165511 }, +] + [[package]] name = "scikit-learn" version = "1.5.1"