Skip to content

Commit

Permalink
numpifying to make vectorization work
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dudt committed Jun 5, 2024
1 parent 2bb9017 commit debecad
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 64 deletions.
5 changes: 3 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def sign(x):
1 where x>=0, -1 where x<0
"""
x = jnp.asarray(x)
y = jnp.where(x == 0, 1, jnp.sign(x))
# FIXME: when this is jnp, Basis with sym is a JAX object for some reason
x = np.asarray(x)
y = np.where(x == 0, 1, np.sign(x))
return y

@jit
Expand Down
10 changes: 5 additions & 5 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def evaluate(
if modes is None:
modes = self.modes
if (derivatives[1] != 0) or (derivatives[2] != 0):
return jnp.zeros((nodes.shape[0], modes.shape[0]))
return np.zeros((nodes.shape[0], modes.shape[0]))
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

Expand Down Expand Up @@ -404,7 +404,7 @@ def evaluate(
if modes is None:
modes = self.modes
if (derivatives[0] != 0) or (derivatives[1] != 0):
return jnp.zeros((nodes.shape[0], modes.shape[0]))
return np.zeros((nodes.shape[0], modes.shape[0]))
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

Expand Down Expand Up @@ -535,7 +535,7 @@ def evaluate(
if modes is None:
modes = self.modes
if derivatives[0] != 0:
return jnp.zeros((nodes.shape[0], modes.shape[0]))
return np.zeros((nodes.shape[0], modes.shape[0]))
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

Expand Down Expand Up @@ -742,7 +742,7 @@ def evaluate(
if modes is None:
modes = self.modes
if derivatives[2] != 0:
return jnp.zeros((nodes.shape[0], modes.shape[0]))
return np.zeros((nodes.shape[0], modes.shape[0]))
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

Expand Down Expand Up @@ -1241,7 +1241,7 @@ def evaluate(
if modes is None:
modes = self.modes
if (derivatives[1] != 0) or (derivatives[2] != 0):
return jnp.zeros((nodes.shape[0], modes.shape[0]))
return np.zeros((nodes.shape[0], modes.shape[0]))
if not len(modes):
return np.array([]).reshape((len(nodes), 0))

Expand Down
21 changes: 13 additions & 8 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from scipy.constants import mu_0
from termcolor import colored

from desc.backend import jnp
from desc.basis import FourierZernikeBasis, fourier, zernike_radial
from desc.compat import ensure_positive_jacobian
from desc.compute import compute as compute_fun
Expand Down Expand Up @@ -376,7 +375,7 @@ def __init__(
assert ("R_lmn" in kwargs) and ("Z_lmn" in kwargs), "Must give both R and Z"
self.R_lmn = kwargs.pop("R_lmn")
self.Z_lmn = kwargs.pop("Z_lmn")
self.L_lmn = kwargs.pop("L_lmn", jnp.zeros(self.L_basis.num_modes))
self.L_lmn = kwargs.pop("L_lmn", np.zeros(self.L_basis.num_modes))
else:
self.set_initial_guess(ensure_nested=ensure_nested)
if check_orientation:
Expand Down Expand Up @@ -600,9 +599,15 @@ def change_resolution(
)
self.axis.change_resolution(self.N, NFP=self.NFP, sym=self.sym)

self._R_lmn = copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes)
self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes)
self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes)
self._R_lmn = np.asarray(
copy_coeffs(self.R_lmn, old_modes_R, self.R_basis.modes)
)
self._Z_lmn = np.asarray(
copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes)
)
self._L_lmn = np.asarray(
copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes)
)

def get_surface_at(self, rho=None, theta=None, zeta=None):
"""Return a representation for a given coordinate surface.
Expand Down Expand Up @@ -1265,7 +1270,7 @@ def R_lmn(self):

