Skip to content

Commit

Permalink
Moved the implementation of custom_partitioning into jax/_src
Browse files Browse the repository at this point in the history
This is necessary to avoid a circular dependency

   jax -> fused_attention_stablehlo -> experimental -> jax

in #21371.

PiperOrigin-RevId: 650183480
  • Loading branch information
superbobry authored and jax authors committed Jul 8, 2024
1 parent 0a48d37 commit 6ea0f63
Show file tree
Hide file tree
Showing 3 changed files with 555 additions and 537 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ py_library_providing_imports_info(
"_src/checkify.py",
"_src/custom_batching.py",
"_src/custom_derivatives.py",
"_src/custom_partitioning.py",
"_src/custom_transpose.py",
"_src/debugging.py",
"_src/dispatch.py",
Expand Down
Loading

0 comments on commit 6ea0f63

Please sign in to comment.