Skip to content

Commit

Permalink
HermiteSplineProfile to improve accuracy of coordinate mapping (#1199)
Browse files Browse the repository at this point in the history
- [x] Adds `HermiteSplineProfile`.
  - Relevant for #1201.
- [x] Resolves #1200.
- Another solution would be to use the radial basis polynomials of the
quadrature grid for spectral spline of iota.
- [x]  Resolves #1203
  • Loading branch information
dpanici authored Aug 20, 2024
2 parents dd3f472 + 236af07 commit 13108f6
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 46 deletions.
2 changes: 1 addition & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None):
Returns
-------
deps : list of str
deps : list[str]
Names of quantities needed to compute key.
"""
Expand Down
17 changes: 11 additions & 6 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,19 @@ def map_coordinates( # noqa: C901

# do surface average to get iota once
if "iota" in profiles and profiles["iota"] is None:
profiles["iota"] = eq.get_profile("iota", params=params)
profiles["iota"] = eq.get_profile(["iota", "iota_r"], params=params)
params["i_l"] = profiles["iota"].params

@functools.partial(jit, static_argnums=1)
def compute(y, basis):
grid = Grid(y, sort=False, jitable=True)
data = {}
if "iota" in deps:
data["iota"] = profiles["iota"](grid, params=params["i_l"])
data["iota"] = profiles["iota"].compute(grid, params=params["i_l"])
if "iota_r" in deps:
data["iota_r"] = profiles["iota"](grid, dr=1, params=params["i_l"])
data["iota_r"] = profiles["iota"].compute(grid, dr=1, params=params["i_l"])
if "iota_rr" in deps:
data["iota_rr"] = profiles["iota"](grid, dr=2, params=params["i_l"])
data["iota_rr"] = profiles["iota"].compute(grid, dr=2, params=params["i_l"])
transforms = get_transforms(basis, eq, grid, jitable=True)
data = compute_fun(eq, basis, params, transforms, profiles, data)
x = jnp.array([data[k] for k in basis]).T
Expand Down Expand Up @@ -243,7 +243,10 @@ def _initial_guess_heuristic(yk, coords, inbasis, eq, profiles):
theta = coords[:, inbasis.index(poloidal)]
elif poloidal == "alpha":
alpha = coords[:, inbasis.index("alpha")]
iota = profiles["iota"](rho)
rho = jnp.atleast_1d(rho)
zero = jnp.zeros_like(rho)
grid = Grid(nodes=jnp.column_stack([rho, zero, zero]), sort=False, jitable=True)
iota = profiles["iota"].compute(grid)
theta = (alpha + iota * zeta) % (2 * jnp.pi)

yk = jnp.column_stack([rho, theta, zeta])
Expand Down Expand Up @@ -677,7 +680,7 @@ def get_rtz_grid(
rtz : rho, theta, zeta
period : tuple of float
Assumed periodicity for each quantity in inbasis.
Use np.inf to denote no periodicity.
Use ``np.inf`` to denote no periodicity.
jitable : bool, optional
If false the returned grid has additional attributes.
Required to be false to retain nodes at magnetic axis.
Expand All @@ -691,6 +694,8 @@ def get_rtz_grid(
grid = Grid.create_meshgrid(
[radial, poloidal, toroidal], coordinates=coordinates, period=period
)
if "iota" in kwargs:
kwargs["iota"] = grid.expand(kwargs["iota"])
inbasis = {
"r": "rho",
"t": "theta",
Expand Down
21 changes: 12 additions & 9 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from desc.optimizable import Optimizable, optimizable_parameter
from desc.optimize import Optimizer
from desc.perturbations import perturb
from desc.profiles import PowerSeriesProfile, SplineProfile
from desc.profiles import HermiteSplineProfile, PowerSeriesProfile, SplineProfile
from desc.transform import Transform
from desc.utils import (
ResolutionWarning,
Expand Down Expand Up @@ -732,6 +732,8 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs):
----------
name : str
Name of the quantity to compute.
If list is given, then two names are expected: the quantity to spline
and its radial derivative.
grid : Grid, optional
Grid of coordinates to evaluate at. Defaults to the quadrature grid.
Note profile will only be a function of the radial coordinate.
Expand All @@ -748,14 +750,17 @@ def get_profile(self, name, grid=None, kind="spline", **kwargs):
if grid is None:
grid = QuadratureGrid(self.L_grid, self.M_grid, self.N_grid, self.NFP)
data = self.compute(name, grid=grid, **kwargs)
f = data[name]
f = grid.compress(f, surface_label="rho")
x = grid.nodes[grid.unique_rho_idx, 0]
p = SplineProfile(f, x, name=name)
knots = grid.compress(grid.nodes[:, 0])
if isinstance(name, str):
f = grid.compress(data[name])
p = SplineProfile(f, knots, name=name)
else:
f, df = map(grid.compress, (data[name[0]], data[name[1]]))
p = HermiteSplineProfile(f, df, knots, name=name)
if kind == "power_series":
p = p.to_powerseries(order=min(self.L, len(x)), xs=x, sym=True)
p = p.to_powerseries(order=min(self.L, grid.num_rho), xs=knots, sym=True)
if kind == "fourier_zernike":
p = p.to_fourierzernike(L=min(self.L, len(x)), xs=x)
p = p.to_fourierzernike(L=min(self.L, grid.num_rho), xs=knots)
return p

def get_axis(self):
Expand Down Expand Up @@ -1161,8 +1166,6 @@ def map_coordinates(
Parameters
----------
eq : Equilibrium
Equilibrium to use.
coords : ndarray
Shape (k, 3).
2D array of input coordinates. Each row is a different point in space.
Expand Down
159 changes: 129 additions & 30 deletions desc/profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
copy_coeffs,
errorif,
multinomial_coefficients,
setdefault,
warnif,
)

Expand Down Expand Up @@ -613,7 +614,7 @@ def get_params(self, l):

def set_params(self, l, a=None):
"""Set specific power series coefficients."""
l, a = np.atleast_1d(l), np.atleast_1d(a)
l, a = np.atleast_1d(l, a)
a = np.broadcast_to(a, l.shape)
for ll, aa in zip(l, a):
idx = self.basis.get_idx(ll, 0, 0)
Expand Down Expand Up @@ -793,24 +794,25 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0):


class SplineProfile(_Profile):
"""Profile represented by a piecewise cubic spline.
"""Radial profile represented by a piecewise cubic spline.
Parameters
----------
values: array-like
Values of the function at knot locations.
knots : int or ndarray
x locations to use for spline. If an integer, uses that many points linearly
spaced between 0,1
1-D array containing values of the dependent variable.
knots : array-like
1-D array containing values of the independent variable.
Must be real, finite, and in strictly increasing order in [0, 1].
If ``None``, assumes ``values`` is given on knots uniformly spaced in [0, 1].
method : str
method of interpolation
Method of interpolation. Default is cubic2.
- `'nearest'`: nearest neighbor interpolation
- `'linear'`: linear interpolation
- `'cubic'`: C1 cubic splines (aka local splines)
- `'cubic2'`: C2 cubic splines (aka natural splines)
- `'catmull-rom'`: C1 cubic centripetal "tension" splines
name : str
name of the profile
Optional name of the profile.
"""

Expand All @@ -821,11 +823,12 @@ def __init__(self, values=None, knots=None, method="cubic2", name=""):

if values is None:
values = [0, 0, 0]
values = np.atleast_1d(values)
values = jnp.atleast_1d(values)
if knots is None:
knots = np.linspace(0, 1, values.size)
else:
knots = np.atleast_1d(knots)
knots = jnp.linspace(0, 1, values.size)
knots = jnp.atleast_1d(knots)
errorif(values.shape[-1] != knots.shape[-1])
errorif(not (values.ndim == knots.ndim == 1), NotImplementedError)
self._knots = knots
self._params = values
self._method = method
Expand All @@ -834,7 +837,7 @@ def __repr__(self):
"""Get the string form of the object."""
s = super().__repr__()
s = s[:-1]
s += ", method={}, num_knots={})".format(self._method, len(self._knots))
s += ", method={}, num_knots={})".format(self._method, self._knots.size)
return s

@property
Expand All @@ -849,24 +852,23 @@ def params(self):

@params.setter
def params(self, new):
if len(new) == len(self._knots):
self._params = jnp.asarray(new)
else:
raise ValueError(
"params should have the same size as the knots, "
+ f"got {len(new)} values for {len(self._knots)} knots"
)
errorif(
len(new) != self._knots.size,
msg="params should have the same size as the knots, "
+ f"got {len(new)} values for {self._knots.size} knots",
)
self._params = jnp.asarray(new)

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.
Parameters
----------
grid : Grid
locations to compute values at.
Locations to compute values at.
params : array-like
spline values to use. If not given, uses the
values given by the params attribute
Values of the function at ``self.knots``.
If not given, uses ``self.params``.
dr, dt, dz : int
derivative order in rho, theta, zeta
Expand All @@ -876,15 +878,112 @@ def compute(self, grid, params=None, dr=0, dt=0, dz=0):
values of the profile or its derivative at the points specified
"""
if params is None:
params = self.params
if dt != 0 or dz != 0:
return jnp.zeros_like(grid.nodes[:, 0])
x = self.knots
f = params
xq = grid.nodes[:, 0]
fq = interp1d(xq, x, f, method=self._method, derivative=dr, extrap=True)
return fq
params = setdefault(params, self._params)
return interp1d(
xq=grid.nodes[:, 0],
x=self._knots,
f=params,
method=self._method,
derivative=dr,
extrap=True,
)


class HermiteSplineProfile(_Profile):
"""Radial profile represented by a piecewise cubic Hermite spline.
Parameters
----------
f: array-like
1-D array containing values of the dependent variable.
df: array-like
1-D array containing derivatives of the dependent variable.
knots : array-like
1-D array containing values of the independent variable.
Must be real, finite, and in strictly increasing order in [0, 1].
If ``None``, assumes ``f`` and ``df`` are given on knots uniformly
spaced in [0, 1].
name : str
Optional name of the profile.
"""

_io_attrs_ = _Profile._io_attrs_ + ["_knots", "_params"]

def __init__(self, f, df, knots=None, name=""):
super().__init__(name)

f, df = jnp.atleast_1d(f, df)
if knots is None:
knots = jnp.linspace(0, 1, f.size)
knots = jnp.atleast_1d(knots)
errorif(not (f.shape[-1] == df.shape[-1] == knots.shape[-1]))
errorif(not (f.ndim == df.ndim == knots.ndim == 1), NotImplementedError)
self._knots = knots
self._params = jnp.concatenate([f, df])

def __repr__(self):
"""Get the string form of the object."""
s = super().__repr__()
s = s[:-1]
s += ", num_knots={})".format(self._knots.size)
return s

@property
def knots(self):
"""ndarray: Knot locations."""
return self._knots

@property
def params(self):
"""ndarray: Parameters for computation.
First (second) half stores function (derivative) values at ``knots``.
"""
return self._params

@params.setter
def params(self, new):
new = jnp.asarray(new)
errorif(
new.ndim != 1 or new.size != 2 * self._knots.size,
msg="Params should be 1D with size twice number of knots. "
f"Got {new.shape} params for {self._knots.size} knots.",
)
self._params = new

def compute(self, grid, params=None, dr=0, dt=0, dz=0):
"""Compute values of profile at specified nodes.
Parameters
----------
grid : Grid
Locations to compute values at.
params : array-like
First (second) half stores function (derivative) values at ``knots``.
If not given, uses ``self.params``.
dr, dt, dz : int
derivative order in rho, theta, zeta
Returns
-------
f : ndarray
Array containing values of the dependent variable at the points specified.
"""
if dt != 0 or dz != 0:
return jnp.zeros_like(grid.nodes[:, 0])
params = setdefault(params, self._params)
return interp1d(
xq=grid.nodes[:, 0],
x=self._knots,
f=params[: self._knots.size],
fx=params[self._knots.size :],
derivative=dr,
extrap=True,
)


class MTanhProfile(_Profile):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy.interpolate import interp1d

from desc.equilibrium import Equilibrium
from desc.examples import get
from desc.grid import LinearGrid
from desc.io import InputReader
from desc.objectives import (
Expand All @@ -15,6 +16,7 @@
)
from desc.profiles import (
FourierZernikeProfile,
HermiteSplineProfile,
MTanhProfile,
PowerSeriesProfile,
SplineProfile,
Expand Down Expand Up @@ -507,3 +509,14 @@ def test_kinetic_pressure(self):
assert np.all(data2["Te_r"] == data2["Ti_r"])
np.testing.assert_allclose(data1["p"], data2["p"])
np.testing.assert_allclose(data1["p_r"], data2["p_r"])

@pytest.mark.unit
def test_hermite_spline_solve(self):
"""Test that spline with double number of parameters is optimized."""
eq = get("DSHAPE")
rho = np.linspace(0, 1.0, 20, endpoint=True)
eq.pressure = HermiteSplineProfile(
eq.pressure(rho), eq.pressure(rho, dr=1), rho
)
eq.solve()
assert eq.is_nested()

0 comments on commit 13108f6

Please sign in to comment.