Regression in _initial_guess_surface from v0.10.4 #941

Closed
f0uriest opened this issue Mar 16, 2024 · 9 comments · Fixed by #998

@f0uriest
Member

On master, initializing an Equilibrium now takes about 2x longer than it did on the last release (v0.10.4). The culprit seems to be the _initial_guess_surface function.

In Jupyter:

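%load_ext line_profiler
import jax
import desc.equilibrium.initial_guess
from desc.equilibrium import Equilibrium

jax.clear_caches()  # start from a cold cache so compilation time is included
%lprun -f Equilibrium.__init__ \
   -f desc.equilibrium.initial_guess.set_initial_guess \
   -f desc.equilibrium.initial_guess._initial_guess_surface \
   eq = Equilibrium(L=12, M=12, N=12)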

I suspect we didn't catch this because the benchmarks run the same code over and over, so they get the advantage of JIT caching; in practice you usually only create a single equilibrium, so the JIT cache doesn't really help. We should add jax.clear_caches() to those benchmarks to avoid this in the future.
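A minimal sketch of the benchmark change, assuming a plain timing harness rather than DESC's actual benchmark suite: toy_kernel and time_first_call are hypothetical names, and the only JAX-specific ingredient is calling jax.clear_caches() before the timed call so that tracing and compilation are counted, as they would be for a user who creates a single object.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_kernel(x):
    # Hypothetical stand-in for an expensive-to-compile function: the Python
    # loop is unrolled at trace time, so compile cost grows with its length.
    for k in range(50):
        x = jnp.sin(x) * (k + 1) + jnp.cos(x)
    return x


def time_first_call(fn, x, clear_cache=True):
    """Time one call of fn, optionally clearing JAX's caches first.

    With clear_cache=True the measurement includes tracing and compilation;
    with clear_cache=False it reuses whatever is already cached.
    """
    if clear_cache:
        jax.clear_caches()
    start = time.perf_counter()
    fn(x).block_until_ready()  # wait for the asynchronous result
    return time.perf_counter() - start


x = jnp.linspace(0.0, 1.0, 1000)
cold = time_first_call(toy_kernel, x, clear_cache=True)
warm = time_first_call(toy_kernel, x, clear_cache=False)
print(f"cold (compile + run): {cold:.4f} s, warm (run only): {warm:.6f} s")
```

A benchmark that only reports the warm number would hide exactly the kind of regression described in this issue.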

@f0uriest
Member Author

f0uriest commented Apr 5, 2024

Profiling function for the different Zernike implementations:

import jax
import desc.basis
import desc.grid

def profile(N, func):
    # Fourier-Zernike basis and a matching concentric grid at resolution N
    basis = desc.basis.FourierZernikeBasis(N, N, N, sym="cos")
    grid = desc.grid.ConcentricGrid(2 * N, 2 * N, 2 * N)
    rho = grid.nodes[:, 0]
    # compile + run: caches are cleared before every call
    print(f"{func} compile, dr=0,1,2,3")
    %timeit jax.clear_caches(); z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=0).block_until_ready()
    %timeit jax.clear_caches(); z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=1).block_until_ready()
    %timeit jax.clear_caches(); z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=2).block_until_ready()
    %timeit jax.clear_caches(); z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=3).block_until_ready()
    # run only: the compiled function is reused from the cache
    print(f"{func} run, dr=0,1,2,3")
    %timeit z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=0).block_until_ready()
    %timeit z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=1).block_until_ready()
    %timeit z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=2).block_until_ready()
    %timeit z2 = func(rho, basis.modes[:,0], basis.modes[:,1], dr=3).block_until_ready()
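Outside Jupyter, the same compile-versus-run split can be measured with the standard timeit module. This is only a sketch: toy_radial is a hypothetical stand-in for the Zernike functions, with dr marked static so that each derivative order is traced and compiled separately (which the per-dr compile timings later in this thread suggest is how the old version behaves).

```python
import timeit
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="dr")
def toy_radial(r, dr=0):
    # Hypothetical radial function: the dr-th derivative of s**5, built with
    # nested jax.grad so each dr value is a different compiled program.
    f = lambda s: s**5
    for _ in range(dr):
        f = jax.grad(f)
    return jax.vmap(f)(r)


def profile(func, r, drs=(0, 1, 2, 3), repeats=3, loops=10):
    """Return {dr: (cold_s, warm_s)}: best cold (trace + compile + run) time
    and best per-call warm (cached) time for each derivative order."""
    results = {}
    for dr in drs:
        def call(dr=dr):
            return func(r, dr=dr).block_until_ready()

        def cold_call():
            jax.clear_caches()  # drop compiled programs so the next call recompiles
            call()

        cold = min(timeit.repeat(cold_call, number=1, repeat=repeats))
        call()  # make sure this dr is compiled and cached
        warm = min(timeit.repeat(call, number=loops, repeat=repeats)) / loops
        results[dr] = (cold, warm)
    return results
```

Calling profile(toy_radial, jnp.linspace(0.1, 1.0, 100)) for each N of interest gives the same two tables as the notebook version, as plain numbers that are easy to log in CI.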

@YigitElma, can you try the different versions in faster zernike on CPU and GPU?

@rahulgaur104
Collaborator

Should we also add jax.clear_caches() to our optimizers?

@f0uriest
Member Author

f0uriest commented Apr 5, 2024

> Should we also add jax.clear_caches() to our optimizers?

I don't think so. Clearing the caches too often significantly slows performance, since everything has to be retraced and recompiled, so I think it's better to only do it at the top level when it's necessary for a specific case.
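A small sketch of why clearing inside an optimizer loop hurts, assuming a hypothetical jitted step function rather than any of DESC's optimizers. The counter is a Python side effect, so it only increments when JAX re-traces step; that makes it a cheap proxy for how often the function is recompiled.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # runs only while JAX is tracing `step`
    return x - 0.1 * jnp.tanh(x)


def run(n_iter, clear_every_iteration):
    """Hypothetical optimizer loop; returns how many times `step` was traced."""
    global trace_count
    jax.clear_caches()  # top-level clear: start from a cold cache once
    trace_count = 0
    x = jnp.ones(3)
    for _ in range(n_iter):
        if clear_every_iteration:
            jax.clear_caches()  # forces re-tracing and re-compiling every step
        x = step(x)
    x.block_until_ready()
    return trace_count


print(run(10, clear_every_iteration=False))
print(run(10, clear_every_iteration=True))
```

Clearing once at the top pays for a single compilation, while clearing per iteration pays for one compilation per step.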

@YigitElma
Collaborator

Ok, I can compare them tomorrow.

@f0uriest
Member Author

f0uriest commented Apr 5, 2024

Thanks. I was using N = [6, 9, 12, 15, 18].

@f0uriest
Member Author

@YigitElma any update?

@YigitElma
Collaborator

YigitElma commented Apr 16, 2024

I'm so sorry, @f0uriest. Here are the GPU results comparing the current version in DESC with the older version.
Profiling Newer on GPU with N=6
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.26 s ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 5.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 5.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 6.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
1.17 ms ± 3.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.76 ms ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.06 ms ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.23 ms ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profiling Older on GPU with N=6
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
510 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
722 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
940 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.15 s ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
422 µs ± 39.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
629 µs ± 23 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
762 µs ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
884 µs ± 3.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Profiling Newer on GPU with N=9
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.39 s ± 3.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 4.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 4.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
2.62 ms ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.74 ms ± 2.94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.21 ms ± 3.07 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.66 ms ± 3.42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profiling Older on GPU with N=9
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
528 ms ± 4.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
748 ms ± 2.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
971 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
1.47 ms ± 22 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.47 ms ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.21 ms ± 41 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.72 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Newer on GPU with N=12
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.42 s ± 4.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.43 s ± 7.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.43 s ± 4.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.42 s ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
11.8 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13.7 ms ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
14.7 ms ± 7.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.1 ms ± 36.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profiling Older on GPU with N=12
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
522 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
740 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
964 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.19 s ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
7.38 ms ± 37.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
13.4 ms ± 55.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.3 ms ± 67.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
22.5 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on GPU with N=15
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.58 s ± 3.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 4.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
51.1 ms ± 91.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
53.1 ms ± 17 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.4 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Profiling Older on GPU with N=15
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
587 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
842 ms ± 5.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.1 s ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.35 s ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
29.3 ms ± 62.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
52.9 ms ± 39.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
71.8 ms ± 70.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
87.7 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on GPU with N=18
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.76 s ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.76 s ± 4.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.76 s ± 7.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.77 s ± 9.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
191 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
193 ms ± 390 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
194 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
197 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Profiling Older on GPU with N=18
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
667 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
984 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.3 s ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.6 s ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
98.1 ms ± 95.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
180 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
250 ms ± 137 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
316 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

It looks like custom_jvp causes all the derivatives to be compiled on the first run, which makes sense because we switch between them. The total compilation time is roughly on the same order as the sum of the per-derivative compilation times of the old method.
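To illustrate the effect, here is a sketch assuming the new implementation selects the derivative with something like jax.lax.switch on a traced dr (an assumption; the thread only says it switches between derivatives), while the old one treats dr as a static argument. radial_switch, radial_static and the r**4 branches are hypothetical. Counting traces shows that the switch version stages every branch into one program on the first call, whereas the static version compiles once per dr value.

```python
from functools import partial

import jax
import jax.numpy as jnp

traces = {"switch": 0, "static": 0}


def derivative_branches():
    # Hypothetical radial function r**4 and its first three derivatives.
    return [
        lambda r: r**4,
        lambda r: 4 * r**3,
        lambda r: 12 * r**2,
        lambda r: 24 * r,
    ]


@jax.jit
def radial_switch(r, dr):
    # dr is traced, so lax.switch must include *all* branches in one
    # program: every derivative is compiled on the first call.
    traces["switch"] += 1
    return jax.lax.switch(dr, derivative_branches(), r)


@partial(jax.jit, static_argnames="dr")
def radial_static(r, dr):
    # dr is static, so only the selected branch is traced, and each new
    # value of dr triggers a separate trace and compilation.
    traces["static"] += 1
    return derivative_branches()[dr](r)


r = jnp.linspace(0.0, 1.0, 5)
for dr in range(4):
    radial_switch(r, dr).block_until_ready()
    radial_static(r, dr=dr).block_until_ready()
print(traces)  # switch: traced once in total; static: traced once per dr
```

The switch design front-loads the cost of all four derivatives, which only pays off when most of them are actually used.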

@YigitElma
Collaborator

So if we call each derivative once for a function, the overall time spent should be about the same (the old method requires a new compilation for each derivative but the new one doesn't, at least as I understand it). For example:

import jax
import numpy as np
from desc.basis import FourierZernikeBasis
from desc.grid import ConcentricGrid

# zernike_radial_switch_gpu and zernike_radial_old_desc are the two
# implementations being compared (defined elsewhere in the notebook)
N = 6
basis = FourierZernikeBasis(N, N, N, sym="cos")
grid = ConcentricGrid(2 * N, 2 * N, 2 * N)
rho = grid.nodes[:, 0]

def call_all_dr(func, r):
    # evaluate dr = 0, 1, 2, 3 once each, waiting for every result
    for dr in range(4):
        func(r, basis.modes[:, 0], basis.modes[:, 1], dr=dr).block_until_ready()

print("new")
%timeit jax.clear_caches(); call_all_dr(zernike_radial_switch_gpu, rho)

print("old")
%timeit jax.clear_caches(); call_all_dr(zernike_radial_old_desc, rho[:, np.newaxis])
new
2.3 s ± 3.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
old
3.04 s ± 7.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
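As a rough check of this reasoning, the single-call compile timings posted above for GPU can be compared directly; the numbers below are copied from the thread, not new measurements. Each "compile" figure also includes one run, and summing the old figures assumes four independent cold starts, so the ratios are approximate. Which case matters depends on how many derivative orders a caller such as the initial guess actually needs (an assumption not settled in the thread).

```python
# Single-call compile timings (seconds) copied from the GPU results above.
# For the new version the dr=0 value is used, since all four are ~equal.
new_compile = {6: 2.26, 9: 2.39, 12: 2.42, 15: 2.58, 18: 2.76}
old_compile = {  # dr = 0, 1, 2, 3
    6: [0.510, 0.722, 0.940, 1.15],
    9: [0.528, 0.748, 0.971, 1.2],
    12: [0.522, 0.740, 0.964, 1.19],
    15: [0.587, 0.842, 1.1, 1.35],
    18: [0.667, 0.984, 1.3, 1.6],
}

ratios = {}
for N, old in old_compile.items():
    only_dr0 = new_compile[N] / old[0]  # new/old cost if only dr=0 is needed
    all_dr = new_compile[N] / sum(old)  # new/old cost if all four are needed
    ratios[N] = (only_dr0, all_dr)
    print(f"N={N}: dr=0 only {only_dr0:.1f}x, all dr {all_dr:.2f}x")
```

For N=6 this gives roughly 4.4x slower compilation when only dr=0 is needed and about 0.68x when all four orders are needed, in line with the combined measurement above (2.3 s vs 3.04 s, about 0.76x).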

@YigitElma
Collaborator

YigitElma commented Apr 16, 2024

For future reference, here are the results for the CPU version of the code (which differs by a single line in how the array is updated, i.e. using fori_loop instead of jnp.where):

Profiling Newer on CPU with N=6
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.11 s ± 6.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 7.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 7.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
2.06 ms ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.82 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.49 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.05 ms ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Profiling Older on CPU with N=6
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
316 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
497 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
681 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
876 ms ± 2.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
13.4 ms ± 86.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.8 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
44.1 ms ± 782 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
61.6 ms ± 574 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on CPU with N=9
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.14 s ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.16 s ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.15 s ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.16 s ± 5.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
40 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 880 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.7 ms ± 984 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Profiling Older on CPU with N=9
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
526 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
886 ms ± 3.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.23 s ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.58 s ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
204 ms ± 894 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
382 ms ± 3.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
551 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
706 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=12
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.4 s ± 5.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.42 s ± 4.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.44 s ± 4.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.46 s ± 5.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
255 ms ± 714 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
271 ms ± 521 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
289 ms ± 775 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
306 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Profiling Older on CPU with N=12
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
1.55 s ± 968 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.82 s ± 3.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.02 s ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.13 s ± 6.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
1.22 s ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.28 s ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.3 s ± 2.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.23 s ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=15
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
3.06 s ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.1 s ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.15 s ± 7.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.21 s ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
967 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.02 s ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.08 s ± 3.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.13 s ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Profiling Older on CPU with N=15
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
5.2 s ± 6.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.73 s ± 9.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
14.3 s ± 35.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.5 s ± 25.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
4.9 s ± 8.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.26 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
13.7 s ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
17.8 s ± 24.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=18
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
4.75 s ± 35.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.94 s ± 64.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.01 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.12 s ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
2.71 s ± 6.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.74 s ± 4.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.86 s ± 3.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.99 s ± 4.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Profiling Older on CPU with N=18
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
16.8 s ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.9 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
46.4 s ± 91.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 1s ± 28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
16.5 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.4 s ± 18.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
45.8 s ± 91.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min ± 39.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
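The changed line itself is not shown in the thread, so the following is only a hypothetical illustration of the two update styles it names: filling the columns of an output array either with a full-array jnp.where mask per iteration or with an in-place .at[].set inside jax.lax.fori_loop. fill_with_where, fill_with_fori_loop and the r**n columns are made up for the example.

```python
import jax
import jax.numpy as jnp


def fill_with_where(r, n_max):
    # Every iteration rewrites the whole array, keeping only the column
    # selected by the boolean mask; the Python loop is unrolled when traced.
    out = jnp.zeros((r.shape[0], n_max + 1))
    cols = jnp.arange(n_max + 1)
    for n in range(n_max + 1):
        out = jnp.where(cols == n, (r**n)[:, None], out)
    return out


def fill_with_fori_loop(r, n_max):
    # Same result, but each iteration writes a single column via .at[].set
    # inside lax.fori_loop, so the loop body is traced only once.
    out = jnp.zeros((r.shape[0], n_max + 1))

    def body(n, out):
        return out.at[:, n].set(r**n)

    return jax.lax.fori_loop(0, n_max + 1, body, out)
```

The where version grows the traced program with n_max and touches the full array on every iteration, while the fori_loop version keeps the program size fixed; which one runs faster depends on the backend, consistent with the CPU and GPU variants differing in exactly this line.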

YigitElma added a commit that referenced this issue Apr 29, 2024
Until we decide whether to use the new repo for zernike_radial, we switch back to the old version of the code.
Resolves #941
3 participants