From 454eb82a827841a15f275a55a7a7435e6076c9ba Mon Sep 17 00:00:00 2001 From: YigitElma Date: Thu, 18 Jul 2024 15:41:56 +0300 Subject: [PATCH] change set_default_cpu to intended JAX version, in case in a future release JAX changes how it works --- desc/backend.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/desc/backend.py b/desc/backend.py index 1ba5128b4c..32227aecfe 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,5 +1,6 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" +import functools import os import warnings @@ -73,7 +74,6 @@ vmap = jax.vmap scan = jax.lax.scan bincount = jnp.bincount - set_default_cpu = jax.default_device(jax.devices("cpu")[0]) from jax import custom_jvp from jax.experimental.ode import odeint from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular @@ -113,6 +113,28 @@ def put(arr, inds, vals): return arr return jnp.asarray(arr).at[inds].set(vals) + def set_default_cpu(func): + """Decorator to set default device to CPU for a function. + + Parameters + ---------- + func : callable + Function to decorate + + Returns + ------- + wrapper : callable + Decorated function that will run always on CPU even if + there are available GPUs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with jax.default_device(jax.devices("cpu")[0]): + return func(*args, **kwargs) + + return wrapper + def sign(x): """Sign function, but returns 1 for x==0.