diff --git a/desc/coils.py b/desc/coils.py index afef9a87fc..4fcb7248b1 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -883,21 +883,22 @@ def linspaced_angular( axis : array-like, shape(3,) axis to rotate about angle : float - total rotational extend of coil set. + total rotational extent of coil set n : int number of copies of original coil endpoint : bool whether to include a coil at final angle + """ assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet) if current is None: current = coil.current currents = jnp.broadcast_to(current, (n,)) + phi = jnp.linspace(0, angle, n, endpoint=endpoint) coils = [] - phis = jnp.linspace(0, angle, n, endpoint=endpoint) for i in range(n): coili = coil.copy() - coili.rotate(axis, angle=phis[i]) + coili.rotate(axis=axis, angle=phi[i]) coili.current = currents[i] coils.append(coili) return cls(*coils) @@ -920,14 +921,15 @@ def linspaced_linear( number of copies of original coil endpoint : bool whether to include a coil at final point + """ assert isinstance(coil, _Coil) and not isinstance(coil, CoilSet) if current is None: current = coil.current currents = jnp.broadcast_to(current, (n,)) displacement = jnp.asarray(displacement) - coils = [] a = jnp.linspace(0, 1, n, endpoint=endpoint) + coils = [] for i in range(n): coili = coil.copy() coili.translate(a[i] * displacement) @@ -953,6 +955,7 @@ def from_symmetry(cls, coils, NFP, sym=False): number of field periods sym : bool whether coils should be stellarator symmetric + """ if not isinstance(coils, CoilSet): coils = CoilSet(coils) diff --git a/desc/compute/_curve.py b/desc/compute/_curve.py index 2fefa33143..f549e79d54 100644 --- a/desc/compute/_curve.py +++ b/desc/compute/_curve.py @@ -3,14 +3,8 @@ from desc.backend import jnp from .data_index import register_compute_fun -from .geom_utils import ( - _rotation_matrix_from_normal, - rpz2xyz, - rpz2xyz_vec, - xyz2rpz, - xyz2rpz_vec, -) -from .utils import cross, dot +from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec +from .utils import cross, dot, safenormalize @register_compute_fun( @@ -184,16 +178,19 @@ def _Z_Curve(params, transforms, profiles, data, **kwargs): basis="basis", ) def _x_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): - # create planar curve at z==0 + # create planar curve at Z==0 r = transforms["r"].transform(params["r_n"], dz=0) Z = jnp.zeros_like(r) X = r * jnp.cos(data["s"]) Y = r * jnp.sin(data["s"]) coords = jnp.array([X, Y, Z]).T # rotate into place - R = _rotation_matrix_from_normal(params["normal"]) - coords = jnp.matmul(coords, R.T) + params["center"] - coords = jnp.matmul(coords, params["rotmat"].T) + params["shift"] + Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis + axis = cross(Zaxis, params["normal"]) + angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"]))) + A = rotation_matrix(axis=axis, angle=angle) + coords = jnp.matmul(coords, A.T) + params["center"] + coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) + params["shift"] if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz(coords) data["x"] = coords @@ -224,16 +221,22 @@ def _x_s_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): dY = dr * jnp.sin(data["s"]) + r * jnp.cos(data["s"]) dZ = jnp.zeros_like(dX) coords = jnp.array([dX, dY, dZ]).T - A = _rotation_matrix_from_normal(params["normal"]) + # rotate into place + Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis + axis = cross(Zaxis, params["normal"]) + angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"]))) + A = rotation_matrix(axis=axis, angle=angle) coords = jnp.matmul(coords, A.T) - coords = jnp.matmul(coords, params["rotmat"].T) + coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) if kwargs.get("basis", "rpz").lower() == "rpz": X = r * jnp.cos(data["s"]) Y = r * jnp.sin(data["s"]) Z = jnp.zeros_like(X) xyzcoords = jnp.array([X, Y, Z]).T xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"] - xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"] + xyzcoords = ( + jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"] + ) x, y, z = xyzcoords.T coords = xyz2rpz_vec(coords, x=x, y=y) data["x_s"] = coords @@ -269,16 +272,22 @@ def _x_ss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): ) d2Z = jnp.zeros_like(d2X) coords = jnp.array([d2X, d2Y, d2Z]).T - A = _rotation_matrix_from_normal(params["normal"]) + # rotate into place + Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis + axis = cross(Zaxis, params["normal"]) + angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"]))) + A = rotation_matrix(axis=axis, angle=angle) coords = jnp.matmul(coords, A.T) - coords = jnp.matmul(coords, params["rotmat"].T) + coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) if kwargs.get("basis", "rpz").lower() == "rpz": X = r * jnp.cos(data["s"]) Y = r * jnp.sin(data["s"]) Z = jnp.zeros_like(X) xyzcoords = jnp.array([X, Y, Z]).T xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"] - xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"] + xyzcoords = ( + jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"] + ) x, y, z = xyzcoords.T coords = xyz2rpz_vec(coords, x=x, y=y) data["x_ss"] = coords @@ -321,16 +330,22 @@ def _x_sss_FourierPlanarCurve(params, transforms, profiles, data, **kwargs): ) d3Z = jnp.zeros_like(d3X) coords = jnp.array([d3X, d3Y, d3Z]).T - A = _rotation_matrix_from_normal(params["normal"]) + # rotate into place + Zaxis = jnp.array([0.0, 0.0, 1.0]) # 2D curve in X-Y plane has normal = +Z axis + axis = cross(Zaxis, params["normal"]) + angle = jnp.arccos(dot(Zaxis, safenormalize(params["normal"]))) + A = rotation_matrix(axis=axis, angle=angle) coords = jnp.matmul(coords, A.T) - coords = jnp.matmul(coords, params["rotmat"].T) + coords = jnp.matmul(coords, params["rotmat"].reshape((3, 3)).T) if kwargs.get("basis", "rpz").lower() == "rpz": X = r * jnp.cos(data["s"]) Y = r * jnp.sin(data["s"]) Z = jnp.zeros_like(X) xyzcoords = jnp.array([X, Y, Z]).T xyzcoords = jnp.matmul(xyzcoords, A.T) + params["center"] - xyzcoords = jnp.matmul(xyzcoords, params["rotmat"].T) + params["shift"] + xyzcoords = ( + jnp.matmul(xyzcoords, params["rotmat"].reshape((3, 3)).T) + params["shift"] + ) x, y, z = xyzcoords.T coords = xyz2rpz_vec(coords, x=x, y=y) data["x_sss"] = coords @@ -363,7 +378,9 @@ def _x_FourierRZCurve(params, transforms, profiles, data, **kwargs): coords = jnp.stack([R, phi, Z], axis=1) # convert to xyz for displacement and rotation coords = rpz2xyz(coords) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :] + ) if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz(coords) data["x"] = coords @@ -397,7 +414,7 @@ def _x_s_FourierRZCurve(params, transforms, profiles, data, **kwargs): coords = jnp.stack([dR, dphi, dZ], axis=1) # convert to xyz for displacement and rotation coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2]) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2]) data["x_s"] = coords @@ -435,7 +452,7 @@ def _x_ss_FourierRZCurve(params, transforms, profiles, data, **kwargs): coords = jnp.stack([R, phi, Z], axis=1) # convert to xyz for displacement and rotation coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2]) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2]) data["x_ss"] = coords @@ -473,7 +490,7 @@ def _x_sss_FourierRZCurve(params, transforms, profiles, data, **kwargs): coords = jnp.stack([R, phi, Z], axis=1) # convert to xyz for displacement and rotation coords = rpz2xyz_vec(coords, phi=transforms["grid"].nodes[:, 2]) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec(coords, phi=transforms["grid"].nodes[:, 2]) data["x_sss"] = coords @@ -504,7 +521,9 @@ def _x_FourierXYZCurve(params, transforms, profiles, data, **kwargs): Y = transforms["Y"].transform(params["Y_n"], dz=0) Z = transforms["Z"].transform(params["Z_n"], dz=0) coords = jnp.stack([X, Y, Z], axis=1) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :] + ) if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz(coords) data["x"] = coords @@ -535,7 +554,7 @@ def _x_s_FourierXYZCurve(params, transforms, profiles, data, **kwargs): dY = transforms["Y"].transform(params["Y_n"], dz=1) dZ = transforms["Z"].transform(params["Z_n"], dz=1) coords = jnp.stack([dX, dY, dZ], axis=1) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec( coords, @@ -570,7 +589,7 @@ def _x_ss_FourierXYZCurve(params, transforms, profiles, data, **kwargs): d2Y = transforms["Y"].transform(params["Y_n"], dz=2) d2Z = transforms["Z"].transform(params["Z_n"], dz=2) coords = jnp.stack([d2X, d2Y, d2Z], axis=1) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec( coords, @@ -605,7 +624,7 @@ def _x_sss_FourierXYZCurve(params, transforms, profiles, data, **kwargs): d3Y = transforms["Y"].transform(params["Y_n"], dz=3) d3Z = transforms["Z"].transform(params["Z_n"], dz=3) coords = jnp.stack([d3X, d3Y, d3Z], axis=1) - coords = coords @ params["rotmat"].T + coords = coords @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz_vec( coords, @@ -662,7 +681,9 @@ def _x_SplineXYZCurve(params, transforms, profiles, data, **kwargs): ) coords = jnp.stack([Xq, Yq, Zq], axis=1) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + params["shift"][jnp.newaxis, :] + ) if kwargs.get("basis", "rpz").lower() == "rpz": coords = xyz2rpz(coords) data["x"] = coords @@ -715,7 +736,7 @@ def _x_s_SplineXYZCurve(params, transforms, profiles, data, **kwargs): ) coords_s = jnp.stack([dXq, dYq, dZq], axis=1) - coords_s = coords_s @ params["rotmat"].T + coords_s = coords_s @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": # calculate the xy coordinates to rotate to rpz @@ -745,7 +766,10 @@ def _x_s_SplineXYZCurve(params, transforms, profiles, data, **kwargs): ) coords = jnp.stack([Xq, Yq, Zq], axis=1) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + + params["shift"][jnp.newaxis, :] + ) coords_s = xyz2rpz_vec(coords_s, x=coords[:, 0], y=coords[:, 1]) data["x_s"] = coords_s @@ -798,7 +822,7 @@ def _x_ss_SplineXYZCurve(params, transforms, profiles, data, **kwargs): ) coords_ss = jnp.stack([d2Xq, d2Yq, d2Zq], axis=1) - coords_ss = coords_ss @ params["rotmat"].T + coords_ss = coords_ss @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": # calculate the xy coordinates to rotate to rpz @@ -827,7 +851,10 @@ def _x_ss_SplineXYZCurve(params, transforms, profiles, data, **kwargs): period=2 * jnp.pi, ) coords = jnp.stack([Xq, Yq, Zq], axis=1) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + + params["shift"][jnp.newaxis, :] + ) coords_ss = xyz2rpz_vec(coords_ss, x=coords[:, 0], y=coords[:, 1]) data["x_ss"] = coords_ss @@ -880,7 +907,7 @@ def _x_sss_SplineXYZCurve(params, transforms, profiles, data, **kwargs): ) coords_sss = jnp.stack([d3Xq, d3Yq, d3Zq], axis=1) - coords_sss = coords_sss @ params["rotmat"].T + coords_sss = coords_sss @ params["rotmat"].reshape((3, 3)).T if kwargs.get("basis", "rpz").lower() == "rpz": # calculate the xy coordinates to rotate to rpz @@ -909,7 +936,10 @@ def _x_sss_SplineXYZCurve(params, transforms, profiles, data, **kwargs): period=2 * jnp.pi, ) coords = jnp.stack([Xq, Yq, Zq], axis=1) - coords = coords @ params["rotmat"].T + params["shift"][jnp.newaxis, :] + coords = ( + coords @ params["rotmat"].reshape((3, 3)).T + + params["shift"][jnp.newaxis, :] + ) coords_sss = xyz2rpz_vec(coords_sss, x=coords[:, 0], y=coords[:, 1]) data["x_sss"] = coords_sss @@ -976,7 +1006,7 @@ def _frenet_normal(params, transforms, profiles, data, **kwargs): def _frenet_binormal(params, transforms, profiles, data, **kwargs): data["frenet_binormal"] = cross( data["frenet_tangent"], data["frenet_normal"] - ) * jnp.linalg.det(params["rotmat"]) + ) * jnp.linalg.det(params["rotmat"].reshape((3, 3))) return data diff --git a/desc/compute/geom_utils.py b/desc/compute/geom_utils.py index 62806f0184..15c62cb4ce 100644 --- a/desc/compute/geom_utils.py +++ b/desc/compute/geom_utils.py @@ -4,6 +4,8 @@ from desc.backend import jnp +from .utils import safenorm, safenormalize + def reflection_matrix(normal): """Matrix to reflect points across plane through origin with specified normal. @@ -35,17 +37,20 @@ def rotation_matrix(axis, angle=None): Returns ------- - rot : ndarray, shape(3,3) - Matrix to rotate points in cartesian (X,Y,Z) coordinates + rotmat : ndarray, shape(3,3) + Matrix to rotate points in cartesian (X,Y,Z) coordinates. + """ axis = jnp.asarray(axis) + norm = safenorm(axis) + axis = safenormalize(axis) if angle is None: - angle = jnp.linalg.norm(axis) - axis = axis / jnp.linalg.norm(axis) + angle = norm + eps = 1e2 * jnp.finfo(axis.dtype).eps R1 = jnp.cos(angle) * jnp.eye(3) R2 = jnp.sin(angle) * jnp.cross(axis, jnp.identity(axis.shape[0]) * -1) R3 = (1 - jnp.cos(angle)) * jnp.outer(axis, axis) - return R1 + R2 + R3 + return jnp.where(norm < eps, jnp.eye(3), R1 + R2 + R3) # if axis=0, no rotation def xyz2rpz(pts): @@ -152,17 +157,3 @@ def inner(vec, phi): return cart return inner(vec, phi) - - -def _rotation_matrix_from_normal(normal): - nx, ny, nz = normal - nxny = jnp.sqrt(nx**2 + ny**2) - R = jnp.array( - [ - [ny / nxny, -nx / nxny, 0], - [nx * nx / nxny, ny * nz / nxny, -nxny], - [nx, ny, nz], - ] - ).T - R = jnp.where(nxny == 0, jnp.eye(3), R) - return R diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 0cb98628c2..2a02bfbb81 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -496,7 +496,7 @@ def safenorm(x, ord=None, axis=None, fill=0, threshold=0): ---------- x : ndarray Vector or array to norm. - ord : {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional Order of norm. axis : {None, int, 2-tuple of ints}, optional Axis to take norm along. @@ -507,12 +507,36 @@ def safenorm(x, ord=None, axis=None, fill=0, threshold=0): """ is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) - x = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero - n = jnp.linalg.norm(x, ord=ord, axis=axis) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = jnp.linalg.norm(y, ord=ord, axis=axis) n = jnp.where(is_zero.squeeze(), fill, n) # replace norm with zero if is_zero return n +def safenormalize(x, ord=None, axis=None, fill=0, threshold=0): + """Normalize a vector to unit length, but without nan gradient at x=0. + + Parameters + ---------- + x : ndarray + Vector or array to norm. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of norm. + axis : {None, int, 2-tuple of ints}, optional + Axis to take norm along. + fill : float, ndarray, optional + Value to return where x is zero. + threshold : float >= 0 + How small is x allowed to be. + + """ + is_zero = (jnp.abs(x) <= threshold).all(axis=axis, keepdims=True) + y = jnp.where(is_zero, jnp.ones_like(x), x) # replace x with ones if is_zero + n = safenorm(x, ord, axis, fill, threshold) * jnp.ones_like(x) + # return unit vector with equal components if norm <= threshold + return jnp.where(n <= threshold, jnp.ones_like(y) / jnp.sqrt(y.size), y / n) + + def safediv(a, b, fill=0, threshold=0): """Divide a/b with guards for division by zero. diff --git a/desc/examples/ARIES-CS_output.h5 b/desc/examples/ARIES-CS_output.h5 index 43751708eb..cfdfa3ca87 100644 Binary files a/desc/examples/ARIES-CS_output.h5 and b/desc/examples/ARIES-CS_output.h5 differ diff --git a/desc/examples/ATF_output.h5 b/desc/examples/ATF_output.h5 index 4c7c220049..3b81e4ddc7 100644 Binary files a/desc/examples/ATF_output.h5 and b/desc/examples/ATF_output.h5 differ diff --git a/desc/examples/DSHAPE_CURRENT_output.h5 b/desc/examples/DSHAPE_CURRENT_output.h5 index 72a4d9544f..5c0968ee21 100644 Binary files a/desc/examples/DSHAPE_CURRENT_output.h5 and b/desc/examples/DSHAPE_CURRENT_output.h5 differ diff --git a/desc/examples/DSHAPE_output.h5 b/desc/examples/DSHAPE_output.h5 index 9bb9e2e2c9..90663aef65 100644 Binary files a/desc/examples/DSHAPE_output.h5 and b/desc/examples/DSHAPE_output.h5 differ diff --git a/desc/examples/ESTELL_output.h5 b/desc/examples/ESTELL_output.h5 index d96abea719..0ded80c81e 100644 Binary files a/desc/examples/ESTELL_output.h5 and b/desc/examples/ESTELL_output.h5 differ diff --git a/desc/examples/HELIOTRON_output.h5 b/desc/examples/HELIOTRON_output.h5 index befea2fb29..d42f43641a 100644 Binary files a/desc/examples/HELIOTRON_output.h5 and b/desc/examples/HELIOTRON_output.h5 differ diff --git a/desc/examples/HSX_output.h5 b/desc/examples/HSX_output.h5 index ad3ab6ff4e..07be6fe4f7 100644 Binary files a/desc/examples/HSX_output.h5 and b/desc/examples/HSX_output.h5 differ diff --git a/desc/examples/NCSX_output.h5 b/desc/examples/NCSX_output.h5 index e9d3111dc6..da5550aa28 100644 Binary files a/desc/examples/NCSX_output.h5 and b/desc/examples/NCSX_output.h5 differ diff --git a/desc/examples/SOLOVEV_output.h5 b/desc/examples/SOLOVEV_output.h5 index 0ffe8d4146..7b84771784 100644 Binary files a/desc/examples/SOLOVEV_output.h5 and b/desc/examples/SOLOVEV_output.h5 differ diff --git a/desc/examples/W7-X_output.h5 b/desc/examples/W7-X_output.h5 index e2432f8270..e1a79ce065 100644 Binary files a/desc/examples/W7-X_output.h5 and b/desc/examples/W7-X_output.h5 differ diff --git a/desc/examples/WISTELL-A_output.h5 b/desc/examples/WISTELL-A_output.h5 index 45e1f590e7..b45b544cfc 100644 Binary files a/desc/examples/WISTELL-A_output.h5 and b/desc/examples/WISTELL-A_output.h5 differ diff --git a/desc/examples/precise_QA_output.h5 b/desc/examples/precise_QA_output.h5 index dc63e62b6b..ed844f2b58 100644 Binary files a/desc/examples/precise_QA_output.h5 and b/desc/examples/precise_QA_output.h5 differ diff --git a/desc/examples/precise_QH_output.h5 b/desc/examples/precise_QH_output.h5 index 0bb57812fa..426574f6e1 100644 Binary files a/desc/examples/precise_QH_output.h5 and b/desc/examples/precise_QH_output.h5 differ diff --git a/desc/geometry/core.py b/desc/geometry/core.py index 00b6b0bfb2..09f3144e95 100644 --- a/desc/geometry/core.py +++ b/desc/geometry/core.py @@ -17,18 +17,44 @@ ) from desc.grid import LinearGrid, QuadratureGrid, _Grid from desc.io import IOAble -from desc.optimizable import Optimizable +from desc.optimizable import Optimizable, optimizable_parameter class Curve(IOAble, Optimizable, ABC): """Abstract base class for 1D curves in 3D space.""" - _io_attrs_ = ["_name", "shift", "rotmat"] + _io_attrs_ = ["_name", "_shift", "_rotmat"] def __init__(self, name=""): - self.shift = jnp.array([0, 0, 0]).astype(float) - self.rotmat = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).astype(float) - self.name = name + self._shift = jnp.array([0, 0, 0], dtype=float) + self._rotmat = jnp.eye(3, dtype=float).flatten() + self._name = name + + @optimizable_parameter + @property + def shift(self): + """Displacement of curve in X, Y, Z.""" + return self._shift + + @shift.setter + def shift(self, new): + if len(new) == 3: + self._shift = jnp.asarray(new) + else: + raise ValueError("shift should be a 3 element vector, got {}".format(new)) + + @optimizable_parameter + @property + def rotmat(self): + """Rotation matrix of curve in X, Y, Z.""" + return self._rotmat + + @rotmat.setter + def rotmat(self, new): + if len(new) == 9: + self._rotmat = jnp.asarray(new) + else: + self._rotmat = jnp.asarray(new.flatten()) @property def name(self): @@ -143,19 +169,19 @@ def compute( return data def translate(self, displacement=[0, 0, 0]): - """Translate the curve by a rigid displacement in x, y, z.""" - self.shift += jnp.asarray(displacement) + """Translate the curve by a rigid displacement in X, Y, Z.""" + self.shift = self.shift + jnp.asarray(displacement) def rotate(self, axis=[0, 0, 1], angle=0): - """Rotate the curve by a fixed angle about axis in xyz coordinates.""" - R = rotation_matrix(axis, angle) - self.rotmat = R @ self.rotmat + """Rotate the curve by a fixed angle about axis in X, Y, Z coordinates.""" + R = rotation_matrix(axis=axis, angle=angle) + self.rotmat = (R @ self.rotmat.reshape(3, 3)).flatten() self.shift = self.shift @ R.T - def flip(self, normal): + def flip(self, normal=[0, 0, 1]): """Flip the curve about the plane with specified normal.""" F = reflection_matrix(normal) - self.rotmat = F @ self.rotmat + self.rotmat = (F @ self.rotmat.reshape(3, 3)).flatten() self.shift = self.shift @ F.T def __repr__(self): diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index e8ecf0b9d0..3dab7ef1b7 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -47,6 +47,8 @@ FixBoundaryR, FixBoundaryZ, FixCurrent, + FixCurveRotation, + FixCurveShift, FixElectronDensity, FixElectronTemperature, FixIonTemperature, diff --git a/desc/objectives/getters.py b/desc/objectives/getters.py index ec2a36ecc1..94dcda2ca8 100644 --- a/desc/objectives/getters.py +++ b/desc/objectives/getters.py @@ -19,6 +19,8 @@ FixBoundaryR, FixBoundaryZ, FixCurrent, + FixCurveRotation, + FixCurveShift, FixElectronDensity, FixElectronTemperature, FixIonTemperature, @@ -214,20 +216,36 @@ def get_NAE_constraints( return constraints -def maybe_add_self_consistency(eq, constraints): +def maybe_add_self_consistency(thing, constraints): """Add self consistency constraints if needed.""" def _is_any_instance(things, cls): return any([isinstance(t, cls) for t in things]) - if not _is_any_instance(constraints, BoundaryRSelfConsistency): - constraints += (BoundaryRSelfConsistency(eq=eq),) - if not _is_any_instance(constraints, BoundaryZSelfConsistency): - constraints += (BoundaryZSelfConsistency(eq=eq),) - if not _is_any_instance(constraints, FixLambdaGauge): - constraints += (FixLambdaGauge(eq=eq),) - if not _is_any_instance(constraints, AxisRSelfConsistency): - constraints += (AxisRSelfConsistency(eq=eq),) - if not _is_any_instance(constraints, AxisZSelfConsistency): - constraints += (AxisZSelfConsistency(eq=eq),) + # Equilibrium + if ( + hasattr(thing, "Ra_n") + and hasattr(thing, "Za_n") + and hasattr(thing, "Rb_lmn") + and hasattr(thing, "Zb_lmn") + and hasattr(thing, "L_lmn") + ): + if not _is_any_instance(constraints, BoundaryRSelfConsistency): + constraints += (BoundaryRSelfConsistency(eq=thing),) + if not _is_any_instance(constraints, BoundaryZSelfConsistency): + constraints += (BoundaryZSelfConsistency(eq=thing),) + if not _is_any_instance(constraints, FixLambdaGauge): + constraints += (FixLambdaGauge(eq=thing),) + if not _is_any_instance(constraints, AxisRSelfConsistency): + constraints += (AxisRSelfConsistency(eq=thing),) + if not _is_any_instance(constraints, AxisZSelfConsistency): + constraints += (AxisZSelfConsistency(eq=thing),) + + # Curve + elif hasattr(thing, "shift") and hasattr(thing, "rotmat"): + if not _is_any_instance(constraints, FixCurveShift): + constraints += (FixCurveShift(curve=thing),) + if not _is_any_instance(constraints, FixCurveRotation): + constraints += (FixCurveRotation(curve=thing),) + return constraints diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index bf405919ee..4a044f0df9 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -889,7 +889,7 @@ class FixLambdaGauge(_Objective): _scalar = False _linear = True _fixed = False - _units = "(radians)" + _units = "(rad)" _print_value_fmt = "lambda gauge error: {:10.3e} " def __init__( @@ -983,7 +983,7 @@ class FixThetaSFL(_Objective): _scalar = False _linear = True _fixed = True - _units = "(radians)" + _units = "(rad)" _print_value_fmt = "Theta - Theta SFL error: {:10.3e} " def __init__(self, eq, name="Theta SFL"): @@ -1602,7 +1602,7 @@ class FixModeLambda(_FixedObjective): """ _target_arg = "L_lmn" - _units = "(dimensionless)" + _units = "(rad)" _print_value_fmt = "Fixed-lambda modes error: {:10.3e} " def __init__( @@ -2074,7 +2074,7 @@ class FixSumModesLambda(_FixedObjective): _fixed = False # not "diagonal", since its fixing a sum _target_arg = "L_lmn" - _units = "(dimensionless)" + _units = "(rad)" _print_value_fmt = "Fixed-lambda sum modes error: {:10.3e} " def __init__( @@ -3238,3 +3238,193 @@ def compute(self, params, constants=None): """ return params["Psi"] + + +class FixCurveShift(_FixedObjective): + """Fixes Curve.shift attribute, which is redundant with other Curve params. + + Parameters + ---------- + curve : Curve + Curve that will be optimized to satisfy the Objective. + target : {float, ndarray}, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. + bounds : tuple of {float, ndarray}, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : {float, ndarray}, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + name : str, optional + Name of the objective function. + + """ + + _target_arg = "shift" + _units = "(m)" + _print_value_fmt = "Fixed-shift error: {:10.3e} " + + def __init__( + self, + curve, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + name="fixed-shift", + ): + self._target_from_user = setdefault(bounds, target) + super().__init__( + things=curve, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + name=name, + ) + + def build(self, use_jit=False, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + curve = self.things[0] + self._dim_f = curve.shift.size + + self.target, self.bounds = self._parse_target_from_user( + self._target_from_user, curve.shift, None, np.arange(self._dim_f) + ) + + if self._normalize: + self._normalization = 1 + + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute fixed-shift error. + + Parameters + ---------- + params : dict + Dictionary of curve degrees of freedom, eg Curve.params_dict + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self.constants + + Returns + ------- + f : ndarray + Curve shift (m). + + """ + return params["shift"] + + +class FixCurveRotation(_FixedObjective): + """Fixes Curve.rotmat attribute, which is redundant with other Curve params. + + Parameters + ---------- + curve : Curve + Curve that will be optimized to satisfy the Objective. + target : {float, ndarray}, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. + bounds : tuple of {float, ndarray}, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : {float, ndarray}, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + name : str, optional + Name of the objective function. + + """ + + _target_arg = "rotmat" + _units = "(rad)" + _print_value_fmt = "Fixed-rotation error: {:10.3e} " + + def __init__( + self, + curve, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + name="fixed-rotation", + ): + self._target_from_user = setdefault(bounds, target) + super().__init__( + things=curve, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + name=name, + ) + + def build(self, use_jit=False, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + curve = self.things[0] + self._dim_f = curve.rotmat.size + + self.target, self.bounds = self._parse_target_from_user( + self._target_from_user, curve.rotmat, None, np.arange(self._dim_f) + ) + + if self._normalize: + self._normalization = 1 + + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute fixed-rotation error. + + Parameters + ---------- + params : dict + Dictionary of curve degrees of freedom, eg Curve.params_dict + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self.constants + + Returns + ------- + f : ndarray + Curve rotation matrix (rad). + + """ + return params["rotmat"] diff --git a/desc/perturbations.py b/desc/perturbations.py index b51e5d33f0..ee1548b2f6 100644 --- a/desc/perturbations.py +++ b/desc/perturbations.py @@ -166,7 +166,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces if not objective.built: objective.build(eq, verbose=verbose) - constraints = maybe_add_self_consistency(eq=eq, constraints=constraints) + constraints = maybe_add_self_consistency(eq, constraints) for con in constraints: if not con.built: con.build(eq, verbose=verbose) @@ -544,7 +544,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces # FIXME: generalize to other constraints constraints = get_fixed_boundary_constraints(eq=eq) - constraints = maybe_add_self_consistency(eq=eq, constraints=constraints) + constraints = maybe_add_self_consistency(eq, constraints) for con in constraints: if not con.built: con.build(eq, verbose=verbose) diff --git a/tests/inputs/DSHAPE_output_saved_without_current.h5 b/tests/inputs/DSHAPE_output_saved_without_current.h5 index a8f9cda7ae..6ace46fd24 100644 Binary files a/tests/inputs/DSHAPE_output_saved_without_current.h5 and b/tests/inputs/DSHAPE_output_saved_without_current.h5 differ diff --git a/tests/inputs/LandremanPaul2022_QA_reactorScale_lowRes.h5 b/tests/inputs/LandremanPaul2022_QA_reactorScale_lowRes.h5 index 0d66cb21ad..015d8e71de 100644 Binary files a/tests/inputs/LandremanPaul2022_QA_reactorScale_lowRes.h5 and b/tests/inputs/LandremanPaul2022_QA_reactorScale_lowRes.h5 differ diff --git a/tests/inputs/LandremanPaul2022_QH_reactorScale_lowRes.h5 b/tests/inputs/LandremanPaul2022_QH_reactorScale_lowRes.h5 index 1fd44cce43..dfa7dd4a30 100644 Binary files a/tests/inputs/LandremanPaul2022_QH_reactorScale_lowRes.h5 and b/tests/inputs/LandremanPaul2022_QH_reactorScale_lowRes.h5 differ diff --git a/tests/inputs/circular_model_tokamak_output.h5 b/tests/inputs/circular_model_tokamak_output.h5 index b827251cab..0d37b0eaa3 100644 Binary files a/tests/inputs/circular_model_tokamak_output.h5 and b/tests/inputs/circular_model_tokamak_output.h5 differ diff --git a/tests/inputs/iotest_HELIOTRON.h5 b/tests/inputs/iotest_HELIOTRON.h5 index 92d5998510..452eac9fb1 100644 Binary files a/tests/inputs/iotest_HELIOTRON.h5 and b/tests/inputs/iotest_HELIOTRON.h5 differ diff --git a/tests/inputs/master_compute_data.pkl b/tests/inputs/master_compute_data.pkl index 71efea318a..15236b9c9d 100644 Binary files a/tests/inputs/master_compute_data.pkl and b/tests/inputs/master_compute_data.pkl differ diff --git a/tests/test_coils.py b/tests/test_coils.py index a2f8287059..e9393f045a 100644 --- a/tests/test_coils.py +++ b/tests/test_coils.py @@ -518,8 +518,8 @@ def test_save_and_load_makegrid_coils_rotated(tmpdir_factory): Z2 = coords2[:, 2] np.testing.assert_allclose(c1.current, c2.current, err_msg=f"Coil {i}") - np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}") - np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}") + np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}", atol=1e-16) + np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}", atol=1e-16) np.testing.assert_allclose(Z1, Z2, atol=2e-7, err_msg=f"Coil {i}") # check Bnormal on torus and ensure is near zero @@ -593,8 +593,8 @@ def test_save_and_load_makegrid_coils_rotated_int_grid(tmpdir_factory): Z2 = coords2[:, 2] np.testing.assert_allclose(c1.current, c2.current, err_msg=f"Coil {i}") - np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}") - np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}") + np.testing.assert_allclose(X1, X2, err_msg=f"Coil {i}", atol=1e-16) + np.testing.assert_allclose(Y1, Y2, err_msg=f"Coil {i}", atol=1e-16) np.testing.assert_allclose(Z1, Z2, atol=2e-7, err_msg=f"Coil {i}") # check Bnormal on torus and ensure is near zero diff --git a/tests/test_compute_utils.py b/tests/test_compute_utils.py index e1b8aa4cb8..02a6947b83 100644 --- a/tests/test_compute_utils.py +++ b/tests/test_compute_utils.py @@ -1,9 +1,12 @@ """Tests compute utilities.""" +import jax import numpy as np import pytest +from desc.backend import jnp from desc.basis import FourierZernikeBasis +from desc.compute.geom_utils import rotation_matrix from desc.compute.utils import ( _get_grid_surface, line_integrals, @@ -568,3 +571,15 @@ def test_surface_min_max(self): Bmin_alt[j] = np.min(B[mask]) np.testing.assert_allclose(Bmax_alt, grid.compress(surface_max(grid, B))) np.testing.assert_allclose(Bmin_alt, grid.compress(surface_min(grid, B))) + + +@pytest.mark.unit +def test_rotation_matrix(): + """Test that rotation_matrix works with fwd & rev AD for axis=[0, 0, 0].""" + dfdx_fwd = jax.jacfwd(rotation_matrix) + dfdx_rev = jax.jacrev(rotation_matrix) + x0 = jnp.array([0.0, 0.0, 0.0]) + + np.testing.assert_allclose(rotation_matrix(x0), np.eye(3)) + np.testing.assert_allclose(dfdx_fwd(x0), np.zeros((3, 3, 3))) + np.testing.assert_allclose(dfdx_rev(x0), np.zeros((3, 3, 3))) diff --git a/tests/test_curves.py b/tests/test_curves.py index 356bf83cd1..b4ff6d8740 100644 --- a/tests/test_curves.py +++ b/tests/test_curves.py @@ -402,6 +402,19 @@ def test_asserts(self): class TestPlanarCurve: """Tests for FourierPlanarCurve class.""" + @pytest.mark.unit + def test_rotation(self): + """Test rotation of planar curve.""" + cx = FourierPlanarCurve(center=[0, 0, 0], normal=[1, 0, 0], r_n=1) + cy = FourierPlanarCurve(center=[0, 0, 0], normal=[0, 1, 0], r_n=1) + cz = FourierPlanarCurve(center=[0, 0, 0], normal=[0, 0, 1], r_n=1) + datax = cx.compute("x", grid=20, basis="xyz") + datay = cy.compute("x", grid=20, basis="xyz") + dataz = cz.compute("x", grid=20, basis="xyz") + np.testing.assert_allclose(datax["x"][:, 0], 0, atol=1e-16) # only in Y-Z plane + np.testing.assert_allclose(datay["x"][:, 1], 0, atol=1e-16) # only in X-Z plane + np.testing.assert_allclose(dataz["x"][:, 2], 0, atol=1e-16) # only in X-Y plane + @pytest.mark.unit def test_length(self): """Test length of circular curve.""" @@ -472,7 +485,7 @@ def test_coords(self): np.testing.assert_allclose(z, 0) dr, dp, dz = c.compute("x_sss", grid=0, basis="rpz")["x_sss"].T np.testing.assert_allclose(dr, 0) - np.testing.assert_allclose(dp, 0) + np.testing.assert_allclose(dp, 0, atol=1e-14) np.testing.assert_allclose(dz, 2) c.rotate(angle=np.pi / 2) c.flip([0, 1, 0])