Hi,
I am really happy to see a multiple dispatch mechanism brought to JAX, thanks for that!
I ran a very small toy benchmark comparing the timings of quaxified functions against equivalent class methods, and I was quite surprised by the results.
From the snippet below, is there something I have missed? Or is Quax aimed more at large, already-implemented models/functions?
Again, thanks for this library and the feedback.
Vadim
Make sure to wrap your programs in jax.jit or equinox.filter_jit. This is an important thing to do for all JAX programs if you want good performance :)
(Once this has happened, Quax itself will have zero overhead, as it only affects the compilation process: the result will be a computation graph like any other jit-compiled JAX computation graph.)
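The advice to wrap programs in jax.jit can be checked with a timing harness like the one below. This is a minimal sketch, not the benchmark from the original question (that snippet is not shown here): the function `f` and the helper `bench` are hypothetical stand-ins. It shows two things that matter when timing JAX code: a warm-up call so that tracing and compilation are not counted, and `block_until_ready()` because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical toy computation standing in for the quaxified function
    # or class method being benchmarked.
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)


# Compile the whole function into a single XLA computation.
f_jit = jax.jit(f)


def bench(fn, x, n_iter=100):
    """Return the mean wall-clock time per call, in seconds.

    The first call happens outside the timed loop, so tracing and
    compilation (for jitted functions) are excluded from the measurement.
    """
    fn(x).block_until_ready()  # warm-up: triggers compilation for jitted fn
    start = time.perf_counter()
    for _ in range(n_iter):
        # Wait for the result; otherwise only the async dispatch is timed.
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / n_iter


x = jnp.linspace(0.0, 1.0, 10_000)
print(f"eager: {bench(f, x):.2e} s/call")
print(f"jit:   {bench(f_jit, x):.2e} s/call")
```

The same pattern should apply to a quaxified function, for example by timing `jax.jit(quax.quaxify(fn))` (or `equinox.filter_jit` when the arguments include non-array leaves) rather than the un-jitted call, so that Quax's dispatch happens once at trace time instead of on every call.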