From a546779e2418ce8a4d4d162ef48d850829458a73 Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Thu, 7 Nov 2024 12:59:00 -0800 Subject: [PATCH] C++ tree with path API * Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening. * Moves all the key classes down to C++ level, while keeping the APIs unchanged. Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy. * Registered defaultdict and ordereddict via the keypath API now. PiperOrigin-RevId: 694219933 --- chex/_src/dataclass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index c4e8426..1d3c8f0 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -279,10 +279,10 @@ def _flatten_with_path(dcls): path = [] keys = [] for k, v in sorted(dcls.__dict__.items()): + keys.append(k) # generate same aux data as flatten without path k = jax.tree_util.GetAttrKey(k) path.append((k, v)) - keys.append(k) - return path, keys + return path, tuple(keys) @functools.cache