Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExternalObjective function to wrap external codes #1028

Draft
wants to merge 69 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
456a02b
initial commit
daniel-dudt May 17, 2024
6a557ec
get external objective working
daniel-dudt May 20, 2024
fc9ef77
test comparison to generic
daniel-dudt May 20, 2024
9a64f25
allow string kwargs in external fun
daniel-dudt May 21, 2024
ae84d5e
Merge branch 'master' into dd/external
ddudt May 21, 2024
11c1438
exclude ExternalObjective from tests
daniel-dudt May 21, 2024
bedcee1
Merge branch 'master' into dd/external
ddudt May 23, 2024
aff7d46
make external fun take eq as its argument
daniel-dudt May 23, 2024
7f3ff5b
Merge branch 'master' into dd/external
ddudt May 31, 2024
2bb9017
simplify wrapped fun to take params
daniel-dudt May 31, 2024
7b1cfaf
Merge branch 'master' into dd/external
ddudt Jun 4, 2024
debecad
numpifying to make vectorization work
daniel-dudt Jun 5, 2024
633fa5b
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jun 5, 2024
9ea37fb
Revert "numpifying to make vectorization work"
daniel-dudt Jun 5, 2024
87ab19f
vectorization working!
daniel-dudt Jun 7, 2024
5395611
allow vectorized to be an int
daniel-dudt Jun 11, 2024
52d58d0
fix numpy cond
daniel-dudt Jun 11, 2024
96bf929
Merge branch 'master' into dd/external
ddudt Jun 17, 2024
30aeea4
merging but no change?
daniel-dudt Jun 17, 2024
90296ea
update test with new UI
daniel-dudt Jun 17, 2024
d16e95d
remove unused pool code
daniel-dudt Jun 18, 2024
6b3f86d
Merge branch 'master' into dd/external
ddudt Jun 19, 2024
f9b7562
remove comment note
daniel-dudt Jul 17, 2024
fe5e95c
Merge branch 'master' into dd/external
ddudt Jul 17, 2024
f1f466b
fix black formatting from merge conflict
daniel-dudt Jul 18, 2024
ecc5b3b
repair test from merge conflict
daniel-dudt Jul 18, 2024
800b9bb
Merge branch 'master' into dd/external
ddudt Jul 18, 2024
bf62014
remove multiprocessing from ExternalObjective class
daniel-dudt Jul 18, 2024
0547bd7
jaxify as a util function
daniel-dudt Jul 19, 2024
e3057dd
Merge branch 'master' into dd/external
ddudt Jul 19, 2024
4323b8a
Merge branch 'master' into dd/external
ddudt Jul 22, 2024
03d0cb5
ExternalObjective no longer an ABC
daniel-dudt Jul 22, 2024
7723ecd
re-add print logic in backend
daniel-dudt Jul 22, 2024
4864b56
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
ef9711b
Merge branch 'yge/cpu' into dd/external
ddudt Jul 23, 2024
180c503
Merge branch 'yge/cpu' into dd/external
ddudt Jul 24, 2024
cea3f4a
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
09c02ec
Merge branch 'master' into dd/external
ddudt Jul 24, 2024
0b2207f
exclude ExternalObjective from tests
daniel-dudt Jul 26, 2024
f6a395b
Merge branch 'master' into dd/external
ddudt Jul 26, 2024
aa570d4
scale FD derivatives by tangent norm
daniel-dudt Jul 26, 2024
d62d9ca
Merge branch 'dd/external' of https://github.com/PlasmaControl/DESC i…
daniel-dudt Jul 26, 2024
8c7bcb1
Merge branch 'master' into dd/external
ddudt Jul 30, 2024
7f1907b
Merge branch 'master' into dd/external
ddudt Aug 11, 2024
76b2a3c
Merge branch 'master' into dd/external
dpanici Aug 20, 2024
ef98142
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
16bb59b
Merge branch 'master' into dd/external
ddudt Aug 22, 2024
a83a671
resolve merge conflict
daniel-dudt Aug 22, 2024
8beb2e6
Merge branch 'master' into dd/external
ddudt Aug 23, 2024
9e57ee1
Merge branch 'master' into dd/external
ddudt Aug 25, 2024
11521b2
fix formatting from merge conflict
daniel-dudt Aug 25, 2024
c004724
add static_attrs, update test
daniel-dudt Aug 26, 2024
37f9ee3
Merge branch 'master' into dd/external
ddudt Aug 27, 2024
aab0bdb
update with master
daniel-dudt Nov 7, 2024
829af5a
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
ba1a252
update depricated jax.pure_callback vmap arg
daniel-dudt Nov 12, 2024
bb8a535
update vmap_method
daniel-dudt Nov 12, 2024
56d6662
Merge branch 'master' into dd/external
ddudt Nov 12, 2024
655fe06
Merge branch 'master' into dd/external
YigitElma Dec 4, 2024
44c25a2
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
795350d
remove duplicate line from merge conflict
daniel-dudt Dec 12, 2024
6fef120
fix test with block_until_ready
daniel-dudt Dec 12, 2024
021106d
Merge branch 'master' into dd/external
ddudt Dec 12, 2024
0f35919
Merge branch 'master' into dd/external
ddudt Dec 17, 2024
24dd2f3
update documentation
daniel-dudt Dec 17, 2024
d3aa2dd
Merge branch 'master' into dd/external
ddudt Dec 18, 2024
ac1aa63
make vectorized a required arg
daniel-dudt Dec 18, 2024
4db5c9f
make ExternalObjective args keyword only
daniel-dudt Dec 19, 2024
96aec58
Merge branch 'master' into dd/external
ddudt Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions desc/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Backend functions for DESC, with options for JAX or regular numpy."""

import multiprocessing
import os
import warnings

Expand All @@ -10,15 +11,23 @@
from desc import config as desc_config
from desc import set_device

verbose = True

# set child processes to use numpy backend and suppress print statements
if not multiprocessing.current_process().name == "MainProcess":
ddudt marked this conversation as resolved.
Show resolved Hide resolved
os.environ["DESC_BACKEND"] = "numpy"
verbose = False

Check warning on line 19 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L18-L19

Added lines #L18 - L19 were not covered by tests

if os.environ.get("DESC_BACKEND") == "numpy":
jnp = np
use_jax = False
set_device(kind="cpu")
print(
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
if verbose:
print(

Check warning on line 26 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L25-L26

Added lines #L25 - L26 were not covered by tests
"DESC version {}, using numpy backend, version={}, dtype={}".format(
desc.__version__, np.__version__, np.linspace(0, 1).dtype
)
)
)
else:
if desc_config.get("device") is None:
set_device("cpu")
Expand All @@ -40,11 +49,12 @@
x = jnp.linspace(0, 5)
y = jnp.exp(x)
use_jax = True
print(
f"DESC version {desc.__version__},"
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
if verbose:
print(
f"DESC version {desc.__version__}, "
+ f"using JAX backend, jax version={jax.__version__}, "
+ f"jaxlib version={jaxlib.__version__}, dtype={y.dtype}"
)
del x, y
except ModuleNotFoundError:
jnp = np
Expand All @@ -58,11 +68,13 @@
desc.__version__, np.__version__, y.dtype
)
)
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")

if verbose:
print(
"Using device: {}, with {:.2f} GB available memory".format(
desc_config.get("device"), desc_config.get("avail_mem")
)
)
)

if use_jax: # noqa: C901 - FIXME: simplify this, define globally and then assign?
jit = jax.jit
Expand Down Expand Up @@ -489,7 +501,7 @@
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, *operand):
def cond(pred, true_fun, false_fun, *operands):
"""Conditionally apply true_fun or false_fun.

This version is for the numpy backend, for jax backend see jax.lax.cond
Expand All @@ -502,7 +514,7 @@
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
operands: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.

Expand All @@ -515,9 +527,9 @@

"""
if pred:
return true_fun(*operand)
return true_fun(*operands)
else:
return false_fun(*operand)
return false_fun(*operands)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
Expand Down
2 changes: 1 addition & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def surface_integrals_map(grid, surface_label="rho", expand_out=True, tol=1e-14)
has_endpoint_dupe,
lambda _: put(mask, jnp.array([0, -1]), mask[0] | mask[-1]),
lambda _: mask,
operand=None,
None,
)
else:
expand_out = False
Expand Down
7 changes: 6 additions & 1 deletion desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
RadialForceBalance,
)
from ._free_boundary import BoundaryError, VacuumBoundaryError
from ._generic import GenericObjective, LinearObjectiveFromUser, ObjectiveFromUser
from ._generic import (
GenericObjective,
LinearObjectiveFromUser,
ObjectiveFromUser,
_ExternalObjective,
)
from ._geometry import (
AspectRatio,
BScaleLength,
Expand Down
Loading
Loading