
Improve coordinate mapping performance #1154

Open
unalmis opened this issue Jul 29, 2024 · 6 comments
Assignees: unalmis
Labels: P3 (Highest Priority, someone is/should be actively working on this), performance (New feature or request to make the code faster), waiting for other PRs

Comments

@unalmis (Collaborator) commented Jul 29, 2024

Neoclassical computations are bottlenecked by root finding. The PR that closes this issue will resolve this, and the benefits should carry over to any other objective that requires a coordinate mapping.

The performance of the coordinate mapping routines map_clebsch_coords (#1153) and compute_theta_coords can be further improved by using partial summation techniques and precomputing basis functions.

Notice that $\rho$ and $\zeta$ are fixed throughout the root finding for $\vartheta \to \theta$ or $\alpha \to \theta$. Therefore, one can condense the 3D map $\lambda \colon \rho, \theta, \zeta \mapsto \lambda(\rho, \theta, \zeta)$ to a set of 1D maps $\theta \mapsto \lambda_{\rho, \zeta}(\theta)$ by computing the parts of each Fourier Zernike basis function that do not depend on $\theta$ outside the root finding. This significantly reduces the computation to

$$\lambda_{\rho, \zeta}(\theta) = \sum_{\ell, m, n} a_{\ell,m,n} \phi_{m}(\theta)$$

where $a_{\ell,m,n}$ absorbs the precomputed radial and toroidal factors and $\phi_m$ is the poloidal Fourier basis.

Moreover, for root finding on a tensor-product grid in $\rho, \zeta$, one can further reduce the evaluation to the inner product

$$\lambda_{\rho, \zeta}(\theta) = \sum_{m} b_{m} \phi_{m}(\theta), \qquad b_m = \sum_{\ell, n} a_{\ell,m,n}$$

or a single matrix product over all the $(\rho, \zeta)$.
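
As a rough illustration, here is a minimal JAX sketch of the precomputation and the per-iteration evaluation. All names, shapes, and basis choices are hypothetical (this is not the DESC API): c holds the Fourier Zernike coefficients of $\lambda$, Rlm the Zernike radial polynomials on the $\rho$ grid, and Fn the toroidal Fourier modes on the $\zeta$ grid.

import jax.numpy as jnp

def precompute_b(c, Rlm, Fn):
    # c: (L, M, N) Fourier-Zernike coefficients of lambda (hypothetical layout)
    # Rlm: (R, L, M) Zernike radial polynomials R_l^m on the rho grid
    # Fn: (Z, N) toroidal Fourier modes on the zeta grid
    # Partial summation: contract one index at a time, outside the root finding.
    t = jnp.einsum("rlm,lmn->rmn", Rlm, c)   # sum over l: O(R L M N)
    return jnp.einsum("rmn,zn->rzm", t, Fn)  # sum over n: O(R Z M N), shape (R, Z, M)

def eval_lambda(b, theta, m_modes):
    # Inside each Newton iteration only this cheap product remains:
    # lambda[r, z, t] = sum_m b[r, z, m] phi_m(theta_t)
    phi = jnp.sin(theta[:, None] * m_modes)  # (T, M); stand-in poloidal basis
    return jnp.einsum("rzm,tm->rzt", b, phi)

In the actual root finding, $\theta$ varies per $(\rho, \zeta)$ node, so the last contraction becomes elementwise over the stacked $(\rho, \zeta)$ grid, i.e. the single matrix product noted above.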

Assuming $L = M = N$, this reduces the computation from $\mathcal{O}(N^6)$ to $\mathcal{O}(N^4)$ with the same constant of proportionality in front if implemented correctly: naive evaluation computes $\sim N^3$ basis functions at $\sim N^3$ nodes, whereas contracting one index at a time costs $\mathcal{O}(N^4)$ at each stage.

The literature reports that such partial summation techniques have achieved factors of 10,000 improvement in runtime for $N = 100$. Given that those results came from compiled codes, I doubt JIT is smart enough to find an algorithmic optimization like this on its own.

Only benchmarks will show whether we obtain the same benefit, but it should be worthwhile for us since this computation sits inside a Newton iteration. That means for us $N$ is proportional to the grid resolution times the number of iterations of the root finding.

For bounce integrals, a typical resolution may be $X = 32$ and $Y = 64$.
So in that case we have something like $N^2 \sim Y \times F \times \text{num iter}$, where $F$ is the resolution of $\lambda$. Less computation means less memory consumed by AD to keep track of all the gradients, so this should also decrease memory use in optimization.

@unalmis unalmis added performance New feature or request to make the code faster theory Requires theory work before coding labels Jul 29, 2024
@unalmis unalmis self-assigned this Jul 29, 2024
@f0uriest (Member) commented Jul 31, 2024

I like the idea of partial summation. I'm also wondering if we could figure out how to do it in map_coordinates when inbasis and outbasis share certain coordinates? If we could get that to work, we may not need compute_{theta,clebsch}_coords, which would simplify the API (and in general would improve other coordinate mapping stuff).

Some tests:

import numpy as np
import jax.numpy as jnp
from jax import jit
import desc.basis

basis1 = desc.basis.FourierZernikeBasis(12, 12, 12)
basis2 = desc.basis.DoubleFourierSeries(12, 12)
basis3 = desc.basis.FourierSeries(12)

rho = theta = zeta = np.linspace(0, 1, 10000)

@jit
def compute1(theta, zeta):
    # rho is closed over, so jit traces it as a compile-time constant
    nodes = jnp.array([rho, theta, zeta]).T
    return basis1.evaluate(nodes)

@jit
def compute2(rho, theta, zeta):
    # rho is a traced runtime argument here
    nodes = jnp.array([rho, theta, zeta]).T
    return basis1.evaluate(nodes)

@jit
def compute3(theta, zeta):
    nodes = jnp.array([rho, theta, zeta]).T
    # use basis1 modes with duplicate m,n for fair comparison
    # this is basically equivalent to eval the fourier zernike
    # assuming fixed rho is precomputed
    return basis2.evaluate(nodes, modes=basis1.modes)

@jit
def compute4(theta):
    nodes = jnp.array([rho, theta, zeta]).T
    # use basis1 modes with duplicate m,n for fair comparison
    # this is basically equivalent to eval the fourier zernike
    # for fixed rho, zeta
    return basis3.evaluate(nodes, modes=basis1.modes)

%timeit _ = compute1(theta, zeta).block_until_ready()
# 732 ms ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute2(rho, theta, zeta).block_until_ready()
# 760 ms ± 24.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute3(theta, zeta).block_until_ready()
# 178 ms ± 8.71 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit _ = compute4(theta).block_until_ready()
# 121 ms ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

It looks like just partial evaluation might be a significant speedup when rho is fixed, even if we don't do the full partial summation to reduce it further. But yeah, it does seem like JAX doesn't automatically compile away the constant part (the first two have roughly the same runtime even though rho is constant in the first one).

@unalmis (Collaborator, Author) commented Aug 20, 2024

Also, after this, make use of partial summation in transforms and basis.

unalmis added a commit that referenced this issue Sep 3, 2024
After the recent refactoring to the `Bounce1D` class that resulted from
#1214, the API is a little too strict for computations like effective
ripple, where we vectorize the computation over some dimensions
and loop over others to save memory.

This PR changes the expected shape of the pitch angle input to
`Bounce1D` in #854 from `(P, M, L)` to `(M, L, P)`. With this change,
the two leading axes of all inputs to the methods in that class are
`(M, L)`.

These changes are tested and already included in downstream branches. I
am making a new PR instead of directly committing to the `bounce` branch
for the sake of people who have already reviewed the `bounce` PR.

This is better because
1. Easier usage for end users. (Previously, you'd have to manually add
trailing axes to the pitch angle array.)
2. Makes it much simpler to use with JAX's new batched map.
3. Previously we would loop over the pitch angles to save memory.
However, this repeats some computation because interpax would
interpolate multiple times. By looping over the field lines instead and
doing the interpolation for all the pitch angles at once, both
`_bounce_quadrature` and `interp_to_argmin` are faster. (I'm seeing 30%
faster speed just from computing effective ripple (no optimization), but
I don't plan to do any benchmarking to see whether that comes from recent
changes like #1154 or #1043, or others.)
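
As a toy illustration of point 2 (all names and shapes hypothetical, not the DESC code): once every input shares the leading axes (M, L), looping over field lines while keeping all pitch angles vectorized is a single jax.lax.map over the flattened leading axes, using the batch_size argument available in recent JAX versions.

import jax
import jax.numpy as jnp

M, L, P, knots = 4, 8, 16, 100
B_samples = jnp.ones((M, L, knots))   # e.g. |B| along each field line (hypothetical)
pitch = jnp.ones((M, L, P))           # pitch angles, now with trailing axis P

def per_field_line(args):
    B, p = args                       # B: (knots,), p: (P,)
    return jnp.sum(B) * p             # placeholder for the bounce computation

out = jax.lax.map(
    per_field_line,
    (B_samples.reshape(M * L, knots), pitch.reshape(M * L, P)),
    batch_size=8,                     # loop in batches to bound memory
)
out = out.reshape(M, L, P)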
@unalmis (Collaborator, Author) commented Oct 7, 2024

We should do partial summation regardless, but there also exists a closed-form solution to this root-finding problem. It is given by equation 19.36 in Boyd's Chebyshev and Fourier book. Have you seen/considered this before, @f0uriest?
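
For reference, a sketch of the companion-matrix idea behind such closed-form root finding (illustrative only; not checked against the exact form of Boyd's equation 19.36): substituting $z = e^{i\theta}$ turns a trigonometric polynomial into an ordinary polynomial in $z$, whose roots on the unit circle give the real roots in $\theta$. np.roots computes them via the eigenvalues of the companion matrix.

import numpy as np

def fourier_roots(c, tol=1e-8):
    # c: complex Fourier coefficients ordered [c_{-N}, ..., c_0, ..., c_N],
    # representing f(theta) = sum_n c_n exp(1j * n * theta).
    # Multiplying f by z^N with z = exp(1j * theta) gives the polynomial
    # p(z) = sum_k c_{k-N} z^k, so np.roots wants c reversed (highest degree first).
    z = np.roots(c[::-1])
    z = z[np.abs(np.abs(z) - 1.0) < tol]  # keep roots on the unit circle
    return np.sort(np.mod(np.angle(z), 2 * np.pi))

# Example: cos(theta) - 1/2 = 0 has roots pi/3 and 5*pi/3.
print(fourier_roots(np.array([0.5, -0.5, 0.5])))  # ~[1.047, 5.236]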

@f0uriest (Member) commented Oct 7, 2024

I hadn't seen that before, but I'm not sure it completely solves the problem. You would need to evaluate lambda on a generally nonuniform grid in theta, which without partial summation would still be expensive, and for fixed N it only gives an approximation that may still need to be refined with Newton. I think in most of our cases we only need ~5 Newton iterations, so I'm not sure it would end up being more efficient.

@unalmis (Collaborator, Author) commented Oct 7, 2024

It won't beat ~5 Newton iterations when the integration region is $\theta \in [0, 2\pi]$, unless the complexity of $\alpha(\rho, \theta, \zeta)$ on that domain is similar to a cubic.

The nice quality I see in that formula is that it enables cheap simultaneous root finding for a set of equations, e.g. when seeking roots of $B - 1/\lambda = 0$ for many values of $\lambda$ over a region where $B$ is not complex.
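
To make the simultaneous-root-finding point concrete, a continuation of the earlier sketch (same hypothetical conventions): for $B(\theta) - 1/\lambda = 0$ only the constant Fourier coefficient depends on $\lambda$, so the companion matrices for all $\lambda$ differ in a single entry, and np.linalg.eigvals can process the whole stack at once.

import numpy as np

def batched_fourier_roots_shifted(c, lam):
    # c: Fourier coefficients [c_{-N}, ..., c_N] of B(theta); we want roots
    # of B(theta) - 1/lam for every value in the array lam at once.
    d = len(c) - 1                    # polynomial degree 2N after multiplying by z^N
    N = d // 2                        # index of the constant Fourier mode c_0
    A = np.zeros((len(lam), d, d), dtype=complex)
    A[:, np.arange(1, d), np.arange(d - 1)] = 1.0  # subdiagonal of ones
    A[:, :, -1] = -np.asarray(c[:-1]) / c[-1]      # last-column companion form
    A[:, N, -1] += (1.0 / lam) / c[-1]             # shift c_0 -> c_0 - 1/lam
    return np.linalg.eigvals(A)       # (len(lam), d) candidate roots z

# As before, keep the eigenvalues on the unit circle and take their angle
# to recover the theta roots for each lam.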

@unalmis (Collaborator, Author) commented Oct 7, 2024

It could also be useful as an initial guess, if convergence is ever an issue.

@unalmis unalmis mentioned this issue Oct 8, 2024
@unalmis unalmis removed the theory Requires theory work before coding label Nov 11, 2024
@dpanici dpanici added the P2 Medium Priority, not urgent but should be on the near-term agenda label Nov 11, 2024
@unalmis unalmis added P3 Highest Priority, someone is/should be actively working on this and removed P2 Medium Priority, not urgent but should be on the near-term agenda labels Dec 31, 2024