Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor surface and curve compute methods #583

Merged
merged 40 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0be4304
Break curve compute functions into standalone module
f0uriest Jun 27, 2023
314742f
Update formula for cross section area to work for surfaces
f0uriest Jul 5, 2023
d82ce84
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 6, 2023
c3711cd
Update curve compute funs to regular data_index
f0uriest Jul 6, 2023
4ddf1da
Add basic surface compute funs
f0uriest Jul 6, 2023
abdc04a
Fix typos
f0uriest Jul 6, 2023
b6bb297
Add curve/surface stuff to compute init
f0uriest Jul 6, 2023
b31030a
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 13, 2023
42d6822
Prevent overwriting of dependencies of shared functions
f0uriest Jul 14, 2023
0f33c18
Add rotmat and shift to transforms
f0uriest Jul 14, 2023
15f95e8
Rename bases for curve classes
f0uriest Jul 14, 2023
ca39f26
Remove grid, transform, compute methods from curve classes
f0uriest Jul 14, 2023
b9bf592
Fix norm axis in curve compute
f0uriest Jul 14, 2023
4eb0fb6
Add compute function to curve base class, add tests
f0uriest Jul 14, 2023
d6be2e0
Rename curve position vector from r to x
f0uriest Jul 14, 2023
67fc53f
Add basis to curve compute, fix basis conversion
f0uriest Jul 14, 2023
e44e70d
Finish updating curve classes
f0uriest Jul 14, 2023
d27ebde
Update curve tests with new API
f0uriest Jul 14, 2023
10d0224
Merge branch 'rc/geometry' into rc/compute_geometry
f0uriest Jul 15, 2023
b0ea3f7
Merge branch 'rc/geometry' into rc/compute_geometry
f0uriest Jul 15, 2023
ed15629
Add parameterization info to misc geometry compute funs
f0uriest Jul 15, 2023
5496dc0
Update surface compute data to match existing name conventions
f0uriest Jul 15, 2023
559c480
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 15, 2023
fd6a0bf
Add compute method to surface base class, add test
f0uriest Jul 15, 2023
2e7a558
Use surface label rather than resolution to determine grid type
f0uriest Jul 15, 2023
bed65cb
Add new method for computing outermost surface area
f0uriest Jul 15, 2023
2647cca
Remove grid, transform, old compute methods from surface classes
f0uriest Jul 15, 2023
8f9b598
Add basis kwarg to surface compute stuff
f0uriest Jul 15, 2023
ec7b657
Update coils to new compute API
f0uriest Jul 15, 2023
f9bc3cb
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 20, 2023
47f44dc
Remove old io attributes
f0uriest Jul 21, 2023
1f0f12d
Merge branch 'master' into rc/compute_utils
f0uriest Jul 28, 2023
489facd
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 28, 2023
9e6e84c
Merge branch 'rc/compute_utils' into rc/compute_geometry
f0uriest Jul 28, 2023
b9f3954
Update parameterization names
f0uriest Jul 28, 2023
bfe0892
Merge branch 'master' into rc/compute_geometry
f0uriest Jul 29, 2023
bb9f194
Merge branch 'master' into rc/compute_geometry
f0uriest Jul 31, 2023
e872df5
Merge branch 'master' into rc/compute_geometry
f0uriest Aug 2, 2023
88c3b84
Move geometry utils to compute module to avoid circular imports
f0uriest Aug 6, 2023
a925e00
Make geometry compute methods use default grid for global quantities
f0uriest Aug 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 79 additions & 48 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import numpy as np

from desc.backend import jnp
from desc.compute import rpz2xyz, xyz2rpz_vec
from desc.geometry import FourierPlanarCurve, FourierRZCurve, FourierXYZCurve
from desc.geometry.utils import rpz2xyz, xyz2rpz_vec
from desc.grid import Grid
from desc.magnetic_fields import MagneticField, biot_savart

Expand Down Expand Up @@ -46,7 +46,7 @@ def current(self, new):
assert jnp.isscalar(new) or new.size == 1
self._current = new

