
An attempt at a pytorch fast_inla implementation. #97

Closed
tbenthompson wants to merge 1 commit

Conversation

tbenthompson (Member)

I played a bunch with pytorch over the last day. Overall, my conclusion is pretty negative in comparison with jax. I’m going to record my explorations in the repo because it’s still worth taking another look once we start playing with GPU stuff more.

Overall:

  • I can’t get a float64 CPU version of pytorch to be competitive with JAX. It’s about 3x slower (21 µs vs. 7 µs). This isn’t a huge deal since neither of these numbers is slow, but it is disappointing since that’s the relevant comparison.
  • Using the M1 GPU/Apple metal backend is REALLY messy. The backend is quite immature and seems to only really support standard neural net operations well.
    • Some things are a lot faster. For example, with v1.13 (the version that enables the “mps” mac gpu backend), I get about a 10x speedup compared to v1.11 for something like a simple matrix multiply. That’s awesome and really important since matmuls are so fundamental.
    • But what we need is a bit more complex, and several of the important operations are not yet supported by the mps backend: determinants, matrix solves, inverses, and a few others.
    • Several operations ran but returned incorrect output: torch.einsum and torch.linalg.inv both gave completely wrong results. I made an issue: torch.linalg.inv gives incorrect results for any index besides zero on M1/mps GPU device pytorch/pytorch#78363
  • The other big issue is that running with the metal backend requires running in float32 (this would be true of cuda too, at least for good performance). It turns out that this causes major problems for our current Berry fast_inla code: a few of our operations fail in 32-bit floating point. The INLA optimization to find the mode fails to converge because the Hessian inverse is unstable, particularly for small sigma2. If I use a linear solver instead of computing the inverse directly, I can get float32 to work correctly, but that makes it nontrivially harder to get the variance of the multivariate gaussian approximation (see the sketch just after this list). Anyway, this is probably another set of issues that we’ll need to explore a bit more in the future. Being able to do correct inference in 32-bit floating point precision would be really helpful for performance because it unlocks the whole GPU world. I don’t think I should dig into it more right now because it’s not critical path, but it’s good to know what kind of problems we might run into.
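
For reference, here’s a minimal PyTorch sketch of the inverse-vs-solve distinction from the last bullet. The shapes and variable names are made up for illustration; this is not the actual fast_inla code.

```python
import torch

# Hypothetical batch of small SPD Hessians and gradients, standing in for the
# quantities in the INLA mode-finding step. Shapes are illustrative only.
hess = torch.randn(100, 4, 4)
hess = hess @ hess.transpose(-1, -2) + 4 * torch.eye(4)  # make symmetric positive definite
grad = torch.randn(100, 4)

# Explicit inverse: prone to float32 instability when the Hessian is
# ill-conditioned (e.g. for small sigma2).
step_inv = torch.einsum("bij,bj->bi", torch.linalg.inv(hess), grad)

# Linear solve: avoids forming the inverse and is typically better behaved,
# but no longer hands you the covariance of the Gaussian approximation for free.
step_solve = torch.linalg.solve(hess, grad.unsqueeze(-1)).squeeze(-1)
```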

@tbenthompson (Member, Author)

Closing this immediately since I don't want to merge this code; I just want to record some thoughts on the branch and be able to reference this elsewhere in issues.

@constinit (Contributor) commented May 27, 2022

Nice, this is a very cool experiment.

  • Interesting, I had heard pytorch doesn't have great CPU performance, but I didn't think the gap would be that large. I would make sure it's getting JIT compiled.
  • I guess the metal backend problems aren't too surprising at this stage of development. Something we could look at is using Choleskies for the inverses and determinants, assuming the Cholesky op works! (A rough sketch follows this list.)
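
To make the suggestion concrete, here's a rough PyTorch sketch of how a single Cholesky factorization could stand in for both the determinant and the inverse, assuming the mps backend (eventually) supports these ops. The matrices are illustrative, not the fast_inla code:

```python
import torch

# Hypothetical batch of SPD matrices standing in for the Hessians.
A = torch.randn(100, 4, 4)
A = A @ A.transpose(-1, -2) + 4 * torch.eye(4)  # make symmetric positive definite
b = torch.randn(100, 4, 1)

L = torch.linalg.cholesky(A)     # lower-triangular factor, batched
x = torch.cholesky_solve(b, L)   # solves A x = b using the factor
logdet = 2 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(-1)  # log|A|
```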

@tbenthompson (Member, Author) commented May 27, 2022

using Choleskies for the inverses and determinants

Oh gosh, I totally forgot about Cholesky!! What a huge oversight. In the JAX/numpy worlds, it was substantially slower, but it's the "correct" algorithm in this case in a lot of ways. I suspect the JAX problem is just that the implementation there is suboptimal in some way. Anyway, I'll try that.

@tbenthompson (Member, Author)

I would make sure it's getting JIT compiled.

Yeah, I tried using the torchscript jit. I'm not sure if I'm just doing something wrong, but I don't get much, if any, speedup. On the other hand, some of our more complex vectorized operations don't work in torchscript, so I have to write the code quite differently.
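
For context, the mechanics I was trying look roughly like this (a toy kernel, not the fast_inla code; whether scripting helps depends on the actual workload):

```python
import torch

def step(hess: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Batched matrix-vector product: (B, n, n) x (B, n) -> (B, n).
    return torch.bmm(hess, grad.unsqueeze(-1)).squeeze(-1)

# Compile the function with TorchScript and check it agrees with eager mode.
scripted = torch.jit.script(step)

hess = torch.randn(10, 4, 4)
grad = torch.randn(10, 4)
assert torch.allclose(scripted(hess, grad), step(hess, grad))
```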

@constinit (Contributor)

In the JAX/numpy worlds, it was substantially slower, but it's the "correct" algorithm in this case in a lot of ways. I suspect the JAX problem is just that the implementation there is suboptimal in some way.

jax.lax.linalg contains a batched, (hopefully) fast cholesky op.
https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.linalg.cholesky.html
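
A small usage sketch of that op (made-up shapes, not the fast_inla code):

```python
import jax
import jax.numpy as jnp

# Hypothetical batch of SPD matrices; the leading axis is the batch dimension.
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (100, 4, 4))
A = A @ jnp.swapaxes(A, -1, -2) + 4 * jnp.eye(4)  # make symmetric positive definite

L = jax.lax.linalg.cholesky(A)  # batched lower-triangular factors
logdet = 2 * jnp.sum(jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)), axis=-1)
```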

@tbenthompson (Member, Author) commented May 28, 2022 via email

@tbenthompson deleted the tbt/pytorch branch July 9, 2022 23:28
tbenthompson pushed a commit that referenced this pull request Dec 21, 2022
* Rename notebooks and integrate simple solvers for tile-based approaches.

* Add in new change for tilt-bound

* Add gaussian bias test (it works woohoo)

* Fix python version to 3.9

* Use 3.9.13 instead?

* Fix to python 3.10