Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic coil objectives #853

Merged
merged 87 commits into from
Mar 28, 2024
Merged
Changes from 1 commit
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
24fdbc6
initial commit
kianorr Feb 3, 2024
dd1c024
Merge branch 'master' into ko/coil_length
kianorr Feb 3, 2024
68301c0
add simple test
kianorr Feb 4, 2024
1b8306f
Merge branch master into ko/coil_length
kianorr Feb 4, 2024
834b6d0
add default args plus todos
kianorr Feb 4, 2024
45b7282
Merge branch 'master' into ko/coil_length
kianorr Feb 5, 2024
06bef53
add to objectives __init__
kianorr Feb 5, 2024
aea78c1
merge PR#840 into ko/coil_length
kianorr Feb 5, 2024
2d64ab5
delete comments and change grid
kianorr Feb 5, 2024
a64ce35
add docstrings and grid option
kianorr Feb 6, 2024
30e72ad
Merge branch 'master' into ko/coil_length
dpanici Feb 6, 2024
5541546
Merge branch 'origin/ko/coil_length' into ko/coil_length
kianorr Feb 7, 2024
6026cf6
merge dd/curve into ko/coil_length
kianorr Feb 7, 2024
5c8823b
Merge branch 'master' into ko/coil_length
kianorr Feb 8, 2024
2ddc10f
initial commit
kianorr Feb 3, 2024
6326552
add simple test
kianorr Feb 4, 2024
c424ec1
add default args plus todos
kianorr Feb 4, 2024
22b4b63
add to objectives __init__
kianorr Feb 5, 2024
c874e84
delete comments and change grid
kianorr Feb 5, 2024
b1e34eb
add docstrings and grid option
kianorr Feb 6, 2024
0e2315d
Merge branch 'ko/coil_length' into ko/coil_length
kianorr Feb 8, 2024
ffc857c
Revert "Merge branch 'ko/coil_length' into ko/coil_length"
kianorr Feb 8, 2024
09fe488
add logic for CoilSet
kianorr Feb 9, 2024
e612fb4
add test for CoilSet
kianorr Feb 9, 2024
c21305c
add mixed coil functionality
kianorr Feb 10, 2024
edae7b8
correct dim_f
kianorr Feb 10, 2024
bc0e23e
Merge branch 'master' into ko/coil_length
kianorr Feb 10, 2024
1ab2e75
add base class for coil objectives
kianorr Feb 11, 2024
dbfb77d
add kwargs
kianorr Feb 11, 2024
6f2cf35
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Feb 11, 2024
39bff34
Merge branch 'master' into ko/coil_length
kianorr Feb 19, 2024
01c872b
address easy reviews
kianorr Feb 21, 2024
fc6cfd9
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Feb 21, 2024
05bf325
Merge branch 'master' into ko/coil_length
kianorr Feb 21, 2024
b0658af
Merge branch 'master' into ko/coil_length
kianorr Feb 24, 2024
a6ef1c2
add tree_flatten for nested coilsets
kianorr Feb 24, 2024
209ea6d
fix import bug
kianorr Feb 24, 2024
9d378be
Merge branch 'master' into ko/coil_length
dpanici Mar 6, 2024
47f615f
Merge branch 'master' into ko/coil_length
kianorr Mar 7, 2024
2559aa1
add nested coil sets functionality
kianorr Mar 7, 2024
af3a424
use tree_flatten from backend and import jax
kianorr Mar 7, 2024
024f0f2
use tree_map for transforms
kianorr Mar 7, 2024
7535497
add logic for default grid
kianorr Mar 7, 2024
00c7bcd
add torsion objective and tests
kianorr Mar 7, 2024
73eafa4
update docstrings
kianorr Mar 7, 2024
dabdba1
change curvature target to bounds
kianorr Mar 7, 2024
2acba22
correct self._dim_f
kianorr Mar 7, 2024
be8fcb4
add normalization
kianorr Mar 7, 2024
bc129d0
move everything to _coils.py
kianorr Mar 7, 2024
2c36497
Merge branch 'master' into ko/coil_length
kianorr Mar 8, 2024
7beb157
import from _coils
kianorr Mar 8, 2024
337c0fd
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Mar 8, 2024
3523178
addressed reviews
kianorr Mar 9, 2024
cdd348e
fix logic for type checking
kianorr Mar 10, 2024
fdb6cae
add actual nested coilset
kianorr Mar 11, 2024
7c06e13
change grid logic
kianorr Mar 12, 2024
7e7c1c9
refactor
kianorr Mar 12, 2024
c471948
Merge branch 'master' into ko/coil_length
kianorr Mar 12, 2024
a3ffe3b
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Mar 12, 2024
77645c5
add comments
kianorr Mar 13, 2024
d071f8c
change if to elif statements
kianorr Mar 13, 2024
752fec9
refactor
kianorr Mar 13, 2024
86537fd
Merge branch 'master' into ko/coil_length
f0uriest Mar 14, 2024
e6712f0
add missing names
dpanici Mar 15, 2024
118b1cc
fix current setting of coils, which could fail if a size 1 dim 1 ndar…
dpanici Mar 15, 2024
cf0ddef
add missing print value fmt, units and whether coil objs are scalar
dpanici Mar 15, 2024
0de46f1
add test for objectives, though torsion test is not working
dpanici Mar 15, 2024
e6814b3
fix torsion test
dpanici Mar 15, 2024
51d2f02
remove unneeded print
dpanici Mar 15, 2024
3c59c73
make test a little faster and more compact
dpanici Mar 15, 2024
1000d9d
Merge branch 'master' into ko/coil_length
dpanici Mar 17, 2024
e8522bd
Merge branch 'master' into ko/coil_length
kianorr Mar 18, 2024
14095f6
adress reviews
kianorr Mar 18, 2024
2dc0510
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Mar 18, 2024
5dd7ca3
merge master and resolve import conflict
kianorr Mar 18, 2024
53a2a20
add quad_weights to constants
kianorr Mar 19, 2024
24c393e
add quad_weights with spacing
kianorr Mar 19, 2024
2a4dd3b
concatenate quad_weights
kianorr Mar 19, 2024
24afec7
Merge branch 'master' into ko/coil_length
kianorr Mar 19, 2024
c630ca9
update docstrings
kianorr Mar 20, 2024
d6a498e
Merge branch 'ko/coil_length' of https://github.com/PlasmaControl/DES…
kianorr Mar 20, 2024
321851b
fix docstring
kianorr Mar 21, 2024
4ab1518
Merge branch 'master' into ko/coil_length
kianorr Mar 21, 2024
01a6bc2
delete scalar comments
kianorr Mar 27, 2024
caaae44
Merge branch 'master' into ko/coil_length
f0uriest Mar 27, 2024
442fc96
Merge branch 'master' into ko/coil_length
ddudt Mar 28, 2024
b0662cc
Merge branch 'master' into ko/coil_length
kianorr Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
change grid logic
- have a case for a single LinearGrid
- have a case for nested grids
- get dim_f correctly
kianorr committed Mar 12, 2024
commit 7c06e13592ba87d3b27f9e0bf8222fbda56b9917
10 changes: 10 additions & 0 deletions desc/backend.py
Original file line number Diff line number Diff line change
@@ -80,7 +80,9 @@
from jax.tree_util import (
register_pytree_node,
tree_flatten,
tree_leaves,
tree_map,
tree_structure,
tree_unflatten,
)

@@ -402,6 +404,14 @@ def tree_map(*args, **kwargs):
"""Map pytree for numpy backend."""
raise NotImplementedError

def tree_structure(*args, **kwargs):
"""Get structure of pytree for numpy backend."""
raise NotImplementedError

def tree_leaves(*args, **kwargs):
"""Get leaves of pytree for numpy backend."""
raise NotImplementedError

def register_pytree_node(foo, *args):
"""Dummy decorator for non-jax pytrees."""
return foo
108 changes: 82 additions & 26 deletions desc/objectives/_coils.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,14 @@

import numpy as np

from desc.backend import jnp, tree_flatten, tree_map
from desc.backend import (
jnp,
tree_flatten,
tree_leaves,
tree_map,
tree_structure,
tree_unflatten,
)
from desc.compute import get_transforms
from desc.grid import LinearGrid
from desc.utils import Timer
@@ -82,7 +89,7 @@
name=name,
)

def build(self, use_jit=True, verbose=1):
def build(self, use_jit=True, verbose=1): # noqa:C901
"""Build constant arrays.
Parameters
@@ -96,38 +103,62 @@
# local import to avoid circular import
from desc.coils import CoilSet, MixedCoilSet, _Coil

is_mixed_coils = isinstance(self.things[0], MixedCoilSet)
is_coil_set = isinstance(self.things[0], CoilSet)

coils = tree_flatten(
self.things[0],
is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet),
)[0]
self._dim_f = 0

def get_dim_f(coilset):
"""Turn a coilset into nested lists."""
if isinstance(coilset, list):
[get_dim_f(x) for x in coilset]
elif isinstance(coilset, MixedCoilSet):
[get_dim_f(x) for x in coilset]
elif isinstance(coilset, CoilSet):
get_dim_f(coilset.coils)
elif isinstance(coilset, LinearGrid):
dpanici marked this conversation as resolved.
Show resolved Hide resolved
self._dim_f += coilset.num_zeta

def to_list(coilset):
"""Turn a coilset into nested lists."""
if isinstance(coilset, list):
return [to_list(x) for x in coilset]

Check warning on line 122 in desc/objectives/_coils.py

Codecov / codecov/patch

desc/objectives/_coils.py#L122

Added line #L122 was not covered by tests
if isinstance(coilset, MixedCoilSet):
return [to_list(x) for x in coilset]
if isinstance(coilset, CoilSet):
# use the same grid/transform for CoilSet
return to_list(coilset.coils[0])
else:
return coilset

# if using single coil, make coils and grid a list so they can be
# used with tree_map
coils = [coils[0]] if not is_coil_set else coils
is_mixed_coils = isinstance(self.things[0], MixedCoilSet)
is_single_coil = lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet)

# check type
if isinstance(self._grid, numbers.Integral):
self._grid = LinearGrid(N=self._grid, endpoint=False)

Check warning on line 136 in desc/objectives/_coils.py

Codecov / codecov/patch

desc/objectives/_coils.py#L136

Added line #L136 was not covered by tests
if self._grid is None:
self._grid = tree_map(
lambda x: LinearGrid(
N=2 * x.N + 5, NFP=getattr(x, "NFP", 1), endpoint=False
),
self.things[0],
is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet),
is_leaf=lambda x: is_single_coil(x),
)
print(self._grid)

if not isinstance(self._grid, (tuple, list)):
self._grid = [self._grid]

if np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]):
raise ValueError("Only use toroidal resolution for coil grids.")

self._dim_f = np.sum([grid.num_zeta for grid in self._grid])
elif isinstance(self._grid, LinearGrid):
treedef = tree_structure(
self.things[0],
is_leaf=lambda x: is_single_coil(x),
)
leaves = tree_leaves(self.things[0], is_leaf=lambda x: is_single_coil(x))
self._grid = [self._grid] * len(leaves)
self._grid = tree_unflatten(treedef, self._grid)
else:
flattened_grid = tree_flatten(
self._grid, is_leaf=lambda x: isinstance(x, LinearGrid)
kianorr marked this conversation as resolved.
Show resolved Hide resolved
)[0]
treedef = tree_structure(
self.things[0],
is_leaf=lambda x: is_single_coil(x),
)
self._grid = tree_unflatten(treedef, flattened_grid)

timer = Timer()
if verbose > 0:
@@ -138,9 +169,20 @@
lambda x, y: get_transforms(self._data_keys, obj=x, grid=y),
self.things[0],
self._grid,
is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, MixedCoilSet),
is_leaf=lambda x: is_single_coil(x),
)

get_dim_f(self._grid)
self._grid = to_list(self._grid)
transforms = to_list(transforms)

if not isinstance(self._grid, (tuple, list)):
self._grid = [self._grid]
transforms = [transforms]

if np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]):
raise ValueError("Only use toroidal resolution for coil grids.")

Check warning on line 184 in desc/objectives/_coils.py

Codecov / codecov/patch

desc/objectives/_coils.py#L184

Added line #L184 was not covered by tests

# tree map always returns a list so take first transform and grid
# for when we are only using a single coil
if not is_mixed_coils:
@@ -151,10 +193,14 @@

timer.stop("Precomputing transforms")
if verbose > 1:
timer.disp("Precomputing transforms")

Check warning on line 196 in desc/objectives/_coils.py

Codecov / codecov/patch

desc/objectives/_coils.py#L196

Added line #L196 was not covered by tests

flattened_coils = tree_flatten(
self.things[0],
is_leaf=lambda x: is_single_coil(x),
)[0]
if self._normalize:
self._scales = compute_scaling_factors(coils[0])
self._scales = compute_scaling_factors(flattened_coils[0])

super().build(use_jit=use_jit, verbose=verbose)

@@ -269,14 +315,24 @@
Level of output.
"""
from desc.coils import CoilSet
from desc.coils import CoilSet, _Coil

