Improve coordinate mapping performance #1154
I like the idea of partial summation, I'm also wondering if we could figure out how to do it in

Some tests:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit
import desc

basis1 = desc.basis.FourierZernikeBasis(12, 12, 12)
basis2 = desc.basis.DoubleFourierSeries(12, 12)
basis3 = desc.basis.FourierSeries(12)
rho = theta = zeta = np.linspace(0, 1, 10000)

@jit
def compute1(theta, zeta):
    # rho is a constant captured from the enclosing scope
    nodes = jnp.array([rho, theta, zeta]).T
    return basis1.evaluate(nodes)

@jit
def compute2(rho, theta, zeta):
    nodes = jnp.array([rho, theta, zeta]).T
    return basis1.evaluate(nodes)

@jit
def compute3(theta, zeta):
    nodes = jnp.array([rho, theta, zeta]).T
    # use basis1 modes with duplicate m, n for a fair comparison;
    # this is basically equivalent to evaluating the Fourier-Zernike basis
    # assuming the fixed-rho part is precomputed
    return basis2.evaluate(nodes, modes=basis1.modes)

@jit
def compute4(theta):
    nodes = jnp.array([rho, theta, zeta]).T
    # use basis1 modes with duplicate m, n for a fair comparison;
    # this is basically equivalent to evaluating the Fourier-Zernike basis
    # for fixed rho, zeta
    return basis3.evaluate(nodes, modes=basis1.modes)

%timeit _ = compute1(theta, zeta).block_until_ready()
# 732 ms ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute2(rho, theta, zeta).block_until_ready()
# 760 ms ± 24.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute3(theta, zeta).block_until_ready()
# 178 ms ± 8.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute4(theta).block_until_ready()
# 121 ms ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

It looks like just partial evaluation might be a significant speedup when rho is fixed, even if we don't do the full partial summation to reduce it further. But yeah, it does seem like JAX doesn't automatically compile away the constant part (the first two take roughly the same time even though rho is a closed-over constant in the first one).
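To push this further, here is a rough sketch of what the actual partial summation could look like (illustrative only, not the DESC transform API; the mode convention used for `F_m` below is an assumption):

```python
import jax
import jax.numpy as jnp

def fold_fixed_coords(c_lmn, radial_at_rho, toroidal_at_zeta):
    # c_lmn:            (num_modes,)           spectral coefficients
    # radial_at_rho:    (num_pts, num_modes)   R_l^{|m|}(rho_i) at the fixed rho values
    # toroidal_at_zeta: (num_pts, num_modes)   F_n(zeta_i) at the fixed zeta values
    # Returns (num_pts, num_modes) coefficients of the 1D series in theta,
    # computed once before the root finding starts.
    return c_lmn[None, :] * radial_at_rho * toroidal_at_zeta

@jax.jit
def eval_theta_series(theta, folded, m):
    # theta:  (num_pts,)            current Newton iterate for theta
    # folded: (num_pts, num_modes)  output of fold_fixed_coords
    # m:      (num_modes,)          poloidal mode numbers
    # Only this evaluation runs inside the Newton loop, so its cost is
    # O(num_pts * num_modes) rather than re-evaluating the full 3D basis.
    angle = jnp.abs(m)[None, :] * theta[:, None]
    F_m = jnp.where(m[None, :] >= 0, jnp.cos(angle), jnp.sin(angle))
    return jnp.sum(folded * F_m, axis=-1)
```

The idea is the same as `compute4` above: everything that depends only on the fixed `rho`, `zeta` is folded into per-mode coefficients once, and the Newton loop only touches the poloidal factor.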
Also, after this, make use of partial summation in transforms and basis.
After the recent refactoring of the `Bounce1D` class that resulted from #1214, the API is a little too strict for computations like effective ripple, etc., where we vectorize the computation over some dimensions and loop over others to save memory. This PR changes the expected shape of the pitch angle input to `Bounce1D` in #854 from `(P, M, L)` to `(M, L, P)`. With this change, the two leading axes of all inputs to the methods in that class are `(M, L)` (a small reshape illustration is shown below). These changes are tested and already included in downstream branches. I am making a new PR instead of committing directly to the `bounce` branch for the benefit of people who have already reviewed the `bounce` PR. This is better because:

1. Easier usage for end users (previously, you'd have to manually add trailing axes to the pitch angle array).
2. Makes it much simpler to use with JAX's new batched map.
3. Previously we would loop over the pitch angles to save memory, which meant some computation was repeated because interpax would interpolate multiple times. By looping over the field lines instead and doing the interpolation for all the pitch angles at once, both `_bounce_quadrature` and `interp_to_argmin` are faster. (I'm seeing a 30% speedup just from computing effective ripple with no optimization, but I don't plan to do any benchmarking to determine whether that comes from recent changes like #1154 or #1043, or others.)
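For anyone updating existing code, a minimal sketch of the reshape from the old convention to the new one (the axis sizes are arbitrary placeholders, named as in the description above):

```python
import numpy as np

M, L, P = 3, 4, 5
pitch_old = np.random.rand(P, M, L)        # old convention: (P, M, L)
pitch_new = np.moveaxis(pitch_old, 0, -1)  # new convention: (M, L, P)
assert pitch_new.shape == (M, L, P)
```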
We should do partial summation regardless, but there also exists a closed-form solution to this root finding problem. It is given by equation 19.36 in Boyd's Chebyshev and Fourier book. Have you seen/considered this before @f0uriest?
I hadn't seen that before, but I'm not sure it completely solves the problem? You would need to evaluate lambda on a generally nonuniform grid in theta, which without partial summation would still be expensive, and for fixed N it only gives an approximation that may still need to be refined with Newton. I think in most of our cases we only need ~<5 Newton iterations, so I'm not sure it would end up being more efficient?
It won't beat ~<5 iterations when the integration region is

The nice quality I see in that formula is that it enables cheap simultaneous root finding for a set of equations, e.g. when seeking a root of
It could also be useful as an initial guess if convergence has ever been an issue.
Neoclassical stuff will be bottlenecked by root finding. The PR that closes this issue will resolve this, and the benefits should assist any other objective that requires a coordinate mapping.
The performance of the coordinate mapping routines `map_clebsch_coords` (#1153) and `compute_theta_coords` can be further improved by using partial summation techniques and precomputing basis functions.

Notice $\rho$ and $\zeta$ are fixed throughout the root finding for $\vartheta \to \theta$ or $\alpha \to \theta$. Therefore, one can condense the 3D map $\lambda \colon \rho, \theta, \zeta \mapsto \lambda(\rho, \theta, \zeta)$ to a set of 1D maps $\lambda_{\rho,\zeta} \colon \theta \mapsto \lambda_{\rho, \zeta}(\theta)$ by computing the parts of each Fourier-Zernike basis function that do not depend on $\theta$ outside of the root finding. This significantly reduces the computation done inside each Newton iteration.
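As a sketch (assuming the usual real Fourier-Zernike expansion $\lambda = \sum_{l,m,n} \lambda_{lmn}\, \mathcal{R}_l^{|m|}(\rho)\, F_m(\theta)\, F_n(\zeta)$; the exact mode ordering in DESC may differ), the condensation is

$$
\lambda_{\rho,\zeta}(\theta) \;=\; \sum_{m} \underbrace{\Big( \sum_{l,n} \lambda_{lmn}\, \mathcal{R}_l^{|m|}(\rho)\, F_n(\zeta) \Big)}_{a_m(\rho,\zeta)\ \text{precomputed once}} \, F_m(\theta).
$$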
Moreover, for root finding on a tensor-product grid in $\rho, \zeta$, one can further reduce the evaluation to an inner product for each $(\rho, \zeta)$ pair, or a single matrix product for all of the $\rho, \zeta$ at once.
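Concretely (illustrative notation, not taken from the issue): stacking the precomputed coefficients into a matrix $A_{im} = a_m(\rho_i, \zeta_i)$, each Newton iterate only requires

$$
\lambda_{\rho_i,\zeta_i}(\theta_i) \;=\; \sum_m A_{im}\, F_m(\theta_i),
$$

i.e. one contraction over the poloidal modes for all of the $\rho, \zeta$ at once.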
Assuming $L = M = N$, this reduces the computation from $\mathcal{O}(N^6)$ to $\mathcal{O}(N^4)$, with the same constant of proportionality in front if implemented correctly.
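One way to see the counts (a sketch of the standard partial summation argument for evaluating $N^3$ modes on an $N^3$ tensor-product grid; not taken from the issue itself): the naive evaluation does $N^3$ points $\times$ $N^3$ terms $= \mathcal{O}(N^6)$ work, whereas summing over one index at a time,

$$
T_1(\rho_i, m, n) = \sum_l \lambda_{lmn} \mathcal{R}_l^{|m|}(\rho_i), \quad
T_2(\rho_i, \theta_j, n) = \sum_m T_1(\rho_i, m, n) F_m(\theta_j), \quad
\lambda(\rho_i, \theta_j, \zeta_k) = \sum_n T_2(\rho_i, \theta_j, n) F_n(\zeta_k),
$$

produces at most $N^3$ intermediate values per stage, each from an $N$-term sum, hence $\mathcal{O}(N^4)$ in total.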
The literature mentions such partial summation techniques have achieved factors of 10000 improvement in runtime for $N = 100$. Given that those were actually compiled codes, I doubt that JIT is smart enough to find an algorithmic optimization like this.
Only benchmarks will show if we obtain the same benefit, but it should be worthwhile for us, as this computation is inside a Newton iteration. That means for us $N$ is proportional to the grid resolution times the number of iterations of the root finding.
For bounce integrals a typical resolution may be $X = 32$ and $Y = 64$, and $N^2 \sim Y \times F \times \text{num iter}$, where $F$ is the resolution of $\lambda$. Less computation also means less memory consumed by AD to keep track of all the gradients, so this should decrease memory use in optimization as well.
So in that case, we have something like
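As a rough illustration only (the specific values below are assumptions, not from the issue): with $Y = 64$, a $\lambda$ poloidal resolution of, say, $F \approx 32$, and roughly 5 Newton iterations, $N^2 \sim 64 \times 32 \times 5 \approx 10^4$, i.e. $N \approx 100$, which is the regime where the literature reports its large speedups.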