This repository has been archived by the owner on Oct 26, 2024. It is now read-only.
A good place to start would be to substantially beat the current CPU JAX INLA implementation.
Things to do here:
Get a CUDA-capable GPU, either a cloud instance or a local machine, and set up a smooth dev environment on it. I think we should be willing to invest substantial time here. Good profiling and remote dev tools will be valuable.
Try running the existing JAX implementation on the CUDA backend. What's the performance like? What improvements could be made? Are there operations that are poorly suited to the CUDA backend?
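As a starting point for the profiling above, here's a minimal sketch of checking which backend JAX picked up and timing a jitted linear solve correctly. JAX dispatches work asynchronously, so you need `block_until_ready()` before stopping the clock, and you should run once first to exclude compile time. The `solve` function and matrix sizes are just placeholders, not the actual INLA workload.

```python
import time

import jax
import jax.numpy as jnp

# On a CUDA machine this should list GPU devices; on CPU-only it falls back.
print(jax.devices())

@jax.jit
def solve(A, b):
    # Stand-in for a hot operation in the INLA inner loop.
    return jnp.linalg.solve(A, b)

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (500, 500))
A = A @ A.T + 500 * jnp.eye(500)  # well-conditioned SPD matrix
b = jnp.ones(500)

solve(A, b).block_until_ready()  # warm-up: triggers compilation
t0 = time.perf_counter()
x = solve(A, b).block_until_ready()  # actual timed run
print(f"solve took {time.perf_counter() - t0:.4f}s")
```

Comparing this number between `jax.default_backend() == "cpu"` and `"gpu"` runs is a quick way to spot ops that don't benefit from the GPU.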
I think if we find the Hessian to be poorly conditioned, we could probably bail out to a first-order optimization method. We could plug in any of the standard NN optimizers (I'd start with Adam), or see here for fancier ones. JAX also has a BFGS solver, which I haven't tried.
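The bail-out idea above could look something like the sketch below: take a Newton step when the Hessian's condition number is acceptable, otherwise fall back to JAX's built-in BFGS (`jax.scipy.optimize.minimize`). The `objective` function, the `newton_or_bfgs` helper, and the condition-number threshold are all hypothetical stand-ins for the real INLA objective and tuning.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def objective(theta):
    # Hypothetical stand-in for the INLA inner objective.
    return jnp.sum((theta - 2.0) ** 2) + 0.1 * jnp.sum(theta ** 4)

def newton_or_bfgs(theta0, cond_threshold=1e8):
    g = jax.grad(objective)(theta0)
    H = jax.hessian(objective)(theta0)
    if jnp.linalg.cond(H) < cond_threshold:
        # Hessian looks healthy: take a plain Newton step.
        return theta0 - jnp.linalg.solve(H, g)
    # Poorly conditioned: bail out to a first-order-ish method (BFGS here).
    return minimize(objective, theta0, method="BFGS").x
```

Swapping the BFGS branch for an optax Adam loop would be the "standard NN optimizer" variant of the same idea.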
I also wonder if there's an analytic solution as sigma -> 0.
If you want to do something really fun, it's probably possible to regularize the Newton steps. See e.g. this for a new Newton algorithm that claims global convergence.
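One simple version of regularizing the Newton step is Levenberg-style damping: solve against `H + lam * I` instead of `H`, so the step interpolates between a pure Newton step (small `lam`) and scaled gradient descent (large `lam`). A minimal sketch, with `damped_newton_step` and the fixed `lam` as illustrative choices only:

```python
import jax
import jax.numpy as jnp

def damped_newton_step(f, x, lam=1e-3):
    """One regularized Newton step: solve (H + lam*I) dx = -g.

    lam -> 0 recovers plain Newton; large lam behaves like a small
    gradient-descent step, which keeps the step sane when H is
    indefinite or nearly singular."""
    g = jax.grad(f)(x)
    H = jax.hessian(f)(x)
    H_reg = H + lam * jnp.eye(x.shape[0])
    return x - jnp.linalg.solve(H_reg, g)
```

A practical scheme would adapt `lam` per iteration (shrink it when the step decreases the objective, grow it otherwise), which is where the globally convergent variants differ.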
Explore other JIT tools that might fit our use case. Does Numba work well? CuPy? ArrayFire?