diff --git a/desc/backend.py b/desc/backend.py index 721920190c..77cf6f090d 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -84,6 +84,7 @@ tree_map, tree_structure, tree_unflatten, + treedef_is_leaf, ) def put(arr, inds, vals): @@ -412,6 +413,10 @@ def tree_leaves(*args, **kwargs): """Get leaves of pytree for numpy backend.""" raise NotImplementedError + def treedef_is_leaf(*args, **kwargs): + """Check is leaf of pytree for numpy backend.""" + raise NotImplementedError + def register_pytree_node(foo, *args): """Dummy decorator for non-jax pytrees.""" return foo diff --git a/desc/basis.py b/desc/basis.py index dafc6bf929..9ba700f499 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -114,11 +114,12 @@ def get_idx(self, L=0, M=0, N=0, error=True): N : int Toroidal mode number. error : bool - whether to raise exception if mode is not in basis, or return empty array + Whether to raise exception if the mode is not in the basis (default), + or to return an empty array. Returns ------- - idx : ndarray of int + idx : int Index of given mode numbers. """ @@ -130,7 +131,7 @@ def get_idx(self, L=0, M=0, N=0, error=True): "mode ({}, {}, {}) is not in basis {}".format(L, M, N, str(self)) ) from e else: - return np.array([]).astype(int) + return np.array([], dtype=int) @abstractmethod def _get_modes(self): diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index e0ae573e60..a99da8cfcf 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -63,9 +63,10 @@ FixOmniBmax, FixOmniMap, FixOmniWell, - FixParameter, + FixParameters, FixPressure, FixPsi, + FixSheetCurrent, FixSumModesLambda, FixSumModesR, FixSumModesZ, diff --git a/desc/objectives/_free_boundary.py b/desc/objectives/_free_boundary.py index 0545baa6cd..ebcdf71ab2 100644 --- a/desc/objectives/_free_boundary.py +++ b/desc/objectives/_free_boundary.py @@ -485,12 +485,11 @@ def build(self, use_jit=True, verbose=1): if self._source_grid is None: # for axisymmetry we still need to know about toroidal effects, so its # cheapest to pretend there are extra field periods - source_NFP = eq.NFP if eq.N > 0 else 64 source_grid = LinearGrid( rho=np.array([1.0]), M=eq.M_grid, N=eq.N_grid, - NFP=source_NFP, + NFP=eq.NFP if eq.N > 0 else 64, sym=False, ) else: diff --git a/desc/objectives/getters.py b/desc/objectives/getters.py index d9151c6cef..9a7c980e32 100644 --- a/desc/objectives/getters.py +++ b/desc/objectives/getters.py @@ -1,8 +1,6 @@ """Utilities for getting standard groups of objectives and constraints.""" -import numpy as np - -from desc.utils import is_any_instance +from desc.utils import flatten_list, is_any_instance, unique_list from ._equilibrium import Energy, ForceBalance, HelicalForceBalance, RadialForceBalance from .linear_objectives import ( @@ -24,9 +22,9 @@ FixIonTemperature, FixIota, FixLambdaGauge, - FixParameter, FixPressure, FixPsi, + FixSheetCurrent, ) from .nae_utils import calc_zeroth_order_lambda, make_RZ_cons_1st_order from .objective_funs import ObjectiveFunction @@ -61,18 +59,15 @@ def get_equilibrium_objective(eq, mode="force", normalize=True): ------- objective, ObjectiveFunction An objective function with default force balance objectives. + """ + kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize} if mode == "energy": - objectives = Energy(eq=eq, normalize=normalize, normalize_target=normalize) + objectives = Energy(**kwargs) elif mode == "force": - objectives = ForceBalance( - eq=eq, normalize=normalize, normalize_target=normalize - ) + objectives = ForceBalance(**kwargs) elif mode == "forces": - objectives = ( - RadialForceBalance(eq=eq, normalize=normalize, normalize_target=normalize), - HelicalForceBalance(eq=eq, normalize=normalize, normalize_target=normalize), - ) + objectives = (RadialForceBalance(**kwargs), HelicalForceBalance(**kwargs)) else: raise ValueError("got an unknown equilibrium objective type '{}'".format(mode)) return ObjectiveFunction(objectives) @@ -96,20 +91,13 @@ def get_fixed_axis_constraints(eq, profiles=True, normalize=True): A list of the linear constraints used in fixed-axis problems. """ - constraints = ( - FixAxisR(eq=eq, normalize=normalize, normalize_target=normalize), - FixAxisZ(eq=eq, normalize=normalize, normalize_target=normalize), - FixPsi(eq=eq, normalize=normalize, normalize_target=normalize), - ) + kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize} + constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs)) if profiles: for name, con in _PROFILE_CONSTRAINTS.items(): if getattr(eq, name) is not None: - constraints += ( - con(eq=eq, normalize=normalize, normalize_target=normalize), - ) - for param in ["I", "G", "Phi_mn"]: - if np.array(getattr(eq, param, [])).size: - constraints += (FixParameter(eq, param),) + constraints += (con(**kwargs),) + constraints += (FixSheetCurrent(**kwargs),) return constraints @@ -132,20 +120,13 @@ def get_fixed_boundary_constraints(eq, profiles=True, normalize=True): A list of the linear constraints used in fixed-boundary problems. """ - constraints = ( - FixBoundaryR(eq=eq, normalize=normalize, normalize_target=normalize), - FixBoundaryZ(eq=eq, normalize=normalize, normalize_target=normalize), - FixPsi(eq=eq, normalize=normalize, normalize_target=normalize), - ) + kwargs = {"eq": eq, "normalize": normalize, "normalize_target": normalize} + constraints = (FixBoundaryR(**kwargs), FixBoundaryZ(**kwargs), FixPsi(**kwargs)) if profiles: for name, con in _PROFILE_CONSTRAINTS.items(): if getattr(eq, name) is not None: - constraints += ( - con(eq=eq, normalize=normalize, normalize_target=normalize), - ) - for param in ["I", "G", "Phi_mn"]: - if np.array(getattr(eq, param, [])).size: - constraints += (FixParameter(eq, param),) + constraints += (con(**kwargs),) + constraints += (FixSheetCurrent(**kwargs),) return constraints @@ -186,24 +167,18 @@ def get_NAE_constraints( ------- constraints, tuple of _Objectives A list of the linear constraints used in fixed-axis problems. + """ + kwargs = {"eq": desc_eq, "normalize": normalize, "normalize_target": normalize} if not isinstance(fix_lambda, bool): fix_lambda = int(fix_lambda) - constraints = ( - FixAxisR(eq=desc_eq, normalize=normalize, normalize_target=normalize), - FixAxisZ(eq=desc_eq, normalize=normalize, normalize_target=normalize), - FixPsi(eq=desc_eq, normalize=normalize, normalize_target=normalize), - ) + constraints = (FixAxisR(**kwargs), FixAxisZ(**kwargs), FixPsi(**kwargs)) if profiles: for name, con in _PROFILE_CONSTRAINTS.items(): if getattr(desc_eq, name) is not None: - constraints += ( - con(eq=desc_eq, normalize=normalize, normalize_target=normalize), - ) - for param in ["I", "G", "Phi_mn"]: - if np.array(getattr(desc_eq, param, [])).size: - constraints += (FixParameter(desc_eq, param),) + constraints += (con(**kwargs),) + constraints += (FixSheetCurrent(**kwargs),) if fix_lambda or (fix_lambda >= 0 and type(fix_lambda) is int): L_axis_constraints, _, _ = calc_zeroth_order_lambda( @@ -222,30 +197,32 @@ def get_NAE_constraints( def maybe_add_self_consistency(thing, constraints): """Add self consistency constraints if needed.""" + params = set(unique_list(flatten_list(thing.optimizable_params))[0]) + # Equilibrium - if ( - hasattr(thing, "Ra_n") - and hasattr(thing, "Za_n") - and hasattr(thing, "Rb_lmn") - and hasattr(thing, "Zb_lmn") - and hasattr(thing, "L_lmn") + if {"R_lmn", "Rb_lmn"} <= params and not is_any_instance( + constraints, BoundaryRSelfConsistency + ): + constraints += (BoundaryRSelfConsistency(eq=thing),) + if {"Z_lmn", "Zb_lmn"} <= params and not is_any_instance( + constraints, BoundaryZSelfConsistency + ): + constraints += (BoundaryZSelfConsistency(eq=thing),) + if {"L_lmn"} <= params and not is_any_instance(constraints, FixLambdaGauge): + constraints += (FixLambdaGauge(eq=thing),) + if {"R_lmn", "Ra_n"} <= params and not is_any_instance( + constraints, AxisRSelfConsistency + ): + constraints += (AxisRSelfConsistency(eq=thing),) + if {"Z_lmn", "Za_n"} <= params and not is_any_instance( + constraints, AxisZSelfConsistency ): - if not is_any_instance(constraints, BoundaryRSelfConsistency): - constraints += (BoundaryRSelfConsistency(eq=thing),) - if not is_any_instance(constraints, BoundaryZSelfConsistency): - constraints += (BoundaryZSelfConsistency(eq=thing),) - if not is_any_instance(constraints, FixLambdaGauge): - constraints += (FixLambdaGauge(eq=thing),) - if not is_any_instance(constraints, AxisRSelfConsistency): - constraints += (AxisRSelfConsistency(eq=thing),) - if not is_any_instance(constraints, AxisZSelfConsistency): - constraints += (AxisZSelfConsistency(eq=thing),) + constraints += (AxisZSelfConsistency(eq=thing),) # Curve - elif hasattr(thing, "shift") and hasattr(thing, "rotmat"): - if not is_any_instance(constraints, FixCurveShift): - constraints += (FixCurveShift(curve=thing),) - if not is_any_instance(constraints, FixCurveRotation): - constraints += (FixCurveRotation(curve=thing),) + if {"shift"} <= params and not is_any_instance(constraints, FixCurveShift): + constraints += (FixCurveShift(curve=thing),) + if {"rotmat"} <= params and not is_any_instance(constraints, FixCurveRotation): + constraints += (FixCurveRotation(curve=thing),) return constraints diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index 5b3c6fc274..ef761dfc36 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -7,19 +7,19 @@ """ import warnings -from abc import ABC import numpy as np from termcolor import colored -from desc.backend import jnp +from desc.backend import jnp, tree_leaves, tree_map, tree_structure from desc.basis import zernike_radial, zernike_radial_coeffs -from desc.utils import errorif, setdefault +from desc.utils import broadcast_tree, errorif, setdefault from .normalization import compute_scaling_factors from .objective_funs import _Objective +# TODO: get rid of this class and inherit from FixParameters instead? class _FixedObjective(_Objective): _fixed = True _linear = True @@ -44,6 +44,8 @@ def update_target(self, thing): def _parse_target_from_user( self, target_from_user, default_target, default_bounds, idx ): + # FIXME: add logic here to deal with `target_from_user` as a pytree? + # FIXME: does this actually need idx? if target_from_user is None: target = default_target bounds = default_bounds @@ -61,31 +63,31 @@ def _parse_target_from_user( return target, bounds -class FixParameter(_FixedObjective): - """Fix specific degrees of freedom associated with a given Optimizable object. +class FixParameters(_Objective): + """Fix specific degrees of freedom associated with a given Optimizable thing. Parameters ---------- thing : Optimizable Object whose degrees of freedom are being fixed. - params : str or list of str - Names of parameters to fix. Defaults to all parameters. - index : array-like or list of array-like - Indices to fix for each parameter in params. Use True to fix all indices. + params : nested list of dicts + Dict keys are the names of parameters to fix (str), and dict values are the + indices to fix for each corresponding parameter (int array). + Use True (False) instead of an int array to fix all (none) of the indices + for that parameter. + Must have the same pytree structure as thing.params_dict. + The default is to fix all indices of all parameters. target : dict of {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. - Should have the same tree structure as thing.params. - Defaults to ``target=thing.params``. + Should have the same tree structure as thing.params. Defaults to things.params. bounds : tuple of dict {float, ndarray}, optional Lower and upper bounds on the objective. Overrides target. Should have the same tree structure as thing.params. - Defaults to ``target=thing.params``. weight : dict of {float, ndarray}, optional Weighting to apply to the Objective, relative to other Objectives. Should be a scalar or have the same tree structure as thing.params. normalize : bool, optional Whether to compute the error in physical units or non-dimensionalize. - Has no effect for this objective. normalize_target : bool, optional Whether target and bounds should be normalized before comparing to computed values. If `normalize` is `True` and the target is in physical units, @@ -93,33 +95,62 @@ class FixParameter(_FixedObjective): name : str, optional Name of the objective function. + Examples + -------- + .. code-block:: python + + import numpy as np + from desc.coils import ( + CoilSet, FourierPlanarCoil, FourierRZCoil, FourierXYZCoil, MixedCoilSet + ) + from desc.objectives import FixParameters + + # toroidal field coil set with 3 coils + tf_coil = FourierPlanarCoil(center=[2, 0, 0], normal=[0, 1, 0], r_n=[1]) + tf_coilset = CoilSet.linspaced_angular(tf_coil, n=3) + # vertical field coil set with 2 coils + vf_coil = FourierRZCoil(R_n=3, Z_n=-1) + vf_coilset = CoilSet.linspaced_linear( + vf_coil, displacement=[0, 0, 2], n=2, endpoint=True + ) + # another single coil + coil = FourierXYZCoil() + # full coil set with TF coils, VF coils, and other single coil + full_coilset = MixedCoilSet((tf_coilset, vf_coilset, xy_coil)) + + params = [ + [ + {"current": True}, # fix "current" in 1st TF coil + # fix "center" and one component of "normal" for 2nd TF coil + {"center": True, "normal": np.array([1])}, + {}, # fix nothing in 3rd TF coil + ], + {"shift": True, "rotmat": True}, # fix "shift" & "rotmat" for all VF coils + # fix specified indices of "X_n" and "Z_n", but not "Y_n", for other coil + {"X_n": np.array([1, 2]), "Y_n": False, "Z_n": np.array([0])}, + ] + obj = FixParameters(full_coilset, params) + """ _scalar = False _linear = True _fixed = True _units = "(~)" - _print_value_fmt = "Fixed parameter error: {:10.3e} " + _print_value_fmt = "Fixed parameters error: {:10.3e} " def __init__( self, thing, params=None, - indices=True, target=None, bounds=None, weight=1, - normalize=False, - normalize_target=False, - name=None, + normalize=True, + normalize_target=True, + name="Fixed parameters", ): - self._target_from_user = target - self._params = params = setdefault(params, thing.optimizable_params) - self._indices = indices - self._print_value_fmt = ( - f"Fixed parameter ({self._params}) error: " + "{:10.3e} " - ) - name = setdefault(name, f"Fixed parameter ({self._params})") + self._params = params super().__init__( things=thing, target=target, @@ -142,57 +173,25 @@ def build(self, use_jit=False, verbose=1): """ thing = self.things[0] - params = setdefault(self._params, thing.optimizable_params) - - if not isinstance(params, (list, tuple)): - params = [params] - for par in params: - errorif( - par not in thing.optimizable_params, - ValueError, - f"couldn't find parameter {par} in optimizable_parameters: " - + f"{thing.optimizable_params}", - ) - self._params = params - # replace indices=True with actual indices - if isinstance(self._indices, bool) and self._indices: - self._indices = [np.arange(thing.dimensions[par]) for par in self._params] - # make sure its iterable if only a scalar was passed in - if not isinstance(self._indices, (list, tuple)): - self._indices = [self._indices] - # replace idx=True with array of all indices, throwing an error if the length - # of indices is different from number of params - indices = {} - errorif( - len(self._params) != len(self._indices), - ValueError, - f"not enough indices ({len(self._indices)}) " - + f"for params ({len(self._params)})", - ) - for idx, par in zip(self._indices, self._params): - if isinstance(idx, bool) and idx: - idx = np.arange(thing.dimensions[par]) - indices[par] = np.atleast_1d(idx) - self._indices = indices - self._dim_f = sum(t.size for t in self._indices.values()) - - default_target = { - par: thing.params_dict[par][self._indices[par]] for par in params - } - default_bounds = None - target, bounds = self._parse_target_from_user( - self._target_from_user, default_target, default_bounds, indices - ) - if target: - self.target = jnp.concatenate([target[par] for par in params]) - self.bounds = None - else: - self.target = None - self.bounds = ( - jnp.concatenate([bounds[0][par] for par in params]), - jnp.concatenate([bounds[1][par] for par in params]), + # default params + default_params = tree_map(lambda dim: np.arange(dim), thing.dimensions) + self._params = setdefault(self._params, default_params) + self._params = broadcast_tree(self._params, default_params) + self._indices = tree_leaves(self._params) + assert tree_structure(self._params) == tree_structure(default_params) + + self._dim_f = sum(idx.size for idx in self._indices) + + # default target + if self.target is None and self.bounds is None: + self.target = np.concatenate( + [ + np.atleast_1d(param[idx]) + for param, idx in zip(tree_leaves(thing.params_dict), self._indices) + ] ) + super().build(use_jit=use_jit, verbose=verbose) def compute(self, params, constants=None): @@ -200,8 +199,8 @@ def compute(self, params, constants=None): Parameters ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict + params : list of dict + List of dictionaries of degrees of freedom, eg CoilSet.params_dict constants : dict Dictionary of constant data, eg transforms, profiles etc. Defaults to self.constants @@ -213,9 +212,25 @@ def compute(self, params, constants=None): """ return jnp.concatenate( - [params[par][self._indices[par]] for par in self._params] + [ + jnp.atleast_1d(param[idx]) + for param, idx in zip(tree_leaves(params), self._indices) + ] ) + def update_target(self, thing): + """Update target values using an Optimizable object. + + Parameters + ---------- + thing : Optimizable + Optimizable object that will be optimized to satisfy the Objective. + + """ + self.target = self.compute(thing.params_dict) + if self._use_jit: + self.jit() + class BoundaryRSelfConsistency(_Objective): """Ensure that the boundary and interior surfaces are self-consistent. @@ -236,7 +251,7 @@ class BoundaryRSelfConsistency(_Objective): _scalar = False _linear = True - _fixed = False + _fixed = False # not "diagonal", since it is fixing a sum _units = "(m)" _print_value_fmt = "R boundary self consistency error: {:10.3e} " @@ -334,7 +349,7 @@ class BoundaryZSelfConsistency(_Objective): _scalar = False _linear = True - _fixed = False + _fixed = False # not "diagonal", since it is fixing a sum _units = "(m)" _print_value_fmt = "Z boundary self consistency error: {:10.3e} " @@ -432,7 +447,7 @@ class AxisRSelfConsistency(_Objective): _scalar = False _linear = True - _fixed = False + _fixed = False # not "diagonal", since it is fixing a sum _print_value_fmt = "R axis self consistency error: {:10.3e} (m)" def __init__( @@ -518,7 +533,7 @@ class AxisZSelfConsistency(_Objective): _scalar = False _linear = True - _fixed = False + _fixed = False # not "diagonal", since it is fixing a sum _print_value_fmt = "Z axis self consistency error: {:10.3e} (m)" def __init__( @@ -587,7 +602,7 @@ def compute(self, params, constants=None): return f -class FixBoundaryR(_FixedObjective): +class FixBoundaryR(FixParameters): """Boundary condition on the R boundary parameters. Parameters @@ -614,16 +629,9 @@ class FixBoundaryR(_FixedObjective): Basis modes numbers [l,m,n] of boundary modes to fix. len(target) = len(weight) = len(modes). If True/False uses all/none of the profile modes. - surface_label : float, optional - Surface to enforce boundary conditions on. Defaults to Equilibrium.surface.rho name : str, optional Name of the objective function. - Notes - ----- - If specifying particular modes to fix, the rows of the resulting constraint `A` - matrix and `target` vector will be re-sorted according to the ordering of - `basis.modes` which may be different from the order that was passed in. """ _units = "(m)" @@ -638,14 +646,17 @@ def __init__( normalize=True, normalize_target=True, modes=True, - surface_label=None, name="lcfs R", ): - self._modes = modes - self._target_from_user = setdefault(bounds, target) - self._surface_label = surface_label + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.surface.R_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"Rb_lmn": indices}, target=target, bounds=bounds, weight=weight, @@ -666,74 +677,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - if self._modes is False or self._modes is None: # no modes - modes = np.array([[]], dtype=int) - idx = np.array([], dtype=int) - modes_idx = idx - elif self._modes is True: # all modes - modes = eq.surface.R_basis.modes - idx = np.arange(eq.surface.R_basis.num_modes) - modes_idx = idx - else: # specified modes - modes = np.atleast_2d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, idx, modes_idx = np.intersect1d( - eq.surface.R_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - # rearrange modes to match order of eq.surface.R_basis.modes - # and eq.surface.R_lmn, - # necessary so that the A matrix rows match up with the target b - modes = np.atleast_2d(eq.surface.R_basis.modes[idx, :]) - - if idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the surface, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = idx.size - # Rb_lmn -> Rb optimization space - self._A = np.eye(eq.surface.R_basis.num_modes)[idx, :] - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.surface.R_lmn[idx], None, modes_idx - ) - if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["a"] - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute boundary R errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - boundary R errors. - - """ - return jnp.dot(self._A, params["Rb_lmn"]) - -class FixBoundaryZ(_FixedObjective): +class FixBoundaryZ(FixParameters): """Boundary condition on the Z boundary parameters. Parameters @@ -760,16 +710,9 @@ class FixBoundaryZ(_FixedObjective): Basis modes numbers [l,m,n] of boundary modes to fix. len(target) = len(weight) = len(modes). If True/False uses all/none of the surface modes. - surface_label : float, optional - Surface to enforce boundary conditions on. Defaults to Equilibrium.surface.rho name : str, optional Name of the objective function. - Notes - ----- - If specifying particular modes to fix, the rows of the resulting constraint `A` - matrix and `target` vector will be re-sorted according to the ordering of - `basis.modes` which may be different from the order that was passed in. """ _units = "(m)" @@ -784,14 +727,17 @@ def __init__( normalize=True, normalize_target=True, modes=True, - surface_label=None, name="lcfs Z", ): - self._modes = modes - self._target_from_user = setdefault(bounds, target) - self._surface_label = surface_label + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.surface.Z_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"Zb_lmn": indices}, target=target, bounds=bounds, weight=weight, @@ -812,72 +758,11 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - if self._modes is False or self._modes is None: # no modes - modes = np.array([[]], dtype=int) - idx = np.array([], dtype=int) - modes_idx = idx - elif self._modes is True: # all modes - modes = eq.surface.Z_basis.modes - idx = np.arange(eq.surface.Z_basis.num_modes) - modes_idx = idx - else: # specified modes - modes = np.atleast_2d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, idx, modes_idx = np.intersect1d( - eq.surface.Z_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - # rearrange modes to match order of eq.surface.Z_basis.modes - # and eq.surface.Z_lmn, - # necessary so that the A matrix rows match up with the target b - modes = np.atleast_2d(eq.surface.Z_basis.modes[idx, :]) - - if idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the surface, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = idx.size - # Zb_lmn -> Zb optimization space - self._A = np.eye(eq.surface.Z_basis.num_modes)[idx, :] - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.surface.Z_lmn[idx], None, modes_idx - ) - if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["a"] - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute boundary Z errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - boundary Z errors. - - """ - return jnp.dot(self._A, params["Zb_lmn"]) - class FixLambdaGauge(_Objective): """Fixes gauge freedom for lambda: lambda(theta=0,zeta=0)=0. @@ -889,6 +774,10 @@ class FixLambdaGauge(_Objective): ---------- eq : Equilibrium Equilibrium that will be optimized to satisfy the Objective. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. name : str, optional Name of the objective function. @@ -896,13 +785,15 @@ class FixLambdaGauge(_Objective): _scalar = False _linear = True - _fixed = False + _fixed = False # not "diagonal", since it is fixing a sum _units = "(rad)" _print_value_fmt = "lambda gauge error: {:10.3e} " def __init__( self, eq, + normalize=True, + normalize_target=True, name="lambda gauge", ): super().__init__( @@ -910,8 +801,8 @@ def __init__( target=0, bounds=None, weight=1, - normalize=False, - normalize_target=False, + normalize=normalize, + normalize_target=normalize_target, name=name, ) @@ -953,7 +844,6 @@ def build(self, use_jit=False, verbose=1): self._A = A self._dim_f = self._A.shape[0] - super().build(use_jit=use_jit, verbose=verbose) def compute(self, params, constants=None): @@ -976,71 +866,49 @@ def compute(self, params, constants=None): return jnp.dot(self._A, params["L_lmn"]) -class FixThetaSFL(_Objective): +class FixThetaSFL(FixParameters): """Fixes lambda=0 so that poloidal angle is the SFL poloidal angle. Parameters ---------- eq : Equilibrium Equilibrium that will be optimized to satisfy the Objective. + weight : {float, ndarray}, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. name : str, optional Name of the objective function. """ - _scalar = False - _linear = True - _fixed = True _units = "(rad)" - _print_value_fmt = "Theta - Theta SFL error: {:10.3e} " - - def __init__(self, eq, name="Theta SFL"): - super().__init__(things=eq, target=0, weight=1, name=name) + _print_value_fmt = "theta - theta SFL error: {:10.3e} " - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - eq = self.things[0] - idx = np.arange(eq.L_basis.num_modes) - modes_idx = idx - self._idx = idx - - self._dim_f = modes_idx.size - - self.target = np.zeros_like(modes_idx) - - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute Theta SFL errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Theta - Theta SFL errors. - - """ - fixed_params = params["L_lmn"][self._idx] - return fixed_params + def __init__( + self, + eq, + weight=1, + normalize=True, + normalize_target=True, + name="theta SFL", + ): + super().__init__( + thing=eq, + params={"L_lmn": True}, + target=0, + bounds=None, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + name=name, + ) -class FixAxisR(_FixedObjective): +class FixAxisR(FixParameters): """Fixes magnetic axis R coefficients. Parameters @@ -1086,10 +954,15 @@ def __init__( modes=True, name="axis R", ): - self._modes = modes - self._target_from_user = setdefault(bounds, target) + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.axis.R_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"Ra_n": indices}, target=target, bounds=bounds, weight=weight, @@ -1110,75 +983,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - - if self._modes is False or self._modes is None: # no modes - modes = np.array([[]], dtype=int) - idx = np.array([], dtype=int) - modes_idx = idx - elif self._modes is True: # all modes - modes = eq.axis.R_basis.modes - idx = np.arange(eq.axis.R_basis.num_modes) - modes_idx = idx - else: # specified modes - modes = np.atleast_1d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, idx, modes_idx = np.intersect1d( - eq.axis.R_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - # rearrange modes to match order of eq.axis.R_basis.modes and eq.axis.R_n, - # necessary so that the A matrix rows match up with the target b - modes = np.atleast_2d(eq.axis.R_basis.modes[idx, :]) - - if idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the axis, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = idx.size - # Ra_lmn -> Ra optimization space - self._A = np.eye(eq.axis.R_basis.num_modes)[idx, :] - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.axis.R_n[idx], None, modes_idx - ) - if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["a"] - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute axis R errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Axis R errors. - - """ - f = jnp.dot(self._A, params["Ra_n"]) - return f - -class FixAxisZ(_FixedObjective): +class FixAxisZ(FixParameters): """Fixes magnetic axis Z coefficients. Parameters @@ -1224,10 +1035,15 @@ def __init__( modes=True, name="axis Z", ): - self._modes = modes - self._target_from_user = setdefault(bounds, target) + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.axis.Z_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"Za_n": indices}, target=target, bounds=bounds, weight=weight, @@ -1248,75 +1064,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - - if self._modes is False or self._modes is None: # no modes - modes = np.array([[]], dtype=int) - idx = np.array([], dtype=int) - modes_idx = idx - elif self._modes is True: # all modes - modes = eq.axis.Z_basis.modes - idx = np.arange(eq.axis.Z_basis.num_modes) - modes_idx = idx - else: # specified modes - modes = np.atleast_1d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, idx, modes_idx = np.intersect1d( - eq.axis.Z_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - # rearrange modes to match order of eq.axis.Z_basis.modes and eq.axis.Z_n, - # necessary so that the A matrix rows match up with the target b - modes = np.atleast_2d(eq.axis.Z_basis.modes[idx, :]) - - if idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the axis, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = idx.size - # Za_lmn -> Za optimization space - self._A = np.eye(eq.axis.Z_basis.num_modes)[idx, :] - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.axis.Z_n[idx], None, modes_idx - ) - if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["a"] - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute axis Z errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Axis Z errors. - - """ - f = jnp.dot(self._A, params["Za_n"]) - return f - -class FixModeR(_FixedObjective): +class FixModeR(FixParameters): """Fixes Fourier-Zernike R coefficients. Parameters @@ -1342,8 +1096,7 @@ class FixModeR(_FixedObjective): modes : ndarray, optional Basis modes numbers [l,m,n] of Fourier-Zernike modes to fix. len(target) = len(weight) = len(modes). - If True uses all of the Equilibrium's modes. - Must be either True or specified as an array + If True/False uses all/none of the basis modes. name : str, optional Name of the objective function. @@ -1361,16 +1114,17 @@ def __init__( normalize=True, normalize_target=True, modes=True, - name="Fix Mode R", + name="fix mode R", ): - self._modes = modes - if modes is None or modes is False: - raise ValueError( - f"modes kwarg must be specified or True with FixModeR got {modes}" - ) - self._target_from_user = setdefault(bounds, target) + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.R_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"R_lmn": indices}, target=target, bounds=bounds, weight=weight, @@ -1391,60 +1145,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - if self._modes is True: # all modes - modes = eq.R_basis.modes - self._idx = np.arange(eq.R_basis.num_modes) - modes_idx = self._idx - else: # specified modes - modes = np.atleast_2d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, self._idx, modes_idx = np.intersect1d( - eq.R_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - if self._idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the basis, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = modes_idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.R_lmn[self._idx], None, modes_idx - ) - + if self._normalize: + scales = compute_scaling_factors(eq) + self._normalization = scales["a"] super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute Fixed mode R errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed mode R errors. - - """ - fixed_params = params["R_lmn"][self._idx] - return fixed_params - -class FixModeZ(_FixedObjective): +class FixModeZ(FixParameters): """Fixes Fourier-Zernike Z coefficients. Parameters @@ -1470,8 +1177,7 @@ class FixModeZ(_FixedObjective): modes : ndarray, optional Basis modes numbers [l,m,n] of Fourier-Zernike modes to fix. len(target) = len(weight) = len(modes). - If True uses all of the Equilibrium's modes. - Must be either True or specified as an array + If True/False uses all/none of the basis modes. name : str, optional Name of the objective function. @@ -1489,16 +1195,17 @@ def __init__( normalize=True, normalize_target=True, modes=True, - name="Fix Mode Z", + name="fix mode Z", ): - self._modes = modes - if modes is None or modes is False: - raise ValueError( - f"modes kwarg must be specified or True with FixModeZ got {modes}" - ) - self._target_from_user = setdefault(bounds, target) + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.Z_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"Z_lmn": indices}, target=target, bounds=bounds, weight=weight, @@ -1519,60 +1226,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - if self._modes is True: # all modes - modes = eq.Z_basis.modes - self._idx = np.arange(eq.Z_basis.num_modes) - modes_idx = self._idx - else: # specified modes - modes = np.atleast_2d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, self._idx, modes_idx = np.intersect1d( - eq.Z_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - if self._idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the basis, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = modes_idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.Z_lmn[self._idx], None, modes_idx - ) - + if self._normalize: + scales = compute_scaling_factors(eq) + self._normalization = scales["a"] super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute Fixed mode Z errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed mode Z errors. - - """ - fixed_params = params["Z_lmn"][self._idx] - return fixed_params - -class FixModeLambda(_FixedObjective): +class FixModeLambda(FixParameters): """Fixes Fourier-Zernike lambda coefficients. Parameters @@ -1590,24 +1250,21 @@ class FixModeLambda(_FixedObjective): weight : float, ndarray, optional Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to Objective.dim_f. - normalize : bool - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool - Whether target should be normalized before comparing to computed values. - if `normalize` is `True` and the target is in physical units, this should also - be set to True. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. modes : ndarray, optional Basis modes numbers [l,m,n] of Fourier-Zernike modes to fix. len(target) = len(weight) = len(modes). - If True uses all of the Equilibrium's modes. - Must be either True or specified as an array + If True/False uses all/none of the basis modes. name : str Name of the objective function. """ _units = "(rad)" - _print_value_fmt = "Fixed-lambda modes error: {:10.3e} " + _print_value_fmt = "Fixed lambda modes error: {:10.3e} " def __init__( self, @@ -1618,89 +1275,25 @@ def __init__( normalize=True, normalize_target=True, modes=True, - name="Fix Mode lambda", + name="fix mode lambda", ): - self._modes = modes - if modes is None or modes is False: - raise ValueError( - "modes kwarg must be specified" - + f" or True with FixModeLambda got {modes}" - ) - self._target_from_user = target + if isinstance(modes, bool): + indices = modes + else: + indices = np.array([], dtype=int) + for mode in np.atleast_2d(modes): + indices = np.append(indices, eq.L_basis.get_idx(*mode)) super().__init__( - things=eq, + thing=eq, + params={"L_lmn": indices}, target=target, bounds=bounds, weight=weight, - name=name, normalize=normalize, normalize_target=normalize_target, + name=name, ) - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - eq = self.things[0] - if self._modes is True: # all modes - modes = eq.L_basis.modes - self._idx = np.arange(eq.L_basis.num_modes) - modes_idx = self._idx - else: # specified modes - modes = np.atleast_2d(self._modes) - dtype = { - "names": ["f{}".format(i) for i in range(3)], - "formats": 3 * [modes.dtype], - } - _, self._idx, modes_idx = np.intersect1d( - eq.L_basis.modes.astype(modes.dtype).view(dtype), - modes.view(dtype), - return_indices=True, - ) - if self._idx.size < modes.shape[0]: - warnings.warn( - colored( - "Some of the given modes are not in the basis, " - + "these modes will not be fixed.", - "yellow", - ) - ) - - self._dim_f = modes_idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.L_lmn[self._idx], None, modes_idx - ) - - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute Fixed mode lambda errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed mode lambda errors. - - """ - fixed_params = params["L_lmn"][self._idx] - return fixed_params - class FixSumModesR(_FixedObjective): """Fixes a linear sum of Fourier-Zernike R coefficients. @@ -1738,7 +1331,7 @@ class FixSumModesR(_FixedObjective): """ - _fixed = False # not "diagonal", since its fixing a sum + _fixed = False # not "diagonal", since it is fixing a sum _units = "(m)" _print_value_fmt = "Fixed-R sum modes error: {:10.3e} " @@ -1904,7 +1497,7 @@ class FixSumModesZ(_FixedObjective): """ - _fixed = False # not "diagonal", since its fixing a sum + _fixed = False # not "diagonal", since it is fixing a sum _units = "(m)" _print_value_fmt = "Fixed-Z sum modes error: {:10.3e} " @@ -2051,12 +1644,10 @@ class FixSumModesLambda(_FixedObjective): weight : {float, ndarray}, optional Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f. - normalize : bool - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool - Whether target should be normalized before comparing to computed values. - if `normalize` is `True` and the target is in physical units, this should also - be set to True. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. sum_weight : float, ndarray, optional Weights on the coefficients in the sum, should be same length as modes. Defaults to 1 i.e. target = 1*L_111 + 1*L_222... @@ -2072,7 +1663,7 @@ class FixSumModesLambda(_FixedObjective): """ - _fixed = False # not "diagonal", since its fixing a sum + _fixed = False # not "diagonal", since it is fixing a sum _units = "(rad)" _print_value_fmt = "Fixed-lambda sum modes error: {:10.3e} " @@ -2115,9 +1706,9 @@ def __init__( target=target, bounds=bounds, weight=weight, - name=name, normalize=normalize, normalize_target=normalize_target, + name=name, ) def build(self, use_jit=False, verbose=1): @@ -2204,8 +1795,8 @@ def compute(self, params, constants=None): return f -class _FixProfile(_FixedObjective, ABC): - """Fixes profile coefficients (or values, for SplineProfile). +class FixPressure(FixParameters): + """Fixes pressure coefficients. Parameters ---------- @@ -2213,10 +1804,11 @@ class _FixProfile(_FixedObjective, ABC): Equilibrium that will be optimized to satisfy the Objective. target : {float, ndarray}, optional Target value(s) of the objective. Only used if bounds is None. - Must be broadcastable to Objective.dim_f. + Must be broadcastable to Objective.dim_f. Defaults to ``target=eq.p_l``. bounds : tuple of {float, ndarray}, optional Lower and upper bounds on the objective. Overrides target. - Both bounds must be broadcastable to to Objective.dim_f + Both bounds must be broadcastable to to Objective.dim_f. + Defaults to ``target=eq.p_l``. weight : {float, ndarray}, optional Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f @@ -2226,9 +1818,7 @@ class _FixProfile(_FixedObjective, ABC): Whether target and bounds should be normalized before comparing to computed values. If `normalize` is `True` and the target is in physical units, this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. - indices : ndarray or Bool, optional + indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices corresponding to knots for a SplineProfile). @@ -2239,7 +1829,8 @@ class _FixProfile(_FixedObjective, ABC): """ - _print_value_fmt = "Fix-profile error: {:10.3e} " + _units = "(Pa)" + _print_value_fmt = "Fixed pressure profile error: {:10.3e} " def __init__( self, @@ -2249,15 +1840,12 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="", + name="fixed pressure", ): - self._profile = profile - self._indices = indices - self._target_from_user = setdefault(bounds, target) super().__init__( - things=eq, + thing=eq, + params={"p_l": indices}, target=target, bounds=bounds, weight=weight, @@ -2266,15 +1854,11 @@ def __init__( name=name, ) - def build(self, eq, profile, use_jit=False, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters ---------- - eq : Equilibrium - Equilibrium that will be optimized to satisfy the Objective. - profile : Profile, optional - profile to fix use_jit : bool, optional Whether to just-in-time compile the objective and derivatives. verbose : int, optional @@ -2282,134 +1866,19 @@ def build(self, eq, profile, use_jit=False, verbose=1): """ eq = self.things[0] - if self._profile is None or self._profile.params.size != eq.L + 1: - self._profile = profile + if eq.pressure is None: + raise RuntimeError( + "Attempting to fix pressure on an Equilibrium with no " + + "pressure profile assigned." + ) + if self._normalize: + scales = compute_scaling_factors(eq) + self._normalization = scales["p"] + super().build(use_jit=use_jit, verbose=verbose) - # find indices to fix - if self._indices is False or self._indices is None: # no indices to fix - self._idx = np.array([], dtype=int) - elif self._indices is True: # all indices of Profile.params - self._idx = np.arange(np.size(self._profile.params)) - else: # specified indices - self._idx = np.atleast_1d(self._indices) - self._dim_f = self._idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, self._profile.params[self._idx], None, self._idx - ) - - super().build(use_jit=use_jit, verbose=verbose) - - -class FixPressure(_FixProfile): - """Fixes pressure coefficients. - - Parameters - ---------- - eq : Equilibrium - Equilibrium that will be optimized to satisfy the Objective. - target : {float, ndarray}, optional - Target value(s) of the objective. Only used if bounds is None. - Must be broadcastable to Objective.dim_f. Defaults to ``target=eq.p_l``. - bounds : tuple of {float, ndarray}, optional - Lower and upper bounds on the objective. Overrides target. - Both bounds must be broadcastable to to Objective.dim_f. - Defaults to ``target=eq.p_l``. - weight : {float, ndarray}, optional - Weighting to apply to the Objective, relative to other Objectives. - Must be broadcastable to to Objective.dim_f - normalize : bool, optional - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool, optional - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. - indices : ndarray or bool, optional - indices of the Profile.params array to fix. - (e.g. indices corresponding to modes for a PowerSeriesProfile or indices - corresponding to knots for a SplineProfile). - Must have len(target) = len(weight) = len(indices). - If True/False uses all/none of the Profile.params indices. - name : str, optional - Name of the objective function. - - """ - - _units = "(Pa)" - _print_value_fmt = "Fixed-pressure profile error: {:10.3e} " - - def __init__( - self, - eq, - target=None, - bounds=None, - weight=1, - normalize=True, - normalize_target=True, - profile=None, - indices=True, - name="fixed-pressure", - ): - super().__init__( - eq=eq, - target=target, - bounds=bounds, - weight=weight, - normalize=normalize, - normalize_target=normalize_target, - profile=profile, - indices=indices, - name=name, - ) - - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - eq = self.things[0] - if eq.pressure is None: - raise RuntimeError( - "Attempting to fix pressure on an equilibrium with no " - + "pressure profile assigned" - ) - profile = eq.pressure - if self._normalize: - scales = compute_scaling_factors(eq) - self._normalization = scales["p"] - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed pressure profile errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["p_l"][self._idx] - - -class FixAnisotropy(_FixProfile): - """Fixes anisotropic pressure coefficients. +class FixAnisotropy(FixParameters): + """Fixes anisotropic pressure coefficients. Parameters ---------- @@ -2425,14 +1894,10 @@ class FixAnisotropy(_FixProfile): weight : {float, ndarray}, optional Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f - normalize : bool - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices @@ -2445,7 +1910,7 @@ class FixAnisotropy(_FixProfile): """ _units = "(dimensionless)" - _print_value_fmt = "Fixed-anisotropy profile error: {:10.3e} " + _print_value_fmt = "Fixed anisotropy profile error: {:10.3e} " def __init__( self, @@ -2455,23 +1920,21 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="fixed-anisotropy", + name="fixed anisotropy", ): super().__init__( - eq=eq, + thing=eq, + params={"a_lmn": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) - def build(self, use_jit=True, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters @@ -2485,33 +1948,13 @@ def build(self, use_jit=True, verbose=1): eq = self.things[0] if eq.anisotropy is None: raise RuntimeError( - "Attempting to fix anisotropy on an equilibrium with no " - + "anisotropy profile assigned" + "Attempting to fix anisotropy on an Equilibrium with no " + + "anisotropy profile assigned." ) - profile = eq.anisotropy - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed pressure profile errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["a_lmn"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixIota(_FixProfile): +class FixIota(FixParameters): """Fixes rotational transform coefficients. Parameters @@ -2529,14 +1972,9 @@ class FixIota(_FixProfile): Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f normalize : bool, optional - Whether to compute the error in physical units or non-dimensionalize. Has no effect for this objective. normalize_target : bool, optional - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. Has no effect for this objective. - profile : Profile, optional - Profile containing the radial modes to evaluate at. + Has no effect for this objective. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices. @@ -2549,7 +1987,7 @@ class FixIota(_FixProfile): """ _units = "(dimensionless)" - _print_value_fmt = "Fixed-iota profile error: {:10.3e} " + _print_value_fmt = "Fixed iota profile error: {:10.3e} " def __init__( self, @@ -2557,21 +1995,19 @@ def __init__( target=None, bounds=None, weight=1, - normalize=False, - normalize_target=False, - profile=None, + normalize=True, + normalize_target=True, indices=True, - name="fixed-iota", + name="fixed iota", ): super().__init__( - eq=eq, + thing=eq, + params={"i_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) @@ -2589,33 +2025,13 @@ def build(self, use_jit=False, verbose=1): eq = self.things[0] if eq.iota is None: raise RuntimeError( - "Attempt to fix rotational transform on an equilibrium with no " - + "rotational transform profile assigned" + "Attempting to fix iota on an Equilibrium with no " + + "iota profile assigned." ) - profile = eq.iota - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed iota errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["i_l"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixCurrent(_FixProfile): +class FixCurrent(FixParameters): """Fixes toroidal current profile coefficients. Parameters @@ -2638,8 +2054,6 @@ class FixCurrent(_FixProfile): Whether target and bounds should be normalized before comparing to computed values. If `normalize` is `True` and the target is in physical units, this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices @@ -2652,7 +2066,7 @@ class FixCurrent(_FixProfile): """ _units = "(A)" - _print_value_fmt = "Fixed-current profile error: {:10.3e} " + _print_value_fmt = "Fixed current profile error: {:10.3e} " def __init__( self, @@ -2662,19 +2076,17 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="fixed-current", + name="fixed current", ): super().__init__( - eq=eq, + thing=eq, + params={"c_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) @@ -2692,36 +2104,16 @@ def build(self, use_jit=False, verbose=1): eq = self.things[0] if eq.current is None: raise RuntimeError( - "Attempting to fix toroidal current on an equilibrium with no " - + "current profile assigned" + "Attempting to fix current on an Equilibrium with no " + + "current profile assigned." ) - profile = eq.current if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["I"] - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed current errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["c_l"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixElectronTemperature(_FixProfile): +class FixElectronTemperature(FixParameters): """Fixes electron temperature profile coefficients. Parameters @@ -2744,8 +2136,6 @@ class FixElectronTemperature(_FixProfile): Whether target and bounds should be normalized before comparing to computed values. If `normalize` is `True` and the target is in physical units, this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices @@ -2758,7 +2148,7 @@ class FixElectronTemperature(_FixProfile): """ _units = "(eV)" - _print_value_fmt = "Fixed-electron-temperature profile error: {:10.3e} " + _print_value_fmt = "Fixed electron temperature profile error: {:10.3e} " def __init__( self, @@ -2768,23 +2158,21 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="fixed-electron-temperature", + name="fixed electron temperature", ): super().__init__( - eq=eq, + thing=eq, + params={"Te_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) - def build(self, use_jit=True, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters @@ -2798,36 +2186,16 @@ def build(self, use_jit=True, verbose=1): eq = self.things[0] if eq.electron_temperature is None: raise RuntimeError( - "Attempting to fix electron temperature on an equilibrium with no " - + "electron temperature profile assigned" + "Attempting to fix electron temperature on an Equilibrium with no " + + "electron temperature profile assigned." ) - profile = eq.electron_temperature if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["T"] - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed electron temperature errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["Te_l"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixElectronDensity(_FixProfile): +class FixElectronDensity(FixParameters): """Fixes electron density profile coefficients. Parameters @@ -2864,7 +2232,7 @@ class FixElectronDensity(_FixProfile): """ _units = "(m^-3)" - _print_value_fmt = "Fixed-electron-density profile error: {:10.3e} " + _print_value_fmt = "Fixed electron density profile error: {:10.3e} " def __init__( self, @@ -2874,23 +2242,21 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="fixed-electron-density", + name="fixed electron density", ): super().__init__( - eq=eq, + thing=eq, + params={"ne_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) - def build(self, use_jit=True, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters @@ -2904,36 +2270,16 @@ def build(self, use_jit=True, verbose=1): eq = self.things[0] if eq.electron_density is None: raise RuntimeError( - "Attempting to fix electron density on an equilibrium with no " - + "electron density profile assigned" + "Attempting to fix electron density on an Equilibrium with no " + + "electron density profile assigned." ) - profile = eq.electron_density if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["n"] - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed electron density errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["ne_l"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixIonTemperature(_FixProfile): +class FixIonTemperature(FixParameters): """Fixes ion temperature profile coefficients. Parameters @@ -2956,8 +2302,6 @@ class FixIonTemperature(_FixProfile): Whether target and bounds should be normalized before comparing to computed values. If `normalize` is `True` and the target is in physical units, this should also be set to True. - profile : Profile, optional - Profile containing the radial modes to evaluate at. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices @@ -2970,7 +2314,7 @@ class FixIonTemperature(_FixProfile): """ _units = "(eV)" - _print_value_fmt = "Fixed-ion-temperature profile error: {:10.3e} " + _print_value_fmt = "Fixed ion temperature profile error: {:10.3e} " def __init__( self, @@ -2980,23 +2324,21 @@ def __init__( weight=1, normalize=True, normalize_target=True, - profile=None, indices=True, - name="fixed-ion-temperature", + name="fixed ion temperature", ): super().__init__( - eq=eq, + thing=eq, + params={"Ti_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) - def build(self, use_jit=True, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters @@ -3010,36 +2352,16 @@ def build(self, use_jit=True, verbose=1): eq = self.things[0] if eq.ion_temperature is None: raise RuntimeError( - "Attempting to fix ion temperature on an equilibrium with no " - + "ion temperature profile assigned" + "Attempting to fix ion temperature on an Equilibrium with no " + + "ion temperature profile assigned." ) - profile = eq.ion_temperature if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["T"] - super().build(eq, profile, use_jit, verbose) - - def compute(self, params, constants=None): - """Compute fixed ion temperature errors. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["Ti_l"][self._idx] + super().build(use_jit=use_jit, verbose=verbose) -class FixAtomicNumber(_FixProfile): +class FixAtomicNumber(FixParameters): """Fixes effective atomic number profile coefficients. Parameters @@ -3057,14 +2379,9 @@ class FixAtomicNumber(_FixProfile): Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f normalize : bool, optional - Whether to compute the error in physical units or non-dimensionalize. Has no effect for this objective. normalize_target : bool, optional - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. Has no effect for this objective. - profile : Profile, optional - Profile containing the radial modes to evaluate at. + Has no effect for this objective. indices : ndarray or bool, optional indices of the Profile.params array to fix. (e.g. indices corresponding to modes for a PowerSeriesProfile or indices @@ -3077,7 +2394,7 @@ class FixAtomicNumber(_FixProfile): """ _units = "(dimensionless)" - _print_value_fmt = "Fixed-atomic-number profile error: {:10.3e} " + _print_value_fmt = "Fixed atomic number profile error: {:10.3e} " def __init__( self, @@ -3085,25 +2402,23 @@ def __init__( target=None, bounds=None, weight=1, - normalize=False, - normalize_target=False, - profile=None, + normalize=True, + normalize_target=True, indices=True, - name="fixed-atomic-number", + name="fixed atomic number", ): super().__init__( - eq=eq, + thing=eq, + params={"Zeff_l": indices}, target=target, bounds=bounds, weight=weight, normalize=normalize, normalize_target=normalize_target, - profile=profile, - indices=indices, name=name, ) - def build(self, use_jit=True, verbose=1): + def build(self, use_jit=False, verbose=1): """Build constant arrays. Parameters @@ -3117,34 +2432,14 @@ def build(self, use_jit=True, verbose=1): eq = self.things[0] if eq.atomic_number is None: raise RuntimeError( - "Attempting to fix atomic number on an equilibrium with no " - + "atomic number profile assigned" + "Attempting to fix atomic number on an Equilibrium with no " + + "atomic_number profile assigned." ) - profile = eq.atomic_number - super().build(eq, profile, use_jit, verbose) + super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute fixed atomic number errors. - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed profile errors. - - """ - return params["Zeff_l"][self._idx] - - -class FixPsi(_FixedObjective): - """Fixes total toroidal magnetic flux within the last closed flux surface. +class FixPsi(FixParameters): + """Fixes total toroidal magnetic flux within the last closed flux surface. Parameters ---------- @@ -3172,7 +2467,7 @@ class FixPsi(_FixedObjective): """ _units = "(Wb)" - _print_value_fmt = "Fixed-Psi error: {:10.3e} " + _print_value_fmt = "Fixed Psi error: {:10.3e} " def __init__( self, @@ -3182,11 +2477,11 @@ def __init__( weight=1, normalize=True, normalize_target=True, - name="fixed-Psi", + name="fixed Psi", ): - self._target_from_user = setdefault(bounds, target) super().__init__( - things=eq, + thing=eq, + params={"Psi": True}, target=target, bounds=bounds, weight=weight, @@ -3207,39 +2502,13 @@ def build(self, use_jit=False, verbose=1): """ eq = self.things[0] - self._dim_f = 1 - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, eq.Psi, None, np.array([0]) - ) - if self._normalize: scales = compute_scaling_factors(eq) self._normalization = scales["Psi"] - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute fixed-Psi error. - - Parameters - ---------- - params : dict - Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Total toroidal magnetic flux error (Wb). - - """ - return params["Psi"] - -class FixCurveShift(_FixedObjective): +class FixCurveShift(FixParameters): """Fixes Curve.shift attribute, which is redundant with other Curve params. Parameters @@ -3267,7 +2536,7 @@ class FixCurveShift(_FixedObjective): """ _units = "(m)" - _print_value_fmt = "Fixed-shift error: {:10.3e} " + _print_value_fmt = "Fixed shift error: {:10.3e} " def __init__( self, @@ -3277,11 +2546,11 @@ def __init__( weight=1, normalize=True, normalize_target=True, - name="fixed-shift", + name="fixed shift", ): - self._target_from_user = setdefault(bounds, target) super().__init__( - things=curve, + thing=curve, + params={"shift": True}, target=target, bounds=bounds, weight=weight, @@ -3289,51 +2558,10 @@ def __init__( normalize_target=normalize_target, name=name, ) - - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - curve = self.things[0] - self._dim_f = curve.shift.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, curve.shift, None, np.arange(self._dim_f) - ) - - if self._normalize: - self._normalization = 1 - - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute fixed-shift error. - - Parameters - ---------- - params : dict - Dictionary of curve degrees of freedom, eg Curve.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Curve shift (m). - - """ - return params["shift"] + # TODO: add normalization? -class FixCurveRotation(_FixedObjective): +class FixCurveRotation(FixParameters): """Fixes Curve.rotmat attribute, which is redundant with other Curve params. Parameters @@ -3350,18 +2578,16 @@ class FixCurveRotation(_FixedObjective): Weighting to apply to the Objective, relative to other Objectives. Must be broadcastable to to Objective.dim_f normalize : bool, optional - Whether to compute the error in physical units or non-dimensionalize. + Has no effect for this objective. normalize_target : bool, optional - Whether target and bounds should be normalized before comparing to computed - values. If `normalize` is `True` and the target is in physical units, - this should also be set to True. + Has no effect for this objective. name : str, optional Name of the objective function. """ _units = "(rad)" - _print_value_fmt = "Fixed-rotation error: {:10.3e} " + _print_value_fmt = "Fixed rotation error: {:10.3e} " def __init__( self, @@ -3371,11 +2597,11 @@ def __init__( weight=1, normalize=True, normalize_target=True, - name="fixed-rotation", + name="fixed rotation", ): - self._target_from_user = setdefault(bounds, target) super().__init__( - things=curve, + thing=curve, + params={"rotmat": True}, target=target, bounds=bounds, weight=weight, @@ -3384,50 +2610,8 @@ def __init__( name=name, ) - def build(self, use_jit=False, verbose=1): - """Build constant arrays. - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - curve = self.things[0] - self._dim_f = curve.rotmat.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, curve.rotmat, None, np.arange(self._dim_f) - ) - - if self._normalize: - self._normalization = 1 - - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute fixed-rotation error. - - Parameters - ---------- - params : dict - Dictionary of curve degrees of freedom, eg Curve.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Curve rotation matrix (rad). - - """ - return params["rotmat"] - - -class FixOmniWell(_FixedObjective): +class FixOmniWell(FixParameters): """Fixes OmnigenousField.B_lm coefficients. Parameters @@ -3469,11 +2653,9 @@ def __init__( indices=True, name="fixed omnigenity well", ): - self._field = field - self._indices = indices - self._target_from_user = setdefault(bounds, target) super().__init__( - things=field, + thing=field, + params={"B_lm": indices}, target=target, bounds=bounds, weight=weight, @@ -3481,57 +2663,10 @@ def __init__( normalize_target=normalize_target, name=name, ) + # TODO: add normalization? - def build(self, use_jit=True, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - field = self.things[0] - - # find indices to fix - if self._indices is False or self._indices is None: # no indices to fix - self._idx = np.array([], dtype=int) - elif self._indices is True: # all indices - self._idx = np.arange(np.size(self._field.B_lm)) - else: # specified indices - self._idx = np.atleast_1d(self._indices) - - self._dim_f = self._idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, field.B_lm[self._idx], None, self._idx - ) - - super().build(use_jit=use_jit, verbose=verbose) - def compute(self, params, constants=None): - """Compute fixed omnigenity well error. - - Parameters - ---------- - params : dict - Dictionary of field degrees of freedom, eg OmnigenousField.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed well shape error. - - """ - return params["B_lm"][self._idx] - - -class FixOmniMap(_FixedObjective): +class FixOmniMap(FixParameters): """Fixes OmnigenousField.x_lmn coefficients. Parameters @@ -3544,12 +2679,10 @@ class FixOmniMap(_FixedObjective): Lower and upper bounds on the objective. Overrides target. weight : float, optional Weighting to apply to the Objective, relative to other Objectives. - normalize : bool - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool - Whether target should be normalized before comparing to computed values. - if `normalize` is `True` and the target is in physical units, this should also - be set to True. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. indices : ndarray or bool, optional indices of the field.x_lmn array to fix. Must have len(target) = len(weight) = len(indices). @@ -3568,16 +2701,14 @@ def __init__( target=None, bounds=None, weight=1, - normalize=False, - normalize_target=False, + normalize=True, + normalize_target=True, indices=True, name="fixed omnigenity map", ): - self._field = field - self._indices = indices - self._target_from_user = setdefault(bounds, target) super().__init__( - things=field, + thing=field, + params={"x_lmn": indices}, target=target, bounds=bounds, weight=weight, @@ -3586,54 +2717,6 @@ def __init__( name=name, ) - def build(self, use_jit=True, verbose=1): - """Build constant arrays. - - Parameters - ---------- - use_jit : bool, optional - Whether to just-in-time compile the objective and derivatives. - verbose : int, optional - Level of output. - - """ - field = self.things[0] - - # find indices to fix - if self._indices is False or self._indices is None: # no indices to fix - self._idx = np.array([], dtype=int) - elif self._indices is True: # all indices - self._idx = np.arange(np.size(self._field.x_lmn)) - else: # specified indices - self._idx = np.atleast_1d(self._indices) - - self._dim_f = self._idx.size - - self.target, self.bounds = self._parse_target_from_user( - self._target_from_user, field.x_lmn[self._idx], None, self._idx - ) - - super().build(use_jit=use_jit, verbose=verbose) - - def compute(self, params, constants=None): - """Compute fixed omnigenity map error. - - Parameters - ---------- - params : dict - Dictionary of field degrees of freedom, eg OmnigenousField.params_dict - constants : dict - Dictionary of constant data, eg transforms, profiles etc. Defaults to - self.constants - - Returns - ------- - f : ndarray - Fixed omnigenity map error. - - """ - return params["x_lmn"][self._idx] - class FixOmniBmax(_FixedObjective): """Ensures the B_max contour is straight in Boozer coordinates. @@ -3648,12 +2731,10 @@ class FixOmniBmax(_FixedObjective): Lower and upper bounds on the objective. Overrides target. weight : float, optional Weighting to apply to the Objective, relative to other Objectives. - normalize : bool - Whether to compute the error in physical units or non-dimensionalize. - normalize_target : bool - Whether target should be normalized before comparing to computed values. - if `normalize` is `True` and the target is in physical units, this should also - be set to True. + normalize : bool, optional + Has no effect for this objective. + normalize_target : bool, optional + Has no effect for this objective. name : str Name of the objective function. @@ -3669,8 +2750,8 @@ def __init__( target=None, bounds=None, weight=1, - normalize=False, - normalize_target=False, + normalize=True, + normalize_target=True, name="fixed omnigenity B_max", ): self._target_from_user = setdefault(bounds, target) @@ -3739,3 +2820,60 @@ def compute(self, params, constants=None): """ f = jnp.dot(self._A, params["x_lmn"]) return f + + +class FixSheetCurrent(FixParameters): + """Fixes the sheet current parameters of a free-boundary equilibrium. + + Note: this constraint is automatically applied when needed, and does not need to be + included by the user. + + Parameters + ---------- + eq : Equilibrium + Equilibrium that will be optimized to satisfy the Objective. + target : {float, ndarray}, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. Default is ``target=eq.Psi``. + bounds : tuple of {float, ndarray}, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f. + Default is ``target=eq.Psi``. + weight : {float, ndarray}, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + name : str, optional + Name of the objective function. + + """ + + _units = "(~)" + _print_value_fmt = "Fixed sheet current error: {:10.3e} " + + def __init__( + self, + eq, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + name="fixed sheet current", + ): + super().__init__( + thing=eq, + params={"I": True, "G": True, "Phi_mn": True}, + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + name=name, + ) + # TODO: add normalization? diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 48ee1ef0f2..0a98bb0a28 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -865,17 +865,23 @@ def _check_dimensions(self): if self.bounds is not None: # must be a tuple of length 2 self._bounds = tuple([np.asarray(bound) for bound in self._bounds]) for bound in self.bounds: - if not is_broadcastable((self.dim_f,), bound.shape): + if not is_broadcastable((self.dim_f,), bound.shape) or ( + self.dim_f == 1 and bound.size != 1 + ): raise ValueError("len(bounds) != dim_f") if np.any(self.bounds[1] < self.bounds[0]): raise ValueError("bounds must be: (lower bound, upper bound)") else: # target only gets used if bounds is None self._target = np.asarray(self._target) - if not is_broadcastable((self.dim_f,), self.target.shape): + if not is_broadcastable((self.dim_f,), self.target.shape) or ( + self.dim_f == 1 and self.target.size != 1 + ): raise ValueError("len(target) != dim_f") self._weight = np.asarray(self._weight) - if not is_broadcastable((self.dim_f,), self.weight.shape): + if not is_broadcastable((self.dim_f,), self.weight.shape) or ( + self.dim_f == 1 and self.weight.size != 1 + ): raise ValueError("len(weight) != dim_f") @abstractmethod diff --git a/desc/optimizable.py b/desc/optimizable.py index f27ff53239..4427f00518 100644 --- a/desc/optimizable.py +++ b/desc/optimizable.py @@ -4,6 +4,8 @@ import warnings from abc import ABC +import numpy as np + from desc.backend import jnp @@ -86,6 +88,7 @@ def pack_params(self, p): x : ndarray optimizable parameters concatenated into a single array, with indices given by ``x_idx`` + """ return jnp.concatenate( [jnp.atleast_1d(jnp.asarray(p[key])) for key in self.optimizable_params] @@ -104,6 +107,7 @@ def unpack_params(self, x): ------- p : dict Dictionary of ndarray of optimizable parameters. + """ x_idx = self.x_idx params = {} @@ -115,7 +119,8 @@ def _sort_args(self, args): """Put arguments in a canonical order. Returns unique sorted elements. Actual order doesn't really matter as long as its consistent, though subclasses - may override this method to enforce a specific ordering + may override this method to enforce a specific ordering. + """ return sorted(set(list(args))) @@ -177,6 +182,7 @@ def pack_params(self, params): x : ndarray optimizable parameters concatenated into a single array, with indices given by ``x_idx`` + """ return jnp.concatenate([s.pack_params(p) for s, p in zip(self, params)]) @@ -193,8 +199,9 @@ def unpack_params(self, x): ------- p : list dict list of dictionary of ndarray of optimizable parameters. + """ - split_idx = jnp.cumsum(jnp.array([s.dim_x for s in self])) + split_idx = np.cumsum([s.dim_x for s in self]) # must be np not jnp xs = jnp.split(x, split_idx) params = [s.unpack_params(xi) for s, xi in zip(self, xs)] return params diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 1c6453d1a2..93bf5385ae 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -681,6 +681,13 @@ def unpack_state(self, x, per_objective=True): if t is self._eq: xi_splits = np.cumsum([self._eq.dimensions[arg] for arg in self._args]) p = {arg: xis for arg, xis in zip(self._args, jnp.split(xi, xi_splits))} + p.update( # add in dummy values for missing parameters + { + arg: jnp.zeros_like(xis) + for arg, xis in t.params_dict.items() + if arg not in self._args # R_lmn, Z_lmn, L_lmn, Ra_n, Za_n + } + ) params += [p] else: params += [t.unpack_params(xi)] diff --git a/desc/utils.py b/desc/utils.py index 710532d432..3447f4804d 100644 --- a/desc/utils.py +++ b/desc/utils.py @@ -608,3 +608,71 @@ def unique_list(thelist): def is_any_instance(things, cls): """Check if any of things is an instance of cls.""" return any([isinstance(t, cls) for t in things]) + + +def broadcast_tree(tree_in, tree_out, dtype=int): + """Broadcast tree_in to the same pytree structure as tree_out. + + Both trees must be nested lists of dicts with string keys and array values. + Or the values can be bools, where False broadcasts to an empty array and True + broadcasts to the corresponding array from tree_out. + + Parameters + ---------- + tree_in : pytree + Tree to broadcast. + tree_out : pytree + Tree with structure to broadcast to. + dtype : optional + Data type of array values. Default = int. + + Returns + ------- + tree : pytree + Tree with the leaves of tree_in broadcast to the structure of tree_out. + + """ + # both trees at leaf layer + if isinstance(tree_in, dict) and isinstance(tree_out, dict): + tree_new = {} + for key, value in tree_in.items(): + errorif( + key not in tree_out.keys(), + ValueError, + f"dict key '{key}' of tree_in must be a subset of those in tree_out: " + + f"{list(tree_out.keys())}", + ) + if isinstance(value, bool): + if value: + tree_new[key] = np.atleast_1d(tree_out[key]).astype(dtype=dtype) + else: + tree_new[key] = np.array([], dtype=dtype) + else: + tree_new[key] = np.atleast_1d(value).astype(dtype=dtype) + for key, value in tree_out.items(): + if key not in tree_new.keys(): + tree_new[key] = np.array([], dtype=dtype) + errorif( + not np.all(np.isin(tree_new[key], value)), + ValueError, + f"dict value {tree_new[key]} of tree_in must be a subset " + + f"of those in tree_out: {value}", + ) + return tree_new + # tree_out is deeper than tree_in + elif isinstance(tree_in, dict) and isinstance(tree_out, list): + return [broadcast_tree(tree_in.copy(), branch) for branch in tree_out] + # both trees at branch layer + elif isinstance(tree_in, list) and isinstance(tree_out, list): + errorif( + len(tree_in) != len(tree_out), + ValueError, + "tree_in must have the same number of branches as tree_out", + ) + return [broadcast_tree(tree_in[k], tree_out[k]) for k in range(len(tree_out))] + # tree_in is deeper than tree_out + elif isinstance(tree_in, list) and isinstance(tree_out, dict): + raise ValueError("tree_in cannot have a deeper structure than tree_out") + # invalid tree structure + else: + raise ValueError("trees must be nested lists of dicts") diff --git a/docs/api.rst b/docs/api.rst index 7f5cb4c9ae..ae0c1efe3f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -177,7 +177,7 @@ Objective Functions desc.objectives.FixOmniBmax desc.objectives.FixOmniMap desc.objectives.FixOmniWell - desc.objectives.FixParameter + desc.objectives.FixParameters desc.objectives.FixPressure desc.objectives.FixPsi desc.objectives.FixSumModesR diff --git a/docs/api_objectives.rst b/docs/api_objectives.rst index 48c650194b..c0885210d8 100644 --- a/docs/api_objectives.rst +++ b/docs/api_objectives.rst @@ -126,7 +126,7 @@ Fixing degrees of freedom desc.objectives.FixSumModesR desc.objectives.FixSumModesZ desc.objectives.FixThetaSFL - desc.objectives.FixParameter + desc.objectives.FixParameters User defined objectives diff --git a/docs/notebooks/tutorials/free_boundary_equilibrium.ipynb b/docs/notebooks/tutorials/free_boundary_equilibrium.ipynb index e1a4d7b651..9c7929f24a 100644 --- a/docs/notebooks/tutorials/free_boundary_equilibrium.ipynb +++ b/docs/notebooks/tutorials/free_boundary_equilibrium.ipynb @@ -446,12 +446,8 @@ ], "source": [ "# we know this is a pretty simple shape so we'll only use |m| <= 2\n", - "R_modes = (\n", - " eq2.surface.R_basis.modes[np.max(np.abs(eq2.surface.R_basis.modes), 1) > 2, :],\n", - ")\n", - "\n", + "R_modes = eq2.surface.R_basis.modes[np.max(np.abs(eq2.surface.R_basis.modes), 1) > 2, :]\n", "Z_modes = eq2.surface.Z_basis.modes[np.max(np.abs(eq2.surface.Z_basis.modes), 1) > 2, :]\n", - "\n", "bdry_constraints = (\n", " FixBoundaryR(eq=eq2, modes=R_modes),\n", " FixBoundaryZ(eq=eq2, modes=Z_modes),\n", @@ -971,10 +967,9 @@ "for k in [2, 4]:\n", "\n", " # get modes where |m|, |n| > k\n", - " R_modes = (\n", - " eq2.surface.R_basis.modes[np.max(np.abs(eq2.surface.R_basis.modes), 1) > k, :],\n", - " )\n", - "\n", + " R_modes = eq2.surface.R_basis.modes[\n", + " np.max(np.abs(eq2.surface.R_basis.modes), 1) > k, :\n", + " ]\n", " Z_modes = eq2.surface.Z_basis.modes[\n", " np.max(np.abs(eq2.surface.Z_basis.modes), 1) > k, :\n", " ]\n", diff --git a/tests/test_examples.py b/tests/test_examples.py index aa6e8c2503..adad7d7104 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,6 +21,7 @@ OmnigenousField, SplineMagneticField, ToroidalMagneticField, + VerticalMagneticField, ) from desc.objectives import ( AspectRatio, @@ -35,7 +36,7 @@ FixIota, FixOmniBmax, FixOmniMap, - FixParameter, + FixParameters, FixPressure, FixPsi, FixSumModesLambda, @@ -504,7 +505,7 @@ def test_NAE_QSC_solve(): np.testing.assert_allclose(iota[0], qsc.iota, atol=1e-5, err_msg=string) np.testing.assert_allclose(iota[1:10], qsc.iota, atol=1e-3, err_msg=string) - ### check lambda to match near axis + # check lambda to match near axis # Evaluate lambda near the axis data_nae = eqq.compute(["lambda", "|B|"], grid=grid_axis) lam_nae = data_nae["lambda"] @@ -639,8 +640,8 @@ def test_multiobject_optimization_al(): constraints = ( ForceBalance(eq=eq, bounds=(-1e-4, 1e-4), normalize_target=False), FixPressure(eq=eq), - FixParameter(surf, ["Z_lmn", "R_lmn"], [[-1], [0]]), - FixParameter(eq, ["Psi", "i_l"]), + FixParameters(surf, {"R_lmn": np.array([0]), "Z_lmn": np.array([3])}), + FixParameters(eq, {"Psi": True, "i_l": True}), FixBoundaryR(eq, modes=[[0, 0, 0]]), PlasmaVesselDistance(surface=surf, eq=eq, target=1), ) @@ -678,8 +679,8 @@ def test_multiobject_optimization_prox(): constraints = ( ForceBalance(eq=eq), FixPressure(eq=eq), - FixParameter(surf, ["Z_lmn", "R_lmn"], [[-1], [0]]), - FixParameter(eq, ["Psi", "i_l"]), + FixParameters(surf, {"R_lmn": np.array([0]), "Z_lmn": np.array([3])}), + FixParameters(eq, {"Psi": True, "i_l": True}), FixBoundaryR(eq, modes=[[0, 0, 0]]), ) @@ -969,7 +970,7 @@ def test_non_eq_optimization(): surf.change_resolution(M=eq.M, N=eq.N) constraints = ( - FixParameter(eq), + FixParameters(eq), MeanCurvature(surf, bounds=(-8, 8)), PrincipalCurvature(surf, bounds=(0, 15)), ) @@ -998,17 +999,14 @@ def test_only_non_eq_optimization(): """Test for optimizing only a non-eq object.""" eq = get("DSHAPE") surf = eq.surface - surf.change_resolution(M=eq.M, N=eq.N) constraints = ( - FixParameter(surf, params="R_lmn", indices=surf.R_basis.get_idx(0, 0, 0)), + FixParameters(surf, {"R_lmn": np.array(surf.R_basis.get_idx(0, 0, 0))}), ) - obj = PrincipalCurvature(surf, target=1) - objective = ObjectiveFunction((obj,)) optimizer = Optimizer("lsq-exact") - (surf), result = optimizer.optimize( + (surf), _ = optimizer.optimize( (surf), objective, constraints, verbose=3, maxiter=100 ) surf = surf[0] @@ -1032,9 +1030,9 @@ def test_freeb_vacuum(): modes_Z=[[-1, 0]], NFP=5, ) - eq = Equilibrium(M=6, N=6, Psi=-0.035, surface=surf) eq.solve() + constraints = ( ForceBalance(eq=eq), FixCurrent(eq=eq), @@ -1044,15 +1042,15 @@ def test_freeb_vacuum(): objective = ObjectiveFunction( VacuumBoundaryError(eq=eq, field=ext_field, field_fixed=True) ) - eq, out = eq.optimize( + eq, _ = eq.optimize( objective, constraints, optimizer="proximal-lsq-exact", verbose=3, options={}, ) - rho_err, _ = area_difference_vmec(eq, "tests/inputs/wout_test_freeb.nc") + rho_err, _ = area_difference_vmec(eq, "tests/inputs/wout_test_freeb.nc") np.testing.assert_allclose(rho_err[:, -1], 0, atol=4e-2) # only check rho=1 @@ -1090,9 +1088,9 @@ def test_freeb_axisym(): modes_Z=[[-1, 0]], NFP=1, ) - eq = Equilibrium(M=10, N=0, Psi=1.0, surface=surf, pressure=pres, iota=iota) eq.solve() + constraints = ( ForceBalance(eq=eq), FixIota(eq=eq), @@ -1104,27 +1102,26 @@ def test_freeb_axisym(): ) # we know this is a pretty simple shape so we'll only use |m| <= 2 - R_modes = ( - eq.surface.R_basis.modes[np.max(np.abs(eq.surface.R_basis.modes), 1) > 2, :], - ) - + R_modes = eq.surface.R_basis.modes[ + np.max(np.abs(eq.surface.R_basis.modes), 1) > 2, : + ] Z_modes = eq.surface.Z_basis.modes[ np.max(np.abs(eq.surface.Z_basis.modes), 1) > 2, : ] - bdry_constraints = ( FixBoundaryR(eq=eq, modes=R_modes), FixBoundaryZ(eq=eq, modes=Z_modes), ) - eq, out = eq.optimize( + + eq, _ = eq.optimize( objective, constraints + bdry_constraints, optimizer="proximal-lsq-exact", verbose=3, options={}, ) - rho_err, _ = area_difference_vmec(eq, "tests/inputs/wout_solovev_freeb.nc") + rho_err, _ = area_difference_vmec(eq, "tests/inputs/wout_solovev_freeb.nc") np.testing.assert_allclose(rho_err[:, -1], 0, atol=2e-2) # only check rho=1 @@ -1255,7 +1252,7 @@ def test_quadratic_flux_optimization_with_analytic_field(): optimizer = Optimizer("lsq-exact") - constraints = (FixParameter(field, ["R0"]),) + constraints = (FixParameters(field, {"R0": True}),) quadflux_obj = QuadraticFlux( eq=eq, field=field, @@ -1276,3 +1273,24 @@ def test_quadratic_flux_optimization_with_analytic_field(): # optimizer should zero out field since that's the easiest way # to get to Bnorm = 0 np.testing.assert_allclose(things[0].B0, 0, atol=1e-12) + + +@pytest.mark.unit +def test_second_stage_optimization(): + """Test optimizing magnetic field for a fixed axisymmetric equilibrium.""" + eq = get("DSHAPE") + field = ToroidalMagneticField(B0=1, R0=3.5) + VerticalMagneticField(B0=1) + objective = ObjectiveFunction(QuadraticFlux(eq=eq, field=field, vacuum=True)) + constraints = FixParameters(field, [{"R0": True}, {}]) + optimizer = Optimizer("scipy-trf") + (field,), _ = optimizer.optimize( + things=field, + objective=objective, + constraints=constraints, + ftol=0, + xtol=0, + verbose=2, + ) + np.testing.assert_allclose(field[0].R0, 3.5) # this value was fixed + np.testing.assert_allclose(field[0].B0, 1) # toroidal field (no change) + np.testing.assert_allclose(field[1].B0, 0, atol=1e-12) # vertical field (vanishes) diff --git a/tests/test_linear_objectives.py b/tests/test_linear_objectives.py index cecbf2bbe4..0aa7daf203 100644 --- a/tests/test_linear_objectives.py +++ b/tests/test_linear_objectives.py @@ -6,6 +6,13 @@ from qsc import Qsc import desc.examples +from desc.coils import ( + CoilSet, + FourierPlanarCoil, + FourierRZCoil, + FourierXYZCoil, + MixedCoilSet, +) from desc.equilibrium import Equilibrium from desc.geometry import FourierRZToroidalSurface from desc.grid import LinearGrid @@ -33,7 +40,7 @@ FixModeZ, FixOmniMap, FixOmniWell, - FixParameter, + FixParameters, FixPressure, FixPsi, FixSumModesLambda, @@ -346,7 +353,7 @@ def test_factorize_linear_constraints_asserts(): _ = factorize_linear_constraints(objective, constraint) # constraining a foreign thing - constraint = ObjectiveFunction(FixParameter(surf)) + constraint = ObjectiveFunction(FixParameters(surf)) constraint.build(verbose=0) with pytest.raises(UserWarning): _ = factorize_linear_constraints(objective, constraint) @@ -406,7 +413,7 @@ def test_kinetic_constraints(): @pytest.mark.unit def test_correct_indexing_passed_modes(): - """Test Indexing when passing in specified modes, related to gh issue #380.""" + """Test indexing when passing in specified modes, related to gh issue #380.""" n = 1 eq = desc.examples.get("W7-X") eq.change_resolution(3, 3, 3, 6, 6, 6) @@ -459,7 +466,7 @@ def test_correct_indexing_passed_modes(): @pytest.mark.unit def test_correct_indexing_passed_modes_and_passed_target(): - """Test Indexing when passing in specified modes, related to gh issue #380.""" + """Test indexing when passing in specified modes, related to gh issue #380.""" n = 1 eq = desc.examples.get("W7-X") eq.change_resolution(3, 3, 3, 6, 6, 6) @@ -521,7 +528,7 @@ def test_correct_indexing_passed_modes_and_passed_target(): @pytest.mark.unit def test_correct_indexing_passed_modes_axis(): - """Test Indexing when passing in specified axis modes, related to gh issue #380.""" + """Test indexing when passing in specified axis modes, related to gh issue #380.""" n = 1 eq = desc.examples.get("W7-X") eq.change_resolution(3, 3, 3, 6, 6, 6) @@ -580,7 +587,7 @@ def test_correct_indexing_passed_modes_axis(): @pytest.mark.unit def test_correct_indexing_passed_modes_and_passed_target_axis(): - """Test Indexing when passing in specified axis modes, related to gh issue #380.""" + """Test indexing when passing in specified axis modes, related to gh issue #380.""" n = 1 eq = desc.examples.get("W7-X") @@ -743,18 +750,18 @@ def test_FixBoundary_passed_target_no_passed_modes_error(): def test_FixAxis_passed_target_no_passed_modes_error(): """Test Fixing Axis with no passed-in modes.""" eq = Equilibrium() - FixZ = FixAxisZ(eq=eq, modes=True, target=np.array([0, 0])) - with pytest.raises(ValueError): - FixZ.build() - FixZ = FixAxisZ(eq=eq, modes=False, target=np.array([0, 0])) - with pytest.raises(ValueError): - FixZ.build() FixR = FixAxisR(eq=eq, modes=True, target=np.array([0, 0])) with pytest.raises(ValueError): FixR.build() FixR = FixAxisR(eq=eq, modes=False, target=np.array([0, 0])) with pytest.raises(ValueError): FixR.build() + FixZ = FixAxisZ(eq=eq, modes=True, target=np.array([0, 0])) + with pytest.raises(ValueError): + FixZ.build() + FixZ = FixAxisZ(eq=eq, modes=False, target=np.array([0, 0])) + with pytest.raises(ValueError): + FixZ.build() @pytest.mark.unit @@ -792,40 +799,22 @@ def test_FixSumModes_passed_target_too_long(): ) -@pytest.mark.unit -def test_FixMode_False_or_None_modes(): - """Test Fixing Modes without specifying modes or All modes.""" - eq = Equilibrium(L=3, M=4) - with pytest.raises(ValueError): - FixModeR(eq, modes=False, target=np.array([[0, 1]])) - with pytest.raises(ValueError): - FixModeR(eq, modes=None, target=np.array([[0, 1]])) - with pytest.raises(ValueError): - FixModeZ(eq, modes=False, target=np.array([[0, 1]])) - with pytest.raises(ValueError): - FixModeZ(eq, modes=None, target=np.array([[0, 1]])) - with pytest.raises(ValueError): - FixModeLambda(eq, modes=False, target=np.array([[0, 1]])) - with pytest.raises(ValueError): - FixModeLambda(eq, modes=None, target=np.array([[0, 1]])) - - @pytest.mark.unit def test_FixSumModes_False_or_None_modes(): """Test Fixing Sum Modes without specifying modes or All modes.""" eq = Equilibrium(L=3, M=4) with pytest.raises(ValueError): - FixSumModesZ(eq, modes=False, target=np.array([[0, 1]])) + FixSumModesR(eq, modes=False) with pytest.raises(ValueError): - FixSumModesZ(eq, modes=None, target=np.array([[0, 1]])) + FixSumModesR(eq, modes=None) with pytest.raises(ValueError): - FixSumModesR(eq, modes=False, target=np.array([[0, 1]])) + FixSumModesZ(eq, modes=False) with pytest.raises(ValueError): - FixSumModesR(eq, modes=None, target=np.array([[0, 1]])) + FixSumModesZ(eq, modes=None) with pytest.raises(ValueError): - FixSumModesLambda(eq, modes=False, target=np.array([[0, 1]])) + FixSumModesLambda(eq, modes=False) with pytest.raises(ValueError): - FixSumModesLambda(eq, modes=None, target=np.array([[0, 1]])) + FixSumModesLambda(eq, modes=None) def _is_any_instance(things, cls): @@ -897,24 +886,107 @@ def test_fix_omni_indices(): # no indices constraint = FixOmniWell(field=field, indices=False) constraint.build() - assert constraint._idx.size == 0 + assert constraint.dim_f == 0 constraint = FixOmniMap(field=field, indices=False) constraint.build() - assert constraint._idx.size == 0 + assert constraint.dim_f == 0 # all indices constraint = FixOmniWell(field=field, indices=True) constraint.build() - assert constraint._idx.size == field.B_lm.size + assert constraint.dim_f == field.B_lm.size constraint = FixOmniMap(field=field, indices=True) constraint.build() - assert constraint._idx.size == field.x_lmn.size + assert constraint.dim_f == field.x_lmn.size # specified indices indices = np.arange(3, 8) constraint = FixOmniWell(field=field, indices=indices) constraint.build() - assert constraint._idx.size == indices.size + assert constraint.dim_f == indices.size constraint = FixOmniMap(field=field, indices=indices) constraint.build() - assert constraint._idx.size == indices.size + assert constraint.dim_f == indices.size + + +@pytest.mark.unit +def test_fix_parameters_input_order(DummyStellarator): + """Test that FixParameters preserves the input indices and target ordering.""" + eq = load(load_from=str(DummyStellarator["output_path"]), file_format="hdf5") + default_target = eq.Rb_lmn + + # default objective + obj = FixBoundaryR(eq) + obj.build() + np.testing.assert_allclose(obj.target, default_target) + + # manually specify default + obj = FixBoundaryR(eq, modes=eq.surface.R_basis.modes) + obj.build() + np.testing.assert_allclose(obj.target, default_target) + + # reverse order + obj = FixBoundaryR(eq, modes=np.flipud(eq.surface.R_basis.modes)) + obj.build() + np.testing.assert_allclose(obj.target, np.flipud(default_target)) + + # custom order + obj = ObjectiveFunction( + FixBoundaryR(eq, modes=np.array([[0, 0, 0], [0, 1, 0], [0, 1, 1]])) + ) + obj.build() + np.testing.assert_allclose(obj.target_scaled, np.array([3, 1, 0.3])) + np.testing.assert_allclose(obj.compute_scaled_error(obj.x(eq)), np.zeros(obj.dim_f)) + + # custom target + obj = ObjectiveFunction( + FixBoundaryR( + eq, + modes=np.array([[0, 0, 0], [0, 1, 0], [0, 1, 1]]), + target=np.array([0, -1, 0.5]), + ) + ) + obj.build() + np.testing.assert_allclose( + obj.compute_scaled_error(obj.x(eq)), np.array([3, 2, -0.2]) + ) + + +@pytest.mark.unit +def test_fix_subset_of_params_in_collection(): + """Tests FixParameters fixing a subset of things in the collection.""" + tf_coil = FourierPlanarCoil(center=[2, 0, 0], normal=[0, 1, 0], r_n=[1]) + tf_coilset = CoilSet.linspaced_angular(tf_coil, n=4) + vf_coil = FourierRZCoil(R_n=3, Z_n=-1) + vf_coilset = CoilSet.linspaced_linear( + vf_coil, displacement=[0, 0, 2], n=3, endpoint=True + ) + xy_coil = FourierXYZCoil() + full_coilset = MixedCoilSet((tf_coilset, vf_coilset, xy_coil)) + + params = [ + [ + {"current": True}, + {"center": True, "normal": np.array([1])}, + {"r_n": True}, + {}, + ], + {"shift": True, "rotmat": True}, + {"X_n": np.array([1, 2]), "Y_n": False, "Z_n": np.array([0])}, + ] + target = np.concatenate( + ( + np.array([1, 2, 0, 0, 1, 1]), + np.eye(3).flatten(), + np.array([0, 0, 0]), + np.eye(3).flatten(), + np.array([0, 0, 1]), + np.eye(3).flatten(), + np.array([0, 0, 2]), + np.array([10, 2, -2]), + ) + ) + + obj = FixParameters(full_coilset, params) + obj.build() + np.testing.assert_allclose(obj.target, target) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 64f71ee2b3..18bf4bf7a7 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -27,7 +27,7 @@ FixBoundaryZ, FixCurrent, FixIota, - FixParameter, + FixParameters, FixPressure, FixPsi, ForceBalance, @@ -895,9 +895,7 @@ def test_constrained_AL_lsq(): constraints = ( FixBoundaryR(eq=eq, modes=[0, 0, 0]), # fix specified major axis position FixPressure(eq=eq), # fix pressure profile - FixParameter( - eq, "i_l", bounds=(eq.i_l * 0.9, eq.i_l * 1.1) - ), # linear inequality + FixIota(eq, bounds=(eq.i_l * 0.9, eq.i_l * 1.1)), # linear inequality FixPsi(eq=eq, bounds=(eq.Psi * 0.99, eq.Psi * 1.01)), # linear inequality ) # some random constraints to keep the shape from getting wacky @@ -1007,10 +1005,10 @@ def test_optimize_multiple_things_different_order(): NFP=eq.NFP, ) constraints = ( - # don't let eq vary - FixParameter(eq), - # only let the minor radius of the surface vary - FixParameter(surf, params=["R_lmn"], indices=surf.R_basis.get_idx(M=0, N=0)), + FixParameters(eq), # don't let eq vary + FixParameters( # only let the minor radius of the surface vary + surf, params={"R_lmn": np.array(surf.R_basis.get_idx(M=0, N=0))} + ), ) target_dist = 1 @@ -1029,7 +1027,7 @@ def test_optimize_multiple_things_different_order(): optimizer = Optimizer("lsq-exact") # ensure it runs when (eq,surf) are passed - (eq1, surf1), result = optimizer.optimize( + (eq1, surf1), _ = optimizer.optimize( (eq, surf), objective, constraints, verbose=3, maxiter=15, copy=True ) # ensure surface changed correctly @@ -1047,10 +1045,10 @@ def test_optimize_multiple_things_different_order(): # fresh start constraints = ( - # don't let eq vary - FixParameter(eq), - # only let the minor radius of the surface vary - FixParameter(surf, params=["R_lmn"], indices=surf.R_basis.get_idx(M=0, N=0)), + FixParameters(eq), # don't let eq vary + FixParameters( # only let the minor radius of the surface vary + surf, params={"R_lmn": np.array(surf.R_basis.get_idx(M=0, N=0))} + ), ) obj = PlasmaVesselDistance( surface=surf, @@ -1063,7 +1061,7 @@ def test_optimize_multiple_things_different_order(): objective = ObjectiveFunction((obj,)) # ensure it runs when (surf,eq) are passed which is opposite # the order of objective.things - (surf2, eq2), result = optimizer.optimize( + (surf2, eq2), _ = optimizer.optimize( (surf, eq), objective, constraints, verbose=3, maxiter=15, copy=True ) @@ -1087,8 +1085,18 @@ def test_optimize_with_single_constraint(): eq = Equilibrium() optimizer = Optimizer("lsq-exact") objectective = ObjectiveFunction(GenericObjective("|B|", eq), use_jit=False) - constraints = FixParameter( # Psi is not constrained - eq, ["R_lmn", "Z_lmn", "L_lmn", "Rb_lmn", "Zb_lmn", "p_l", "c_l"] + constraints = FixParameters( + eq, + { + "R_lmn": True, + "Z_lmn": True, + "L_lmn": True, + "Rb_lmn": True, + "Zb_lmn": True, + "p_l": True, + "c_l": True, + "Psi": False, # Psi is not constrained + }, ) # test depends on verbose > 0 diff --git a/tests/test_utils.py b/tests/test_utils.py index ff421875b6..6bfadb4008 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,8 +3,9 @@ import numpy as np import pytest +from desc.backend import tree_leaves, tree_structure from desc.grid import LinearGrid -from desc.utils import isalmostequal, islinspaced +from desc.utils import broadcast_tree, isalmostequal, islinspaced @pytest.mark.unit @@ -53,3 +54,146 @@ def test_islinspaced(): # 0D arrays will return True assert islinspaced(np.array(0)) + + +@pytest.mark.unit +def test_broadcast_tree(): + """Test that broadcast_tree works on various pytree structures.""" + tree_out = [ + {"a": np.arange(1), "b": np.arange(2), "c": np.arange(3)}, + [ + {"a": np.arange(2)}, + [{"a": np.arange(1), "d": np.arange(3)}, {"a": np.arange(2)}], + ], + ] + + # tree with tuples, not lists + tree_in = [{}, ({}, [{}, {}])] + with pytest.raises(ValueError): + _ = broadcast_tree(tree_in, tree_out) + + # tree_in is deeper than tree_out + tree_in = [ + [{"a": np.arange(1)}, {"b": np.arange(2), "c": np.arange(3)}], + [{}, [{}, {"a": np.arange(2)}]], + ] + with pytest.raises(ValueError): + _ = broadcast_tree(tree_in, tree_out) + + # tree_in has different number of branches as tree_out + tree_in = [{}, [{}, [{}]]] + with pytest.raises(ValueError): + _ = broadcast_tree(tree_in, tree_out) + + # tree with incorrect keys + tree_in = [{"a": np.arange(1), "b": np.arange(2)}, {"d": np.arange(2)}] + with pytest.raises(ValueError): + _ = broadcast_tree(tree_in, tree_out) + + # tree with incorrect values + tree_in = [{"a": np.arange(1), "b": np.arange(2)}, {"a": np.arange(2)}] + with pytest.raises(ValueError): + _ = broadcast_tree(tree_in, tree_out) + + # tree with proper structure already does not change + tree_in = tree_out.copy() + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_out)): + np.testing.assert_allclose(leaf, leaf_correct) + + # broadcast single leaf to full tree + tree_in = {"a": np.arange(1)} + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + tree_correct = [ + {"a": np.arange(1), "b": np.array([], dtype=int), "c": np.array([], dtype=int)}, + [ + {"a": np.arange(1)}, + [{"a": np.arange(1), "d": np.array([], dtype=int)}, {"a": np.arange(1)}], + ], + ] + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): + np.testing.assert_allclose(leaf, leaf_correct) + + # broadcast from only major branches + tree_in = [{"b": np.arange(2), "c": np.arange(1, 3)}, {"a": np.arange(1)}] + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + tree_correct = [ + {"a": np.array([], dtype=int), "b": np.arange(2), "c": np.arange(1, 3)}, + [ + {"a": np.arange(1)}, + [{"a": np.arange(1), "d": np.array([], dtype=int)}, {"a": np.arange(1)}], + ], + ] + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): + np.testing.assert_allclose(leaf, leaf_correct) + + # broadcast from minor branches + tree_in = [ + {"b": np.arange(2), "c": np.arange(1, 3)}, + [{"a": np.arange(2)}, {"a": np.arange(1)}], + ] + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + tree_correct = [ + {"a": np.array([], dtype=int), "b": np.arange(2), "c": np.arange(1, 3)}, + [ + {"a": np.arange(2)}, + [{"a": np.arange(1), "d": np.array([], dtype=int)}, {"a": np.arange(1)}], + ], + ] + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): + np.testing.assert_allclose(leaf, leaf_correct) + + # tree_in with empty dicts and arrays + tree_in = [ + {}, + [ + {"a": np.array([], dtype=int)}, + [{"a": np.arange(1), "d": np.array([0, 2], dtype=int)}, {}], + ], + ] + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + tree_correct = [ + { + "a": np.array([], dtype=int), + "b": np.array([], dtype=int), + "c": np.array([], dtype=int), + }, + [ + {"a": np.array([], dtype=int)}, + [ + {"a": np.arange(1), "d": np.array([0, 2], dtype=int)}, + {"a": np.array([], dtype=int)}, + ], + ], + ] + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): + np.testing.assert_allclose(leaf, leaf_correct) + + # tree_in with bool values + tree_in = [ + {"a": False, "b": True, "c": np.array([0, 2], dtype=int)}, + [ + {"a": True}, + [{"a": False, "d": np.arange(2)}, {"a": True}], + ], + ] + tree = broadcast_tree(tree_in, tree_out) + assert tree_structure(tree) == tree_structure(tree_out) + tree_correct = [ + { + "a": np.array([], dtype=int), + "b": np.arange(2), + "c": np.array([0, 2], dtype=int), + }, + [ + {"a": np.arange(2)}, + [{"a": np.array([], dtype=int), "d": np.arange(2)}, {"a": np.arange(2)}], + ], + ] + for leaf, leaf_correct in zip(tree_leaves(tree), tree_leaves(tree_correct)): + np.testing.assert_allclose(leaf, leaf_correct)