
Export compiled objectives for common equilibrium resolutions? #1402

Open
dpanici opened this issue Nov 18, 2024 · 8 comments · May be fixed by #1483
Labels
P1 Lowest Priority, will get to eventually

Comments

@dpanici
Collaborator

dpanici commented Nov 18, 2024

https://jax.readthedocs.io/en/latest/export/export.html
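The linked page describes `jax.export`. Below is a minimal sketch of exporting, serializing, and reloading a jitted function, assuming a JAX version recent enough to provide `from jax import export` (older releases kept it under `jax.experimental`). `toy_objective` is a hypothetical stand-in for a DESC objective. As we understand the docs, an exported artifact is StableHLO that is specialized to the input shapes and dtypes (unless symbolic shapes are used), so each equilibrium resolution would need its own artifact. Calling it still triggers XLA compilation for the target device, so the savings are mainly in tracing and lowering.

```python
import jax
import jax.numpy as jnp
from jax import export


def toy_objective(x):
    # Hypothetical stand-in for a compiled DESC objective.
    return jnp.sum(jnp.sin(x) ** 2)


# The export is specialized to this shape/dtype, so each "resolution"
# (here: the length of x) would need its own artifact.
spec = jax.ShapeDtypeStruct((8,), jnp.float32)
exported = export.export(jax.jit(toy_objective))(spec)

blob = exported.serialize()  # bytearray that could be written to disk
restored = export.deserialize(blob)

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
print(restored.call(x), toy_objective(x))
```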

@YigitElma
Collaborator

In my last meeting with Egemen, we talked about this after I showed him some profiling. The long initialization and compile times prompted his question. I think #1374 and general performance improvements to factorize_linear_constraints will make the actual minimization start much earlier, and further tricks like storing compiled code will be redundant. But it might be worth trying.

Some concerns, though:

  • Does every JAX version create different compiled code? Should we compile for multiple versions?
  • How many different resolutions do we want to store?
  • And should we store them in the master branch?

I feel like this is not easy, but anyway.
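On the version question: as we understand the JAX docs, persistent-cache keys already include the JAX/jaxlib version and the target platform, so after an upgrade old entries should simply miss rather than be reused; they just pile up. One hypothetical way to keep that manageable is one subfolder per JAX version and backend, so stale folders can be deleted wholesale. `versioned_cache_dir` is an illustrative name, not an existing DESC or JAX function.

```python
import os

import jax


def versioned_cache_dir(base="jax-caches"):
    """Return a cache subfolder keyed on the JAX version and backend.

    Hypothetical helper: one folder per (version, backend) makes stale
    entries easy to find and delete after an upgrade.
    """
    tag = f"jax-{jax.__version__}-{jax.default_backend()}"
    return os.path.join(base, tag)
```

The result would then be passed to `jax.config.update("jax_compilation_cache_dir", versioned_cache_dir())`.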

@dpanici dpanici added the P1 Lowest Priority, will get to eventually label Nov 20, 2024
@dpanici
Collaborator Author

dpanici commented Dec 2, 2024

The JAX persistent cache, if enabled by a user, could partially alleviate this.

@YigitElma
Collaborator

For example, I created a jax-caches folder inside DESC, and I am using a notebook in my test-features subfolder. Adding the following lines to the top of the notebook makes JAX store compiled functions in the specified folder, and they are not deleted after each session!

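```python
import jax
import jax.numpy as jnp

# Store compiled executables in a persistent folder (path relative to the notebook).
jax.config.update("jax_compilation_cache_dir", "../jax-caches")
# -1 disables the minimum entry size check, so entries of any size are cached.
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# Cache every compilation, no matter how quickly it finished.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
```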

This sets the storage folder, sets the minimum compile time to 0, and disables the minimum entry size (-1), so basically every compiled function is stored.

@YigitElma
Collaborator

YigitElma commented Dec 3, 2024

Maybe we can add a placeholder folder to the repo whose contents are not tracked (add them to the .gitignore), with a README inside. Then we add the lines of code above to backend.py. Users can clear the folder once they change their dependencies etc., or not, since the cache is not very big, only a couple of MBs. @dpanici @ddudt @f0uriest
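A minimal sketch of what an opt-in version of this could look like in backend.py, assuming the hypothetical helper name `enable_persistent_cache` and the hypothetical environment variable `DESC_JAX_CACHE_DIR` (neither exists in DESC). Making it opt-in leaves users who never set the variable on JAX's default behavior, which sidesteps some of the storage questions raised earlier.

```python
import os

import jax


def enable_persistent_cache(cache_dir=None):
    """Opt-in persistent compilation cache (hypothetical helper).

    Uses ``cache_dir`` if given, otherwise the hypothetical environment
    variable ``DESC_JAX_CACHE_DIR``. Returns the directory used, or None
    if caching was left disabled.
    """
    cache_dir = cache_dir or os.environ.get("DESC_JAX_CACHE_DIR")
    if not cache_dir:
        return None  # user did not opt in; keep JAX defaults
    cache_dir = os.path.abspath(os.path.expanduser(cache_dir))
    os.makedirs(cache_dir, exist_ok=True)
    jax.config.update("jax_compilation_cache_dir", cache_dir)
    # Same thresholds as the notebook snippet: cache everything.
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
    return cache_dir
```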

@dpanici
Collaborator Author

dpanici commented Dec 6, 2024

Is there a way to cap the cache at a maximum size? In theory it could grow without bound, right?
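Newer JAX releases appear to document a `jax_compilation_cache_max_size` option that evicts least-recently-used entries; whether it exists depends on the installed version, so it is worth checking the persistent cache docs before relying on it. Independently of that, a manual cap can be sketched in plain Python. `prune_cache` is a hypothetical helper that deletes the oldest files (by modification time) until the folder fits a byte budget. Running it between sessions, rather than while JAX may be writing to the cache, seems safest.

```python
import os


def prune_cache(cache_dir, max_bytes):
    """Delete the oldest files in ``cache_dir`` until it holds at most ``max_bytes``.

    Hypothetical helper; file age is judged by modification time.
    Returns the list of removed file paths.
    """
    entries = []
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            path = os.path.join(root, name)
            st = os.stat(path)
            entries.append((st.st_mtime, st.st_size, path))

    total = sum(size for _, size, _ in entries)
    removed = []
    # Oldest files first; stop as soon as the folder fits the budget.
    for _mtime, size, path in sorted(entries):
        if total <= max_bytes:
            break
        os.remove(path)
        total -= size
        removed.append(path)
    return removed
```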

@dpanici
Collaborator Author

dpanici commented Dec 6, 2024

@dpanici
Collaborator Author

dpanici commented Dec 9, 2024

Make a note in the docs that users can enable this to speed up repeated compilations.

@YigitElma YigitElma mentioned this issue Dec 13, 2024
@YigitElma YigitElma linked a pull request Dec 20, 2024 that will close this issue

2 participants