Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalObjective function to wrap external codes #1028

Merged
merged 104 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
104 commits
Select commit Hold shift + click to select a range
456a02b
initial commit
daniel-dudt May 17, 2024
6a557ec
get external objective working
daniel-dudt May 20, 2024
fc9ef77
test comparison to generic
daniel-dudt May 20, 2024
9a64f25
allow string kwargs in external fun
daniel-dudt May 21, 2024
ae84d5e
Merge branch 'master' into dd/external
ddudt May 21, 2024
11c1438
exclude ExternalObjective from tests
daniel-dudt May 21, 2024
bedcee1
Merge branch 'master' into dd/external
ddudt May 23, 2024
aff7d46
make external fun take eq as its argument
daniel-dudt May 23, 2024
7f3ff5b
Merge branch 'master' into dd/external
ddudt May 31, 2024
2bb9017
simplify wrapped fun to take params
daniel-dudt May 31, 2024
7b1cfaf
Merge branch 'master' into dd/external
ddudt Jun 4, 2024
debecad
numpifying to make vectorization work
daniel-dudt Jun 5, 2024
633fa5b
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jun 5, 2024
9ea37fb
Revert "numpifying to make vectorization work"
daniel-dudt Jun 5, 2024
87ab19f
vectorization working!
daniel-dudt Jun 7, 2024
5395611
allow vectorized to be an int
daniel-dudt Jun 11, 2024
52d58d0
fix numpy cond
daniel-dudt Jun 11, 2024
96bf929
Merge branch 'master' into dd/external
ddudt Jun 17, 2024
30aeea4
merging but no change?
daniel-dudt Jun 17, 2024
90296ea
update test with new UI
daniel-dudt Jun 17, 2024
d16e95d
remove unused pool code
daniel-dudt Jun 18, 2024
6b3f86d
Merge branch 'master' into dd/external
ddudt Jun 19, 2024
f9b7562
remove comment note
daniel-dudt Jul 17, 2024
fe5e95c
Merge branch 'master' into dd/external
ddudt Jul 17, 2024
f1f466b
fix black formatting from merge conflict
daniel-dudt Jul 18, 2024
ecc5b3b
repair test from merge conflict
daniel-dudt Jul 18, 2024
800b9bb
Merge branch 'master' into dd/external
ddudt Jul 18, 2024
bf62014
remove multiprocessing from ExternalObjective class
daniel-dudt Jul 18, 2024
0547bd7
jaxify as a util function
daniel-dudt Jul 19, 2024
e3057dd
Merge branch 'master' into dd/external
ddudt Jul 19, 2024
4323b8a
Merge branch 'master' into dd/external
ddudt Jul 22, 2024
03d0cb5
ExternalObjective no longer an ABC
daniel-dudt Jul 22, 2024
7723ecd
re-add print logic in backend
daniel-dudt Jul 22, 2024
4864b56
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
ef9711b
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
180c503
Merge branch 'yge/cpu' into dd/external
ddudt Jul 24, 2024
cea3f4a
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
09c02ec
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
0b2207f
exclude ExternalObjective from tests
daniel-dudt Jul 26, 2024
f6a395b
Merge branch 'master' into dd/external
ddudt Jul 26, 2024
aa570d4
scale FD derivatives by tangent norm
daniel-dudt Jul 26, 2024
d62d9ca
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jul 26, 2024
8c7bcb1
Merge branch 'master' into dd/external
ddudt Jul 30, 2024
7f1907b
Merge branch 'master' into dd/external
ddudt Aug 11, 2024
76b2a3c
Merge branch 'master' into dd/external
dpanici Aug 20, 2024
ef98142
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
16bb59b
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
a83a671
resolve merge conflict
daniel-dudt Aug 22, 2024
8beb2e6
Merge branch 'master' into dd/external
ddudt Aug 23, 2024
9e57ee1
Merge branch 'master' into dd/external
ddudt Aug 25, 2024
11521b2
fix formatting from merge conflict
daniel-dudt Aug 25, 2024
c004724
add static_attrs, update test
daniel-dudt Aug 26, 2024
37f9ee3
Merge branch 'master' into dd/external
ddudt Aug 27, 2024
aab0bdb
update with master
daniel-dudt Nov 7, 2024
829af5a
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
ba1a252
update depricated jax.pure_callback vmap arg
daniel-dudt Nov 12, 2024
bb8a535
update vmap_method
daniel-dudt Nov 12, 2024
56d6662
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
655fe06
Merge branch 'master' into dd/external
YigitElma Dec 4, 2024
44c25a2
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
795350d
remove duplicate line from merge conflict
daniel-dudt Dec 12, 2024
6fef120
fix test with block_until_ready
daniel-dudt Dec 12, 2024
021106d
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
0f35919
Merge branch 'master' into dd/external
ddudt Dec 17, 2024
24dd2f3
update documentation
daniel-dudt Dec 17, 2024
d3aa2dd
Merge branch 'master' into dd/external
ddudt Dec 18, 2024
ac1aa63
make vectorized a required arg
daniel-dudt Dec 18, 2024
4db5c9f
make ExternalObjective args keyword only
daniel-dudt Dec 19, 2024
96aec58
Merge branch 'master' into dd/external
ddudt Dec 19, 2024
d5453a9
Merge branch 'master' into dd/external
ddudt Jan 2, 2025
4f31c42
Merge branch 'master' into dd/external
ddudt Jan 3, 2025
a90459a
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
b4463b5
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
bcf77fd
remove ABC inheritance from ExternalObjective
daniel-dudt Jan 28, 2025
7ed2cf7
Merge branch 'master' into dd/external
ddudt Jan 28, 2025
db8e62c
Merge branch 'master' into dd/external
ddudt Jan 29, 2025
721ce5e
create print_info fun in backend
daniel-dudt Jan 30, 2025
a1dd8bc
wrapper for jax.pure_callback syntax
daniel-dudt Jan 30, 2025
85caade
Merge branch 'master' into dd/external
ddudt Jan 30, 2025
0dc4856
kwargs -> fun_kwargs, add example
daniel-dudt Jan 31, 2025
2d561ee
pure_callback -> io_callback
daniel-dudt Jan 31, 2025
2344d74
reference print_backend_info in docs
daniel-dudt Jan 31, 2025
7d125f7
edit docs args order
daniel-dudt Feb 3, 2025
f000dbd
Merge branch 'master' into dd/external
ddudt Feb 3, 2025
1657cb8
edit docs for print_backend_info
daniel-dudt Feb 3, 2025
8c4fd51
io_callback -> pure_callback
daniel-dudt Feb 3, 2025
3c1c318
Merge branch 'master' into dd/external
YigitElma Feb 4, 2025
6ac8021
Merge branch 'master' into dd/external
ddudt Feb 4, 2025
e33246d
Update desc/objectives/_generic.py
ddudt Feb 4, 2025
0a4c84b
Update desc/objectives/_generic.py
ddudt Feb 4, 2025
f1dd2de
Update desc/utils.py
ddudt Feb 4, 2025
67d7680
better jax version handling
daniel-dudt Feb 4, 2025
86ca1f5
Merge branch 'master' into dd/external
ddudt Feb 5, 2025
0801e22
fix versioning logic
daniel-dudt Feb 5, 2025
d372b06
add jnp lines back to backend
daniel-dudt Feb 5, 2025
7ce4266
Merge branch 'master' into dd/external
ddudt Feb 6, 2025
8060b02
Merge branch 'master' into dd/external
ddudt Feb 7, 2025
96a2a3f
Merge branch 'master' into dd/external
ddudt Feb 11, 2025
62267f1
improved version handling
daniel-dudt Feb 11, 2025
4dacc74
default kwarg values
daniel-dudt Feb 11, 2025
89b6677
Merge branch 'master' into dd/external
ddudt Feb 12, 2025
e1cfe95
Merge branch 'master' into dd/external
YigitElma Feb 13, 2025
ee345cf
Merge branch 'master' into dd/external
ddudt Feb 13, 2025
fd1ed51
Merge branch 'master' into dd/external
YigitElma Feb 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ New Features
for compatibility with other codes which expect such files from the Booz_Xform code.
- Renames compute quantity ``sqrt(g)_B`` to ``sqrt(g)_Boozer_DESC`` to more accurately reflect what the quantiy is (the jacobian from (rho,theta_B,zeta_B) to (rho,theta,zeta)), and adds a new function to compute ``sqrt(g)_Boozer`` which is the jacobian from (rho,theta_B,zeta_B) to (R,phi,Z).
- Allows specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file
- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences.
- DESC/JAX version and device info is no longer printed by default, but can be accessed with the function `desc.backend.print_backend_info()`.

