Regression in `_initial_guess_surface` from v0.10.4 #941
Profiling function for the different Zernike methods (an IPython/Jupyter cell, since `%timeit` is an IPython magic):

```python
import jax
import desc.basis
import desc.grid


def profile(N, func):
    basis = desc.basis.FourierZernikeBasis(N, N, N, sym="cos")
    grid = desc.grid.ConcentricGrid(2 * N, 2 * N, 2 * N)
    rho = grid.nodes[:, 0]
    print(f"{func} compile, dr=0,1,2,3")
    for dr in range(4):
        # clear JAX caches first, so each timing includes tracing and compilation
        %timeit jax.clear_caches(); func(rho, basis.modes[:, 0], basis.modes[:, 1], dr=dr).block_until_ready()
    print(f"{func} run, dr=0,1,2,3")
    for dr in range(4):
        # no cache clearing: after the first call this measures execution only
        %timeit func(rho, basis.modes[:, 0], basis.modes[:, 1], dr=dr).block_until_ready()
```

@YigitElma, can you try the different versions in faster zernike on CPU and GPU?
Should we also add `jax.clear_caches()` to our optimizers?
I don't think so. Doing it too often significantly slows performance, so I think it's better to do it only at the top level, when it's necessary for a specific case.
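For illustration, a minimal sketch of why clearing caches inside a hot loop is expensive: after `jax.clear_caches()`, the next call to a jitted function is traced and compiled again. The `step` function is a hypothetical stand-in for one optimizer iteration, and the global counter only increments while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Python side effect: runs only while JAX traces `step`


@jax.jit
def step(x):
    global trace_count
    trace_count += 1
    return x - 0.1 * jnp.sin(x)  # hypothetical stand-in for one optimizer iteration


x = jnp.ones(3)
for _ in range(5):
    x = step(x)
cached_traces = trace_count  # traced once, then reused from the cache

trace_count = 0
for _ in range(5):
    jax.clear_caches()  # clearing every iteration forces a fresh trace and compile
    x = step(x)
cleared_traces = trace_count  # one trace per iteration

print(cached_traces, cleared_traces)
```

Clearing once at the top level, before a cold-start measurement, pays the compile cost once; clearing on every iteration pays it every time.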
Ok, I can compare them tomorrow.
Thanks. I was using
@YigitElma, any update?
I'm so sorry @f0uriest. Here are the GPU results comparing the current version in DESC with the older version.
Profiling Newer on GPU with N=6
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.26 s ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 5.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 5.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.26 s ± 6.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
1.17 ms ± 3.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.76 ms ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
2.06 ms ± 2.61 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.23 ms ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Older on GPU with N=6
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
510 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
722 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
940 ms ± 2.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.15 s ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
422 µs ± 39.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
629 µs ± 23 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
762 µs ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
884 µs ± 3.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Profiling Newer on GPU with N=9
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.39 s ± 3.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 4.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.39 s ± 4.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
2.62 ms ± 1.83 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.74 ms ± 2.94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.21 ms ± 3.07 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.66 ms ± 3.42 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Older on GPU with N=9
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
528 ms ± 4.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
748 ms ± 2.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
971 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.2 s ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
1.47 ms ± 22 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.47 ms ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.21 ms ± 41 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.72 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Newer on GPU with N=12
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.42 s ± 4.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.43 s ± 7.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.43 s ± 4.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.42 s ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
11.8 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
13.7 ms ± 12.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
14.7 ms ± 7.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.1 ms ± 36.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Older on GPU with N=12
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
522 ms ± 929 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
740 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
964 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.19 s ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
7.38 ms ± 37.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
13.4 ms ± 55.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.3 ms ± 67.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
22.5 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on GPU with N=15
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.58 s ± 3.93 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 4.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.59 s ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
51.1 ms ± 91.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
53.1 ms ± 17 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.4 ms ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Older on GPU with N=15
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
587 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
842 ms ± 5.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.1 s ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.35 s ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
29.3 ms ± 62.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
52.9 ms ± 39.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
71.8 ms ± 70.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
87.7 ms ± 17.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on GPU with N=18
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> compile, dr=0,1,2,3
2.76 s ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.76 s ± 4.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.76 s ± 7.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.77 s ± 9.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x14e2088432d0> run, dr=0,1,2,3
191 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
193 ms ± 390 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
194 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
197 ms ± 391 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Older on GPU with N=18
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> compile, dr=0,1,2,3
667 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
984 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.3 s ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.6 s ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x14e200ba2ac0>> run, dr=0,1,2,3
98.1 ms ± 95.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
180 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
250 ms ± 137 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
316 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
It looks like
So if we call each derivative order once, the overall time spent should be about the same, because the old method requires a new compilation for each derivative order while the new one doesn't (at least, that is my understanding):

```python
import jax
from desc.basis import FourierZernikeBasis
from desc.grid import ConcentricGrid

# zernike_radial_switch_gpu (new) and zernike_radial_old_desc (old) are not imported here.
N = 6
basis = FourierZernikeBasis(N, N, N, sym="cos")
grid = ConcentricGrid(2 * N, 2 * N, 2 * N)
rho = grid.nodes[:, 0]


def run_all_derivs(func, r):
    jax.clear_caches()  # start each timing loop from cold caches
    for dr in range(4):
        func(r, basis.modes[:, 0], basis.modes[:, 1], dr=dr).block_until_ready()


print("new")
%timeit run_all_derivs(zernike_radial_switch_gpu, rho)
print("old")
%timeit run_all_derivs(zernike_radial_old_desc, rho[:, None])
```
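A hedged sketch of the mechanism described above, assuming the old function takes `dr` as a static argument (its compile time grows with `dr` in the GPU logs, which is consistent with that) while the new one selects the derivative order at run time. The polynomial evaluators `poly_static` and `poly_switch` are hypothetical toys, not DESC's `zernike_radial` code; the counters only increment while JAX traces each function.

```python
import functools

import jax
import jax.numpy as jnp

traces = {"static": 0, "switch": 0}


def _poly(r, dr):
    # p(r) = 1 + 2r + 3r^2 + 4r^3 (ascending coefficients), differentiated dr times
    coeffs = jnp.array([1.0, 2.0, 3.0, 4.0])
    for _ in range(dr):
        coeffs = coeffs[1:] * jnp.arange(1, coeffs.size)
    return jnp.polyval(coeffs[::-1], r)  # polyval wants highest degree first


# Static `dr`: every new value of dr triggers its own trace and compilation.
@functools.partial(jax.jit, static_argnames="dr")
def poly_static(r, dr):
    traces["static"] += 1
    return _poly(r, dr)


# Traced `dr`: lax.switch picks the branch at run time, so one compilation serves all dr.
@jax.jit
def poly_switch(r, dr):
    traces["switch"] += 1
    branches = [functools.partial(_poly, dr=k) for k in range(4)]
    return jax.lax.switch(dr, branches, r)


r = jnp.linspace(0.0, 1.0, 5)
for dr in range(4):
    assert jnp.allclose(poly_static(r, dr=dr), poly_switch(r, dr))
print(traces)
```

The trade-off matches the discussion: the switch version compiles every branch up front (a larger one-time cost), while the static version compiles only the orders actually requested, one at a time.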
For future reference, here are the CPU results. The CPU version of the new code has a single-line change to how the array is updated (using `fori_loop` instead of `jnp.where`):
Profiling Newer on CPU with N=6
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.11 s ± 6.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 7.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 7.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.11 s ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
2.06 ms ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.82 ms ± 15.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.49 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.05 ms ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Profiling Older on CPU with N=6
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
316 ms ± 1.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
497 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
681 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
876 ms ± 2.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
13.4 ms ± 86.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
26.8 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
44.1 ms ± 782 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
61.6 ms ± 574 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Newer on CPU with N=9
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.14 s ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.16 s ± 11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.15 s ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.16 s ± 5.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
40 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
46.6 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.6 ms ± 880 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.7 ms ± 984 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Profiling Older on CPU with N=9
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
526 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
886 ms ± 3.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.23 s ± 1.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.58 s ± 1.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
204 ms ± 894 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
382 ms ± 3.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
551 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
706 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=12
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
2.4 s ± 5.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.42 s ± 4.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.44 s ± 4.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.46 s ± 5.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
255 ms ± 714 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
271 ms ± 521 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
289 ms ± 775 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
306 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Older on CPU with N=12
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
1.55 s ± 968 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.82 s ± 3.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.02 s ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.13 s ± 6.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
1.22 s ± 2.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.28 s ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.3 s ± 2.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.23 s ± 3.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=15
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
3.06 s ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.1 s ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.15 s ± 7.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.21 s ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
967 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.02 s ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.08 s ± 3.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.13 s ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Older on CPU with N=15
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
5.2 s ± 6.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.73 s ± 9.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
14.3 s ± 35.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.5 s ± 25.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
4.9 s ± 8.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
9.26 s ± 14.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
13.7 s ± 23.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
17.8 s ± 24.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Newer on CPU with N=18
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> compile, dr=0,1,2,3
4.75 s ± 35.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.94 s ± 64.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.01 s ± 37.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5.12 s ± 12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<jax._src.custom_derivatives.custom_jvp object at 0x152b86e9f910> run, dr=0,1,2,3
2.71 s ± 6.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.74 s ± 4.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.86 s ± 3.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.99 s ± 4.48 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Profiling Older on CPU with N=18
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> compile, dr=0,1,2,3
16.8 s ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.9 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
46.4 s ± 91.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 1s ± 28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
<PjitFunction of <function zernike_radial_old_desc at 0x152b73839b20>> run, dr=0,1,2,3
16.5 s ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.4 s ± 18.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
45.8 s ± 91.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min ± 39.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
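The CPU variant is described only as a one-line change from `jnp.where` to `fori_loop` for the array update. As a hedged illustration of those two update styles (the power-table functions `powers_where` and `powers_fori` are hypothetical, not DESC's `zernike_radial`), both functions below fill row `i` of an output array with `r**i`:

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="n")
def powers_where(r, n):
    # Masked update: each step rewrites the whole (n, m) array via jnp.where.
    out = jnp.zeros((n, r.size))
    val = jnp.ones_like(r)
    rows = jnp.arange(n)[:, None]
    for i in range(n):  # Python loop, unrolled at trace time
        out = jnp.where(rows == i, val, out)
        val = val * r
    return out


@functools.partial(jax.jit, static_argnames="n")
def powers_fori(r, n):
    # Indexed update inside lax.fori_loop: one compiled loop body, one row per step.
    def body(i, carry):
        out, val = carry
        return out.at[i].set(val), val * r

    init = (jnp.zeros((n, r.size)), jnp.ones_like(r))
    out, _ = jax.lax.fori_loop(0, n, body, init)
    return out


r = jnp.linspace(0.0, 1.0, 7)
expected = r[None, :] ** jnp.arange(5)[:, None]
print(jnp.allclose(powers_where(r, 5), expected), jnp.allclose(powers_fori(r, 5), expected))
```

The masked version unrolls `n` full-array updates into the traced program, so its trace and compile cost grows with `n`; the `fori_loop` version compiles a single loop body. Which one runs faster depends on the backend and array sizes, which is presumably why the thread tests separate CPU and GPU variants.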
Until we decide whether to use the new repo for `zernike_radial`, we are switching back to the old version of the code. Resolves #941
On master, initializing an `Equilibrium` now takes about 2x longer than it did on the last release, v0.10.4. The culprit seems to be the `_initial_guess_surface` function. In Jupyter:
I suspect we didn't see it because the benchmarks run the same code over and over, so they get the advantage of JIT caching. In practice, though, you usually create only a single equilibrium, so JIT caching doesn't really help. We should add `jax.clear_caches()` to those benchmarks to avoid this in the future.
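A minimal sketch of the cold-start measurement proposed here, assuming a plain `time.perf_counter` harness rather than DESC's actual benchmark setup. The helper `time_call` and the jitted `work` function are hypothetical stand-ins for an expensive setup step such as `_initial_guess_surface`.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, cold):
    """Time one call of `fn`; if cold, clear JAX caches first so compilation is included."""
    if cold:
        jax.clear_caches()
    start = time.perf_counter()
    fn(*args).block_until_ready()  # wait for async dispatch to finish
    return time.perf_counter() - start


@jax.jit
def work(x):
    # hypothetical stand-in for an expensive setup computation
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((200, 200))
cold = time_call(work, x, cold=True)   # trace + compile + run, like a fresh session
time_call(work, x, cold=False)         # make sure the compiled version is cached
warm = time_call(work, x, cold=False)  # run only, like a repeated benchmark iteration
print(f"cold: {cold:.4f} s, warm: {warm:.4f} s")
```

A benchmark that only ever reports the warm number hides exactly the kind of compile-time regression described in this issue.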