diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 7be78bc90d..8f514c1c88 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -174,7 +174,7 @@ If the 2nd case is the reason, then you must update the ``master_compute_data.pk except AssertionError as e: error = True - print(e) + print(e) with @@ -183,7 +183,7 @@ with except AssertionError as e: error = False update_master_data = True - print(e) + print(e) - rerun the test ``pytest tests -k test_compute_everything``, now any compute quantity that is different between the PR and master will be updated with the PR value diff --git a/desc/basis.py b/desc/basis.py index b394fcfa90..dafc6bf929 100644 --- a/desc/basis.py +++ b/desc/basis.py @@ -7,7 +7,7 @@ import mpmath import numpy as np -from desc.backend import cond, custom_jvp, fori_loop, gammaln, jit, jnp, sign, switch +from desc.backend import custom_jvp, fori_loop, gammaln, jit, jnp, sign from desc.io import IOAble from desc.utils import check_nonnegint, check_posint, flatten_list @@ -767,7 +767,7 @@ def evaluate( lm = lm[lmidx] m = m[midx] - radial = zernike_radial(r, lm[:, 0], lm[:, 1], dr=derivatives[0]) + radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0]) poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1]) if unique: @@ -1129,7 +1129,7 @@ def evaluate( m = m[midx] n = n[nidx] - radial = zernike_radial(r, lm[:, 0], lm[:, 1], dr=derivatives[0]) + radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0]) poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1]) toroidal = fourier(z[:, np.newaxis], n, NFP=self.NFP, dt=derivatives[2]) if unique: @@ -1484,149 +1484,83 @@ def zernike_radial_poly(r, l, m, dr=0, exact="auto"): return polyval_vec(coeffs, r, prec=prec).T -@custom_jvp -@jit +@functools.partial(jit, static_argnums=3) def zernike_radial(r, l, m, dr=0): """Radial part of zernike polynomials. - Calculates Radial part of Zernike Polynomials using Jacobi recursion relation - by getting rid of the redundant calculations for appropriate modes. - https://en.wikipedia.org/wiki/Jacobi_polynomials#Recurrence_relations - - For the derivatives, the following formula is used with above recursion relation, - https://en.wikipedia.org/wiki/Jacobi_polynomials#Derivatives - - Used formulas are also in the zerike_eval.ipynb notebook in docs. - - This function can be made faster. However, JAX reverse mode AD causes problems. - In future, we may use vmap() instead of jnp.vectorize() to be able to set dr as - static argument, and not calculate every derivative even thoguh not asked. + Evaluates basis functions using JAX and a stable + evaluation scheme based on jacobi polynomials and + binomial coefficients. Generally faster for L>24 + and differentiable, but slower for low resolution. Parameters ---------- - r : ndarray, shape(N,) or scalar + r : ndarray, shape(N,) radial coordinates to evaluate basis - l : ndarray of int, shape(K,) or integer + l : ndarray of int, shape(K,) radial mode number(s) - m : ndarray of int, shape(K,) or integer + m : ndarray of int, shape(K,) azimuthal mode number(s) dr : int order of derivative (Default = 0) Returns ------- - out : ndarray, shape(N,K) + y : ndarray, shape(N,K) basis function(s) evaluated at specified points """ - dr = jnp.asarray(dr).astype(int) - - branches = [ - _zernike_radial_vectorized, - _zernike_radial_vectorized_d1, - _zernike_radial_vectorized_d2, - _zernike_radial_vectorized_d3, - _zernike_radial_vectorized_d4, - ] - return switch(dr, branches, r, l, m, dr) - - -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized(r, l, m, dr): - """Calculation of Radial part of Zernike polynomials.""" - - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jacobi_poly_single(r_jacobi, N, alpha, 0, P_n1, P_n2) - - # Calculate Radial part of Zernike for N,alpha - result = (-1) ** N * r**alpha * P_n - # Check if the calculated values is in the given modes - mask = jnp.logical_and(m == alpha, n == N) - out = jnp.where(mask, result, out) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0].set(P_n2) - P_past = P_past.at[1].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros(2) - P_past = P_past.at[0].set(jacobi_poly_single(r_jacobi, 0, alpha, beta=0)) - # Jacobi for n=1 - P_past = P_past.at[1].set(jacobi_poly_single(r_jacobi, 1, alpha, beta=0)) - - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) + m = jnp.abs(m).astype(float) + alpha = m + beta = 0 + n = (l - m) // 2 + s = (-1) ** n + jacobi_arg = 1 - 2 * r**2 + if dr == 0: + out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0) + elif dr == 1: + f = _jacobi(n, alpha, beta, jacobi_arg, 0) + df = _jacobi(n, alpha, beta, jacobi_arg, 1) + out = m * r ** jnp.maximum(m - 1, 0) * f - 4 * r ** (m + 1) * df + elif dr == 2: + f = _jacobi(n, alpha, beta, jacobi_arg, 0) + df = _jacobi(n, alpha, beta, jacobi_arg, 1) + d2f = _jacobi(n, alpha, beta, jacobi_arg, 2) + out = ( + (m - 1) * m * r ** jnp.maximum(m - 2, 0) * f + - 4 * (2 * m + 1) * r**m * df + + 16 * r ** (m + 2) * d2f ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out - - -def jacobi_poly_single(x, n, alpha, beta=0, P_n1=0, P_n2=0): - """Evaluate Jacobi for single alpha and n pair.""" - c = 2 * n + alpha + beta - a1 = 2 * n * (c - n) * (c - 2) - a2 = (c - 1) * (c * (c - 2) * x + (alpha - beta) * (alpha + beta)) - a3 = 2 * (n + alpha - 1) * (n + beta - 1) * c - - # Check if a1 is 0, to prevent division by 0 - a1 = jnp.where(a1 == 0, 1e-6, a1) - P_n = (a2 * P_n1 - a3 * P_n2) / a1 - # Checks for special cases - P_n = jnp.where(n < 0, 0, P_n) - P_n = jnp.where(n == 0, 1, P_n) - P_n = jnp.where(n == 1, (alpha + 1) + (alpha + beta + 2) * (x - 1) / 2, P_n) - return P_n - - -@zernike_radial.defjvp -def _zernike_radial_jvp(x, xdot): - (r, l, m, dr) = x - (rdot, ldot, mdot, drdot) = xdot - f = zernike_radial(r, l, m, dr) - df = zernike_radial(r, l, m, dr + 1) - # in theory l, m, dr aren't differentiable (they're integers) - # but marking them as non-diff argnums seems to cause escaped tracer values. - # probably a more elegant fix, but just setting those derivatives to zero seems - # to work fine. - return f, (df.T * rdot).T + 0 * ldot + 0 * mdot + 0 * drdot + elif dr == 3: + f = _jacobi(n, alpha, beta, jacobi_arg, 0) + df = _jacobi(n, alpha, beta, jacobi_arg, 1) + d2f = _jacobi(n, alpha, beta, jacobi_arg, 2) + d3f = _jacobi(n, alpha, beta, jacobi_arg, 3) + out = ( + (m - 2) * (m - 1) * m * r ** jnp.maximum(m - 3, 0) * f + - 12 * m**2 * r ** jnp.maximum(m - 1, 0) * df + + 48 * (m + 1) * r ** (m + 1) * d2f + - 64 * r ** (m + 3) * d3f + ) + elif dr == 4: + f = _jacobi(n, alpha, beta, jacobi_arg, 0) + df = _jacobi(n, alpha, beta, jacobi_arg, 1) + d2f = _jacobi(n, alpha, beta, jacobi_arg, 2) + d3f = _jacobi(n, alpha, beta, jacobi_arg, 3) + d4f = _jacobi(n, alpha, beta, jacobi_arg, 4) + out = ( + (m - 3) * (m - 2) * (m - 1) * m * r ** jnp.maximum(m - 4, 0) * f + - 8 * m * (2 * m**2 - 3 * m + 1) * r ** jnp.maximum(m - 2, 0) * df + + 48 * (2 * m**2 + 2 * m + 1) * r**m * d2f + - 128 * (2 * m + 3) * r ** (m + 2) * d3f + + 256 * r ** (m + 4) * d4f + ) + else: + raise NotImplementedError( + "Analytic radial derivatives of Zernike polynomials for order>4 " + + "have not been implemented." + ) + return s * jnp.where((l - m) % 2 == 0, out, 0) def power_coeffs(l): @@ -1749,419 +1683,117 @@ def zernike_norm(l, m): return np.sqrt((2 * (l + 1)) / (np.pi * (1 + int(m == 0)))) -def find_intermadiate_jacobi(dx, args): - """Finds Jacobi function and its derivatives for nth loop.""" - r_jacobi, N, alpha, P_n1, P_n2, P_n = args - P_n = P_n.at[dx].set( - jacobi_poly_single(r_jacobi, N - dx, alpha + dx, dx, P_n1[dx], P_n2[dx]) - ) - return (r_jacobi, N, alpha, P_n1, P_n2, P_n) - - -def update_zernike_output(i, args): - """Updates Zernike radial output, if the mode is in the inputs.""" - m, n, alpha, N, result, out = args - idx = jnp.where(jnp.logical_and(m[i] == alpha, n[i] == N), i, -1) - - def falseFun(args): - _, _, out = args - return out - - def trueFun(args): - idx, result, out = args - out = out.at[idx].set(result) - return out - - out = cond(idx >= 0, trueFun, falseFun, (idx, result, out)) - return (m, n, alpha, N, result, out) - - -def find_initial_jacobi(dx, args): - """Finds initial values of Jacobi Polynomial and derivatives.""" - r_jacobi, alpha, P_past = args - # Jacobi for n=0 - P_past = P_past.at[0, dx].set(jacobi_poly_single(r_jacobi, 0, alpha + dx, beta=dx)) - # Jacobi for n=1 - P_past = P_past.at[1, dx].set(jacobi_poly_single(r_jacobi, 1, alpha + dx, beta=dx)) - return (r_jacobi, alpha, P_past) - - -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized_d1(r, l, m, dr): - """First derivative calculation of Radial part of Zernike polynomials.""" +@jit +@jnp.vectorize +def _binom(n, k): + """Binomial coefficient. - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jnp.zeros(MAXDR + 1) # Jacobi at N + Implementation is only correct for positive integer n,k and n>=k - # Calculate Jacobi polynomial and derivatives for (alpha,N) - _, _, _, _, _, P_n = fori_loop( - 0, - MAXDR + 1, - find_intermadiate_jacobi, - (r_jacobi, N, alpha, P_n1, P_n2, P_n), - ) - # Calculate coefficients for derivatives. coef[0] will never be used. Jax - # doesn't have Gamma function directly, that's why we calculate Logarithm of - # Gamma function and then exponentiate it. - coef = jnp.exp( - gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) - ) - # 1th Derivative of Zernike Radial - result = (-1) ** N * ( - alpha * r ** jnp.maximum(alpha - 1, 0) * P_n[0] - - coef[1] * 4 * r ** (alpha + 1) * P_n[1] - ) - # Check if the calculated values is in the given modes - mask = jnp.logical_and(m == alpha, n == N) - out = jnp.where(mask, result, out) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 + dxs - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0, :].set(P_n2) - P_past = P_past.at[1, :].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros((2, MAXDR + 1)) - _, _, P_past = fori_loop( - 0, MAXDR + 1, find_initial_jacobi, (r_jacobi, alpha, P_past) - ) + Parameters + ---------- + n : int, array-like + number of things to choose from + k : int, array-like + number of things chosen - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) - ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - dr = jnp.asarray(dr).astype(int) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - # This part can be better implemented. Try to make dr as static argument - # jnp.vectorize doesn't allow it to be static - MAXDR = 1 - dxs = jnp.arange(0, MAXDR + 1) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out - - -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized_d2(r, l, m, dr): - """Second derivative calculation of Radial part of Zernike polynomials.""" - - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jnp.zeros(MAXDR + 1) # Jacobi at N - - # Calculate Jacobi polynomial and derivatives for (alpha,N) - _, _, _, _, _, P_n = fori_loop( - 0, - MAXDR + 1, - find_intermadiate_jacobi, - (r_jacobi, N, alpha, P_n1, P_n2, P_n), - ) + Returns + ------- + val : int, float, array-like + number of possible combinations + """ + # adapted from scipy: + # https://github.com/scipy/scipy/blob/701ffcc8a6f04509d115aac5e5681c538b5265a2/ + # scipy/special/orthogonal_eval.pxd#L68 - # Calculate coefficients for derivatives. coef[0] will never be used. Jax - # doesn't have Gamma function directly, that's why we calculate Logarithm of - # Gamma function and then exponentiate it. - coef = jnp.exp( - gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) - ) + n, k = map(jnp.asarray, (n, k)) - result = (-1) ** N * ( - (alpha - 1) * alpha * r ** jnp.maximum(alpha - 2, 0) * P_n[0] - - coef[1] * 4 * (2 * alpha + 1) * r**alpha * P_n[1] - + coef[2] * 16 * r ** (alpha + 2) * P_n[2] - ) - # Check if the calculated values is in the given modes - mask = jnp.logical_and(m == alpha, n == N) - out = jnp.where(mask, result, out) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 + dxs - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0, :].set(P_n2) - P_past = P_past.at[1, :].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros((2, MAXDR + 1)) - _, _, P_past = fori_loop( - 0, MAXDR + 1, find_initial_jacobi, (r_jacobi, alpha, P_past) - ) - - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) - ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - dr = jnp.asarray(dr).astype(int) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - # This part can be better implemented. Try to make dr as static argument - # jnp.vectorize doesn't allow it to be static - MAXDR = 2 - dxs = jnp.arange(0, MAXDR + 1) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out - - -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized_d3(r, l, m, dr): - """Third derivative calculation of Radial part of Zernike polynomials.""" - - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jnp.zeros(MAXDR + 1) # Jacobi at N - - # Calculate Jacobi polynomial and derivatives for (alpha,N) - _, _, _, _, _, P_n = fori_loop( - 0, - MAXDR + 1, - find_intermadiate_jacobi, - (r_jacobi, N, alpha, P_n1, P_n2, P_n), - ) + def _binom_body_fun(i, b_n): + b, n = b_n + num = n + 1 - i + den = i + return (b * num / den, n) - # Calculate coefficients for derivatives. coef[0] will never be used. Jax - # doesn't have Gamma function directly, that's why we calculate Logarithm of - # Gamma function and then exponentiate it. - coef = jnp.exp( - gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) - ) + kx = k.astype(int) + b, n = fori_loop(1, 1 + kx, _binom_body_fun, (1.0, n)) + return b - # 3rd Derivative of Zernike Radial - result = (-1) ** N * ( - (alpha - 2) * (alpha - 1) * alpha * r ** jnp.maximum(alpha - 3, 0) * P_n[0] - - coef[1] * 12 * alpha**2 * r ** jnp.maximum(alpha - 1, 0) * P_n[1] - + coef[2] * 48 * (alpha + 1) * r ** (alpha + 1) * P_n[2] - - coef[3] * 64 * r ** (alpha + 3) * P_n[3] - ) - # Check if the calculated values is in the given modes - mask = jnp.logical_and(m == alpha, n == N) - out = jnp.where(mask, result, out) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 + dxs - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0, :].set(P_n2) - P_past = P_past.at[1, :].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros((2, MAXDR + 1)) - _, _, P_past = fori_loop( - 0, MAXDR + 1, find_initial_jacobi, (r_jacobi, alpha, P_past) - ) - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) - ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - dr = jnp.asarray(dr).astype(int) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - # This part can be better implemented. Try to make dr as static argument - # jnp.vectorize doesn't allow it to be static - MAXDR = 3 - dxs = jnp.arange(0, MAXDR + 1) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out - - -@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)") -def _zernike_radial_vectorized_d4(r, l, m, dr): - """Fourth derivative calculation of Radial part of Zernike polynomials.""" - - def body_inner(N, args): - alpha, out, P_past = args - P_n2 = P_past[0] # Jacobi at N-2 - P_n1 = P_past[1] # Jacobi at N-1 - P_n = jnp.zeros(MAXDR + 1) # Jacobi at N - - # Calculate Jacobi polynomial and derivatives for (alpha,N) - _, _, _, _, _, P_n = fori_loop( - 0, - MAXDR + 1, - find_intermadiate_jacobi, - (r_jacobi, N, alpha, P_n1, P_n2, P_n), - ) +@custom_jvp +@jit +@jnp.vectorize +def _jacobi(n, alpha, beta, x, dx=0): + """Jacobi polynomial evaluation. - # Calculate coefficients for derivatives. coef[0] will never be used. Jax - # doesn't have Gamma function directly, that's why we calculate Logarithm of - # Gamma function and then exponentiate it. - coef = jnp.exp( - gammaln(alpha + N + 1 + dxs) - dxs * jnp.log(2) - gammaln(alpha + N + 1) - ) + Implementation is only correct for non-negative integer coefficients, + returns 0 otherwise. - # 4th Derivative of Zernike Radial - result = (-1) ** N * ( - (alpha - 3) - * (alpha - 2) - * (alpha - 1) - * alpha - * r ** jnp.maximum(alpha - 4, 0) - * P_n[0] - - coef[1] - * 8 - * alpha - * (2 * alpha**2 - 3 * alpha + 1) - * r ** jnp.maximum(alpha - 2, 0) - * P_n[1] - + coef[2] * 48 * (2 * alpha**2 + 2 * alpha + 1) * r**alpha * P_n[2] - - coef[3] * 128 * (2 * alpha + 3) * r ** (alpha + 2) * P_n[3] - + coef[4] * 256 * r ** (alpha + 4) * P_n[4] - ) - # Check if the calculated values is in the given modes - mask = jnp.logical_and(m == alpha, n == N) - out = jnp.where(mask, result, out) - - # Shift past values if needed - # For derivative order dx, if N is smaller than 2+dx, then only the initial - # value calculated by find_init_jacobi function will be used. So, if you update - # P_n's, preceeding values will be wrong. - mask = N >= 2 + dxs - P_n2 = jnp.where(mask, P_n1, P_n2) - P_n1 = jnp.where(mask, P_n, P_n1) - # Form updated P_past matrix - P_past = P_past.at[0, :].set(P_n2) - P_past = P_past.at[1, :].set(P_n1) - - return (alpha, out, P_past) - - def body(alpha, out): - # find l values with m values equal to alpha - l_alpha = jnp.where(m == alpha, l, 0) - # find the maximum among them - L_max = jnp.max(l_alpha) - # Maximum possible value for n for loop bound - N_max = (L_max - alpha) // 2 - - # First 2 Jacobi Polynomials (they don't need recursion) - # P_past stores last 2 Jacobi polynomials (and required derivatives) - # evaluated at given r points - P_past = jnp.zeros((2, MAXDR + 1)) - _, _, P_past = fori_loop( - 0, MAXDR + 1, find_initial_jacobi, (r_jacobi, alpha, P_past) - ) + Parameters + ---------- + n : int, array_like + Degree of the polynomial. + alpha : int, array_like + Parameter + beta : int, array_like + Parameter + x : float, array_like + Points at which to evaluate the polynomial - # Loop over every n value - _, out, _ = fori_loop( - 0, (N_max + 1).astype(int), body_inner, (alpha, out, P_past) - ) - return out - - # Make inputs 1D arrays in case they aren't - m = jnp.atleast_1d(m) - l = jnp.atleast_1d(l) - dr = jnp.asarray(dr).astype(int) - - # From the vectorization, the overall output will be (r.size, m.size) - out = jnp.zeros(m.size) - r_jacobi = 1 - 2 * r**2 - m = jnp.abs(m) - n = ((l - m) // 2).astype(int) - - # This part can be better implemented. Try to make dr as static argument - # jnp.vectorize doesn't allow it to be static - MAXDR = 4 - dxs = jnp.arange(0, MAXDR + 1) - - M_max = jnp.max(m) - # Loop over every different m value. There is another nested - # loop which will execute necessary n values. - out = fori_loop(0, (M_max + 1).astype(int), body, (out)) - return out + Returns + ------- + P : ndarray + Values of the Jacobi polynomial + """ + # adapted from scipy: + # https://github.com/scipy/scipy/blob/701ffcc8a6f04509d115aac5e5681c538b5265a2/ + # scipy/special/orthogonal_eval.pxd#L144 + + def _jacobi_body_fun(kk, d_p_a_b_x): + d, p, alpha, beta, x = d_p_a_b_x + k = kk + 1.0 + t = 2 * k + alpha + beta + d = ( + (t * (t + 1) * (t + 2)) * (x - 1) * p + 2 * k * (k + beta) * (t + 2) * d + ) / (2 * (k + alpha + 1) * (k + alpha + beta + 1) * t) + p = d + p + return (d, p, alpha, beta, x) + + n, alpha, beta, x = map(jnp.asarray, (n, alpha, beta, x)) + + # coefficient for derivative + c = ( + gammaln(alpha + beta + n + 1 + dx) + - dx * jnp.log(2) + - gammaln(alpha + beta + n + 1) + ) + c = jnp.exp(c) + # taking derivative is same as coeff*jacobi but for shifted n,a,b + n -= dx + alpha += dx + beta += dx + + d = (alpha + beta + 2) * (x - 1) / (2 * (alpha + 1)) + p = d + 1 + d, p, alpha, beta, x = fori_loop( + 0, jnp.maximum(n - 1, 0).astype(int), _jacobi_body_fun, (d, p, alpha, beta, x) + ) + out = _binom(n + alpha, n) * p + # should be complex for n<0, but it gets replaced elsewhere so just return 0 here + out = jnp.where(n < 0, 0, out) + # other edge cases + out = jnp.where(n == 0, 1.0, out) + out = jnp.where(n == 1, 0.5 * (2 * (alpha + 1) + (alpha + beta + 2) * (x - 1)), out) + return c * out + + +@_jacobi.defjvp +def _jacobi_jvp(x, xdot): + (n, alpha, beta, x, dx) = x + (ndot, alphadot, betadot, xdot, dxdot) = xdot + f = _jacobi(n, alpha, beta, x, dx) + df = _jacobi(n, alpha, beta, x, dx + 1) + # in theory n, alpha, beta, dx aren't differentiable (they're integers) + # but marking them as non-diff argnums seems to cause escaped tracer values. + # probably a more elegant fix, but just setting those derivatives to zero seems + # to work fine. + return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot diff --git a/desc/equilibrium/initial_guess.py b/desc/equilibrium/initial_guess.py index 05c5fa7703..f2f39915f5 100644 --- a/desc/equilibrium/initial_guess.py +++ b/desc/equilibrium/initial_guess.py @@ -314,7 +314,6 @@ def body(k, x_lmn): # now overwrite stuff to deal with the axis scale = zernike_radial(coord, 0, 0) - scale = scale.flatten() for k, (l, m, n) in enumerate(b_basis.modes): if m != 0: continue diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index 4f7587da6b..5b3c6fc274 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -150,7 +150,7 @@ def build(self, use_jit=False, verbose=1): errorif( par not in thing.optimizable_params, ValueError, - f"parameter {par} not found in optimizable_parameters: " + f"couldn't find parameter {par} in optimizable_parameters: " + f"{thing.optimizable_params}", ) self._params = params diff --git a/desc/vmec_utils.py b/desc/vmec_utils.py index 771f8a37a1..9ea5bf7df9 100644 --- a/desc/vmec_utils.py +++ b/desc/vmec_utils.py @@ -299,7 +299,7 @@ def fourier_to_zernike(m, n, x_mn, basis): surfs = x_mn.shape[0] rho = np.sqrt(np.linspace(0, 1, surfs)) - As = zernike_radial(rho, basis.modes[:, 0], basis.modes[:, 1]) + As = zernike_radial(rho[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1]) for k in range(len(m)): idx = np.where((basis.modes[:, 1:] == [m[k], n[k]]).all(axis=1))[0] if len(idx): @@ -343,7 +343,7 @@ def zernike_to_fourier(x_lmn, basis, rho): n = mn[:, 1] x_mn = np.zeros((rho.size, m.size)) - As = zernike_radial(rho, basis.modes[:, 0], basis.modes[:, 1]) + As = zernike_radial(rho[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1]) for k in range(len(m)): idx = np.where((basis.modes[:, 1:] == [m[k], n[k]]).all(axis=1))[0] if len(idx): diff --git a/docs/notebooks/zernike_eval.ipynb b/docs/notebooks/zernike_eval.ipynb index 1525506d28..4b6d1cdaa8 100644 --- a/docs/notebooks/zernike_eval.ipynb +++ b/docs/notebooks/zernike_eval.ipynb @@ -163,13 +163,13 @@ ], "source": [ "print(\"zernike_radial, 0th derivative\")\n", - "%timeit -n 1000 _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], 0).block_until_ready()\n", + "%timeit -n 1000 _ = zernike_radial(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0).block_until_ready()\n", "print(\"zernike_radial, 1st derivative\")\n", - "%timeit -n 1000 _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], 1).block_until_ready()\n", + "%timeit -n 1000 _ = zernike_radial(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 1).block_until_ready()\n", "print(\"zernike_radial, 2nd derivative\")\n", - "%timeit -n 1000 _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], 2).block_until_ready()\n", + "%timeit -n 1000 _ = zernike_radial(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 2).block_until_ready()\n", "print(\"zernike_radial, 3rd derivative\")\n", - "%timeit -n 1000 _ = zernike_radial(r, basis.modes[:,0], basis.modes[:,1], 3).block_until_ready()" + "%timeit -n 1000 _ = zernike_radial(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 3).block_until_ready()" ] }, { @@ -294,10 +294,10 @@ "metadata": {}, "outputs": [], "source": [ - "zr0 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 0)\n", - "zr1 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 1)\n", - "zr2 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 2)\n", - "zr3 = zernike_radial(r, basis.modes[:, 0], basis.modes[:, 1], 3)\n", + "zr0 = zernike_radial(r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 0)\n", + "zr1 = zernike_radial(r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 1)\n", + "zr2 = zernike_radial(r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 2)\n", + "zr3 = zernike_radial(r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], 3)\n", "zp0 = zernike_radial_poly(\n", " r[:, np.newaxis], basis.modes[:, 0], basis.modes[:, 1], dr=0, exact=False\n", ")\n", @@ -927,7 +927,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/tests/inputs/master_compute_data.pkl b/tests/inputs/master_compute_data.pkl index 1d641b6681..4e72ac112d 100644 Binary files a/tests/inputs/master_compute_data.pkl and b/tests/inputs/master_compute_data.pkl differ diff --git a/tests/test_basis.py b/tests/test_basis.py index ace86b2c45..74c81c403a 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from desc.backend import jnp from desc.basis import ( ChebyshevDoubleFourierBasis, ChebyshevPolynomial, @@ -12,9 +13,9 @@ FourierZernikeBasis, PowerSeries, ZernikePolynomial, + _jacobi, chebyshev, fourier, - jacobi_poly_single, polyder_vec, polyval_vec, powers, @@ -22,6 +23,7 @@ zernike_radial_coeffs, zernike_radial_poly, ) +from desc.derivatives import Derivative from desc.grid import LinearGrid @@ -97,10 +99,10 @@ def test_polyval_exact(self): ).T mpmath.mp.dps = 15 - approx1f = zernike_radial(r, l, m) - approx1df = zernike_radial(r, l, m, dr=1) - approx1ddf = zernike_radial(r, l, m, dr=2) - approx1dddf = zernike_radial(r, l, m, dr=3) + approx1f = zernike_radial(r[:, np.newaxis], l, m) + approx1df = zernike_radial(r[:, np.newaxis], l, m, dr=1) + approx1ddf = zernike_radial(r[:, np.newaxis], l, m, dr=2) + approx1dddf = zernike_radial(r[:, np.newaxis], l, m, dr=3) approx2f = zernike_radial_poly(r[:, np.newaxis], l, m) approx2df = zernike_radial_poly(r[:, np.newaxis], l, m, dr=1) approx2ddf = zernike_radial_poly(r[:, np.newaxis], l, m, dr=2) @@ -200,7 +202,9 @@ def Z6_2(x, dx=0): dr: np.array([Z3_1(r, dr), Z4_2(r, dr), Z6_2(r, dr), Z4_2(r, dr)]).T for dr in range(max_dr + 1) } - radial = {dr: zernike_radial(r, l, m, dr) for dr in range(max_dr + 1)} + radial = { + dr: zernike_radial(r[:, np.newaxis], l, m, dr) for dr in range(max_dr + 1) + } radial_poly = { dr: zernike_radial_poly(r[:, np.newaxis], l, m, dr) for dr in range(max_dr + 1) @@ -209,43 +213,6 @@ def Z6_2(x, dx=0): np.testing.assert_allclose(radial[dr], desired[dr], err_msg=dr) np.testing.assert_allclose(radial_poly[dr], desired[dr], err_msg=dr) - @pytest.mark.unit - def test_jacobi_poly_single(self): - """Test Jacobi Polynomial evaluation for special cases.""" - # https://en.wikipedia.org/wiki/Jacobi_polynomials#Special_cases - - def exact(r, n, alpha, beta): - if n == 0: - return np.ones_like(r) - elif n == 1: - return (alpha + 1) + (alpha + beta + 2) * ((r - 1) / 2) - elif n == 2: - a0 = (alpha + 1) * (alpha + 2) / 2 - a1 = (alpha + 2) * (alpha + beta + 3) - a2 = (alpha + beta + 3) * (alpha + beta + 4) / 2 - z = (r - 1) / 2 - return a0 + a1 * z + a2 * z**2 - elif n < 0: - return np.zeros_like(r) - - r = np.linspace(0, 1, 11) - # alpha and beta pairs for test - pairs = np.array([[2, 3], [3, 0], [1, 1], [10, 4]]) - n_values = np.array([-1, -2, 0, 1, 2]) - - for pair in pairs: - alpha = pair[0] - beta = pair[1] - P0 = jacobi_poly_single(r, 0, alpha, beta) - P1 = jacobi_poly_single(r, 1, alpha, beta) - desired = {n: exact(r, n, alpha, beta) for n in n_values} - values = { - n: jacobi_poly_single(r, n, alpha, beta, P1, P0) for n in n_values - } - - for n in n_values: - np.testing.assert_allclose(values[n], desired[n], err_msg=n) - @pytest.mark.unit def test_fourier(self): """Test Fourier series evaluation.""" @@ -432,3 +399,23 @@ def test_basis_resolutions_assert_integers(self): _ = ChebyshevDoubleFourierBasis(L=3, M=1, N=1.0) with pytest.raises(ValueError): _ = ChebyshevDoubleFourierBasis(L=3, M=1, N=1, NFP=1.0) + + +@pytest.mark.unit +def test_jacobi_jvp(): + """Test that custom derivative rule for jacobi polynomials works.""" + basis = ZernikePolynomial(25, 25) + l, m = basis.modes[:, :2].T + m = jnp.abs(m) + alpha = m + beta = 0 + n = (l - m) // 2 + r = np.linspace(0, 1, 1000) + jacobi_arg = 1 - 2 * r**2 + for i in range(5): + # custom jvp rule for derivative of jacobi should just call jacobi with dx+1 + f1 = jnp.vectorize(Derivative(_jacobi, 3, "grad"))( + n, alpha, beta, jacobi_arg[:, None], i + ) + f2 = _jacobi(n, alpha, beta, jacobi_arg[:, None], i + 1) + np.testing.assert_allclose(f1, f2)