Speed Improvements

Expand Down
2 changes: 2 additions & 0 deletions desc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def main(cl_args=sys.argv[1:]):

import matplotlib.pyplot as plt

from desc.backend import print_backend_info
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.plotting import plot_section, plot_surfaces

if ir.args.verbose:
print_backend_info()
print("Reading input from {}".format(ir.input_path))
print("Outputs will be written to {}".format(ir.output_path))

Expand Down
68 changes: 46 additions & 22 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import numpy as np
from packaging.version import Version
from termcolor import colored

import desc
Expand All @@ -15,11 +16,6 @@
jnp = np
use_jax = False
set_device(kind="cpu")
print(
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
)
)
else:
if desc_config.get("device") is None:
set_device("cpu")
Expand All @@ -41,29 +37,31 @@
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = True
print(
f"DESC version {desc.__version__},"
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
del x, y
except ModuleNotFoundError:
jnp = np
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = False
set_device(kind="cpu")
warnings.warn(colored("Failed to load JAX", "red"))


def print_backend_info():
"""Prints DESC version, backend type & version, device type & memory."""
print(f"DESC version={desc.__version__}.")
if use_jax:
print(
"DESC version {}, using NumPy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, y.dtype
)
f"Using JAX backend: jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}."
)
else:
print(f"Using NumPy backend: version={np.__version__}, dtype={y.dtype}.")

