Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Optimization with OptimizableCollections to Work #857

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions desc/magnetic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,12 @@
[type(field) for field in fields]
)
self._fields = fields
offset = jnp.concatenate(

Check warning on line 481 in desc/magnetic_fields.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields.py#L481

Added line #L481 was not covered by tests
[jnp.array([0]), jnp.cumsum(jnp.array([s.dim_x for s in self]))[:-1]]
)
# store split indices for unpacking
# as a numpy array to avoid jax issues later
self._split_idx = np.asarray(offset)

Check warning on line 486 in desc/magnetic_fields.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields.py#L486

Added line #L486 was not covered by tests
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Way to put this in a super init for OptimizableCollection? otherwise this would have to be done in every init of subclasses of OptimizableCollection

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this needed for? I think there already is something similar in OptimizableCollection

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was, but it used jnp.split and jax threw an error complaining that it could not deal with non-static indices, and it seems that jnp.split cannot be jitted, but if you use an np array as the indices it works

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm ok. We could add an init method to do this, and then call super init from all the subclasses. It might also be possible to do some metaclass stuff like we do to register things as pytrees:

class _AutoRegisterPytree(type):

Another option might be something like we do in the Optimizable class where it gets built the first time its called:

def optimizable_params(self):


def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None):
"""Compute magnetic field at a set of points.
Expand Down
1 change: 1 addition & 0 deletions desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._geometry import (
AspectRatio,
BScaleLength,
DummyFields,
Elongation,
GoodCoordinates,
MeanCurvature,
Expand Down
141 changes: 141 additions & 0 deletions desc/objectives/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,3 +1339,144 @@
f = data["g_rr"]

return jnp.concatenate([g, constants["sigma"] * f])


class DummyFields(_Objective):
"""Target a quantity at points from magnetic fields.

Parameters
----------
fields : MagneticField
MagneticField object.
eq : Equilibrium
eq to calc coords to target Bphi over BZ from fields at.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f.
bounds : tuple of {float, ndarray}, optional
Lower and upper bounds on the objective. Overrides target.
Both bounds must be broadcastable to to Objective.dim_f
weight : {float, ndarray}, optional
Weighting to apply to the Objective, relative to other Objectives.
Must be broadcastable to to Objective.dim_f
normalize : bool, optional
Whether to compute the error in physical units or non-dimensionalize.
normalize_target : bool
Whether target should be normalized before comparing to computed values.
if `normalize` is `True` and the target is in physical units, this should also
be set to True.
loss_function : {None, 'mean', 'min', 'max'}, optional
Loss function to apply to the objective values once computed. This loss function
is called on the raw compute value, before any shifting, scaling, or
normalization.
deriv_mode : {"auto", "fwd", "rev"}
Specify how to compute jacobian matrix, either forward mode or reverse mode AD.
"auto" selects forward or reverse mode based on the size of the input and output
of the objective. Has no effect on self.grad or self.hess which always use
reverse mode and forward over reverse mode respectively.
eval_grid : Grid, optional
Collocation grid containing the nodes to evaluate at.
source_grid : Grid, optional
Grid to discretize field objects with.
name : str, optional
Name of the objective function.

"""

_coordinates = "rtz"
_units = "unitless"
_print_value_fmt = "Magnetic field quantity: {:10.3e} "

def __init__(
self,
field,
eq,
target=None,
bounds=None,
weight=1,
normalize=True,
normalize_target=True,
loss_function=None,
deriv_mode="auto",
eval_grid=None,
source_grid=None,
name="Bphi over BZ",
):
if target is None and bounds is None:
bounds = (1, np.inf)
self._eval_grid = eval_grid
self._source_grid = source_grid
self._eq = eq

Check warning on line 1409 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1405-L1409

Added lines #L1405 - L1409 were not covered by tests

super().__init__(

Check warning on line 1411 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1411

Added line #L1411 was not covered by tests
things=field,
target=target,
bounds=bounds,
weight=weight,
normalize=normalize,
normalize_target=normalize_target,
loss_function=loss_function,
deriv_mode=deriv_mode,
name=name,
)

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.

"""
eq = self._eq
if self._eval_grid is None:
eval_grid = LinearGrid(M=eq.M * 2, N=eq.N * 2, NFP=eq.NFP)

Check warning on line 1436 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1434-L1436

Added lines #L1434 - L1436 were not covered by tests
else:
eval_grid = self._eval_grid

Check warning on line 1438 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1438

Added line #L1438 was not covered by tests

self._dim_f = eval_grid.num_nodes
self._data_keys = ["R", "phi", "Z", "e^theta", "e^zeta"]

Check warning on line 1441 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1440-L1441

Added lines #L1440 - L1441 were not covered by tests

timer = Timer()
if verbose > 0:
print("Precomputing transforms")
timer.start("Precomputing transforms")
data = eq.compute(self._data_keys, grid=eval_grid)
self._constants = {

Check warning on line 1448 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1443-L1448

Added lines #L1443 - L1448 were not covered by tests
"coords_rpz": np.vstack([data["R"], data["phi"], data["Z"]]).T,
"quad_weights": eval_grid.weights,
}

timer.stop("Precomputing transforms")
if verbose > 1:
timer.disp("Precomputing transforms")

Check warning on line 1455 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1453-L1455

Added lines #L1453 - L1455 were not covered by tests

super().build(use_jit=use_jit, verbose=verbose)

Check warning on line 1457 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1457

Added line #L1457 was not covered by tests

def compute(self, params, constants=None):
"""Compute magnetic field scale length.

Parameters
----------
params : dict
Dictionary of equilibrium degrees of freedom, eg Equilibrium.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants

Returns
-------
f : ndarray
quantity.

"""
if constants is None:
constants = self.constants
B = self.things[0].compute_magnetic_field(

Check warning on line 1478 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1476-L1478

Added lines #L1476 - L1478 were not covered by tests
constants["coords_rpz"], basis="rpz", params=params
)
# return Bphi over BZ (some random quantity we can directly control)
return B[:, 1] / B[:, 2]

Check warning on line 1482 in desc/objectives/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/_geometry.py#L1482

Added line #L1482 was not covered by tests
Loading
Loading