From 6ec9082cf5a1f0bb73d738a411b5e0fc9b0eed76 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 8 Feb 2023 11:55:10 -0800 Subject: [PATCH] Default `jax_jit_pjit_api_merge` to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the `Mesh` context manager. This changes the internals of JAX without affecting any public API. Before, `jit` was a final style primitive. This means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. PiperOrigin-RevId: 508143501 --- CHANGELOG.md | 16 ++++++++++++++++ jax/_src/config.py | 8 ++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b0da0491e950..aec9cbd9e139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,22 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.4 +* Changes + * The implementation of `jit` and `pjit` has been merged. Merging jit and pjit + changes the internals of JAX without affecting the public API of JAX. + Before, `jit` was a final style primitive. Final style means that the creation + of jaxpr was delayed as much as possible and transformations were stacked + on top of each other. With the `jit`-`pjit` implementation merge, `jit` + becomes an initial style primitive which means that we trace to jaxpr + as early as possible. For more information see + [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). + Moving to initial style should simplify JAX's internals and make + development of features like dynamic shapes, etc easier. + You can disable it only via the environment variable i.e. + `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. + The merge must be disabled via an environment variable since it affects JAX + at import time so it needs to be disabled before jax is imported. + ## jaxlib 0.4.4 ## jax 0.4.3 (Feb 8, 2023) diff --git a/jax/_src/config.py b/jax/_src/config.py index 1012ee5bacf2..ef7fedca9323 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -780,9 +780,13 @@ def _update_jax_array_thread_local(val): jit_pjit_api_merge = config.define_bool_state( name='jax_jit_pjit_api_merge', - default=False, + default=True, upgrade=True, - help=('If True, jit and pjit API will be merged.')) + help=('If True, jit and pjit API will be merged. You can only disable it via ' + "the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. " + "The merge must be disabled via an environment variable since it " + "affects JAX at import time so it needs to be disabled before jax is " + "imported.")) spmd_mode = config.define_enum_state(