def compute_magnetic_field(self, coords, params={}, basis="rpz"):
def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None):
"""Compute magnetic field at a set of points.

The coil is discretized into a series of straight line segments, using
Expand All @@ -64,6 +64,9 @@ def compute_magnetic_field(self, coords, params={}, basis="rpz"):
parameters to pass to curve
basis : {"rpz", "xyz"}
basis for input coordinates and returned magnetic field
grid : Grid, int or None
Grid used to discretize coil. If an integer, uses that many equally spaced
points.

Returns
-------
Expand All @@ -76,8 +79,11 @@ def compute_magnetic_field(self, coords, params={}, basis="rpz"):
coords = jnp.atleast_2d(coords)
if basis == "rpz":
coords = rpz2xyz(coords)
current = params.pop("current", self.current)
coil_coords = self.compute_coordinates(**params, basis="xyz")
if params is None:
current = self.current
else:
current = params.pop("current", self.current)
coil_coords = self.compute("x", params=params, grid=grid, basis="xyz")["x"]
B = biot_savart(coords, coil_coords, current)
if basis == "rpz":
B = xyz2rpz_vec(B, x=coords[:, 0], y=coords[:, 1])
Expand Down Expand Up @@ -110,8 +116,6 @@ class FourierRZCoil(Coil, FourierRZCurve):
number of field periods
sym : bool
whether to enforce stellarator symmetry
grid : Grid
default grid for computation
name : str
name for this coil
"""
Expand All @@ -127,10 +131,9 @@ def __init__(
modes_Z=None,
NFP=1,
sym="auto",
grid=None,
name="",
):
super().__init__(current, R_n, Z_n, modes_R, modes_Z, NFP, sym, grid, name)
super().__init__(current, R_n, Z_n, modes_R, modes_Z, NFP, sym, name)


class FourierXYZCoil(Coil, FourierXYZCurve):
Expand All @@ -144,8 +147,6 @@ class FourierXYZCoil(Coil, FourierXYZCurve):
fourier coefficients for X, Y, Z
modes : array-like
mode numbers associated with X_n etc.
grid : Grid
default grid or computation
name : str
name for this coil

Expand All @@ -160,10 +161,9 @@ def __init__(
Y_n=[0, 0, 0],
Z_n=[-2, 0, 0],
modes=None,
grid=None,
name="",
):
super().__init__(current, X_n, Y_n, Z_n, modes, grid, name)
super().__init__(current, X_n, Y_n, Z_n, modes, name)


class FourierPlanarCoil(Coil, FourierPlanarCurve):
Expand All @@ -185,8 +185,6 @@ class FourierPlanarCoil(Coil, FourierPlanarCurve):
fourier coefficients for radius from center as function of polar angle
modes : array-like
mode numbers associated with r_n
grid : Grid
default grid for computation
name : str
name for this coil

Expand All @@ -201,10 +199,9 @@ def __init__(
normal=[0, 1, 0],
r_n=2,
modes=None,
grid=None,
name="",
):
super().__init__(current, center, normal, r_n, modes, grid, name)
super().__init__(current, center, normal, r_n, modes, name)


class CoilSet(Coil, MutableSequence):
Expand Down Expand Up @@ -251,35 +248,65 @@ def current(self, new):
for coil, cur in zip(self.coils, new):
coil.current = cur

@property
def grid(self):
"""Grid: nodes for computation."""
return self.coils[0].grid

@grid.setter
def grid(self, new):
for coil in self.coils:
coil.grid = new

def compute_coordinates(self, *args, **kwargs):
"""Compute real space coordinates using underlying curve method."""
return [coil.compute_coordinates(*args, **kwargs) for coil in self.coils]

def compute_frenet_frame(self, *args, **kwargs):
"""Compute Frenet frame using underlying curve method."""
return [coil.compute_frenet_frame(*args, **kwargs) for coil in self.coils]
def _make_arraylike(self, x):
if isinstance(x, dict):
x = [x] * len(self)
try:
len(x)
except TypeError:
x = [x] * len(self)
assert len(x) == len(self)
return x

