Skip to content

Commit

Permalink
Make Bounce1D pytree and ioable and ensure eigh_tridiagonal is revers…
Browse files Browse the repository at this point in the history
…e mode diffable
  • Loading branch information
unalmis committed Aug 29, 2024
1 parent e4dcd2e commit 1c1fa96
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 19 deletions.
17 changes: 8 additions & 9 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
get_quadrature,
grad_automorphism_sin,
)
from desc.io import IOAble
from desc.utils import setdefault, warnif


class Bounce1D:
class Bounce1D(IOAble):
"""Computes bounce integrals using one-dimensional local spline methods.
The bounce integral is defined as ∫ f(ℓ) dℓ, where
Expand Down Expand Up @@ -86,6 +87,8 @@ class Bounce1D:
Attributes
----------
required_names : list
Names in ``data_index`` required to compute bounce integrals.
_B : jnp.ndarray
TODO: Make this (4, M, L, N-1) now that tensor product in rho and alpha
required as well after GitHub PR #1214.
Expand All @@ -97,6 +100,7 @@ class Bounce1D:
"""

required_names = ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"]
plot_ppoly = staticmethod(plot_ppoly)
get_pitch_inv = staticmethod(get_pitch_inv)

Expand All @@ -121,7 +125,7 @@ def __init__(
L = ``grid.num_rho``, M = ``grid.num_alpha``, and N = ``grid.num_zeta``.
data : dict[str, jnp.ndarray]
Data evaluated on ``grid``.
Must include names in ``Bounce1D.required_names()``.
Must include names in ``Bounce1D.required_names``.
quad : (jnp.ndarray, jnp.ndarray)
Quadrature points xₖ and weights wₖ for the approximate evaluation of an
integral ∫₋₁¹ g(x) dx = ∑ₖ wₖ g(xₖ). Default is 32 points.
Expand Down Expand Up @@ -157,8 +161,8 @@ def __init__(
"|B|_z|r,a": data["|B|_z|r,a"] / Bref, # This is already the correct sign.
}
self._data = {
key: grid.meshgrid_reshape(val, "raz").reshape(-1, grid.num_zeta)
for key, val in data.items()
name: grid.meshgrid_reshape(data[name], "raz").reshape(-1, grid.num_zeta)
for name in Bounce1D.required_names
}
self._x, self._w = get_quadrature(quad, automorphism)

Expand All @@ -181,11 +185,6 @@ def __init__(
assert self._dB_dz.shape[0] == degree
assert self._B.shape[-1] == self._dB_dz.shape[-1] == grid.num_zeta - 1

@staticmethod
def required_names():
"""Return names in ``data_index`` required to compute bounce integrals."""
return ["B^zeta", "B^zeta_z|r,a", "|B|", "|B|_z|r,a"]

@staticmethod
def reshape_data(grid, *arys):
"""Reshape arrays for acceptable input to ``integrate``.
Expand Down
4 changes: 2 additions & 2 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def bounce_quadrature(
These functions should be arguments to the callable ``integrand``.
data : dict[str, jnp.ndarray]
Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``.
Must include names in ``Bounce1D.required_names()``.
Must include names in ``Bounce1D.required_names``.
knots : jnp.ndarray
Shape (knots.size, ).
Unique ζ coordinates where the arrays in ``data`` and ``f`` were evaluated.
Expand Down Expand Up @@ -420,7 +420,7 @@ def _interpolate_and_integrate(
Quadrature points in ζ coordinates.
data : dict[str, jnp.ndarray]
Data evaluated on ``grid`` and reshaped with ``Bounce1D.reshape_data``.
Must include names in ``Bounce1D.required_names()``.
Must include names in ``Bounce1D.required_names``.
Returns
-------
Expand Down
13 changes: 7 additions & 6 deletions desc/io/optimizable_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,16 +169,17 @@ class IOAble(ABC, metaclass=_CombinedMeta):
"""Abstract Base Class for savable and loadable objects.
Objects inheriting from this class can be saved and loaded via hdf5 or pickle.
To save properly, each object should have an attribute `_io_attrs_` which
To save properly, each object should have an attribute ``_io_attrs_`` which
is a list of strings of the object attributes or properties that should be
saved and loaded.
For saved objects to be loaded correctly, the __init__ method of any custom
types being saved should only assign attributes that are listed in `_io_attrs_`.
For saved objects to be loaded correctly, the ``__init__`` method of any custom
types being saved should only assign attributes that are listed in ``_io_attrs_``.
Other attributes or other initialization should be done in a separate
`set_up()` method that can be called during __init__. The loading process
will involve creating an empty object, bypassing init, then setting any `_io_attrs_`
of the object, then calling `_set_up()` without any arguments, if it exists.
``set_up()`` method that can be called during ``__init__``. The loading process
will involve creating an empty object, bypassing init, then setting any
``_io_attrs_`` of the object, then calling ``_set_up()`` without any arguments,
if it exists.
"""

Expand Down
4 changes: 2 additions & 2 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ def test_integrate_checks(self):
)
# 4. Compute input data.
data = eq.compute(
Bounce1D.required_names() + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid
Bounce1D.required_names + ["min_tz |B|", "max_tz |B|", "g_zz"], grid=grid
)
# 5. Make the bounce integration operator.
bounce = Bounce1D(grid.source_grid, data, quad=leggauss(3), check=True)
Expand Down Expand Up @@ -1292,7 +1292,7 @@ def test_binormal_drift_bounce1d(self):
iota=iota,
)
data = eq.compute(
Bounce1D.required_names()
Bounce1D.required_names
+ [
"cvdrift",
"gbdrift",
Expand Down
10 changes: 10 additions & 0 deletions tests/test_quad_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import numpy as np
import pytest
from jax import grad

from desc.backend import jnp
from desc.integrals.quad_utils import (
automorphism_arcsin,
automorphism_sin,
Expand Down Expand Up @@ -91,3 +93,11 @@ def test_leggauss_lobatto():
np.testing.assert_allclose(x, [-1, -np.sqrt(3 / 7), 0, np.sqrt(3 / 7), 1])
np.testing.assert_allclose(w, [1 / 10, 49 / 90, 32 / 45, 49 / 90, 1 / 10])
np.testing.assert_allclose(leggauss_lob(x.size - 2, True), (x[1:-1], w[1:-1]))

def fun(a):
x, w = leggauss_lob(a.size)
return jnp.dot(x * a, w)

# make sure differentiable
# https://github.com/PlasmaControl/DESC/pull/854#discussion_r1733323161
assert np.isfinite(grad(fun)(jnp.arange(10) * np.pi)).all()

0 comments on commit 1c1fa96

Please sign in to comment.