From debecad3624a0f8e64edd763e7218bfc4b3323c5 Mon Sep 17 00:00:00 2001 From: daniel-dudt Date: Tue, 4 Jun 2024 18:02:48 -0600 Subject: [PATCH] numpifying to make vectorization work --- desc/backend.py | 5 ++-- desc/basis.py | 10 ++++---- desc/equilibrium/equilibrium.py | 21 ++++++++++------- desc/geometry/core.py | 17 +++++++------- desc/geometry/curve.py | 36 ++++++++++++++--------------- desc/geometry/surface.py | 8 +++---- desc/objectives/_generic.py | 41 ++++++++++++++++++++++++++++----- desc/profiles.py | 24 +++++++++---------- tests/test_examples.py | 1 + 9 files changed, 99 insertions(+), 64 deletions(-) diff --git a/desc/backend.py b/desc/backend.py index 77cf6f090d..240db14afd 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -126,8 +126,9 @@ def sign(x): 1 where x>=0, -1 where x<0 """ - x = jnp.asarray(x) - y = jnp.where(x == 0, 1, jnp.sign(x)) + # FIXME: when this is jnp, Basis with sym is a JAX object for some reason + x = np.asarray(x) + y = np.where(x == 0, 1, np.sign(x)) return y @jit diff --git a/desc/basis.py b/desc/basis.py index 9ba700f499..2c8099e5b2 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -290,7 +290,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -404,7 +404,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[0] != 0) or (derivatives[1] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -535,7 +535,7 @@ def evaluate( if modes is None: modes = self.modes if derivatives[0] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -742,7 +742,7 @@ def evaluate( if modes is None: modes = self.modes if derivatives[2] != 0: - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) @@ -1241,7 +1241,7 @@ def evaluate( if modes is None: modes = self.modes if (derivatives[1] != 0) or (derivatives[2] != 0): - return jnp.zeros((nodes.shape[0], modes.shape[0])) + return np.zeros((nodes.shape[0], modes.shape[0])) if not len(modes): return np.array([]).reshape((len(nodes), 0)) diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index b32ae3b0fa..c213d0115d 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -10,7 +10,6 @@ from scipy.constants import mu_0 from termcolor import colored -from desc.backend import jnp from desc.basis import FourierZernikeBasis, fourier, zernike_radial from desc.compat import ensure_positive_jacobian from desc.compute import compute as compute_fun @@ -376,7 +375,7 @@ def __init__( assert ("R_lmn" in kwargs) and ("Z_lmn" in kwargs), "Must give both R and Z" self.R_lmn = kwargs.pop("R_lmn") self.Z_lmn = kwargs.pop("Z_lmn") - self.L_lmn = kwargs.pop("L_lmn", jnp.zeros(self.L_basis.num_modes)) + self.L_lmn = kwargs.pop("L_lmn", np.zeros(self.L_basis.num_modes)) else: self.set_initial_guess(ensure_nested=ensure_nested) if check_orientation: @@ -600,9 +599,15 @@ def change_resolution( ) self.axis.change_resolution(self.N, NFP=self.NFP, sym=self.sym) - self._R_lmn = copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) - self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) - self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) + self._R_lmn = np.asarray( + copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes) + ) + self._Z_lmn = np.asarray( + copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) + ) + self._L_lmn = np.asarray( + copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) + ) def get_surface_at(self, rho=None, theta=None, zeta=None): """Return a representation for a given coordinate surface. @@ -1265,7 +1270,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, R_lmn): - R_lmn = jnp.atleast_1d(jnp.asarray(R_lmn)) + R_lmn = np.atleast_1d(np.asarray(R_lmn)) errorif( R_lmn.size != self._R_lmn.size, ValueError, @@ -1282,7 +1287,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, Z_lmn): - Z_lmn = jnp.atleast_1d(jnp.asarray(Z_lmn)) + Z_lmn = np.atleast_1d(np.asarray(Z_lmn)) errorif( Z_lmn.size != self._Z_lmn.size, ValueError, @@ -1299,7 +1304,7 @@ def L_lmn(self): @L_lmn.setter def L_lmn(self, L_lmn): - L_lmn = jnp.atleast_1d(jnp.asarray(L_lmn)) + L_lmn = np.atleast_1d(np.asarray(L_lmn)) errorif( L_lmn.size != self._L_lmn.size, ValueError, diff --git a/desc/geometry/core.py b/desc/geometry/core.py index a637ff529e..e2cb8417c3 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -5,7 +5,6 @@ import numpy as np -from desc.backend import jnp from desc.compute import compute as compute_fun from desc.compute import data_index from desc.compute.geom_utils import reflection_matrix, rotation_matrix @@ -26,8 +25,8 @@ class Curve(IOAble, Optimizable, ABC): _io_attrs_ = ["_name", "_shift", "_rotmat"] def __init__(self, name=""): - self._shift = jnp.array([0, 0, 0], dtype=float) - self._rotmat = jnp.eye(3, dtype=float).flatten() + self._shift = np.array([0, 0, 0], dtype=float) + self._rotmat = np.eye(3, dtype=float).flatten() self._name = name def _set_up(self): @@ -43,12 +42,12 @@ def _set_up(self): @property def shift(self): """Displacement of curve in X, Y, Z.""" - return self.__dict__.setdefault("_shift", jnp.array([0, 0, 0], dtype=float)) + return self.__dict__.setdefault("_shift", np.array([0, 0, 0], dtype=float)) @shift.setter def shift(self, new): if len(new) == 3: - self._shift = jnp.asarray(new) + self._shift = np.asarray(new) else: raise ValueError("shift should be a 3 element vector, got {}".format(new)) @@ -56,14 +55,14 @@ def shift(self, new): @property def rotmat(self): """Rotation matrix of curve in X, Y, Z.""" - return self.__dict__.setdefault("_rotmat", jnp.eye(3, dtype=float).flatten()) + return self.__dict__.setdefault("_rotmat", np.eye(3, dtype=float).flatten()) @rotmat.setter def rotmat(self, new): if len(new) == 9: - self._rotmat = jnp.asarray(new) + self._rotmat = np.asarray(new) else: - self._rotmat = jnp.asarray(new.flatten()) + self._rotmat = np.asarray(new.flatten()) @property def name(self): @@ -179,7 +178,7 @@ def compute( def translate(self, displacement=[0, 0, 0]): """Translate the curve by a rigid displacement in X,Y,Z coordinates.""" - self.shift = self.shift + jnp.asarray(displacement) + self.shift = self.shift + np.asarray(displacement) def rotate(self, axis=[0, 0, 1], angle=0): """Rotate the curve by a fixed angle about axis in X,Y,Z coordinates.""" diff --git a/desc/geometry/curve.py b/desc/geometry/curve.py index 3f670656e7..bd087dda61 100644 --- a/desc/geometry/curve.py +++ b/desc/geometry/curve.py @@ -2,7 +2,7 @@ import numpy as np -from desc.backend import jnp, put +from desc.backend import put from desc.basis import FourierSeries from desc.compute import rpz2xyz, xyz2rpz from desc.grid import LinearGrid @@ -86,11 +86,11 @@ def __init__( NZ = np.max(abs(modes_Z)) N = max(NR, NZ) self._NFP = check_posint(NFP, "NFP", False) - self._R_basis = FourierSeries(N, int(NFP), sym="cos" if sym else False) - self._Z_basis = FourierSeries(N, int(NFP), sym="sin" if sym else False) + self._R_basis = FourierSeries(int(N), int(NFP), sym="cos" if sym else False) + self._Z_basis = FourierSeries(int(N), int(NFP), sym="sin" if sym else False) - self._R_n = copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2]) - self._Z_n = copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2]) + self._R_n = np.array(copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2])) + self._Z_n = np.array(copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2])) @property def sym(self): @@ -138,8 +138,8 @@ def change_resolution(self, N=None, NFP=None, sym=None): self.Z_basis.change_resolution( N=N, NFP=self.NFP, sym="sin" if self.sym else self.sym ) - self.R_n = copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes) - self.Z_n = copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes) + self.R_n = np.array(copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes)) + self.Z_n = np.array(copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes)) def get_coeffs(self, n): """Get Fourier coefficients for given mode number(s).""" @@ -176,7 +176,7 @@ def R_n(self): @R_n.setter def R_n(self, new): if len(new) == self.R_basis.num_modes: - self._R_n = jnp.asarray(new) + self._R_n = np.asarray(new) else: raise ValueError( f"R_n should have the same size as the basis, got {len(new)} for " @@ -192,7 +192,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = jnp.asarray(new) + self._Z_n = np.asarray(new) else: raise ValueError( f"Z_n should have the same size as the basis, got {len(new)} for " @@ -439,7 +439,7 @@ def X_n(self): @X_n.setter def X_n(self, new): if len(new) == self.X_basis.num_modes: - self._X_n = jnp.asarray(new) + self._X_n = np.asarray(new) else: raise ValueError( f"X_n should have the same size as the basis, got {len(new)} for " @@ -455,7 +455,7 @@ def Y_n(self): @Y_n.setter def Y_n(self, new): if len(new) == self.Y_basis.num_modes: - self._Y_n = jnp.asarray(new) + self._Y_n = np.asarray(new) else: raise ValueError( f"Y_n should have the same size as the basis, got {len(new)} for " @@ -471,7 +471,7 @@ def Z_n(self): @Z_n.setter def Z_n(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_n = jnp.asarray(new) + self._Z_n = np.asarray(new) else: raise ValueError( f"Z_n should have the same size as the basis, got {len(new)} for " @@ -663,7 +663,7 @@ def r_n(self): @r_n.setter def r_n(self, new): if len(np.asarray(new)) == self.r_basis.num_modes: - self._r_n = jnp.asarray(new) + self._r_n = np.asarray(new) else: raise ValueError( f"r_n should have the same size as the basis, got {len(new)} for " @@ -829,7 +829,7 @@ def X(self): @X.setter def X(self, new): if len(new) == len(self.knots): - self._X = jnp.asarray(new) + self._X = np.asarray(new) else: raise ValueError( "X should have the same size as the knots, " @@ -845,7 +845,7 @@ def Y(self): @Y.setter def Y(self, new): if len(new) == len(self.knots): - self._Y = jnp.asarray(new) + self._Y = np.asarray(new) else: raise ValueError( "Y should have the same size as the knots, " @@ -861,7 +861,7 @@ def Z(self): @Z.setter def Z(self, new): if len(new) == len(self.knots): - self._Z = jnp.asarray(new) + self._Z = np.asarray(new) else: raise ValueError( "Z should have the same size as the knots, " @@ -876,7 +876,7 @@ def knots(self): @knots.setter def knots(self, new): if len(new) == len(self.knots): - knots = jnp.atleast_1d(jnp.asarray(new)) + knots = np.atleast_1d(np.asarray(new)) errorif( not np.all(np.diff(knots) > 0), ValueError, @@ -884,7 +884,7 @@ def knots(self, new): ) errorif(knots[0] < 0, ValueError, "knots must lie in [0, 2pi]") errorif(knots[-1] > 2 * np.pi, ValueError, "knots must lie in [0, 2pi]") - self._knots = jnp.asarray(knots) + self._knots = np.asarray(knots) else: raise ValueError( "new knots should have the same size as the current knots, " diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 3a4cf0c136..0bbd750a93 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -225,7 +225,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = jnp.asarray(new) + self._R_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"R_lmn should have the same size as the basis, got {len(new)} for " @@ -241,7 +241,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = jnp.asarray(new) + self._Z_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " @@ -963,7 +963,7 @@ def R_lmn(self): @R_lmn.setter def R_lmn(self, new): if len(new) == self.R_basis.num_modes: - self._R_lmn = jnp.asarray(new) + self._R_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"R_lmn should have the same size as the basis, got {len(new)} for " @@ -979,7 +979,7 @@ def Z_lmn(self): @Z_lmn.setter def Z_lmn(self, new): if len(new) == self.Z_basis.num_modes: - self._Z_lmn = jnp.asarray(new) + self._Z_lmn = np.atleast_1d(np.asarray(new)) else: raise ValueError( f"Z_lmn should have the same size as the basis, got {len(new)} for " diff --git a/desc/objectives/_generic.py b/desc/objectives/_generic.py index 00159e1ebc..f140781148 100644 --- a/desc/objectives/_generic.py +++ b/desc/objectives/_generic.py @@ -2,7 +2,9 @@ import functools import inspect +import multiprocessing import re +import warnings import numpy as np @@ -76,6 +78,7 @@ def __init__( normalize_target=False, loss_function=None, fd_step=1e-4, # TODO: generalize this to allow a vector of different scales + vectorized=False, name="external", **kwargs, ): @@ -85,6 +88,7 @@ def __init__( self._fun = fun self._dim_f = dim_f self._fd_step = fd_step + self._vectorized = vectorized self._kwargs = kwargs super().__init__( things=eq, @@ -114,11 +118,30 @@ def build(self, use_jit=True, verbose=1): def fun_wrapped(params): """Wrap external function with optimizable params arguments.""" - for param_key in self._eq.optimizable_params: - param_value = params[param_key] - if len(param_value): - setattr(self._eq, param_key, param_value) - return self._fun(self._eq, **self._kwargs) + param_shape = params["Psi"].shape + num_eq = param_shape[0] if len(param_shape) > 1 else 1 + if self._vectorized and num_eq > 1: + # convert params to list of Equilibria + eqs = [self._eq.copy() for _ in range(num_eq)] + for k, eq in enumerate(eqs): + for param_key in self._eq.optimizable_params: + param_value = np.array(params[param_key][k, :]) + if len(param_value): + setattr(eq, param_key, param_value) + # parallelize calls to external function + with warnings.catch_warnings(action="ignore", category=RuntimeWarning): + with multiprocessing.Pool() as pool: + results = pool.map( + functools.partial(self._fun, **self._kwargs), eqs + ) + return jnp.vstack(results) + else: + # update Equilibrium with params + for param_key in self._eq.optimizable_params: + param_value = params[param_key] + if len(param_value): + setattr(self._eq, param_key, param_value) + return self._fun(self._eq, **self._kwargs) # wrap external function to work with JAX abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f) @@ -187,7 +210,13 @@ def wrap_pure_callback(func): @functools.wraps(func) def wrapper(*args, **kwargs): result_shape_dtype = abstract_eval(*args, **kwargs) - return jax.pure_callback(func, result_shape_dtype, *args, **kwargs) + return jax.pure_callback( + func, + result_shape_dtype, + *args, + vectorized=self._vectorized, + **kwargs, + ) return wrapper diff --git a/desc/profiles.py b/desc/profiles.py index 2190359501..4223f8283f 100644 --- a/desc/profiles.py +++ b/desc/profiles.py @@ -277,7 +277,7 @@ def __init__(self, scale, profile, **kwargs): @property def params(self): """ndarray: Parameters for computation [scale, profile.params].""" - return jnp.concatenate([jnp.atleast_1d(self._scale), self._profile.params]) + return np.concatenate([np.atleast_1d(self._scale), self._profile.params]) @params.setter def params(self, x): @@ -364,7 +364,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return jnp.concatenate([profile.params for profile in self._profiles]) + return np.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -451,7 +451,7 @@ def __init__(self, *profiles, **kwargs): @property def params(self): """ndarray: Concatenated array of parameters for computation.""" - return jnp.concatenate([profile.params for profile in self._profiles]) + return np.concatenate([profile.params for profile in self._profiles]) @params.setter def params(self, x): @@ -591,9 +591,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == self._basis.num_modes: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have the same size as the basis, " @@ -745,9 +745,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == 3: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError(f"params should be an array of size 3, got {len(new)}.") @@ -849,7 +849,7 @@ def params(self): @params.setter def params(self, new): if len(new) == len(self._knots): - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have the same size as the knots, " @@ -932,9 +932,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size >= 5: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( "params should have at least 5 elements [ped, offset, sym, width," @@ -1199,9 +1199,9 @@ def params(self): @params.setter def params(self, new): - new = jnp.atleast_1d(jnp.asarray(new)) + new = np.atleast_1d(np.asarray(new)) if new.size == self._basis.num_modes: - self._params = jnp.asarray(new) + self._params = np.asarray(new) else: raise ValueError( f"params should have the same size as the basis, got {new.size} " diff --git a/tests/test_examples.py b/tests/test_examples.py index 5ee219d2fb..23b5c7e899 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1319,6 +1319,7 @@ def test_optimize_with_fourier_planar_coil(): @pytest.mark.unit +@pytest.mark.slow def test_external_vs_generic_objectives(tmpdir_factory): """Test ExternalObjective compared to GenericObjective.""" target = np.array([6.2e-3, 1.1e-1, 6.5e-3, 0]) # values at p_l = [2e2, -2e2]