An attempt at a pytorch fast_inla implementation. #97
Conversation
Closing this immediately since I don't want to merge this code; I just want to record some thoughts on the branch and be able to reference this elsewhere in issues.
Nice, this is a very cool experiment.
Oh gosh, I totally forgot about Cholesky!! What a huge oversight. In the JAX/numpy worlds, it was substantially slower, but it's the "correct" algorithm in this case in a lot of ways. I suspect the JAX problem is just that the implementation there is suboptimal in some way. Anyway, I'll try that.
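For concreteness, here's a minimal sketch of the batched Cholesky approach being discussed, in PyTorch. The shapes, function name, and the SPD construction are illustrative, not taken from this branch:

```python
# A minimal sketch of a batched Cholesky factor-and-solve in PyTorch.
# The shapes and the random SPD construction are illustrative only.
import torch

def batched_cholesky_solve(P: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Solve P x = b for each matrix in a batch via Cholesky.

    P: (batch, n, n) symmetric positive-definite matrices.
    b: (batch, n) right-hand sides.
    """
    L = torch.linalg.cholesky(P)  # P = L @ L.T, batched over the leading dim
    # cholesky_solve expects the right-hand side with a trailing column dim
    x = torch.cholesky_solve(b.unsqueeze(-1), L)
    return x.squeeze(-1)

# Usage with a random batch of SPD matrices:
A = torch.randn(8, 4, 4, dtype=torch.float64)
P = A @ A.transpose(-1, -2) + 4 * torch.eye(4, dtype=torch.float64)
b = torch.randn(8, 4, dtype=torch.float64)
x = batched_cholesky_solve(P, b)
print(torch.allclose(P @ x.unsqueeze(-1), b.unsqueeze(-1)))  # True
```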
Yeah, I tried using the TorchScript JIT. I'm not sure if I'm just doing something wrong, but I don't get much, if any, speedup. On the other hand, some of our more complex vectorized operations don't work in TorchScript, so I have to write the code quite differently.
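A hedged sketch of the kind of TorchScript experiment described above; the scripted function here is a stand-in, not the actual fast_inla kernel from this branch:

```python
# Scripting a small batched function with torch.jit.script. Vectorized
# ops like einsum are where TorchScript support can get restrictive for
# more complex code, forcing rewrites.
import torch

@torch.jit.script
def quadratic_form(P: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Batched x^T P x over the leading batch dimension.
    return torch.einsum("bi,bij,bj->b", x, P, x)

P = torch.randn(8, 4, 4)
x = torch.randn(8, 4)
print(quadratic_form(P, x).shape)  # torch.Size([8])
```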
Yeah, I tried that and found it to be a bit slower than the current fast_invert method. On the other hand, it'll likely be necessary in 32-bit float JAX.
…On Sat, May 28, 2022 at 6:48 PM constinit wrote:

> In the JAX/numpy worlds, it was substantially slower, but it's the "correct" algorithm in this case in a lot of ways. I suspect the JAX problem is just that the implementation there is suboptimal in some way.

jax.lax.linalg contains a batched, (hopefully) fast cholesky op.
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.linalg.cholesky.html
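For reference, a minimal sketch of that batched op; the SPD construction and the reconstruction check are illustrative, not from this branch:

```python
# jax.lax.linalg.cholesky batches over leading dimensions.
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (8, 4, 4))
# Build a batch of symmetric positive-definite matrices.
P = A @ jnp.swapaxes(A, -1, -2) + 4.0 * jnp.eye(4)

L = lax_linalg.cholesky(P)  # lower-triangular factors, shape (8, 4, 4)
# Check the factorization reconstructs P (loose tolerance for float32).
print(jnp.allclose(L @ jnp.swapaxes(L, -1, -2), P, atol=1e-4))  # True
```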
* Rename notebooks and integrate simple solvers for tile-based approaches.
* Add in new change for tilt-bound
* Add gaussian bias test (it works woohoo)
* Fix python version to 3.9
* Use 3.9.13 instead?
* Fix to python 3.10
I played around a bunch with PyTorch over the last day. Overall, my conclusion is pretty negative in comparison with JAX. I'm going to record my explorations in the repo because it's still worth taking another look once we start playing with GPU stuff more.
Overall: