Skip to content

Commit

Permalink
Make users pass in spacing/weights to custom grids (#985)
Browse files Browse the repository at this point in the history
Resolves #981
  • Loading branch information
dpanici authored May 8, 2024
2 parents 267d533 + abf5b4f commit 082be4e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 208 deletions.
16 changes: 13 additions & 3 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from desc.backend import fori_loop, jit, jnp, put, root, root_scalar, vmap
from desc.compute import compute as compute_fun
from desc.compute import data_index, get_profiles, get_transforms
from desc.compute import data_index, get_data_deps, get_profiles, get_transforms
from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid
from desc.transform import Transform
from desc.utils import setdefault
Expand Down Expand Up @@ -105,8 +105,11 @@ def periodic(x):
coords = periodic(coords)

params = setdefault(params, eq.params_dict)

profiles = get_profiles(inbasis + basis_derivs, eq, None)
p = "desc.equilibrium.equilibrium.Equilibrium"
names = inbasis + basis_derivs + outbasis
deps = list(set(get_data_deps(names, obj=p) + list(names)))

# do surface average to get iota once
if "iota" in profiles and profiles["iota"] is None:
profiles["iota"] = eq.get_profile("iota", params=params)
Expand All @@ -115,8 +118,15 @@ def periodic(x):
@functools.partial(jit, static_argnums=1)
def compute(y, basis):
grid = Grid(y, sort=False, jitable=True)
data = {}
if "iota" in deps:
data["iota"] = profiles["iota"](grid, params=params["i_l"])
if "iota_r" in deps:
data["iota_r"] = profiles["iota"](grid, dr=1, params=params["i_l"])
if "iota_rr" in deps:
data["iota_rr"] = profiles["iota"](grid, dr=2, params=params["i_l"])
transforms = get_transforms(basis, eq, grid, jitable=True)
data = compute_fun(eq, basis, params, transforms, profiles)
data = compute_fun(eq, basis, params, transforms, profiles, data)
x = jnp.array([data[k] for k in basis]).T
return x

Expand Down
79 changes: 50 additions & 29 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,6 @@ def _scale_weights(self):
# duplicates nodes are scaled down properly regardless of which two columns
# span the surface.

# scale areas sum to full area
# The following operation is not a general solution to return the weight
# removed from the duplicate nodes back to the unique nodes.
# (For the 3 predefined grid types this line of code has no effect).
# For this reason, duplicates should typically be deleted rather than rescaled.
# Note we multiply each column by duplicates^(1/6) to account for the extra
# division by duplicates^(1/2) in one of the columns above.
if (self.spacing.T * duplicates ** (1 / 6)).prod(axis=0).sum():
self._spacing *= (
4
* np.pi**2
/ (self.spacing.T * duplicates ** (1 / 6)).prod(axis=0).sum()
) ** (1 / 3)
return weights

@property
Expand Down Expand Up @@ -310,12 +297,22 @@ def nodes(self):
@property
def spacing(self):
"""ndarray: Node spacing, in (rho,theta,zeta)."""
return self.__dict__.setdefault("_spacing", np.array([]).reshape((0, 3)))
errorif(
not hasattr(self, "_spacing"),
AttributeError,
"Custom grids must have spacing specified by user.",
)
return self._spacing

@property
def weights(self):
"""ndarray: Weight for each node, either exact quadrature or volume based."""
return self.__dict__.setdefault("_weights", np.array([]).reshape((0, 3)))
errorif(
not hasattr(self, "_weights"),
AttributeError,
"Custom grids must have weights specified by user.",
)
return self._weights

def __repr__(self):
"""str: string form of the object."""
Expand Down Expand Up @@ -507,6 +504,10 @@ class Grid(_Grid):
----------
nodes : ndarray of float, size(num_nodes,3)
Node coordinates, in (rho,theta,zeta)
spacing : ndarray of float, size(num_nodes, 3)
Spacing between nodes in each direction.
weights : ndarray of float, size(num_nodes, )
Quadrature weights for each node.
sort : bool
Whether to sort the nodes for use with FFT method.
jitable : bool
Expand All @@ -515,14 +516,32 @@ class Grid(_Grid):
etc may be wrong if grid contains duplicate nodes.
"""

def __init__(self, nodes, sort=False, jitable=False, **kwargs):
def __init__(
self, nodes, spacing=None, weights=None, sort=False, jitable=False, **kwargs
):
# Python 3.3 (PEP 412) introduced key-sharing dictionaries.
# This change measurably reduces memory usage of objects that
# define all attributes in their __init__ method.
self._NFP = 1
self._sym = False
self._node_pattern = "custom"
self._nodes, self._spacing = self._create_nodes(nodes)
self._nodes = self._create_nodes(nodes)
if spacing is not None:
spacing = (
jnp.atleast_2d(jnp.asarray(spacing))
.reshape(self.nodes.shape)
.astype(float)
)
self._spacing = spacing
if weights is None and spacing is not None:
self._weights = self._spacing.prod(axis=1)
elif weights is not None:
weights = (
jnp.atleast_1d(jnp.asarray(weights))
.reshape(self.nodes.shape[0])
.astype(float)
)
self._weights = weights
if sort:
self._sort_nodes()
if jitable:
Expand All @@ -532,8 +551,6 @@ def __init__(self, nodes, sort=False, jitable=False, **kwargs):
r = jnp.where(r == 0, 1e-12, r)
self._nodes = jnp.array([r, t, z]).T
self._axis = np.array([], dtype=int)
# don't do anything fancy with weights
self._weights = self._spacing.prod(axis=1)
# allow for user supplied indices/inverse indices for special cases
for attr in [
"_unique_rho_idx",
Expand All @@ -546,7 +563,6 @@ def __init__(self, nodes, sort=False, jitable=False, **kwargs):
if attr in kwargs:
setattr(self, attr, jnp.asarray(kwargs.pop(attr)))
else:
self._enforce_symmetry()
self._axis = self._find_axis()
(
self._unique_rho_idx,
Expand All @@ -556,13 +572,25 @@ def __init__(self, nodes, sort=False, jitable=False, **kwargs):
self._unique_zeta_idx,
self._inverse_zeta_idx,
) = self._find_unique_inverse_nodes()
self._weights = self._scale_weights()

self._L = self.num_nodes
self._M = self.num_nodes
self._N = self.num_nodes
errorif(len(kwargs), ValueError, f"Got unexpected kwargs {kwargs.keys()}")

def _sort_nodes(self):
"""Sort nodes for use with FFT."""
sort_idx = np.lexsort((self.nodes[:, 1], self.nodes[:, 0], self.nodes[:, 2]))
self._nodes = self.nodes[sort_idx]
try:
self._spacing = self.spacing[sort_idx]
except AttributeError:
pass
try:
self._weights = self.weights[sort_idx]
except AttributeError:
pass

def _create_nodes(self, nodes):
"""Allow for custom node creation.
Expand All @@ -575,8 +603,6 @@ def _create_nodes(self, nodes):
-------
nodes : ndarray of float, size(num_nodes,3)
Node coordinates, in (rho,theta,zeta).
spacing : ndarray of float, size(num_nodes,3)
Node spacing, in (rho,theta,zeta).
"""
nodes = jnp.atleast_2d(jnp.asarray(nodes)).reshape((-1, 3)).astype(float)
Expand All @@ -585,12 +611,7 @@ def _create_nodes(self, nodes):
# This may cause the surface_integrals() function to fail recognizing
# surfaces outside the interval [0, 2pi] as duplicates. However, most
# surface integral computations are done with LinearGrid anyway.
spacing = ( # make weights sum to 4pi^2
jnp.ones_like(nodes)
* jnp.array([1, 2 * np.pi, 2 * np.pi])
/ nodes.shape[0] ** (1 / 3)
)
return nodes, spacing
return nodes


class LinearGrid(_Grid):
Expand Down
139 changes: 0 additions & 139 deletions tests/test_equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from desc.__main__ import main
from desc.backend import sign
from desc.compute.utils import cross, dot
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.examples import get
from desc.grid import Grid, LinearGrid
Expand Down Expand Up @@ -370,141 +369,3 @@ def test_backward_compatible_load_and_resolve():
f_obj = ForceBalance(eq=eq)
obj = ObjectiveFunction(f_obj, use_jit=False)
eq.solve(maxiter=1, objective=obj)


@pytest.mark.unit
def test_shifted_circle_geometry():
"""
In this test, we calculate a low-beta shifted circle equilibrium with DESC.
We then compare the various geometric coefficients with their respective analytical
expressions. These expression are available in Edmund Highcock's thesis on arxiv
https://arxiv.org/pdf/1207.4419.pdf (Table 3.5)
"""
eq = Equilibrium.load(".//tests//inputs//low-beta-shifted-circle.h5")

eq_keys = ["iota", "iota_r", "a", "rho", "psi"]

psi = 0.25 # rho^2 (or normalized psi)
alpha = 0

eq_keys = ["iota", "iota_r", "a", "rho", "psi"]

data_eq = eq.compute(eq_keys)

iotas = np.interp(np.sqrt(psi), data_eq["rho"], data_eq["iota"])
shears = np.interp(np.sqrt(psi), data_eq["rho"], data_eq["iota_r"])

N = int((2 * eq.M_grid) * 4 + 1)

zeta = np.linspace(-1.0 * np.pi / iotas, 1.0 * np.pi / iotas, N)
theta_PEST = alpha * np.ones(N, dtype=int) + iotas * zeta

coords1 = np.zeros((N, 3))
coords1[:, 0] = np.sqrt(psi) * np.ones(N, dtype=int)
coords1[:, 1] = theta_PEST
coords1[:, 2] = zeta

# Creating a grid along a field line
c1 = eq.compute_theta_coords(coords1)
grid = Grid(c1, sort=False)

data_keys = [
"kappa",
"|grad(psi)|^2",
"grad(|B|)",
"grad(alpha)",
"grad(psi)",
"B",
"grad(|B|)",
"iota",
"|B|",
"B^zeta",
"cvdrift0",
"cvdrift",
"gbdrift",
]

data = eq.compute(data_keys, grid=grid, override_grid=False)

psib = data_eq["psi"][-1]

# signs
sign_psi = psib / np.abs(psib)
sign_iota = iotas / np.abs(iotas)

# normalizations
Lref = data_eq["a"]
Bref = 2 * np.abs(psib) / Lref**2

modB = data["|B|"]
bmag = modB / Bref

x = Lref * np.sqrt(psi)
s_hat = -x / iotas * shears / Lref

grad_psi = data["grad(psi)"]
grad_alpha = data["grad(alpha)"]

iota = data["iota"]

gradpar = Lref * data["B^zeta"] / modB

gds21 = -sign_iota * np.array(dot(grad_psi, grad_alpha)) * s_hat / Bref

gbdrift = np.array(dot(cross(data["B"], data["grad(|B|)"]), grad_alpha))
gbdrift *= -sign_psi * 2 * Bref * Lref**2 / modB**3 * np.sqrt(psi)

cvdrift = (
-sign_psi
* 2
* Bref
* Lref**2
* np.sqrt(psi)
* dot(cross(data["B"], data["kappa"]), grad_alpha)
/ modB**2
)

cvdrift0 = np.array(dot(cross(data["B"], data["grad(|B|)"]), grad_psi))
cvdrift0 *= sign_iota * sign_psi * s_hat * 2 / modB**3 / np.sqrt(psi)

## Comparing coefficient calculation here with coefficients from compute/_mtric
cvdrift_2 = -2 * sign_psi * Bref * Lref**2 * np.sqrt(psi) * data["cvdrift"]
gbdrift_2 = -2 * sign_psi * Bref * Lref**2 * np.sqrt(psi) * data["gbdrift"]

# The error here should be of the same order as the max force error
np.testing.assert_allclose(gbdrift, gbdrift_2, atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(cvdrift, cvdrift_2, atol=8e-4, rtol=9e-5)

a0_over_R0 = Lref * np.sqrt(psi)

# For the rest of the expressions, the error ~ a0_over_R0
fudge_factor1 = -3.8
cvdrift0_an = fudge_factor1 * a0_over_R0 * s_hat * np.sin(theta_PEST)
np.testing.assert_allclose(cvdrift0, cvdrift0_an, atol=5e-3, rtol=5e-3)

bmag_an = np.mean(bmag) * (1 - a0_over_R0 * np.cos(theta_PEST))
np.testing.assert_allclose(bmag, bmag_an, atol=5e-3, rtol=5e-3)

gradpar_an = 2 * Lref * iota * (1 - a0_over_R0 * np.cos(theta_PEST))
np.testing.assert_allclose(gradpar, gradpar_an, atol=9e-3, rtol=5e-3)

dPdrho = np.mean(-0.5 * (cvdrift - gbdrift) * modB**2)
alpha_MHD = -dPdrho * 1 / iota**2 * 0.5

gds21_an = (
-1 * s_hat * (s_hat * theta_PEST - alpha_MHD / bmag**4 * np.sin(theta_PEST))
)
np.testing.assert_allclose(gds21, gds21_an, atol=1.7e-2, rtol=5e-4)

fudge_factor2 = 0.19
gbdrift_an = fudge_factor2 * (
-1 * s_hat + (np.cos(theta_PEST) - 1.0 * gds21 / s_hat * np.sin(theta_PEST))
)

fudge_factor3 = 0.07
cvdrift_an = gbdrift_an + fudge_factor3 * alpha_MHD / bmag**2

# Comparing coefficients with their analytical expressions
np.testing.assert_allclose(gbdrift, gbdrift_an, atol=1.5e-2, rtol=5e-3)
np.testing.assert_allclose(cvdrift, cvdrift_an, atol=9e-3, rtol=5e-3)
37 changes: 0 additions & 37 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,6 @@
class TestGrid:
"""Test for Grid classes."""

@pytest.mark.unit
def test_custom_grid(self):
"""Test creating a grid with custom set of nodes."""
nodes = np.array(
[
[0, 0, 0],
[0.25, 0, 0],
[0.5, np.pi / 2, np.pi / 3],
[0.5, np.pi / 2, np.pi / 3],
[0.75, np.pi, np.pi],
[1, 2 * np.pi, 3 * np.pi / 2],
]
)
grid = Grid(nodes)
weights = grid.weights

w = 4 * np.pi**2 / (grid.num_nodes - 1)
weights_ref = np.array([w, w, w / 2, w / 2, w, w])

np.testing.assert_allclose(weights, weights_ref)
np.testing.assert_allclose(grid.weights.sum(), (2 * np.pi) ** 2)

@pytest.mark.unit
def test_linear_grid(self):
"""Test node placement in a LinearGrid."""
Expand Down Expand Up @@ -832,18 +810,3 @@ def test_custom_jitable_grid_indexing():
_ = eq.compute(["|B|"], grid=grid2, override_grid=True)["|B|"]
b3 = eq.compute(["|B|"], grid=grid3, override_grid=True)["|B|"]
np.testing.assert_allclose(b1, b3)


@pytest.mark.unit
def test_custom_jitable_grid_weights():
"""Test that grid weights are set correctly when jitable=True."""
rho = np.random.random(100)
theta = np.random.random(100) * 2 * np.pi
zeta = np.random.random(100) * 2 * np.pi
grid1 = Grid(np.array([rho, theta, zeta]).T, jitable=True)
grid2 = Grid(np.array([rho, theta, zeta]).T, jitable=False)

np.testing.assert_allclose(grid1.spacing, grid2.spacing)
np.testing.assert_allclose(grid1.weights, grid2.weights)
np.testing.assert_allclose(grid1.weights.sum(), 4 * np.pi**2)
np.testing.assert_allclose(grid2.weights.sum(), 4 * np.pi**2)

0 comments on commit 082be4e

Please sign in to comment.