Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Default
jax_jit_pjit_api_merge
to True. This means that the impleme…
…ntation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the `Mesh` context manager. This changes the internals of JAX without affecting any public API. Before, `jit` was a final style primitive. This means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. PiperOrigin-RevId: 508143501
- Loading branch information