Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add benchmark on state traversal, and a readme #4428

Merged
merged 1 commit into from
Dec 11, 2024

Conversation

IvyZX
Copy link
Collaborator

@IvyZX IvyZX commented Dec 10, 2024

Adds a benchmark script showing that the NNX state pytree traversal improved ~6x on flatten and ~3x on tree-map, after the latest JAX release.

Running the script in jaxlib==0.4.35 jax==0.4.35 gives:

### tree_flatten_with_path ###
total time: 3.397632122039795
time per step: 3397.63 µs
time per layer: 849.41 µs
### tree_map_with_path ###
total time: 4.1260809898376465
time per step: 4126.08 µs
time per layer: 1031.52 µs

And running in jaxlib==0.4.36 jax==0.4.37 gives:

### tree_flatten_with_path ###
total time: 0.5094578266143799
time per step: 509.46 µs
time per layer: 127.36 µs
### tree_map_with_path ###
total time: 1.18902587890625
time per step: 1189.03 µs
time per layer: 297.26 µs

@IvyZX IvyZX requested a review from cgarciae December 10, 2024 22:36
@copybara-service copybara-service bot merged commit a785bff into google:main Dec 11, 2024
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants