diff --git a/desc/integrals/bounce_integral.py b/desc/integrals/bounce_integral.py index 19583de69d..3fd3f0b1cc 100644 --- a/desc/integrals/bounce_integral.py +++ b/desc/integrals/bounce_integral.py @@ -18,10 +18,11 @@ get_quadrature, grad_automorphism_sin, ) +from desc.io import IOAble from desc.utils import setdefault, warnif -class Bounce1D: +class Bounce1D(IOAble): """Computes bounce integrals using one-dimensional local spline methods. The bounce integral is defined as ∫ f(ℓ) dℓ, where @@ -86,6 +87,8 @@ class Bounce1D: Attributes ---------- + required_names : list + Names in ``data_index`` required to compute bounce integrals. _B : jnp.ndarray TODO: Make this (4, M, L, N-1) now that tensor product in rho and alpha required as well after GitHub PR #1214. @@ -97,6 +100,7 @@ class Bounce1D: """ + required_names = ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] plot_ppoly = staticmethod(plot_ppoly) get_pitch_inv = staticmethod(get_pitch_inv) @@ -121,7 +125,7 @@ def __init__( L = ``grid.num_rho``, M = ``grid.num_alpha``, and N = ``grid.num_zeta``. data : dict[str, jnp.ndarray] Data evaluated on ``grid``. - Must include names in ``Bounce1D.required_names()``. + Must include names in ``Bounce1D.required_names``. quad : (jnp.ndarray, jnp.ndarray) Quadrature points xₖ and weights wₖ for the approximate evaluation of an integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points. @@ -157,8 +161,8 @@ def __init__( "|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign. } self._data = { - key: grid.meshgrid_reshape(val, "raz").reshape(-1, grid.num_zeta) - for key, val in data.items() + name: grid.meshgrid_reshape(data[name], "raz").reshape(-1, grid.num_zeta) + for name in Bounce1D.required_names } self._x, self._w = get_quadrature(quad, automorphism) @@ -181,11 +185,6 @@ def __init__( assert self._dB_dz.shape[0] == degree assert self._B.shape[-1] == self._dB_dz.shape[-1] == grid.num_zeta - 1 - @staticmethod - def required_names(): - """Return names in ``data_index`` required to compute bounce integrals.""" - return ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"] - @staticmethod def reshape_data(grid, *arys): """Reshape arrays for acceptable input to ``integrate``. diff --git a/desc/integrals/bounce_utils.py b/desc/integrals/bounce_utils.py index e9a9cff613..90d7a30273 100644 --- a/desc/integrals/bounce_utils.py +++ b/desc/integrals/bounce_utils.py @@ -319,7 +319,7 @@ def bounce_quadrature( These functions should be arguments to the callable ``integrand``. data : dict[str, jnp.ndarray] Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. - Must include names in ``Bounce1D.required_names()``. + Must include names in ``Bounce1D.required_names``. knots : jnp.ndarray Shape (knots.size, ). Unique ζ coordinates where the arrays in ``data`` and ``f`` were evaluated. @@ -420,7 +420,7 @@ def _interpolate_and_integrate( Quadrature points in ζ coordinates. data : dict[str, jnp.ndarray] Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``. - Must include names in ``Bounce1D.required_names()``. + Must include names in ``Bounce1D.required_names``. Returns ------- diff --git a/desc/io/optimizable_io.py b/desc/io/optimizable_io.py index 554cdac070..e15a21756e 100644 --- a/desc/io/optimizable_io.py +++ b/desc/io/optimizable_io.py @@ -169,16 +169,17 @@ class IOAble(ABC, metaclass=_CombinedMeta): """Abstract Base Class for savable and loadable objects. Objects inheriting from this class can be saved and loaded via hdf5 or pickle. - To save properly, each object should have an attribute `_io_attrs_` which + To save properly, each object should have an attribute ``_io_attrs_`` which is a list of strings of the object attributes or properties that should be saved and loaded. - For saved objects to be loaded correctly, the __init__ method of any custom - types being saved should only assign attributes that are listed in `_io_attrs_`. + For saved objects to be loaded correctly, the ``__init__`` method of any custom + types being saved should only assign attributes that are listed in ``_io_attrs_``. Other attributes or other initialization should be done in a separate - `set_up()` method that can be called during __init__. The loading process - will involve creating an empty object, bypassing init, then setting any `_io_attrs_` - of the object, then calling `_set_up()` without any arguments, if it exists. + ``set_up()`` method that can be called during ``__init__``. The loading process + will involve creating an empty object, bypassing init, then setting any + ``_io_attrs_`` of the object, then calling ``_set_up()`` without any arguments, + if it exists. """ diff --git a/tests/test_integrals.py b/tests/test_integrals.py index 08ffeb0042..dd14cc0dcd 100644 --- a/tests/test_integrals.py +++ b/tests/test_integrals.py @@ -1066,7 +1066,7 @@ def test_integrate_checks(self): ) # 4. Compute input data. data = eq.compute( - Bounce1D.required_names() + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid + Bounce1D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid ) # 5. Make the bounce integration operator. bounce = Bounce1D(grid.source_grid, data, quad=leggauss(3), check=True) @@ -1292,7 +1292,7 @@ def test_binormal_drift_bounce1d(self): iota=iota, ) data = eq.compute( - Bounce1D.required_names() + Bounce1D.required_names + [ "cvdrift", "gbdrift", diff --git a/tests/test_quad_utils.py b/tests/test_quad_utils.py index 07dfcd85e6..5a7c3d00e7 100644 --- a/tests/test_quad_utils.py +++ b/tests/test_quad_utils.py @@ -2,7 +2,9 @@ import numpy as np import pytest +from jax import grad +from desc.backend import jnp from desc.integrals.quad_utils import ( automorphism_arcsin, automorphism_sin, @@ -91,3 +93,11 @@ def test_leggauss_lobatto(): np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1]) np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10]) np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1])) + + def fun(a): + x, w = leggauss_lob(a.size) + return jnp.dot(x * a, w) + + # make sure differentiable + # https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161 + assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all()