Skip to content

Commit

Permalink
C++ tree with path API
Browse files Browse the repository at this point in the history
* Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening.

* Moves all the key classes down to C++ level, while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* Registered defaultdict and ordereddict via the keypath API now.

PiperOrigin-RevId: 694219933
  • Loading branch information
IvyZX authored and ChexDev committed Nov 20, 2024
1 parent adaf1b2 commit 5647bb8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,10 @@ def _flatten_with_path(dcls):
path = []
keys = []
for k, v in sorted(dcls.__dict__.items()):
keys.append(k) # generate same aux data as flatten without path
k = jax.tree_util.GetAttrKey(k)
path.append((k, v))
keys.append(k)
return path, keys
return path, tuple(keys)


@functools.cache
Expand Down

0 comments on commit 5647bb8

Please sign in to comment.