Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dd/optimizable #956

Merged
merged 69 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
6402ab0
create FixCollectionParameters objective
daniel-dudt Mar 25, 2024
9dc9a38
allow fixing of non-default params
daniel-dudt Mar 26, 2024
9a78600
bug fix: use np instead of jnp
daniel-dudt Mar 26, 2024
b0df4e4
update maybe_add_self_consistency
daniel-dudt Mar 26, 2024
4d73c2e
Merge branch 'master' into dd/optimizable
ddudt Mar 26, 2024
bd1eedc
fix NFP int bug
daniel-dudt Mar 26, 2024
b29ec75
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Mar 26, 2024
0bdc245
Merge branch 'master' into dd/optimizable
ddudt Mar 29, 2024
03dfe05
allow fixing params for only some things in collection
daniel-dudt Mar 29, 2024
19a8955
add test for second stage optimization
daniel-dudt Mar 29, 2024
ef63085
Merge branch 'master' into dd/optimizable
f0uriest Apr 2, 2024
f6aa6c1
Merge branch 'master' into dd/optimizable
ddudt Apr 2, 2024
672dd48
merge with master
daniel-dudt Apr 12, 2024
efb5429
Merge branch 'master' into dd/optimizable
ddudt Apr 12, 2024
bfb5f61
update tests
daniel-dudt Apr 12, 2024
d5f8fc2
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 12, 2024
4f75638
Merge branch 'master' into dd/optimizable
ddudt Apr 18, 2024
8db2c3f
broadcast_tree util function
daniel-dudt Apr 19, 2024
135c0a9
broadcast_tree until function + tests
daniel-dudt Apr 20, 2024
fafa8ca
FixCollectionParameters working with custom params input
daniel-dudt Apr 20, 2024
566c810
tiny typos
daniel-dudt Apr 20, 2024
c39faa2
repair 2nd stage opt test
daniel-dudt Apr 20, 2024
80e7dbe
more test cases for broadcast_tree
daniel-dudt Apr 20, 2024
65f6f6e
combine FixCollectionParameters into FixParameter
daniel-dudt Apr 20, 2024
3477a96
fix params_leaves sorting issue
daniel-dudt Apr 20, 2024
b41e761
fix list vs array bug
daniel-dudt Apr 21, 2024
31c94c6
clean up a few lines
daniel-dudt Apr 21, 2024
d41a1fc
fix missing tree_leaves call
daniel-dudt Apr 21, 2024
f0dbd83
add assert statement to fix later
daniel-dudt Apr 22, 2024
be981f6
remove debugging print statement
daniel-dudt Apr 24, 2024
34cb43d
proximal projection hack
daniel-dudt Apr 24, 2024
2818ecb
cast indices to array in test
daniel-dudt Apr 29, 2024
414f296
refactor broadcast_tree
daniel-dudt Apr 29, 2024
1a5ae51
refactor FixParameter objective
daniel-dudt Apr 29, 2024
9115d8b
Merge branch 'master' into dd/optimizable
ddudt Apr 29, 2024
70fd03f
add dtype option to broadcast_tree
daniel-dudt Apr 29, 2024
b997d63
replace some FixedObjectives with FixParameter
daniel-dudt Apr 29, 2024
2d7c033
add FixSheetCurrent objective
daniel-dudt Apr 29, 2024
f88eb4b
use FixParameter for axis/boundary/etc. objectives
daniel-dudt Apr 29, 2024
9dc6a17
replace FixProfile with FixParameter
daniel-dudt Apr 29, 2024
eb6dc48
FixTheatSFL inherit from FixParameter
daniel-dudt Apr 29, 2024
42d7042
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 29, 2024
e9d95a8
syntax error
daniel-dudt Apr 30, 2024
900012c
bug fixes
daniel-dudt Apr 30, 2024
5fb1614
remove outdated FIXME comment
daniel-dudt Apr 30, 2024
1193017
Merge branch 'master' into dd/optimizable
ddudt Apr 30, 2024
7a314b9
repairing tests
daniel-dudt Apr 30, 2024
259a16d
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt Apr 30, 2024
6917eb6
copy and paste is dangerous
daniel-dudt Apr 30, 2024
32731e6
another copy/paste mistake
daniel-dudt Apr 30, 2024
3f3d3c2
syntax issue in freeb test
daniel-dudt Apr 30, 2024
a288df6
Merge branch 'master' into dd/optimizable
ddudt Apr 30, 2024
94532a6
syntax issue in freeb notebook
daniel-dudt Apr 30, 2024
13305b5
make Rory's suggested changes
daniel-dudt May 2, 2024
e0b746c
Merge branch 'master' into dd/optimizable
ddudt May 2, 2024
f524cc0
repair getter funs
daniel-dudt May 2, 2024
68de431
Merge branch 'dd/optimizable' of https://github.com/PlasmaControl/DES…
daniel-dudt May 2, 2024
255f242
add test to check input order is preserved
daniel-dudt May 2, 2024
1752190
re-add normalize kwargs
daniel-dudt May 2, 2024
0615d56
update 2nd stage opt test with note
daniel-dudt May 2, 2024
c690222
update FixParameters example
daniel-dudt May 2, 2024
022861d
I hate debugging code
daniel-dudt May 2, 2024
847e355
Merge branch 'master' into dd/optimizable
ddudt May 3, 2024
8f9acfc
set vacuum=True for second_stage test
daniel-dudt May 3, 2024
fee6990
remove old comment
dpanici May 8, 2024
d025324
add note about True to FixParameters
dpanici May 8, 2024
4015aff
and note of False
dpanici May 8, 2024
7315d76
Merge branch 'master' into dd/optimizable
dpanici May 8, 2024
0aff196
Merge branch 'master' into dd/optimizable
ddudt May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
FixAxisZ,
FixBoundaryR,
FixBoundaryZ,
FixCollectionParameters,
FixCurrent,
FixCurveRotation,
FixCurveShift,
Expand Down
3 changes: 1 addition & 2 deletions desc/objectives/_free_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,12 +481,11 @@ def build(self, use_jit=True, verbose=1):
if self._source_grid is None:
# for axisymmetry we still need to know about toroidal effects, so its
# cheapest to pretend there are extra field periods
source_NFP = eq.NFP if eq.N > 0 else 64
source_grid = LinearGrid(
rho=np.array([1.0]),
M=eq.M_grid,
N=eq.N_grid,
NFP=source_NFP,
NFP=int(eq.NFP) if eq.N > 0 else 64,
sym=False,
)
else:
Expand Down
15 changes: 8 additions & 7 deletions desc/objectives/getters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Utilities for getting standard groups of objectives and constraints."""

from desc.utils import flatten_list, unique_list

from ._equilibrium import Energy, ForceBalance, HelicalForceBalance, RadialForceBalance
from .linear_objectives import (
AxisRSelfConsistency,
Expand Down Expand Up @@ -213,12 +215,8 @@
return any([isinstance(t, cls) for t in things])

# Equilibrium
if (
hasattr(thing, "Ra_n")
and hasattr(thing, "Za_n")
and hasattr(thing, "Rb_lmn")
and hasattr(thing, "Zb_lmn")
and hasattr(thing, "L_lmn")
if {"Rb_lmn", "Zb_lmn", "L_lmn", "Ra_n", "Za_n"} <= set(
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
unique_list(flatten_list(thing.optimizable_params))[0]
):
if not _is_any_instance(constraints, BoundaryRSelfConsistency):
constraints += (BoundaryRSelfConsistency(eq=thing),)
Expand All @@ -232,7 +230,10 @@
constraints += (AxisZSelfConsistency(eq=thing),)

# Curve
elif hasattr(thing, "shift") and hasattr(thing, "rotmat"):
# FIXME: make this work for CoilSet
elif not hasattr(thing, "__len__") and {"shift", "rotmat"} <= set(

Check warning on line 234 in desc/objectives/getters.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/getters.py#L234

Added line #L234 was not covered by tests
unique_list(flatten_list(thing.optimizable_params))[0]
):
if not _is_any_instance(constraints, FixCurveShift):
constraints += (FixCurveShift(curve=thing),)
if not _is_any_instance(constraints, FixCurveRotation):
Expand Down
203 changes: 200 additions & 3 deletions desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from desc.backend import jnp
from desc.basis import zernike_radial, zernike_radial_coeffs
from desc.utils import errorif, setdefault
from desc.utils import errorif, flatten_list, setdefault, unique_list

from .normalization import compute_scaling_factors
from .objective_funs import _Objective
Expand Down Expand Up @@ -110,6 +110,7 @@
normalize_target=False,
name="Fixed parameter",
):
# TODO: assert that `thing` is not of type `OptimizableCollection`
self._target_from_user = target
self._params = params
self._indices = indices
Expand Down Expand Up @@ -160,8 +161,8 @@
errorif(
len(self._params) != len(self._indices),
ValueError,
f"not enough indices ({len(self._indices)}) "
+ f"for params ({len(self._params)})",
f"Unequal number of indices ({len(self._indices)}) "
+ f"and params ({len(self._params)}).",
)
for idx, par in zip(self._indices, self._params):
if isinstance(idx, bool) and idx:
Expand Down Expand Up @@ -210,6 +211,202 @@
)


class FixCollectionParameters(_FixedObjective):
"""Fix specific degrees of freedom associated with a given OptimizableCollection.