@R_lmn.setter
def R_lmn(self, R_lmn):
R_lmn = jnp.atleast_1d(jnp.asarray(R_lmn))
R_lmn = np.atleast_1d(np.asarray(R_lmn))
errorif(
R_lmn.size != self._R_lmn.size,
ValueError,
Expand All @@ -1282,7 +1287,7 @@ def Z_lmn(self):

@Z_lmn.setter
def Z_lmn(self, Z_lmn):
Z_lmn = jnp.atleast_1d(jnp.asarray(Z_lmn))
Z_lmn = np.atleast_1d(np.asarray(Z_lmn))
errorif(
Z_lmn.size != self._Z_lmn.size,
ValueError,
Expand All @@ -1299,7 +1304,7 @@ def L_lmn(self):

@L_lmn.setter
def L_lmn(self, L_lmn):
L_lmn = jnp.atleast_1d(jnp.asarray(L_lmn))
L_lmn = np.atleast_1d(np.asarray(L_lmn))
errorif(
L_lmn.size != self._L_lmn.size,
ValueError,
Expand Down
17 changes: 8 additions & 9 deletions desc/geometry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np

from desc.backend import jnp
from desc.compute import compute as compute_fun
from desc.compute import data_index
from desc.compute.geom_utils import reflection_matrix, rotation_matrix
Expand All @@ -26,8 +25,8 @@ class Curve(IOAble, Optimizable, ABC):
_io_attrs_ = ["_name", "_shift", "_rotmat"]

def __init__(self, name=""):
self._shift = jnp.array([0, 0, 0], dtype=float)
self._rotmat = jnp.eye(3, dtype=float).flatten()
self._shift = np.array([0, 0, 0], dtype=float)
self._rotmat = np.eye(3, dtype=float).flatten()
self._name = name

def _set_up(self):
Expand All @@ -43,27 +42,27 @@ def _set_up(self):
@property
def shift(self):
"""Displacement of curve in X, Y, Z."""
return self.__dict__.setdefault("_shift", jnp.array([0, 0, 0], dtype=float))
return self.__dict__.setdefault("_shift", np.array([0, 0, 0], dtype=float))

@shift.setter
def shift(self, new):
if len(new) == 3:
self._shift = jnp.asarray(new)
self._shift = np.asarray(new)
else:
raise ValueError("shift should be a 3 element vector, got {}".format(new))

@optimizable_parameter
@property
def rotmat(self):
"""Rotation matrix of curve in X, Y, Z."""
return self.__dict__.setdefault("_rotmat", jnp.eye(3, dtype=float).flatten())
return self.__dict__.setdefault("_rotmat", np.eye(3, dtype=float).flatten())

@rotmat.setter
def rotmat(self, new):
if len(new) == 9:
self._rotmat = jnp.asarray(new)
self._rotmat = np.asarray(new)
else:
self._rotmat = jnp.asarray(new.flatten())
self._rotmat = np.asarray(new.flatten())

@property
def name(self):
Expand Down Expand Up @@ -179,7 +178,7 @@ def compute(

def translate(self, displacement=[0, 0, 0]):
"""Translate the curve by a rigid displacement in X,Y,Z coordinates."""
self.shift = self.shift + jnp.asarray(displacement)
self.shift = self.shift + np.asarray(displacement)

def rotate(self, axis=[0, 0, 1], angle=0):
"""Rotate the curve by a fixed angle about axis in X,Y,Z coordinates."""
Expand Down
36 changes: 18 additions & 18 deletions desc/geometry/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from desc.backend import jnp, put
from desc.backend import put
from desc.basis import FourierSeries
from desc.compute import rpz2xyz, xyz2rpz
from desc.grid import LinearGrid
Expand Down Expand Up @@ -86,11 +86,11 @@ def __init__(
NZ = np.max(abs(modes_Z))
N = max(NR, NZ)
self._NFP = check_posint(NFP, "NFP", False)
self._R_basis = FourierSeries(N, int(NFP), sym="cos" if sym else False)
self._Z_basis = FourierSeries(N, int(NFP), sym="sin" if sym else False)
self._R_basis = FourierSeries(int(N), int(NFP), sym="cos" if sym else False)
self._Z_basis = FourierSeries(int(N), int(NFP), sym="sin" if sym else False)

self._R_n = copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2])
self._Z_n = copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2])
self._R_n = np.array(copy_coeffs(R_n, modes_R, self.R_basis.modes[:, 2]))
self._Z_n = np.array(copy_coeffs(Z_n, modes_Z, self.Z_basis.modes[:, 2]))

@property
def sym(self):
Expand Down Expand Up @@ -138,8 +138,8 @@ def change_resolution(self, N=None, NFP=None, sym=None):
self.Z_basis.change_resolution(
N=N, NFP=self.NFP, sym="sin" if self.sym else self.sym
)
self.R_n = copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes)
self.Z_n = copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes)
self.R_n = np.array(copy_coeffs(self.R_n, R_modes_old, self.R_basis.modes))
self.Z_n = np.array(copy_coeffs(self.Z_n, Z_modes_old, self.Z_basis.modes))

