Skip to content

Commit

Permalink
Change interp2argmin to expect already reshaped data to be consistent…
Browse files Browse the repository at this point in the history
… with nw API
  • Loading branch information
unalmis committed Aug 25, 2024
1 parent 4b9ff2d commit cee3da7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 34 deletions.
38 changes: 19 additions & 19 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@
from desc.integrals.bounce_utils import (
_add2legend,
_check_bounce_points,
_interp_to_argmin_B_soft,
_plot_intersect,
bounce_points,
bounce_quadrature,
chebroots_vec,
epigraph_and,
flatten_matrix,
get_alpha,
interp_to_argmin_B_soft,
plot_ppoly,
subtract,
)
from desc.integrals.interp_utils import (
_filter_distinct,
cheb_from_dct,
cheb_pts,
filter_distinct,
fourier_pts,
harmonic,
idct_non_uniform,
Expand Down Expand Up @@ -297,7 +297,7 @@ def intersect2d(self, k=0.0, eps=_eps):

# Intersects must satisfy y ∈ [-1, 1].
# Pick sentinel such that only distinct roots are considered intersects.
y = _filter_distinct(y, sentinel=-2.0, eps=eps)
y = filter_distinct(y, sentinel=-2.0, eps=eps)
is_intersect = (jnp.abs(y.imag) <= eps) & (jnp.abs(y.real) <= 1.0)
y = jnp.where(is_intersect, y.real, 1.0) # ensure y is in domain of arcos

Expand All @@ -324,12 +324,12 @@ def intersect1d(self, k=0.0, num_intersect=None, pad_value=0.0):
Shape must broadcast with (..., *cheb.shape[:-2]).
Specify to find solutions yᵢ to fₓ(yᵢ) = k. Default 0.
num_intersect : int or None
If not specified, then all intersects are returned in an array whose
last axis has size ``self.M*(self.N-1)``. If there were less than that many
intersects detected, then the last axis of the returned arrays is padded
with ``pad_value``. Specify to return the first ``num_intersect`` pairs
of intersects. This is useful if ``num_intersect`` tightly bounds the
actual number.
Specify to return the first ``num_intersect`` intersects.
This is useful if ``num_intersect`` tightly bounds the actual number.
If not specified, then all intersects are returned. If there were fewer
intersects detected than the size of the last axis of the returned arrays,
then that axis is padded with ``pad_value``.
pad_value : float
Value with which to pad array. Default 0.
Expand Down Expand Up @@ -988,8 +988,8 @@ def bounce_points(self, pitch, num_well=None):
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
bounds the actual number of wells. As a reference, there are typically
at most 5 wells per toroidal transit for a given pitch.
bounds the actual number. As a reference, there are typically at most 5
wells per toroidal transit for a given pitch.
If not specified, then all bounce points are returned. If there were fewer
wells detected along a field line than the size of the last axis of the
Expand Down Expand Up @@ -1050,8 +1050,8 @@ def integrate(self, pitch, integrand, f, weight=None, num_well=None):
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
bounds the actual number of wells. As a reference, there are typically
at most 5 wells per toroidal transit for a given pitch.
bounds the actual number. As a reference, there are typically at most 5
wells per toroidal transit for a given pitch.
If not specified, then all bounce points are returned. If there were fewer
wells detected along a field line than the size of the last axis of the
Expand Down Expand Up @@ -1300,8 +1300,8 @@ def bounce_points(self, pitch, num_well=None):
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
bounds the actual number of wells. As a reference, there are typically
at most 5 wells per toroidal transit for a given pitch.
bounds the actual number. As a reference, there are typically at most 5
wells per toroidal transit for a given pitch.
If not specified, then all bounce points are returned. If there were fewer
wells detected along a field line than the size of the last axis of the
Expand Down Expand Up @@ -1397,16 +1397,16 @@ def integrate(
``integrand``. Use the method ``self.reshape_data`` to reshape the data
into the expected shape.
weight : jnp.ndarray
Shape (L * M, N).
Shape must broadcast with (L * M, N).
If supplied, the bounce integral labeled by well j is weighted such that
the returned value is w(j) ∫ f(ℓ) dℓ, where w(j) is ``weight``
interpolated to the deepest point in the magnetic well. Use the method
``self.reshape_data`` to reshape the data into the expected shape.
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
bounds the actual number of wells. As a reference, there are typically
at most 5 wells per toroidal transit for a given pitch.
bounds the actual number. As a reference, there are typically at most 5
wells per toroidal transit for a given pitch.
If not specified, then all bounce points are returned. If there were fewer
wells detected along a field line than the size of the last axis of the
Expand Down Expand Up @@ -1445,7 +1445,7 @@ def integrate(
check=check,
)
if weight is not None:
result *= _interp_to_argmin_B_soft(
result *= interp_to_argmin_B_soft(
g=weight,
bp1=bp1,
bp2=bp2,
Expand Down
24 changes: 16 additions & 8 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ def bounce_points(
num_well : int or None
Specify to return the first ``num_well`` pairs of bounce points for each
pitch along each field line. This is useful if ``num_well`` tightly
bounds the actual number of wells. As a reference, there are typically
at most 5 wells per toroidal transit for a given pitch.
bounds the actual number. As a reference, there are typically at most 5
wells per toroidal transit for a given pitch.
If not specified, then all bounce points are returned. If there were fewer
wells detected along a field line than the size of the last axis of the
Expand Down Expand Up @@ -685,13 +685,16 @@ def _get_extrema(knots, B, dB_dz, sentinel=jnp.nan):
return extrema, B_extrema


def _interp_to_argmin_B_soft(g, bp1, bp2, knots, B, dB_dz, method="cubic", beta=-50):
def interp_to_argmin_B_soft(g, bp1, bp2, knots, B, dB_dz, method="cubic", beta=-50):
"""Interpolate ``g`` to the deepest point in the magnetic well.
Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A = argmin_E |B|(ζ). Returns mean_A g(ζ).
Parameters
----------
g : jnp.ndarray
Shape must broadcast with (S, knots.size).
Values evaluated on ``knots`` to interpolate.
beta : float
More negative gives exponentially better approximation at the
expense of noisier gradients - noisier in the physics sense (unrelated
Expand All @@ -712,19 +715,24 @@ def _interp_to_argmin_B_soft(g, bp1, bp2, knots, B, dB_dz, method="cubic", beta=
)
g = jnp.linalg.vecdot(
argmin,
interp1d_vec(ext, knots, g.reshape(-1, knots.size), method=method)[
:, jnp.newaxis
],
interp1d_vec(ext, knots, jnp.atleast_2d(g), method=method)[:, jnp.newaxis],
)
assert g.shape == bp1.shape == bp2.shape
return g


# Less efficient than soft if P >> 1.
def _interp_to_argmin_B_hard(g, bp1, bp2, knots, B, dB_dz, method="cubic"):
def interp_to_argmin_B_hard(g, bp1, bp2, knots, B, dB_dz, method="cubic"):
"""Interpolate ``g`` to the deepest point in the magnetic well.
Let E = {ζ ∣ ζ₁ < ζ < ζ₂} and A ∈ argmin_E |B|(ζ). Returns g(A).
Parameters
----------
g : jnp.ndarray
Shape must broadcast with (S, knots.size).
Values evaluated on ``knots`` to interpolate.
"""
ext, B = _get_extrema(knots, B, dB_dz, sentinel=0)
assert ext.shape[0] == B.shape[0] == bp1.shape[1] == bp2.shape[1]
Expand All @@ -738,7 +746,7 @@ def _interp_to_argmin_B_hard(g, bp1, bp2, knots, B, dB_dz, method="cubic"):
axis=-1,
)
A = jnp.take_along_axis(ext[jnp.newaxis], argmin, axis=-1)
g = interp1d_vec(A, knots, g.reshape(-1, knots.size), method=method)
g = interp1d_vec(A, knots, jnp.atleast_2d(g), method=method)
assert g.shape == bp1.shape == bp2.shape
return g

Expand Down
4 changes: 2 additions & 2 deletions desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def poly_root(

if sort or distinct:
r = jnp.sort(r, axis=-1)
return _filter_distinct(r, sentinel, eps) if distinct else r
return filter_distinct(r, sentinel, eps) if distinct else r


def _root_cubic(a, b, c, d, sentinel, eps, distinct):
Expand Down Expand Up @@ -640,7 +640,7 @@ def _concat_sentinel(r, sentinel, num=1):
return jnp.append(r, sent, axis=-1)


def _filter_distinct(r, sentinel, eps):
def filter_distinct(r, sentinel, eps):
"""Set all but one of matching adjacent elements in ``r`` to ``sentinel``."""
# eps needs to be low enough that close distinct roots do not get removed.
# Otherwise, algorithms relying on continuity will fail.
Expand Down
8 changes: 3 additions & 5 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
)
from desc.integrals.bounce_utils import (
_get_extrema,
_interp_to_argmin_B_hard,
_interp_to_argmin_B_soft,
bounce_points,
get_pitch,
interp_to_argmin_B_hard,
interp_to_argmin_B_soft,
plot_ppoly,
)
from desc.integrals.quad_utils import (
Expand Down Expand Up @@ -1096,9 +1096,7 @@ def denominator(B, pitch):
print(pitch[:, i, j])

@pytest.mark.unit
@pytest.mark.parametrize(
"func", [_interp_to_argmin_B_soft, _interp_to_argmin_B_hard]
)
@pytest.mark.parametrize("func", [interp_to_argmin_B_soft, interp_to_argmin_B_hard])
def test_interp_to_argmin_B(self, func):
"""Test argmin interpolation.""" # noqa: D202

Expand Down

0 comments on commit cee3da7

Please sign in to comment.