Parameters
----------
thing : OptimizableCollection
Object whose degrees of freedom are being fixed.
params : str or list of str
Names of parameters to fix. Defaults to all parameters.
index : array-like or list of array-like
Indices to fix for each parameter in params. Use True to fix all indices.
target : dict of {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Should have the same tree structure as thing.params. Defaults to things.params.
bounds : tuple of dict {float, ndarray}, optional
Lower and upper bounds on the objective. Overrides target.
Should have the same tree structure as thing.params.
weight : dict of {float, ndarray}, optional
Weighting to apply to the Objective, relative to other Objectives.
Should be a scalar or have the same tree structure as thing.params.
normalize : bool, optional
Whether to compute the error in physical units or non-dimensionalize.
Has no effect for this objective.
normalize_target : bool, optional
Whether target and bounds should be normalized before comparing to computed
values. If `normalize` is `True` and the target is in physical units,
this should also be set to True. Has no effect for this objective.
name : str, optional
Name of the objective function.

"""

_scalar = False
_linear = True
_fixed = True
_units = "(~)"
_print_value_fmt = "Fixed parameters error: {:10.3e} "

def __init__(
self,
thing,
params=None,
indices=True,
target=None,
bounds=None,
weight=1,
normalize=False,
normalize_target=False,
name="Fixed parameters",
):
errorif(
not hasattr(thing, "__len__"),
ValueError,
f"Thing must be of type `OptimizableCollection`; got {thing}.",
)
self._target_from_user = target
self._params = params
self._indices = indices
super().__init__(
things=thing,
target=target,
bounds=bounds,
weight=weight,
normalize=normalize,
normalize_target=normalize_target,
name=name,
)

def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.

"""
thing = self.things[0]
thing_idx = range(len(thing)) # indices of things in OptimizableCollection
params = setdefault(self._params, thing.optimizable_params)

if not isinstance(params, (list, tuple)):
params = [params]
if not all([isinstance(par, (list, tuple)) for par in params]):
params = [
params if set(params) <= set(thing[k].optimizable_params) else []
for k in thing_idx
]
for k in thing_idx:
for par in params[k]:
errorif(
par not in unique_list(flatten_list(thing.optimizable_params))[0],
ValueError,
f"Parameter {par} not found in optimizable_parameters: "
+ f"{thing.optimizable_params}.",
)
self._params = params

# replace indices=True with actual indices
if isinstance(self._indices, bool) and self._indices:
self._indices = [
[
np.arange(thing.dimensions[k][par])
for par in self._params[k]
if par in thing.dimensions[k].keys()
]
for k in thing_idx
]
# make sure its iterable if only a scalar was passed in
if not isinstance(self._indices, (list, tuple)):
self._indices = [self._indices]

Check warning on line 326 in desc/objectives/linear_objectives.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/linear_objectives.py#L326

Added line #L326 was not covered by tests
# replace idx=True with array of all indices, throwing an error if the length
# of indices is different from number of params
errorif(
len(sum(self._params, [])) != len(sum(self._indices, [])),
ValueError,
f"Unequal number of indices ({len(sum(self._indices, []))}) "
+ f"and params ({len(sum(self._params, []))}).",
)
indices = []
for k in thing_idx:
indices.append({})
for idx, par in zip(self._indices[k], self._params[k]):
if isinstance(idx, bool) and idx:
idx = np.arange(thing.dimensions[k][par])

Check warning on line 340 in desc/objectives/linear_objectives.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/linear_objectives.py#L340

Added line #L340 was not covered by tests
indices[k][par] = np.atleast_1d(idx)
self._indices = indices
self._dim_f = sum(
sum(t.size for t in self._indices[k].values()) for k in thing_idx
)

# FIXME: I don't think custom target/bounds works yet (default target is ok)
default_target = [
{par: thing.params_dict[k][par][self._indices[k][par]] for par in params[k]}
for k in thing_idx
]
default_bounds = None
target, bounds = self._parse_target_from_user(
self._target_from_user, default_target, default_bounds, indices
)
if target:
self.target = jnp.concatenate(
[
(
jnp.concatenate([target[k][par] for par in params[k]])
if par in target[k].keys()
else jnp.array([])
)
for k in thing_idx
]
)
self.bounds = None
else:
self.target = None
self.bounds = (

Check warning on line 370 in desc/objectives/linear_objectives.py

View check run for this annotation

Codecov / codecov/patch

desc/objectives/linear_objectives.py#L369-L370

Added lines #L369 - L370 were not covered by tests
jnp.concatenate([bounds[0][par] for par in params]),
jnp.concatenate([bounds[1][par] for par in params]),
)
super().build(use_jit=use_jit, verbose=verbose)

def compute(self, params, constants=None):
"""Compute fixed degree of freedom errors.

Parameters
----------
params : list of dict
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
List of dictionaries of degrees of freedom, eg CoilSet.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants

Returns
-------
f : ndarray
Fixed degree of freedom errors.

"""
return jnp.concatenate(
[
(
jnp.concatenate(
[
params[k][par][self._indices[k][par]]
for par in self._params[k]
]
)
if len(self._params[k])
else jnp.array([])
)
for k in range(len(self._params))
]
)


class BoundaryRSelfConsistency(_Objective):
"""Ensure that the boundary and interior surfaces are self-consistent.

Expand Down
11 changes: 9 additions & 2 deletions desc/optimizable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import warnings
from abc import ABC

import numpy as np

from desc.backend import jnp


Expand Down Expand Up @@ -86,6 +88,7 @@ def pack_params(self, p):
x : ndarray
optimizable parameters concatenated into a single array, with indices
given by ``x_idx``

"""
return jnp.concatenate(
[jnp.atleast_1d(jnp.asarray(p[key])) for key in self.optimizable_params]
Expand All @@ -104,6 +107,7 @@ def unpack_params(self, x):
-------
p : dict
Dictionary of ndarray of optimizable parameters.

"""
x_idx = self.x_idx
params = {}
Expand All @@ -115,7 +119,8 @@ def _sort_args(self, args):
"""Put arguments in a canonical order. Returns unique sorted elements.

Actual order doesn't really matter as long as its consistent, though subclasses
may override this method to enforce a specific ordering
may override this method to enforce a specific ordering.

"""
return sorted(set(list(args)))

Expand Down Expand Up @@ -177,6 +182,7 @@ def pack_params(self, params):
x : ndarray
optimizable parameters concatenated into a single array, with indices
given by ``x_idx``

"""
return jnp.concatenate([s.pack_params(p) for s, p in zip(self, params)])

Expand All @@ -193,8 +199,9 @@ def unpack_params(self, x):
-------
p : list dict
list of dictionary of ndarray of optimizable parameters.

"""
split_idx = jnp.cumsum(jnp.array([s.dim_x for s in self]))
split_idx = np.cumsum([s.dim_x for s in self]) # must be np not jnp
ddudt marked this conversation as resolved.
Show resolved Hide resolved
xs = jnp.split(x, split_idx)
params = [s.unpack_params(xi) for s, xi in zip(self, xs)]
return params
Expand Down
28 changes: 27 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
from desc.io import load
from desc.magnetic_fields import OmnigenousField, SplineMagneticField
from desc.magnetic_fields import (
OmnigenousField,
SplineMagneticField,
ToroidalMagneticField,
VerticalMagneticField,
)
from desc.objectives import (
AspectRatio,
BoundaryError,
Expand All @@ -27,6 +32,7 @@
CurrentDensity,
FixBoundaryR,
FixBoundaryZ,
FixCollectionParameters,
FixCurrent,
FixIota,
FixOmniBmax,
Expand Down Expand Up @@ -1344,3 +1350,23 @@ def test_single_coil_optimization():
np.testing.assert_allclose(
coil.compute("torsion", grid=grid)["torsion"], target, atol=1e-5
)


@pytest.mark.unit
def test_second_stage_optimization():
"""Test optimizing magnetic field for a fixed axisymmetric equilibrium."""
# This also tests that FixCollectionParameters works properly when fixing a
# parameter that does not exist for all things in the collection.

# TODO: change the objective to QuadraticFlux
eq = get("DSHAPE")
field = ToroidalMagneticField(B0=1, R0=3.5) + VerticalMagneticField(B0=1)
objective = ObjectiveFunction(BoundaryError(eq=eq, field=field))
constraints = (FixParameter(eq), FixCollectionParameters(field, "R0"))
optimizer = Optimizer("lsq-exact")
(eq, field), _ = optimizer.optimize(
things=(eq, field), objective=objective, constraints=constraints
)
np.testing.assert_allclose(field[0].R0, 3.5)
np.testing.assert_allclose(field[0].B0, 0.218, rtol=1e-3) # toroidal field
np.testing.assert_allclose(field[1].B0, -0.021, rtol=2e-2) # vertical field
Loading