super().build(use_jit=use_jit, verbose=verbose)

if self._normalize:
self._normalization = self._scales["a"]

self._dim_f = len(self._coils.coils) if isinstance(self._coils, CoilSet) else 1
# TODO: repeated code but maybe it's fine
flattened_coils = tree_flatten(
self._coils,
is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet),
)[0]
flattened_coils = (
[flattened_coils[0]]
if not isinstance(self._coils, CoilSet)
else flattened_coils
)
self._dim_f = len(flattened_coils)

def compute(self, params, constants=None):
"""Compute coil length.
48 changes: 30 additions & 18 deletions tests/test_objective_funs.py
Original file line number Diff line number Diff line change
@@ -584,15 +584,19 @@ def test(coil, grid=None):
assert len(f) == obj.dim_f

coil = FourierPlanarCoil(r_n=1)
coils = CoilSet.linspaced_linear(coil, n=4)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=4)
mixed_coils_grid = [LinearGrid(N=5)] * len(mixed_coils.coils)
coils = CoilSet.linspaced_linear(coil, n=2)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2)
nested_coils = MixedCoilSet(coils, coils)

test(coil)
nested_grids = [
[LinearGrid(N=5), LinearGrid(N=5)],
[LinearGrid(N=5), LinearGrid(N=5)],
]

test(coil, grid=LinearGrid(N=5))
test(coils)
test(mixed_coils, grid=mixed_coils_grid)
test(nested_coils)
test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils))
test(nested_coils, grid=nested_grids)

@pytest.mark.unit
def test_coil_curvature(self):
@@ -606,15 +610,19 @@ def test(coil, grid=None):
assert len(f) == obj.dim_f

coil = FourierPlanarCoil()
coils = CoilSet.linspaced_linear(coil, n=4)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=4)
mixed_coils_grid = [LinearGrid(N=5)] * len(mixed_coils.coils)
coils = CoilSet.linspaced_linear(coil, n=2)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2)
nested_coils = MixedCoilSet(coils, coils)

test(coil)
nested_grids = [
[LinearGrid(N=5), LinearGrid(N=5)],
[LinearGrid(N=5), LinearGrid(N=5)],
]

test(coil, grid=LinearGrid(N=5))
test(coils)
test(mixed_coils, grid=mixed_coils_grid)
test(nested_coils)
test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils))
test(nested_coils, grid=nested_grids)

@pytest.mark.unit
def test_coil_torsion(self):
@@ -628,15 +636,19 @@ def test(coil, grid=None):
assert len(f) == obj.dim_f

coil = FourierPlanarCoil()
coils = CoilSet.linspaced_linear(coil, n=4)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=4)
mixed_coils_grid = [LinearGrid(N=5)] * len(mixed_coils.coils)
coils = CoilSet.linspaced_linear(coil, n=2)
mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2)
nested_coils = MixedCoilSet(coils, coils)

test(coil)
nested_grids = [
[LinearGrid(N=5), LinearGrid(N=5)],
[LinearGrid(N=5), LinearGrid(N=5)],
]

test(coil, grid=LinearGrid(N=5))
test(coils)
test(mixed_coils, grid=mixed_coils_grid)
test(nested_coils)
test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils))
test(nested_coils, grid=nested_grids)


@pytest.mark.unit