
ExternalObjective function to wrap external codes #1028

Merged Feb 17, 2025 (104 commits; showing changes from 44 commits)

Commits
456a02b
initial commit
daniel-dudt May 17, 2024
6a557ec
get external objective working
daniel-dudt May 20, 2024
fc9ef77
test comparison to generic
daniel-dudt May 20, 2024
9a64f25
allow string kwargs in external fun
daniel-dudt May 21, 2024
ae84d5e
Merge branch 'master' into dd/external
ddudt May 21, 2024
11c1438
exclude ExternalObjective from tests
daniel-dudt May 21, 2024
bedcee1
Merge branch 'master' into dd/external
ddudt May 23, 2024
aff7d46
make external fun take eq as its argument
daniel-dudt May 23, 2024
7f3ff5b
Merge branch 'master' into dd/external
ddudt May 31, 2024
2bb9017
simplify wrapped fun to take params
daniel-dudt May 31, 2024
7b1cfaf
Merge branch 'master' into dd/external
ddudt Jun 4, 2024
debecad
numpifying to make vectorization work
daniel-dudt Jun 5, 2024
633fa5b
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jun 5, 2024
9ea37fb
Revert "numpifying to make vectorization work"
daniel-dudt Jun 5, 2024
87ab19f
vectorization working!
daniel-dudt Jun 7, 2024
5395611
allow vectorized to be an int
daniel-dudt Jun 11, 2024
52d58d0
fix numpy cond
daniel-dudt Jun 11, 2024
96bf929
Merge branch 'master' into dd/external
ddudt Jun 17, 2024
30aeea4
merging but no change?
daniel-dudt Jun 17, 2024
90296ea
update test with new UI
daniel-dudt Jun 17, 2024
d16e95d
remove unused pool code
daniel-dudt Jun 18, 2024
6b3f86d
Merge branch 'master' into dd/external
ddudt Jun 19, 2024
f9b7562
remove comment note
daniel-dudt Jul 17, 2024
fe5e95c
Merge branch 'master' into dd/external
ddudt Jul 17, 2024
f1f466b
fix black formatting from merge conflict
daniel-dudt Jul 18, 2024
ecc5b3b
repair test from merge conflict
daniel-dudt Jul 18, 2024
800b9bb
Merge branch 'master' into dd/external
ddudt Jul 18, 2024
bf62014
remove multiprocessing from ExternalObjective class
daniel-dudt Jul 18, 2024
0547bd7
jaxify as a util function
daniel-dudt Jul 19, 2024
e3057dd
Merge branch 'master' into dd/external
ddudt Jul 19, 2024
4323b8a
Merge branch 'master' into dd/external
ddudt Jul 22, 2024
03d0cb5
ExternalObjective no longer an ABC
daniel-dudt Jul 22, 2024
7723ecd
re-add print logic in backend
daniel-dudt Jul 22, 2024
4864b56
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
ef9711b
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
180c503
Merge branch 'yge/cpu' into dd/external
ddudt Jul 24, 2024
cea3f4a
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
09c02ec
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
0b2207f
exclude ExternalObjective from tests
daniel-dudt Jul 26, 2024
f6a395b
Merge branch 'master' into dd/external
ddudt Jul 26, 2024
aa570d4
scale FD derivatives by tangent norm
daniel-dudt Jul 26, 2024
d62d9ca
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jul 26, 2024
8c7bcb1
Merge branch 'master' into dd/external
ddudt Jul 30, 2024
7f1907b
Merge branch 'master' into dd/external
ddudt Aug 11, 2024
76b2a3c
Merge branch 'master' into dd/external
dpanici Aug 20, 2024
ef98142
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
16bb59b
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
a83a671
resolve merge conflict
daniel-dudt Aug 22, 2024
8beb2e6
Merge branch 'master' into dd/external
ddudt Aug 23, 2024
9e57ee1
Merge branch 'master' into dd/external
ddudt Aug 25, 2024
11521b2
fix formatting from merge conflict
daniel-dudt Aug 25, 2024
c004724
add static_attrs, update test
daniel-dudt Aug 26, 2024
37f9ee3
Merge branch 'master' into dd/external
ddudt Aug 27, 2024
aab0bdb
update with master
daniel-dudt Nov 7, 2024
829af5a
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
ba1a252
update depricated jax.pure_callback vmap arg
daniel-dudt Nov 12, 2024
bb8a535
update vmap_method
daniel-dudt Nov 12, 2024
56d6662
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
655fe06
Merge branch 'master' into dd/external
YigitElma Dec 4, 2024
44c25a2
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
795350d
remove duplicate line from merge conflict
daniel-dudt Dec 12, 2024
6fef120
fix test with block_until_ready
daniel-dudt Dec 12, 2024
021106d
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
0f35919
Merge branch 'master' into dd/external
ddudt Dec 17, 2024
24dd2f3
update documentation
daniel-dudt Dec 17, 2024
d3aa2dd
Merge branch 'master' into dd/external
ddudt Dec 18, 2024
ac1aa63
make vectorized a required arg
daniel-dudt Dec 18, 2024
4db5c9f
make ExternalObjective args keyword only
daniel-dudt Dec 19, 2024
96aec58
Merge branch 'master' into dd/external
ddudt Dec 19, 2024
d5453a9
Merge branch 'master' into dd/external
ddudt Jan 2, 2025
4f31c42
Merge branch 'master' into dd/external
ddudt Jan 3, 2025
a90459a
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
b4463b5
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
bcf77fd
remove ABC inheritance from ExternalObjective
daniel-dudt Jan 28, 2025
7ed2cf7
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
db8e62c
Merge branch 'master' into dd/external
ddudt Jan 29, 2025
721ce5e
create print_info fun in backend
daniel-dudt Jan 30, 2025
a1dd8bc
wrapper for jax.pure_callback syntax
daniel-dudt Jan 30, 2025
85caade
Merge branch 'master' into dd/external
ddudt Jan 30, 2025
0dc4856
kwargs -> fun_kwargs, add example
daniel-dudt Jan 31, 2025
2d561ee
pure_callback -> io_callback
daniel-dudt Jan 31, 2025
2344d74
reference print_backend_info in docs
daniel-dudt Jan 31, 2025
7d125f7
edit docs args order
daniel-dudt Feb 3, 2025
f000dbd
Merge branch 'master' into dd/external
ddudt Feb 3, 2025
1657cb8
edit docs for print_backend_info
daniel-dudt Feb 3, 2025
8c4fd51
io_callback -> pure_callback
daniel-dudt Feb 3, 2025
3c1c318
Merge branch 'master' into dd/external
YigitElma Feb 4, 2025
6ac8021
Merge branch 'master' into dd/external
ddudt Feb 4, 2025
e33246d
Update desc/objectives/_generic.py
ddudt Feb 4, 2025
0a4c84b
Update desc/objectives/_generic.py
ddudt Feb 4, 2025
f1dd2de
Update desc/utils.py
ddudt Feb 4, 2025
67d7680
better jax version handling
daniel-dudt Feb 4, 2025
86ca1f5
Merge branch 'master' into dd/external
ddudt Feb 5, 2025
0801e22
fix versioning logic
daniel-dudt Feb 5, 2025
d372b06
add jnp lines back to backend
daniel-dudt Feb 5, 2025
7ce4266
Merge branch 'master' into dd/external
ddudt Feb 6, 2025
8060b02
Merge branch 'master' into dd/external
ddudt Feb 7, 2025
96a2a3f
Merge branch 'master' into dd/external
ddudt Feb 11, 2025
62267f1
improved version handling
daniel-dudt Feb 11, 2025
4dacc74
default kwarg values
daniel-dudt Feb 11, 2025
89b6677
Merge branch 'master' into dd/external
ddudt Feb 12, 2025
e1cfe95
Merge branch 'master' into dd/external
YigitElma Feb 13, 2025
ee345cf
Merge branch 'master' into dd/external
ddudt Feb 13, 2025
fd1ed51
Merge branch 'master' into dd/external
YigitElma Feb 14, 2025
42 changes: 25 additions & 17 deletions desc/backend.py
@@ -1,6 +1,7 @@
"""Backend functions for DESC, with options for JAX or regular numpy."""

import functools
import multiprocessing as mp
import os
import warnings

@@ -11,15 +12,19 @@
from desc import config as desc_config
from desc import set_device

# only print details in the main process, not child processes spawned by multiprocessing
verbose = bool(mp.current_process().name == "MainProcess")

if os.environ.get("DESC_BACKEND") == "numpy":
jnp = np
use_jax = False
set_device(kind="cpu")
    print(
        "DESC version {}, using numpy backend, version={}, dtype={}".format(
            desc.__version__, np.__version__, np.linspace(0, 1).dtype
        )
    )
    if verbose:
        print(
            "DESC version {}, using numpy backend, version={}, dtype={}".format(
                desc.__version__, np.__version__, np.linspace(0, 1).dtype
            )
        )
else:
if desc_config.get("device") is None:
set_device("cpu")
@@ -41,11 +46,12 @@
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = True
print(
f"DESC version {desc.__version__},"
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
if verbose:
print(
f"DESC version {desc.__version__}, "
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
del x, y
except ModuleNotFoundError:
jnp = np
@@ -59,11 +65,13 @@
desc.__version__, np.__version__, y.dtype
)
)
    print(
        "Using device: {}, with {:.2f} GB available memory".format(
            desc_config.get("device"), desc_config.get("avail_mem")
        )
    )

    if verbose:
        print(
            "Using device: {}, with {:.2f} GB available memory".format(
                desc_config.get("device"), desc_config.get("avail_mem")
            )
        )

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
@@ -515,7 +523,7 @@
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, *operand):
def cond(pred, true_fun, false_fun, *operands):
"""Conditionally apply true_fun or false_fun.

This version is for the numpy backend, for jax backend see jax.lax.cond
@@ -528,7 +536,7 @@
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
operands: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.

@@ -541,9 +549,9 @@

"""
if pred:
return true_fun(*operand)
return true_fun(*operands)
else:
return false_fun(*operand)
return false_fun(*operands)
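The numpy fallback above evaluates ``pred`` eagerly, so it reduces to a plain if/else that merely matches ``jax.lax.cond``'s call signature. A standalone sketch mirroring the diff:

```python
# Sketch of the numpy-backend cond fallback: pred is a concrete bool here,
# so this is ordinary Python control flow with jax.lax.cond's call shape.
def cond(pred, true_fun, false_fun, *operands):
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)

result = cond(True, lambda a, b: a + b, lambda a, b: a - b, 2, 3)  # 5
```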

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
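The ``verbose`` flag this diff adds to ``desc/backend.py`` gates the startup printout on the process name, so multiprocessing workers (whose names differ from ``"MainProcess"``, e.g. ``"SpawnPoolWorker-1"``) stay silent. A standalone sketch of that check:

```python
# Sketch of the main-process check from the diff: only the parent process
# is named "MainProcess", so only it prints the backend info.
import multiprocessing as mp

verbose = mp.current_process().name == "MainProcess"

if verbose:
    print("backend info printed once, in the parent process only")
```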
2 changes: 1 addition & 1 deletion desc/compute/utils.py
@@ -1113,7 +1113,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
has_endpoint_dupe,
lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]),
lambda _: mask,
operand=None,
None,
)
else:
# If we don't have the idx attributes, we are forced to expand out.
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
@@ -20,7 +20,12 @@
RadialForceBalance,
)
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
ExternalObjective,
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
)
from ._geometry import (
AspectRatio,
BScaleLength,
177 changes: 172 additions & 5 deletions desc/objectives/_generic.py
@@ -2,6 +2,7 @@

import inspect
import re
from abc import ABC

import numpy as np

@@ -11,12 +12,178 @@
from desc.compute.utils import _parse_parameterization, get_profiles, get_transforms
from desc.grid import QuadratureGrid
from desc.optimizable import OptimizableCollection
from desc.utils import errorif, parse_argname_change
from desc.utils import errorif, jaxify, parse_argname_change

from .linear_objectives import _FixedObjective
from .objective_funs import _Objective


class ExternalObjective(_Objective, ABC):
"""Wrap an external code.

Similar to ``ObjectiveFromUser``, except derivatives of the objective function are
    computed with finite differences instead of AD. The function does not need to be
JAX transformable.

The user supplied function must take an Equilibrium as its only positional argument,
but can take additional keyword arguments.

Parameters
----------
eq : Equilibrium
Equilibrium that will be optimized to satisfy the Objective.
fun : callable
External objective function. It must take an Equilibrium as its only positional
    argument, but can take additional keyword arguments. It does not need to be JAX
transformable.
dim_f : int
Dimension of the output of ``fun``.
target : {float, ndarray}, optional
Target value(s) of the objective. Only used if bounds is None.
Must be broadcastable to Objective.dim_f. Defaults to ``target=0``.
bounds : tuple of {float, ndarray}, optional
Lower and upper bounds on the objective. Overrides target.
    Both bounds must be broadcastable to Objective.dim_f.
Defaults to ``target=0``.
weight : {float, ndarray}, optional
Weighting to apply to the Objective, relative to other Objectives.
    Must be broadcastable to Objective.dim_f.
normalize : bool, optional
Whether to compute the error in physical units or non-dimensionalize.
Has no effect for this objective.
normalize_target : bool, optional
Whether target and bounds should be normalized before comparing to computed
values. If `normalize` is `True` and the target is in physical units,
this should also be set to True.
loss_function : {None, 'mean', 'min', 'max'}, optional
Loss function to apply to the objective values once computed. This loss function
is called on the raw compute value, before any shifting, scaling, or
normalization.
vectorized : bool, optional
Whether or not ``fun`` is vectorized. Default = False.
abs_step : float, optional
Absolute finite difference step size. Default = 1e-4.
Total step size is ``abs_step + rel_step * mean(abs(x))``.
rel_step : float, optional
Relative finite difference step size. Default = 0.
Total step size is ``abs_step + rel_step * mean(abs(x))``.
name : str, optional
Name of the objective function.
kwargs : any, optional
Keyword arguments that are passed as inputs to ``fun``.

# TODO: add example

"""

_units = "(Unknown)"
Review comment (Collaborator): Could it be better to take it as an argument, too? The default can still be (unknown).

_print_value_fmt = "External objective value: {:10.3e}"

def __init__(
self,
eq,
fun,
dim_f,
target=None,
bounds=None,
weight=1,
normalize=False,
normalize_target=False,
loss_function=None,
vectorized=False,
abs_step=1e-4,
rel_step=0,
name="external",
**kwargs,
):
if target is None and bounds is None:
target = 0

self._eq = eq.copy()
self._fun = fun
self._dim_f = dim_f
self._vectorized = vectorized
self._abs_step = abs_step
self._rel_step = rel_step
self._kwargs = kwargs
super().__init__(
things=eq,
target=target,
bounds=bounds,
weight=weight,
normalize=normalize,
normalize_target=normalize_target,
loss_function=loss_function,
deriv_mode="fwd",
name=name,
)

def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Parameters
----------
use_jit : bool, optional
Whether to just-in-time compile the objective and derivatives.
verbose : int, optional
Level of output.

"""
self._scalar = self._dim_f == 1
self._constants = {"quad_weights": 1.0}

def fun_wrapped(params):
"""Wrap external function with possibly vectorized params."""
# number of equilibria for vectorized computations
param_shape = params["Psi"].shape
num_eq = param_shape[0] if len(param_shape) > 1 else 1

# convert params to list of equilibria
eqs = [self._eq.copy() for _ in range(num_eq)]
for k, eq in enumerate(eqs):
# update equilibria with new params
for param_key in self._eq.optimizable_params:
param_value = np.atleast_2d(params[param_key])[k, :]
if len(param_value):
setattr(eq, param_key, param_value)

# call external function on equilibrium or list of equilibria
if not self._vectorized:
eqs = eqs[0]
return self._fun(eqs, **self._kwargs)

# wrap external function to work with JAX
abstract_eval = lambda *args, **kwargs: jnp.empty(self._dim_f)
self._fun_wrapped = jaxify(
fun_wrapped,
abstract_eval,
vectorized=self._vectorized,
abs_step=self._abs_step,
rel_step=self._rel_step,
)

super().build(use_jit=use_jit, verbose=verbose)

def compute(self, params, constants=None):
"""Compute the quantity.

Parameters
----------
params : list of dict
List of dictionaries of degrees of freedom, eg CoilSet.params_dict
constants : dict
Dictionary of constant data, eg transforms, profiles etc. Defaults to
self.constants

Returns
-------
f : ndarray
Computed quantity.

"""
f = self._fun_wrapped(params)
return f
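The ``jaxify`` wrapper used in ``build`` differentiates the external function by finite differences, with total step ``abs_step + rel_step * mean(abs(x))`` as stated in the docstring above. A minimal forward-difference sketch under that step rule — a hypothetical standalone helper for illustration, not DESC's actual implementation:

```python
import numpy as np

def fd_jacobian(fun, x, abs_step=1e-4, rel_step=0.0):
    # Forward-difference Jacobian of fun at x, using the step rule from the
    # ExternalObjective docstring: total step = abs_step + rel_step * mean(abs(x)).
    f0 = np.atleast_1d(fun(x))
    h = abs_step + rel_step * np.mean(np.abs(x))
    J = np.zeros((f0.size, x.size))
    for i in range(x.size):
        xi = x.copy()
        xi[i] += h  # perturb one input at a time
        J[:, i] = (np.atleast_1d(fun(xi)) - f0) / h
    return J

# Jacobian of f(x) = [x0**2, x1] at x = [1, 2] is approximately [[2, 0], [0, 1]].
J = fd_jacobian(lambda x: np.array([x[0] ** 2, x[1]]), np.array([1.0, 2.0]))
```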


class GenericObjective(_Objective):
"""A generic objective that can compute any quantity from the `data_index`.

@@ -352,10 +519,9 @@
def myfun(grid, data):
# This will compute the flux surface average of the function
# R*B_T from the Grad-Shafranov equation
f = data['R']*data['B_phi']
f = data['R'] * data['B_phi']
f_fsa = surface_averages(grid, f, sqrt_g=data['sqrt_g'])
# this has the FSA values on the full grid, but we just want
# the unique values:
# this is the FSA on the full grid, but we only want the unique values:
return grid.compress(f_fsa)

myobj = ObjectiveFromUser(fun=myfun, thing=eq)
@@ -414,6 +580,8 @@
Level of output.

"""
import jax

thing = self.things[0]
if self._grid is None:
errorif(
@@ -444,7 +612,6 @@
).squeeze()

self._fun_wrapped = lambda data: self._fun(grid, data)
import jax

self._dim_f = jax.eval_shape(self._fun_wrapped, dummy_data).size
self._scalar = self._dim_f == 1
4 changes: 1 addition & 3 deletions desc/objectives/objective_funs.py
@@ -90,9 +90,7 @@ def jac_(op, x, constants=None):
for obj, const in zip(self.objectives, constants):
# get the xs that go to that objective
xi = [x for x, t in zip(xs, self.things) if t in obj.things]
Ji_ = getattr(obj, op)(
*xi, constants=const
) # jac wrt to just those things
Ji_ = getattr(obj, op)(*xi, constants=const) # jac wrt only xi
Ji = [] # jac wrt all things
for thing in self.things:
if thing in obj.things: