Skip to content

Commit

Permalink
change set_default_cpu to intended JAX version, in case in a future r…
Browse files Browse the repository at this point in the history
…elease JAX changes how it works
  • Loading branch information
YigitElma committed Jul 18, 2024
1 parent 150117f commit 454eb82
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Backend functions for DESC, with options for JAX or regular numpy."""

import functools
import os
import warnings

Expand Down Expand Up @@ -73,7 +74,6 @@
vmap = jax.vmap
scan = jax.lax.scan
bincount = jnp.bincount
set_default_cpu = jax.default_device(jax.devices("cpu")[0])
from jax import custom_jvp
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
Expand Down Expand Up @@ -113,6 +113,28 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def set_default_cpu(func):
"""Decorator to set default device to CPU for a function.
Parameters
----------
func : callable
Function to decorate
Returns
-------
wrapper : callable
Decorated function that will run always on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

def sign(x):
"""Sign function, but returns 1 for x==0.
Expand Down

0 comments on commit 454eb82

Please sign in to comment.