Skip to content

Commit

Permalink
Make map_coordinates differentiable for JAX 0.4.34 (#1293)
Browse files Browse the repository at this point in the history
- Adds `full_output` flags to `root` and `root_scalar` to make them
differentiable.
- Adds tests for differentiability of `root` and `root_scalar` in
addition to `map_coordinates_derivative`

- While working on the JAX problems, I used this PR to update our `test_jax`
workflow with newer JAX versions and a better dependency installation
routine (previously, JAX was installed after the other packages, so the
rest of the packages were at their latest versions while only JAX was old,
which caused incompatibilities and false errors)

Resolves #1291
  • Loading branch information
dpanici authored Oct 30, 2024
2 parents 831e7bd + c508da8 commit 7887a52
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 70 deletions.
37 changes: 19 additions & 18 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,35 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: [0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5,
0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11,
0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17,
0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24,
0.3.25, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5,
0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11,
0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18,
jax-version: [0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17,
0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23,
0.4.24, 0.4.25]
0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29,
0.4.30, 0.4.31, 0.4.33, 0.4.34, 0.4.35]
# 0.4.32 is not available on PyPI
# earlier jax versions are not compatible with other
# dependencies as of 2024-10-04
group: [1, 2]
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
- name: Install dependencies
python-version: '3.10'
- name: Upgrade pip
run: |
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.5.0
- name: Remove jax
- name: Install dependencies with given JAX version
run: |
pip uninstall jax jaxlib -y
- name: install jax
sed -i '/jax/d' ./requirements.txt
sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
cat ./requirements.txt
pip install -r ./devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Verify dependencies
run: |
pip install "jax[cpu]==${{ matrix.jax-version }}"
python --version
pip --version
pip list
- name: Test with pytest
run: |
pwd
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Bug Fixes

- Fixes bugs that occur when saving asymmetric equilibria as wout files
- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file
- Fixes bug that occurs when taking the gradient of ``root`` and ``root_scalar`` with newer versions of JAX (>=0.4.34) and unpins the JAX version

Deprecations

Expand Down
73 changes: 56 additions & 17 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)
if hasattr(jnp, "trapezoid"):
trapezoid = jnp.trapezoid # for JAX 0.4.26 and later
elif hasattr(jax.scipy, "integrate"):
trapezoid = jax.scipy.integrate.trapezoid
else:
trapezoid = jnp.trapz # for older versions of JAX, deprecated by jax 0.4.16

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Expand Down Expand Up @@ -200,6 +203,7 @@ def root_scalar(
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.
Expand Down Expand Up @@ -227,6 +231,9 @@ def root_scalar(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
Returns
-------
Expand Down Expand Up @@ -271,18 +278,25 @@ def bodyfun(state):
xk1, fk1 = backtrack(xk1, fk1, d)
return xk1, fk1, k1 + 1

state = guess, res(guess), 0
state = guess, res(guess), 0.0
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jax.jacfwd(g)(y)
return y / A

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x

def root(
fun,
Expand All @@ -294,6 +308,7 @@ def root(
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.
Expand Down Expand Up @@ -321,6 +336,9 @@ def root(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
Returns
-------
Expand Down Expand Up @@ -388,19 +406,26 @@ def bodyfun(state):
state = (
jnp.atleast_1d(jnp.asarray(guess)),
jnp.atleast_1d(resfun(guess)),
0,
0.0,
)
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jnp.atleast_2d(jax.jacfwd(g)(y))
return _lstsq(A, jnp.atleast_1d(y))

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand Down Expand Up @@ -711,6 +736,7 @@ def root_scalar(
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.
Expand Down Expand Up @@ -738,6 +764,9 @@ def root_scalar(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
Returns
-------
Expand All @@ -750,7 +779,10 @@ def root_scalar(
out = scipy.optimize.root_scalar(
fun, args, x0=x0, fprime=jac, xtol=tol, rtol=tol
)
return out.root, out
if full_output:
return out.root, out
else:
return out.root

def root(
fun,
Expand All @@ -762,6 +794,7 @@ def root(
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.
Expand Down Expand Up @@ -789,6 +822,9 @@ def root(
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
Returns
-------
Expand All @@ -803,7 +839,10 @@ def root(
will solve it in a least squares sense.
"""
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out
if full_output:
return out.x, out
else:
return out.x

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
Expand Down
12 changes: 6 additions & 6 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ def zernike_radial(r, l, m, dr=0):
"Analytic radial derivatives of Zernike polynomials for order>4 "
+ "have not been implemented."
)
return s * jnp.where((l - m) % 2 == 0, out, 0)
return s * jnp.where((l - m) % 2 == 0, out, 0.0)


def power_coeffs(l):
Expand Down Expand Up @@ -1732,7 +1732,7 @@ def _binom_body_fun(i, b_n):
return b


@custom_jvp
@functools.partial(custom_jvp, nondiff_argnums=(4,))
@jit
@jnp.vectorize
def _jacobi(n, alpha, beta, x, dx=0):
Expand Down Expand Up @@ -1804,13 +1804,13 @@ def _jacobi_body_fun(kk, d_p_a_b_x):


@_jacobi.defjvp
def _jacobi_jvp(x, xdot):
(n, alpha, beta, x, dx) = x
(ndot, alphadot, betadot, xdot, dxdot) = xdot
def _jacobi_jvp(dx, x, xdot):
(n, alpha, beta, x) = x
(*_, xdot) = xdot
f = _jacobi(n, alpha, beta, x, dx)
df = _jacobi(n, alpha, beta, x, dx + 1)
# in theory n, alpha, beta, dx aren't differentiable (they're integers)
# but marking them as non-diff argnums seems to cause escaped tracer values.
# probably a more elegant fix, but just setting those derivatives to zero seems
# to work fine.
return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot
return f, df * xdot
36 changes: 27 additions & 9 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,18 @@ def fixup(y, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
# See description here
# https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
# except we make sure to properly handle periodic coordinates.
yk, (res, niter) = vecroot(yk, coords)
if full_output:
yk, (res, niter) = vecroot(yk, coords)
else:
yk = vecroot(yk, coords)

out = compute(yk, outbasis)
if full_output:
Expand Down Expand Up @@ -363,18 +367,28 @@ def fixup(x, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
rho, theta_PEST, zeta = coords.T
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
if full_output:
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
else:
theta = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
Expand Down Expand Up @@ -466,6 +480,7 @@ def fixup(x, *args):
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
Expand All @@ -474,7 +489,10 @@ def fixup(x, *args):
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
if full_output:
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
else:
theta = vecroot(guess, alpha, rho, zeta, iota)

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
Expand Down
15 changes: 12 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,10 +741,19 @@ def fun_jax(zeta_hat, theta, zeta):
n, r, r_offset = n_and_r_jax(nodes)
return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta

vecroot = jit(vmap(lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p)))
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
fun_jax, x0, jac=None, args=p, full_output=full_output
)
)
)
if full_output:
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
)
else:
zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2])

zetas = np.asarray(zetas)
nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T
Expand Down
3 changes: 1 addition & 2 deletions devtools/dev-requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ dependencies:
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- jax >= 0.4.24, < 0.5.0
- diffrax >= 0.4.1
- interpax >= 0.3.3
- jax[cpu] >= 0.3.2, < 0.5.0
- nvgpu
- orthax
- plotly >= 5.16, < 6.0
Expand All @@ -29,7 +29,6 @@ dependencies:
- qicna @ git+https://github.com/rogeriojorge/pyQIC/
- black[jupyter] = 24.3.0


# building the docs
- nbsphinx == 0.8.12
- pandoc
Expand Down
Loading

0 comments on commit 7887a52

Please sign in to comment.