Check warning on line 58 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L58

Added line #L58 was not covered by tests
print(
"Using device: {}, with {:.2f} GB available memory.".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
)


if use_jax: # noqa: C901
from jax import custom_jvp, jit, vmap
Expand All @@ -85,13 +83,35 @@
treedef_is_leaf,
)

# TODO: update this when JAX min version >= 0.4.26
if hasattr(jnp, "trapezoid"):
trapezoid = jnp.trapezoid # for JAX 0.4.26 and later
elif hasattr(jax.scipy, "integrate"):
trapezoid = jax.scipy.integrate.trapezoid
else:
trapezoid = jnp.trapz # for older versions of JAX, deprecated by jax 0.4.16

# TODO: update this when JAX min version >= 0.4.35
if Version(jax.__version__) >= Version("0.4.35"):

def pure_callback(func, result_shape_dtype, *args, vectorized=False, **kwargs):
"""Wrapper for jax.pure_callback for versions >=0.4.35."""
return jax.pure_callback(
func,
result_shape_dtype,
*args,
vmap_method="expand_dims" if vectorized else "sequential",
**kwargs,
)

else:

def pure_callback(func, result_shape_dtype, *args, vectorized=False, **kwargs):
"""Wrapper for jax.pure_callback for versions <0.4.35."""
return jax.pure_callback(

Check warning on line 111 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L111

Added line #L111 was not covered by tests
func, result_shape_dtype, *args, vectorized=vectorized, **kwargs
)

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.

Expand Down Expand Up @@ -481,6 +501,10 @@
"""
return lambda xs: _map(fun, xs, in_axes=in_axes, out_axes=out_axes)

def pure_callback(*args, **kwargs):
"""IO callback for numpy backend."""
raise NotImplementedError

def tree_stack(*args, **kwargs):
"""Stack pytree for numpy backend."""
raise NotImplementedError
Expand Down Expand Up @@ -586,7 +610,7 @@
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, *operand):
def cond(pred, true_fun, false_fun, *operands):
"""Conditionally apply true_fun or false_fun.

This version is for the numpy backend, for jax backend see jax.lax.cond
Expand All @@ -599,7 +623,7 @@
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
operands: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.

Expand All @@ -612,9 +636,9 @@

"""
if pred:
return true_fun(*operand)
return true_fun(*operands)
else:
return false_fun(*operand)
return false_fun(*operands)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
Expand Down
2 changes: 1 addition & 1 deletion desc/integrals/surface_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
has_endpoint_dupe,
lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]),
lambda _: mask,
operand=None,
None,
)
else:
# If we don't have the idx attributes, we are forced to expand out.
Expand Down
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
)
from ._fast_ion import GammaC
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
ExternalObjective,
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
)
from ._geometry import (
AspectRatio,
BScaleLength,
Expand Down
Loading
Loading