def compute(
self,
names,
grid=None,
params=None,
transforms=None,
data=None,
**kwargs,
):
"""Compute the quantity given by name on grid, for each coil in the coilset.

def compute_curvature(self, *args, **kwargs):
"""Compute curvature using underlying curve method."""
return [coil.compute_curvature(*args, **kwargs) for coil in self.coils]
Parameters
----------
names : str or array-like of str
Name(s) of the quantity(s) to compute.
grid : Grid or int or array-like, optional
Grid of coordinates to evaluate at. Defaults to a Linear grid.
If an integer, uses that many equally spaced points.
If array-like, should be 1 value per coil.
params : dict of ndarray or array-like
Parameters from the equilibrium. Defaults to attributes of self.
If array-like, should be 1 value per coil.
transforms : dict of Transform or array-like
Transforms for R, Z, lambda, etc. Default is to build from grid.
If array-like, should be 1 value per coil.
data : dict of ndarray or array-like
Data computed so far, generally output from other compute functions
If array-like, should be 1 value per coil.

def compute_torsion(self, *args, **kwargs):
"""Compute torsion using underlying curve method."""
return [coil.compute_torsion(*args, **kwargs) for coil in self.coils]
Returns
-------
data : list of dict of ndarray
Computed quantity and intermediate variables, for each coil in the set.
List entries map to coils in coilset, each dict contains data for an
individual coil.

def compute_length(self, *args, **kwargs):
"""Compute the length of the curve using underlying curve method."""
return [coil.compute_length(*args, **kwargs) for coil in self.coils]
"""
grid = self._make_arraylike(grid)
params = self._make_arraylike(params)
transforms = self._make_arraylike(transforms)
data = self._make_arraylike(data)
return [
coil.compute(
names, grid=grd, params=par, transforms=tran, data=dat, **kwargs
)
for (coil, grd, par, tran, dat) in zip(
self.coils, grid, params, transforms, data
)
]

def translate(self, *args, **kwargs):
"""Translate the coils along an axis."""
Expand All @@ -293,7 +320,7 @@ def flip(self, *args, **kwargs):
"""Flip the coils across a plane."""
[coil.flip(*args, **kwargs) for coil in self.coils]

def compute_magnetic_field(self, coords, params={}, basis="rpz"):
def compute_magnetic_field(self, coords, params=None, basis="rpz", grid=None):
"""Compute magnetic field at a set of points.

Parameters
Expand All @@ -305,18 +332,22 @@ def compute_magnetic_field(self, coords, params={}, basis="rpz"):
or one for each member
basis : {"rpz", "xyz"}
basis for input coordinates and returned magnetic field
grid : Grid, int or None or array-like, optional
Grid used to discretize coil, either the same for all coils or one for each
member of the coilset. If an integer, uses that many equally spaced
points.

Returns
-------
field : ndarray, shape(n,3)
magnetic field at specified points, in either rpz or xyz coordinates
"""
if isinstance(params, dict):
params = [params] * len(self)
assert len(params) == len(self)
params = self._make_arraylike(params)
grid = self._make_arraylike(grid)

B = 0
for coil, par in zip(self.coils, params):
B += coil.compute_magnetic_field(coords, par, basis)
for coil, par, grd in zip(self.coils, params, grid):
B += coil.compute_magnetic_field(coords, par, basis, grd)

return B

Expand Down
3 changes: 3 additions & 0 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@
_basis_vectors,
_bootstrap,
_core,
_curve,
_equil,
_field,
_geometry,
_metric,
_profiles,
_qs,
_stability,
_surface,
)
from .data_index import data_index
from .geom_utils import rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .utils import (
arg_order,
compute,
Expand Down
12 changes: 12 additions & 0 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,10 @@ def _b(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_theta", "e_zeta", "|e_theta x e_zeta|"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_rho(params, transforms, profiles, data, **kwargs):
# equal to e^rho / |e^rho| but works correctly for surfaces as well that don't have
Expand All @@ -688,6 +692,10 @@ def _n_rho(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_zeta", "|e_zeta x e_rho|"],
dpanici marked this conversation as resolved.
Show resolved Hide resolved
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_theta(params, transforms, profiles, data, **kwargs):
data["n_theta"] = (
Expand All @@ -708,6 +716,10 @@ def _n_theta(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="rtz",
data=["e_rho", "e_theta", "|e_rho x e_theta|"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
)
def _n_zeta(params, transforms, profiles, data, **kwargs):
data["n_zeta"] = (
Expand Down
Loading