def get_coeffs(self, n):
"""Get Fourier coefficients for given mode number(s)."""
Expand Down Expand Up @@ -176,7 +176,7 @@ def R_n(self):
@R_n.setter
def R_n(self, new):
if len(new) == self.R_basis.num_modes:
self._R_n = jnp.asarray(new)
self._R_n = np.asarray(new)
else:
raise ValueError(
f"R_n should have the same size as the basis, got {len(new)} for "
Expand All @@ -192,7 +192,7 @@ def Z_n(self):
@Z_n.setter
def Z_n(self, new):
if len(new) == self.Z_basis.num_modes:
self._Z_n = jnp.asarray(new)
self._Z_n = np.asarray(new)
else:
raise ValueError(
f"Z_n should have the same size as the basis, got {len(new)} for "
Expand Down Expand Up @@ -439,7 +439,7 @@ def X_n(self):
@X_n.setter
def X_n(self, new):
if len(new) == self.X_basis.num_modes:
self._X_n = jnp.asarray(new)
self._X_n = np.asarray(new)
else:
raise ValueError(
f"X_n should have the same size as the basis, got {len(new)} for "
Expand All @@ -455,7 +455,7 @@ def Y_n(self):
@Y_n.setter
def Y_n(self, new):
if len(new) == self.Y_basis.num_modes:
self._Y_n = jnp.asarray(new)
self._Y_n = np.asarray(new)
else:
raise ValueError(
f"Y_n should have the same size as the basis, got {len(new)} for "
Expand All @@ -471,7 +471,7 @@ def Z_n(self):
@Z_n.setter
def Z_n(self, new):
if len(new) == self.Z_basis.num_modes:
self._Z_n = jnp.asarray(new)
self._Z_n = np.asarray(new)
else:
raise ValueError(
f"Z_n should have the same size as the basis, got {len(new)} for "
Expand Down Expand Up @@ -663,7 +663,7 @@ def r_n(self):
@r_n.setter
def r_n(self, new):
if len(np.asarray(new)) == self.r_basis.num_modes:
self._r_n = jnp.asarray(new)
self._r_n = np.asarray(new)
else:
raise ValueError(
f"r_n should have the same size as the basis, got {len(new)} for "
Expand Down Expand Up @@ -829,7 +829,7 @@ def X(self):
@X.setter
def X(self, new):
if len(new) == len(self.knots):
self._X = jnp.asarray(new)
self._X = np.asarray(new)
else:
raise ValueError(
"X should have the same size as the knots, "
Expand All @@ -845,7 +845,7 @@ def Y(self):
@Y.setter
def Y(self, new):
if len(new) == len(self.knots):
self._Y = jnp.asarray(new)
self._Y = np.asarray(new)
else:
raise ValueError(
"Y should have the same size as the knots, "
Expand All @@ -861,7 +861,7 @@ def Z(self):
@Z.setter
def Z(self, new):
if len(new) == len(self.knots):
self._Z = jnp.asarray(new)
self._Z = np.asarray(new)
else:
raise ValueError(
"Z should have the same size as the knots, "
Expand All @@ -876,15 +876,15 @@ def knots(self):
@knots.setter
def knots(self, new):
if len(new) == len(self.knots):
knots = jnp.atleast_1d(jnp.asarray(new))
knots = np.atleast_1d(np.asarray(new))
errorif(
not np.all(np.diff(knots) > 0),
ValueError,
"supplied knots must be monotonically increasing",
)
errorif(knots[0] < 0, ValueError, "knots must lie in [0, 2pi]")
errorif(knots[-1] > 2 * np.pi, ValueError, "knots must lie in [0, 2pi]")
self._knots = jnp.asarray(knots)
self._knots = np.asarray(knots)
else:
raise ValueError(
"new knots should have the same size as the current knots, "
Expand Down
8 changes: 4 additions & 4 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def R_lmn(self):
@R_lmn.setter
def R_lmn(self, new):
if len(new) == self.R_basis.num_modes:
self._R_lmn = jnp.asarray(new)
self._R_lmn = np.atleast_1d(np.asarray(new))
else:
raise ValueError(
f"R_lmn should have the same size as the basis, got {len(new)} for "
Expand All @@ -241,7 +241,7 @@ def Z_lmn(self):
@Z_lmn.setter
def Z_lmn(self, new):
if len(new) == self.Z_basis.num_modes:
self._Z_lmn = jnp.asarray(new)
self._Z_lmn = np.atleast_1d(np.asarray(new))
else:
raise ValueError(
f"Z_lmn should have the same size as the basis, got {len(new)} for "
Expand Down Expand Up @@ -963,7 +963,7 @@ def R_lmn(self):
@R_lmn.setter
def R_lmn(self, new):
if len(new) == self.R_basis.num_modes:
self._R_lmn = jnp.asarray(new)
self._R_lmn = np.atleast_1d(np.asarray(new))
else:
raise ValueError(
f"R_lmn should have the same size as the basis, got {len(new)} for "
Expand All @@ -979,7 +979,7 @@ def Z_lmn(self):
@Z_lmn.setter
def Z_lmn(self, new):
if len(new) == self.Z_basis.num_modes:
self._Z_lmn = jnp.asarray(new)
self._Z_lmn = np.atleast_1d(np.asarray(new))
else:
raise ValueError(
f"Z_lmn should have the same size as the basis, got {len(new)} for "
Expand Down
Loading

0 comments on commit debecad

Please sign in to comment.