Skip to content

Commit

Permalink
Default jax_jit_pjit_api_merge to True. This means that the impleme…
Browse files Browse the repository at this point in the history
…ntation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the `Mesh` context manager.

This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
  • Loading branch information
yashk2810 authored and jax authors committed Feb 8, 2023
1 parent 9a1f9b1 commit 6ec9082
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.4

* Changes
* The implementation of `jit` and `pjit` has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Before, `jit` was a final style primitive. Final style means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see
[this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
You can disable it only via the environment variable i.e.
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
The merge must be disabled via an environment variable since it affects JAX
at import time so it needs to be disabled before jax is imported.

## jaxlib 0.4.4

## jax 0.4.3 (Feb 8, 2023)
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,9 +780,13 @@ def _update_jax_array_thread_local(val):

jit_pjit_api_merge = config.define_bool_state(
name='jax_jit_pjit_api_merge',
default=False,
default=True,
upgrade=True,
help=('If True, jit and pjit API will be merged.'))
help=('If True, jit and pjit API will be merged. You can only disable it via '
"the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. "
"The merge must be disabled via an environment variable since it "
"affects JAX at import time so it needs to be disabled before jax is "
"imported."))


spmd_mode = config.define_enum_state(
Expand Down

0 comments on commit 6ec9082

Please sign in to comment.