Bump the Version of Jax and Jaxlib to 0.4.34 #98
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR changes the dependencies for the pixi environments to require JAX>= 0.4.34 and jaxlib >= 0.4.34. The reason for this is, that #77 created memory allocation issues when using older JAX and jaxlib versions. For larger models this made them impossible to solve with limited GPU memory.
Problem description
With #77 the solve_model() function returned by LCM is jitted by default. After this change, one would run into problems when solving a model on the GPU. Jax would throw an error, because the program tried to save huge arrays, with the same dimensions as the the whole state-choice-space, into the GPU memory. For larger models these arrays could be multiple TB big. In the past high memory usage has never been a problem and considering the algorithm used for LCM, there should be no reason to save these arrays.
Reasons for Memory Allocation Issues
When jitting a function with JAX a computation graph will be created, that graph will then be passed to a compiler for further optimization. The Memory Allocation Issues probably stem from this optimization step. It is possible to visualize the computation graph before and after the compiler optimization. Below you can see the optimized computation graph with older Jax and Jaxlib versions. For some reason, the compiler splits the fusion into two parts, instead of creating one big fusion, with a reduce operator as the root. The arrays get passed as parameters from one fusion to the other and therefore are saved in the GPU memory.