Using JAX for accelerated high-resolution inference of generative models of neuroimaging data #18520
maedoc asked this question in Show and tell
hi, I've been developing the brain simulation software The Virtual Brain (thevirtualbrain.org) for 10+ years. I wanted to show and tell a package I started just over a year ago for building, running, and doing inference with generative models of brain imaging data, named vbjax:
https://github.com/ins-amu/vbjax
disclaimer: no brain images there, sorry :) it's a bring-your-own-data deal.
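To give a flavour of the kind of model vbjax targets, here's a minimal sketch of a toy two-variable neural mass integrated with Euler-Maruyama via `lax.scan`. All the names here (`dfun`, `simulate`, the parameters) are hypothetical illustrations for this post, not vbjax's actual API:

```python
import jax
import jax.numpy as jnp

def dfun(x, p):
    # toy 2-variable neural mass: state x = (u, v), parameters p = (a, k)
    u, v = x
    a, k = p
    du = u - u ** 3 / 3 + v
    dv = a - u + k * jnp.tanh(u)
    return jnp.stack([du, dv])

def simulate(x0, p, noise, dt=0.1):
    # Euler-Maruyama SDE integration; lax.scan keeps the loop jit-compilable
    def step(x, z):
        x = x + dt * dfun(x, p) + jnp.sqrt(dt) * z
        return x, x
    _, xs = jax.lax.scan(step, x0, noise)
    return xs

key = jax.random.PRNGKey(0)
noise = 0.05 * jax.random.normal(key, (1000, 2))  # pre-sampled driving noise
xs = jax.jit(simulate)(jnp.zeros(2), jnp.array([1.0, 0.5]), noise)
```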
About two years ago, I wrote a neural field solver for GPU in CUDA C++ (among other things), and when we needed gradients, I started looking around for a Better Way, since these kinds of models didn't work well in Stan (great autodiff and stats, but CPU-only, mostly single-threaded, and highly templated C++). After a brief time with Futhark, which I enjoyed, JAX ticked a bunch of boxes for me:

- (a) I really liked the autograd package
- (b) I enjoy, but can't work day-to-day in, functional languages that favour pure functions & immutability
- (c) I need to provide a flexible Python API that covers the diverse needs of the research group I support
- (d) the Pyro group has a sister package, NumPyro, so the Bayesian estimation use cases are covered
- (e) and of course, it's fast enough (though I will look at how to add custom ops, since ISPC still eats everyone's lunch on multicore CPU)

My current focus is extending the use cases to larger datasets, which has involved learning the parallelisation & distribution mechanisms in JAX, and improving the user-facing API and examples.
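As a rough illustration of points (a) and (e) and the parallelisation focus: the hypothetical `simulate` above composes directly with `jax.grad` for gradient-based fitting and with `jax.vmap` for batched parameter sweeps. Again, this is a generic JAX sketch under the assumptions of the previous block, not vbjax's API:

```python
# reusing the hypothetical simulate() from the sketch above
def loss(p, x0, noise, observed):
    xs = simulate(x0, p, noise)
    return jnp.mean((xs[:, 0] - observed) ** 2)

# reverse-mode gradient of the simulation w.r.t. parameters, autograd-style
grad_loss = jax.jit(jax.grad(loss))
g = grad_loss(jnp.array([1.0, 0.5]), jnp.zeros(2), noise, xs[:, 0])

# vmap batches a parameter sweep on one device; pmap / sharding extend the
# same code across devices for the larger-dataset use cases mentioned above
sweep = jax.vmap(loss, in_axes=(0, None, None, None))
params = jnp.stack([jnp.linspace(0.5, 1.5, 32), jnp.full(32, 0.5)], axis=1)
losses = sweep(params, jnp.zeros(2), noise, xs[:, 0])
```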
Lastly, I am starting to teach JAX to students/postdocs in the team and to others in the local community, so hopefully it will gain some mindshare here.
🙏 many thanks for the great foundation library 🙏