diff --git a/docs/performance-tips.rst b/docs/performance-tips.rst index 5206ff941..b72957314 100644 --- a/docs/performance-tips.rst +++ b/docs/performance-tips.rst @@ -4,7 +4,7 @@ Performance Tips Caching the Compiled (Jitted) Code ---------------------------------- -Although the compiled code is fast, it still takes time to compile. If you are running the same optimization, or some similar optimization, multiple times, you can save time by caching the compiled code. This automatically happens for a single session (for example, until you restart your kernel in Jupyter Notebook) but once you start using another session, the code will need to be recompiled. Fortunately, there is a way to bypass this. First create a cache directory, and put the following code at the beginning of your script: +Although the compiled code is fast, it still takes time to compile. If you are running the same optimization, or some similar optimization, multiple times, you can save time by caching the compiled code. This automatically happens for a single session (for example, until you restart your kernel in Jupyter Notebook) but once you start using another session, the code will need to be recompiled. Fortunately, there is a way to bypass this. First create a cache directory (i.e. ``jax-caches``), and put the following code at the beginning of your script: .. code-block:: python import jax @@ -14,7 +14,7 @@ Although the compiled code is fast, it still takes time to compile. If you are r jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) -This will create a cache directory called ``jax-caches`` in the parent directory of the script. The ``jax_persistent_cache_min_entry_size_bytes`` and ``jax_persistent_cache_min_compile_time_secs`` parameters are set to -1 and 0, respectively, to ensure that all compiled code is cached. For more details on caching, refer to official JAX documentation `here `__. +This will use a directory called ``jax-caches`` in the parent directory of the script to store the compiled code. The ``jax_persistent_cache_min_entry_size_bytes`` and ``jax_persistent_cache_min_compile_time_secs`` parameters are set to -1 and 0, respectively, to ensure that all compiled code is cached. For more details on caching, refer to official JAX documentation `here `__. Note: Updating JAX version might re-compile some previously cached code, and thi might increase the cache size. Every once in a while, you might need to clear your cache directory.