Skip to content

Commit

Permalink
Quad change for speed (#1478)
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis authored Dec 17, 2024
2 parents 1cf86c2 + 3b7e5e1 commit 5c374d8
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 68 deletions.
45 changes: 10 additions & 35 deletions desc/compute/_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,27 +309,20 @@ def _v_tau(data, B, pitch):
return safediv(2.0, jnp.sqrt(jnp.abs(1 - pitch * B)))


def _f1(data, B, pitch):
def _drift1(data, B, pitch):
return (
safediv(
1 - 0.5 * pitch * B,
jnp.sqrt(jnp.abs(1 - pitch * B)),
)
safediv(1 - 0.5 * pitch * B, jnp.sqrt(jnp.abs(1 - pitch * B)))
* data["|grad(psi)|*kappa_g"]
/ B
)


def _f2(data, B, pitch):
def _drift2(data, B, pitch):
return (
safediv(1 - 0.5 * pitch * B, jnp.sqrt(jnp.abs(1 - pitch * B)))
* data["|B|_r|v,p"]
/ B
)


def _f3(data, B, pitch):
return jnp.sqrt(jnp.abs(1 - pitch * B)) * data["K"] / B
+ jnp.sqrt(jnp.abs(1 - pitch * B)) * data["K"]
) / B


@register_compute_fun(
Expand Down Expand Up @@ -366,7 +359,6 @@ def _f3(data, B, pitch):
resolution_requirement="tz",
grid_requirement={"can_fft2": True},
**_bounce_doc,
quad2="Same as ``quad`` for the weak singular integrals in particular.",
)
@partial(
jit,
Expand Down Expand Up @@ -411,7 +403,6 @@ def _Gamma_c(params, transforms, profiles, data, **kwargs):
leggauss(kwargs.get("num_quad", 32)),
(automorphism_sin, grad_automorphism_sin),
)
quad2 = kwargs["quad2"] if "quad2" in kwargs else chebgauss2(quad[0].size)

def Gamma_c(data):
"""∫ dλ ∑ⱼ [v τ γ_c²]ⱼ π²/4."""
Expand All @@ -429,35 +420,19 @@ def Gamma_c(data):

def fun(pitch_inv):
points = bounce.points(pitch_inv, num_well=num_well)
v_tau, f1, f2 = bounce.integrate(
[_v_tau, _f1, _f2],
v_tau, drift1, drift2 = bounce.integrate(
[_v_tau, _drift1, _drift2],
pitch_inv,
data,
["|grad(psi)|*kappa_g", "|B|_r|v,p"],
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
points,
is_fourier=True,
)
# This is γ_c π/2.
gamma_c = jnp.arctan(
safediv(
f1,
(
f2
# TODO: Once people are happy with benchmarking
# we can push this integral into f2.
# The quadrature is less optimal, but
# it still works and it would be more efficient
# since we don't have to interpolate twice.
+ bounce.integrate(
_f3,
pitch_inv,
data,
"K",
points,
quad=quad2,
is_fourier=True,
)
)
drift1,
drift2
* bounce.interp_to_argmin(
data["|grad(rho)|*|e_alpha|r,p|"], points, is_fourier=True
),
Expand Down
30 changes: 6 additions & 24 deletions desc/compute/_neoclassical_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
grad_automorphism_sin,
)
from ..utils import cross, dot, safediv
from ._neoclassical import _bounce_doc, _cvdrift0, _dH, _dI, _f1, _f2, _f3, _v_tau
from ._neoclassical import _bounce_doc, _cvdrift0, _dH, _dI, _drift1, _drift2, _v_tau
from .data_index import register_compute_fun

_bounce1D_doc = {
Expand Down Expand Up @@ -286,7 +286,6 @@ def _effective_ripple_1D(params, transforms, profiles, data, **kwargs):
+ Bounce1D.required_names,
source_grid_requirement={"coordinates": "raz", "is_meshgrid": True},
**_bounce1D_doc,
quad2="Same as ``quad`` for the weak singular integrals in particular.",
)
@partial(jit, static_argnames=["num_well", "num_quad", "num_pitch", "batch"])
def _Gamma_c_1D(params, transforms, profiles, data, **kwargs):
Expand All @@ -312,41 +311,24 @@ def _Gamma_c_1D(params, transforms, profiles, data, **kwargs):
leggauss(kwargs.get("num_quad", 32)),
(automorphism_sin, grad_automorphism_sin),
)
quad2 = kwargs["quad2"] if "quad2" in kwargs else chebgauss2(quad[0].size)

def Gamma_c(data):
"""∫ dλ ∑ⱼ [v τ γ_c²]ⱼ π²/4."""
bounce = Bounce1D(grid, data, quad, automorphism=None, is_reshaped=True)
points = bounce.points(data["pitch_inv"], num_well=num_well)
v_tau, f1, f2 = bounce.integrate(
[_v_tau, _f1, _f2],
v_tau, drift1, drift2 = bounce.integrate(
[_v_tau, _drift1, _drift2],
data["pitch_inv"],
data,
["|grad(psi)|*kappa_g", "|B|_r|v,p"],
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
points,
batch=batch,
)
# This is γ_c π/2.
gamma_c = jnp.arctan(
safediv(
f1,
(
f2
# TODO: Once people are happy with benchmarking
# we can push this integral into f2.
# The quadrature is less optimal, but
# it still works and it would be more efficient
# since we don't have to interpolate twice.
+ bounce.integrate(
_f3,
data["pitch_inv"],
data,
"K",
points,
batch=batch,
quad=quad2,
)
)
drift1,
drift2
* bounce.interp_to_argmin(data["|grad(rho)|*|e_alpha|r,p|"], points),
)
)
Expand Down
10 changes: 1 addition & 9 deletions desc/objectives/_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,10 @@ def build(self, use_jit=True, verbose=1):
domain=(0, 2 * np.pi),
)
self._constants["fieldline quad"] = leggauss(self._hyperparam["Y_B"] // 2)
num_quad = self._hyperparam.pop("num_quad")
self._constants["quad"] = get_quadrature(
leggauss(num_quad),
leggauss(self._hyperparam.pop("num_quad")),
(automorphism_sin, grad_automorphism_sin),
)
if self._key == "Gamma_c":
self._constants["quad2"] = chebgauss2(num_quad)

self._dim_f = self._grid.num_rho
self._target, self._bounds = _parse_callable_target_bounds(
Expand Down Expand Up @@ -483,10 +480,6 @@ def compute(self, params, constants=None):
"""
if constants is None:
constants = self.constants
quad2 = {}
if self._key == "Gamma_c":
quad2["quad2"] = constants["quad2"]

eq = self.things[0]
data = compute_fun(
eq, "iota", params, constants["transforms"], constants["profiles"]
Expand All @@ -511,7 +504,6 @@ def compute(self, params, constants=None):
),
fieldline_quad=constants["fieldline quad"],
quad=constants["quad"],
**quad2,
**self._hyperparam,
)
return constants["transforms"]["grid"].compress(data[self._key])
Binary file modified tests/baseline/test_Gamma_c.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/baseline/test_Gamma_c_1D.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 5c374d8

Please sign in to comment.