Skip to content

Commit

Permalink
Merge pull request #3380 from BradyPlanden/jax-metal
Browse files Browse the repository at this point in the history
Jax on Metal w/ GPU support
  • Loading branch information
BradyPlanden authored Oct 2, 2023
2 parents 7eb5879 + 577f9a1 commit f81de94
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
4 changes: 3 additions & 1 deletion pybamm/expression_tree/operations/evaluate_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import jax
from jax.config import config

config.update("jax_enable_x64", True)
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform != "metal":
config.update("jax_enable_x64", True)


class JaxCooMatrix:
Expand Down
4 changes: 3 additions & 1 deletion pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import cache, safe_map, split_list

config.update("jax_enable_x64", True)
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
if platform != "metal":
config.update("jax_enable_x64", True)

MAX_ORDER = 5
NEWTON_MAXITER = 4
Expand Down
6 changes: 5 additions & 1 deletion pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,11 @@ async def solve_model_async(inputs_v):
return await asyncio.gather(*coro)

y = asyncio.run(solve_model_for_inputs())
elif platform.startswith("gpu") or platform.startswith("tpu"):
elif (
platform.startswith("gpu")
or platform.startswith("tpu")
or platform.startswith("metal")
):
# gpu execution runs faster when parallelised with vmap
# (see also comment below regarding single-program multiple-data
# execution (SPMD) using pmap on multiple XLAs)
Expand Down

0 comments on commit f81de94

Please sign in to comment.