jaxtomo implements tomographic projectors with JAX.
They are written purely in Python, which keeps the code readable and hackable. Because JAX just-in-time compiles them for the GPU, the projectors are reasonably fast. However, they don't use texture memory and are slower than optimized implementations such as torch-radon.
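To make the pure-Python-plus-JIT point concrete, here is a minimal sketch of a 2D parallel-beam forward projector written in plain JAX. It is not jaxtomo's code. The name `parallel_fp`, the unit pixel size, one detector bin per pixel, and the ray-driven sampling with bilinear interpolation via `jax.scipy.ndimage.map_coordinates` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


@jax.jit
def parallel_fp(image, angles):
    """Ray-driven 2D parallel-beam forward projection (illustrative sketch).

    image:  (n, n) array with unit pixel size, centred on the rotation axis.
    angles: (n_angles,) projection angles in radians.
    Returns a sinogram of shape (n_angles, n), one detector bin per pixel.
    """
    n = image.shape[0]
    centre = (n - 1) / 2.0
    u = jnp.arange(n) - centre  # detector bin positions
    t = jnp.arange(n) - centre  # sample positions along each ray (step 1)

    def project_one(theta):
        c, s = jnp.cos(theta), jnp.sin(theta)
        # Rotate the (u, t) sampling grid into image coordinates (x, y).
        x = u[:, None] * c - t[None, :] * s
        y = u[:, None] * s + t[None, :] * c
        # Bilinear interpolation (order=1); samples outside the image read 0.
        # map_coordinates expects (row, col) = (y, x) index coordinates.
        vals = map_coordinates(image, [y + centre, x + centre], order=1,
                               mode="constant", cval=0.0)
        # Approximate each line integral by a sum with unit step length.
        return vals.sum(axis=1)

    return jax.vmap(project_one)(angles)


# Usage: a rectangular phantom projected over 90 angles in [0, pi).
phantom = jnp.zeros((64, 64)).at[24:40, 20:36].set(1.0)
angles = jnp.linspace(0.0, jnp.pi, 90, endpoint=False)
sino = parallel_fp(phantom, angles)  # shape (90, 64)
```

In this convention the rays are vertical at angle 0, so that projection reduces to the column sums of the image. The first call traces and compiles the function for the given shapes; later calls with the same shapes reuse the compiled kernel.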
This is a personal project and very much a work in progress. It is meant as a learning exercise for me, a pedagogical implementation for others (once I add some comments), and maybe even a tool for building proof-of-concept pipelines.
- Parallel beam
- Fan beam
- Cone beam
- FBP
- ... all with a flat detector
- Forward projection (FP) and backprojection (BP) registered as each other's transpose for autodiff with JAX
- End-to-end SIR via autodiff
- `jax.pmap` for multi-GPU speedup
- Valid FBP for large fan/cone angles (currently we just apply a Ram-Lak filter followed by BP)
- Other FP methods (Siddon, Footprint, ...)
- Curved detector
- Different voxel basis functions [1], [2]
- Speed up the bilinear-interpolation and/or profile FP, as it is rather slow: according to `examples/timing.py`, FP takes ~5x longer than BP.
- Try JAX Pallas
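One way to register FP and BP as each other's transpose is `jax.custom_vjp`: the backward pass of FP calls BP and vice versa, so `jax.grad` never differentiates through the projector internals. This is a minimal sketch, assuming a hypothetical two-angle (0 and 90 degrees) projector whose adjoint is easy to write by hand; jaxtomo may register its operators differently.

```python
import jax
import jax.numpy as jnp


# Hypothetical two-angle parallel-beam projector (0 and 90 degrees) for a
# square (n, n) image; its exact adjoint is easy to write by hand.
def _fp_impl(image):
    return jnp.stack([image.sum(axis=0), image.sum(axis=1)])  # (2, n)


def _bp_impl(sino):
    # Smear each projection back along its rays: the adjoint of _fp_impl.
    return sino[0][None, :] + sino[1][:, None]  # (n, n)


@jax.custom_vjp
def fp(image):
    return _fp_impl(image)


@jax.custom_vjp
def bp(sino):
    return _bp_impl(sino)


# Both operators are linear, so no residuals are saved. The VJP of a linear
# operator A is A^T applied to the cotangent: FP's VJP is BP and vice versa.
fp.defvjp(lambda image: (_fp_impl(image), None),
          lambda _, g: (_bp_impl(g),))
bp.defvjp(lambda sino: (_bp_impl(sino), None),
          lambda _, g: (_fp_impl(g),))


# Adjoint ("dot-product") test: <FP x, y> should equal <x, BP y>.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (5, 5))
y = jax.random.normal(k2, (2, 5))
print(jnp.vdot(fp(x), y), jnp.vdot(x, bp(y)))
```

The dot-product test at the end is a cheap check worth running whenever a hand-written BP is paired with an FP: if the two inner products disagree, gradients computed through the registered transpose will be wrong.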
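Once the projector is differentiable, an iterative reconstruction can be written as gradient descent on a data-fidelity loss, with `jax.grad` supplying the gradient. The sketch below is simplified and hypothetical: a small random non-negative matrix `A` stands in for the projector, the loss is unweighted least squares (a statistical method would add per-ray weights and usually a regularizer), and non-negativity is enforced by projecting after each step.

```python
import jax
import jax.numpy as jnp


def make_problem(key, n_pix=16, n_rays=32):
    # Hypothetical stand-in for a projector: a random non-negative system matrix.
    k1, k2 = jax.random.split(key)
    A = jax.random.uniform(k1, (n_rays, n_pix))
    x_true = jax.random.uniform(k2, (n_pix,))
    return A, x_true, A @ x_true


def loss(x, A, y):
    # Data-fidelity term 0.5 * ||A x - y||^2 (no statistical weights here).
    r = A @ x - y
    return 0.5 * jnp.sum(r ** 2)


def reconstruct(A, y, n_iter=200):
    # Step 1/L, where L = ||A||_2^2 is the Lipschitz constant of the gradient.
    step = 1.0 / jnp.linalg.norm(A, 2) ** 2
    grad_fn = jax.jit(jax.grad(loss))  # gradient w.r.t. x via autodiff
    x = jnp.zeros(A.shape[1])
    for _ in range(n_iter):
        # Projected gradient step: descend, then clip to x >= 0.
        x = jnp.maximum(x - step * grad_fn(x, A, y), 0.0)
    return x


A, x_true, y = make_problem(jax.random.PRNGKey(0))
x_rec = reconstruct(A, y)
print(loss(jnp.zeros_like(x_true), A, y), loss(x_rec, A, y))
```

With a real projector, `A @ x` would become a call to the FP, and the gradient would flow through the BP registered as its transpose.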
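Projections at different angles are independent, so a natural way to use several GPUs is to split the angle list across devices. A minimal sketch with `jax.pmap`, assuming the image fits on every device; `project_chunk` is a hypothetical placeholder (not a real projector) to be replaced by an actual per-angle FP. Newer JAX versions also offer sharding-based alternatives to `pmap`.

```python
import jax
import jax.numpy as jnp


def project_chunk(image, angles):
    # Placeholder for a real projector (square image assumed): one row per
    # angle, mixing the 0- and 90-degree line sums. Replace with an actual FP.
    cols, rows = image.sum(axis=0), image.sum(axis=1)
    return jnp.cos(angles)[:, None] * cols + jnp.sin(angles)[:, None] * rows


# in_axes=(None, 0): broadcast the image, map the angle chunks over devices.
_project_pmapped = jax.pmap(project_chunk, in_axes=(None, 0))


def project_multi_device(image, angles):
    n_dev = jax.local_device_count()
    n_ang = angles.shape[0]
    # Pad the angle list so it splits evenly, then reshape to (n_dev, chunk).
    pad = (-n_ang) % n_dev
    chunks = jnp.concatenate([angles, jnp.zeros(pad, angles.dtype)])
    chunks = chunks.reshape(n_dev, -1)
    sino = _project_pmapped(image, chunks)  # (n_dev, chunk, n_det)
    # Merge the per-device results and drop the padded angles.
    return sino.reshape(-1, sino.shape[-1])[:n_ang]


img = jnp.ones((32, 32))
angles = jnp.linspace(0.0, jnp.pi, 45, endpoint=False)
sino = project_multi_device(img, angles)  # shape (45, 32) on any device count
```

On a single device this falls back to one chunk containing all angles, so the same code runs on a laptop CPU and on a multi-GPU node.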
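For parallel-beam data, a Ram-Lak filter followed by BP is exactly FBP. For fan and cone beams the same recipe is only approximate, because those geometries additionally need distance/cosine weighting (and, for cone beam, an FDK-type reconstruction). A minimal parallel-beam sketch, assuming one detector bin per pixel centred on the rotation axis, angles covering [0, pi), and a ramp sampled in the frequency domain with zero-padding; `ramlak_filter`, `backproject` and `fbp` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def ramlak_filter(sino):
    """Ramp (Ram-Lak) filter along the detector axis of a (n_angles, n_det) sinogram."""
    n_det = sino.shape[-1]
    n_pad = 2 * n_det  # zero-pad to limit wrap-around of the circular FFT convolution
    ramp = jnp.fft.rfftfreq(n_pad)  # |f| in cycles per detector bin (already >= 0)
    spec = jnp.fft.rfft(sino, n=n_pad, axis=-1)
    return jnp.fft.irfft(spec * ramp, n=n_pad, axis=-1)[..., :n_det]


def backproject(sino, angles):
    """Pixel-driven parallel-beam BP with linear interpolation on the detector."""
    n_det = sino.shape[-1]
    coords = jnp.arange(n_det) - (n_det - 1) / 2.0
    X, Y = jnp.meshgrid(coords, coords)  # image grid, same spacing as the detector

    def one(row, theta):
        s = X * jnp.cos(theta) + Y * jnp.sin(theta)  # detector coordinate of each pixel
        return jnp.interp(s, coords, row, left=0.0, right=0.0)

    return jax.vmap(one)(sino, angles).sum(axis=0)


def fbp(sino, angles):
    # Discretised integral over [0, pi): scale by the angular step.
    return backproject(ramlak_filter(sino), angles) * (jnp.pi / angles.shape[0])
```

Sampling the ramp directly in the frequency domain zeroes the DC term; a common refinement builds the Ram-Lak kernel in the spatial domain instead, which avoids a small DC offset in the reconstruction.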