Skip to content

Commit

Permalink
fix docs
Browse files Browse the repository at this point in the history
  • Loading branch information
YigitElma committed Dec 20, 2024
1 parent 2ab9f0d commit 8ff1c59
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions docs/performance-tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Performance Tips

Caching the Compiled (Jitted) Code
----------------------------------
Although the compiled code is fast, it still takes time to compile. If you are running the same optimization, or some similar optimization, multiple times, you can save time by caching the compiled code. This automatically happens for a single session (for example, until you restart your kernel in Jupyter Notebook) but once you start using another session, the code will need to be recompiled. Fortunately, there is a way to bypass this. First create a cache directory, and put the following code at the beginning of your script:
Although the compiled code is fast, it still takes time to compile. If you are running the same optimization, or some similar optimization, multiple times, you can save time by caching the compiled code. This automatically happens for a single session (for example, until you restart your kernel in Jupyter Notebook) but once you start using another session, the code will need to be recompiled. Fortunately, there is a way to bypass this. First create a cache directory (i.e. ``jax-caches``), and put the following code at the beginning of your script:

.. code-block:: python
import jax
Expand All @@ -14,7 +14,7 @@ Although the compiled code is fast, it still takes time to compile. If you are r
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
This will create a cache directory called ``jax-caches`` in the parent directory of the script. The ``jax_persistent_cache_min_entry_size_bytes`` and ``jax_persistent_cache_min_compile_time_secs`` parameters are set to -1 and 0, respectively, to ensure that all compiled code is cached. For more details on caching, refer to official JAX documentation `here <https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html#persistent-compilation-cache>`__.
This will use a directory called ``jax-caches`` in the parent directory of the script to store the compiled code. The ``jax_persistent_cache_min_entry_size_bytes`` and ``jax_persistent_cache_min_compile_time_secs`` parameters are set to -1 and 0, respectively, to ensure that all compiled code is cached. For more details on caching, refer to official JAX documentation `here <https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html#persistent-compilation-cache>`__.

Note: Updating JAX version might re-compile some previously cached code, and thi might increase the cache size. Every once in a while, you might need to clear your cache directory.

Expand Down

0 comments on commit 8ff1c59

Please sign in to comment.