Bump the Version of Jax and Jaxlib to 0.4.34 #98

Merged: 2 commits, Nov 6, 2024


@mj023 (Collaborator) commented Oct 23, 2024

This PR changes the dependencies for the pixi environments to require jax >= 0.4.34 and jaxlib >= 0.4.34. The reason is that #77 caused memory allocation issues with older JAX and jaxlib versions. For larger models, these issues made the models impossible to solve with limited GPU memory.
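As a rough illustration of the new requirement, the sketch below checks whether the installed jax and jaxlib meet the 0.4.34 minimum. The helper names `parse_release` and `check_installed` are hypothetical and not part of LCM; the actual constraint lives in the pixi environment configuration.

```python
import re
from importlib import metadata

# Minimum release required by this PR.
MINIMUM = (0, 4, 34)


def parse_release(version: str) -> tuple[int, ...]:
    """Return the leading numeric components, e.g. '0.4.34.dev1' -> (0, 4, 34)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def check_installed(packages=("jax", "jaxlib")) -> None:
    """Print whether each package is installed and meets MINIMUM."""
    for name in packages:
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            print(f"{name}: not installed")
            continue
        status = "ok" if parse_release(installed) >= MINIMUM else "too old"
        print(f"{name} {installed}: {status}")


check_installed()
```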

Problem description
With #77, the solve_model() function returned by LCM is jitted by default. After this change, solving a model on the GPU could fail: JAX threw an error because the program tried to store huge arrays, with the same dimensions as the whole state-choice space, in GPU memory. For larger models, these arrays could be multiple TB in size. High memory usage had never been a problem before, and given the algorithm LCM uses, there should be no reason to store these arrays.
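To get a feeling for the scale, here is a back-of-the-envelope sketch of how much memory a single dense array over the full state-choice space would need. The grid sizes are made up for illustration and do not come from an actual LCM model.

```python
import math


def dense_array_bytes(grid_sizes, itemsize=4):
    """Bytes needed to materialize one dense array over the full grid.

    itemsize=4 assumes float32 entries.
    """
    return math.prod(grid_sizes) * itemsize


# Hypothetical model: five state/choice dimensions with 100 grid points each
# and one with 50 grid points.
grid_sizes = (100, 100, 100, 100, 100, 50)
n_bytes = dense_array_bytes(grid_sizes)
print(f"{n_bytes / 1e12:.1f} TB")  # 2.0 TB
```

If the reduction over choices is fused with the computation of the values, such an array never has to exist in memory at once, which is why the algorithm itself does not require it.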

Reasons for Memory Allocation Issues
When a function is jitted with JAX, a computation graph is created and passed to a compiler for further optimization. The memory allocation issues probably stem from this optimization step. It is possible to visualize the computation graph before and after the compiler optimization. Below you can see the optimized computation graph with older JAX and jaxlib versions. For some reason, the compiler splits the computation into two fusions instead of creating one big fusion with a reduce operator as the root. The arrays are passed as parameters from one fusion to the other and are therefore stored in GPU memory.

[Image: for_loop_comp]
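For anyone who wants to look at the graph themselves, the following is a minimal sketch of how to obtain a text dump of the computation before and after compiler optimization, using JAX's ahead-of-time lowering API. The function `max_over_choices` is a toy stand-in, not LCM's solve_model(); it only mimics the pattern of computing values on a state-choice grid and reducing over the choice axis.

```python
import jax
import jax.numpy as jnp


def max_over_choices(states, choices):
    # Broadcast to the full state-choice grid, then reduce over choices.
    values = -((states[:, None] - choices[None, :]) ** 2)
    return values.max(axis=1)


states = jnp.linspace(0.0, 1.0, 8)
choices = jnp.linspace(0.0, 1.0, 4)

# Lower the jitted function for these argument shapes.
lowered = jax.jit(max_over_choices).lower(states, choices)
unoptimized_text = lowered.as_text()  # representation before XLA optimization

# Compile, then dump the optimized HLO produced by the compiler.
compiled = lowered.compile()
optimized_text = compiled.as_text()

# Show only the lines that mention fusions, if the backend produced any.
for line in optimized_text.splitlines():
    if "fusion" in line:
        print(line.strip())

best = compiled(states, choices)
```

Whether the intermediate grid ends up inside a single fusion or is passed between two fusions depends on the backend and the jaxlib version, so the output on a CPU may look different from the GPU graph shown above.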

@hmgaudecker (Member)

Great explanation! From my perspective, you don't have to go into such detail; it is not as if this library is in use across thousands of weird environments. Everyone can update their environments without problems.

@timmens (Member) left a comment

Thanks a lot!

I like the detailed explanation, I believe it can help us/others to find similar problems in the future. For now, I am relieved that the memory issues are gone! 🎉

I've quickly updated the README and pre-commit hooks myself.

@timmens merged commit d27b9c7 into main Nov 6, 2024
7 checks passed
@timmens deleted the Update-Jax-Version branch November 6, 2024 10:17