diff --git a/scipy/pixi.toml b/scipy/pixi.toml index 0a7fbdb..6e89362 100644 --- a/scipy/pixi.toml +++ b/scipy/pixi.toml @@ -149,14 +149,22 @@ cupy = "*" [feature.cupy.tasks] test-cupy = { cmd = "python dev.py test -b cupy", cwd = "scipy" } -[feature.jax] +[feature.jax-cpu] # Windows support pending: https://github.com/conda-forge/jaxlib-feedstock/issues/161 platforms = ["linux-64", "osx-arm64"] -[feature.jax.dependencies] -jax = "=0.4.28" +[feature.jax-cpu.dependencies] +jax = "*" +jaxlib = { version = "*", build = "cpu" } -[feature.jax.tasks] +[feature.jax-cuda] +platforms = ["linux-64"] + +[feature.jax-cuda.dependencies] +jax = "*" +jaxlib = { version = "*", build = "*cuda*" } + +[feature.jax-base.tasks] test-jax = { cmd = "python dev.py test -b jax.numpy", cwd = "scipy" } [feature.array_api_strict.dependencies] @@ -214,17 +222,25 @@ netlib = ["netlib", "test"] blis = ["blis", "test"] torch = ["torch", "mkl"] # FIXME: add env var cupy = ["cupy"] -jax = ["jax"] +jax = ["jax-base", "jax-cpu"] mlx = ["mlx"] array-api-strict = ["array_api_strict"] -array-api = ["cpu", "array_api_strict", "jax", "mkl", "torch"] +array-api = [ + "cpu", + "array_api_strict", + "jax-base", + "jax-cpu", + "mkl", + "torch", +] array-api-cuda = [ "cuda", "array_api_strict", "cupy", - "jax", + "jax-base", + "jax-cuda", "mkl", "torch", "test", ] -free-threading = {features = ["free-threading"], no-default-feature = true } +free-threading = { features = ["free-threading"], no-default-feature = true }