Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RayTransform custom backends #1540

Merged
merged 39 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0dbddb0
Simplify adjoint method
adriaangraas Jan 26, 2020
4a4b934
Make RayTransform operators independent from their implementations
adriaangraas Jan 26, 2020
c4db486
move __future__ import to beginning
adriaangraas Jan 26, 2020
b251fbd
Keep interface of RayTransform.impl the same, fix lowercasing
adriaangraas Jan 27, 2020
468f78e
Share cache with adjoint
adriaangraas Jan 27, 2020
3a3036b
Fix self.__impl initialization
adriaangraas Jan 27, 2020
9b5025e
Style changes, change ValueError into TypeError
adriaangraas Jan 28, 2020
894cb07
Remove RayTransformImplBase
adriaangraas Feb 16, 2020
cfdf701
Remove RayTransformBase; RayBackProjection becomes Operator
adriaangraas Feb 16, 2020
d7d4228
Move _call logic into backend for potential optimization
adriaangraas Feb 16, 2020
ad90713
Remove backends from __all__, update imports
adriaangraas Feb 16, 2020
2e3c274
Bring `RayBackProjection` class inline in `RayTransform.adjoint`
adriaangraas Feb 19, 2020
bc329b9
Fix linear=True kwarg in `RayBackProjection`
adriaangraas Feb 19, 2020
a1f54eb
Decorate RayTransform backend calls with `_add_default_complex_impl`
adriaangraas Feb 20, 2020
3e623e1
Fix import
adriaangraas Feb 20, 2020
25a9b91
Use `impl_type.__name__` as a return value for custom types
adriaangraas Feb 20, 2020
edeafe1
Add `geometry` with @property
adriaangraas Feb 20, 2020
af7cea7
Make `_check_impl` static
adriaangraas Feb 20, 2020
b39fa5f
Change class names of implementations
adriaangraas Feb 20, 2020
e2e1319
Change `reco_space` into `vol_space` and formatting
adriaangraas Feb 20, 2020
8cf7398
Change `reco_space` to `vol_space`
adriaangraas Feb 20, 2020
c64993f
Fix `self` in function call
adriaangraas Feb 20, 2020
7642de1
Fix complex spaces, and fix `out` argument
adriaangraas Feb 23, 2020
05d9fb2
Add properties for `vol_space` and `proj_space` to implementation cla…
adriaangraas Feb 23, 2020
da3fd97
Add docstrings
adriaangraas Feb 23, 2020
3bd5640
Update test to include complex adjoint of `RayTransform`
adriaangraas Feb 27, 2020
2aab34e
Make `_IMPL_STR2TYPE` public as `RAY_TRAFO_IMPLS`
adriaangraas Feb 27, 2020
f952b9e
Do not reassign to `out` when `out` is not None
adriaangraas Feb 27, 2020
fb2d29d
Some formatting updates
kohr-h Apr 13, 2020
5c7dad5
Update of README.md in tomo subpackage
kohr-h Apr 13, 2020
2fc5d60
Copyright notice and functool.wraps in backend utils
kohr-h Apr 13, 2020
d9cfc62
Fix failing import in skimage_radon.py
kohr-h Apr 13, 2020
60a18cb
Change of words
adriaangraas Apr 16, 2020
217e33a
Docstring extended with use case
adriaangraas Apr 16, 2020
204ec7f
Removed _ALL_IMPLS, enhanced exception messages, allow duck-typing `i…
adriaangraas Apr 16, 2020
6f15246
Import sorting order corrected
adriaangraas Apr 16, 2020
034494c
Renamed `_check_impl` and `create_impl`. Simplified docstrings.
adriaangraas Apr 16, 2020
55b0c63
Whitespace changes
adriaangraas Apr 16, 2020
5982904
Turn MD syntax in docstring to rst
kohr-h Apr 16, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions odl/test/tomo/backends/astra_cuda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
import pytest

import odl
from odl.tomo.backends.astra_cuda import (
AstraCudaBackProjectorImpl, AstraCudaProjectorImpl)
from odl.tomo.backends.astra_cuda import AstraCudaImpl
from odl.tomo.util.testutils import skip_if_no_astra_cuda


Expand Down Expand Up @@ -85,24 +84,25 @@ def test_astra_cuda_projector(space_and_geometry):
"""Test ASTRA CUDA projector."""

# Create reco space and a phantom
reco_space, geom = space_and_geometry
phantom = odl.phantom.cuboid(reco_space)
vol_space, geom = space_and_geometry
phantom = odl.phantom.cuboid(vol_space)

# Make projection space
proj_space = odl.uniform_discr_frompartition(geom.partition,
dtype=reco_space.dtype)
dtype=vol_space.dtype)

# create RayTransform implementation
astra_cuda = AstraCudaImpl(geom, vol_space, proj_space)

# Forward evaluation
projector = AstraCudaProjectorImpl(geom, reco_space, proj_space)
proj_data = projector.call_forward(phantom)
proj_data = astra_cuda.call_forward(phantom)
assert proj_data in proj_space
assert proj_data.norm() > 0
assert np.all(proj_data.asarray() >= 0)

# Backward evaluation
back_projector = AstraCudaBackProjectorImpl(geom, reco_space, proj_space)
backproj = back_projector.call_backward(proj_data)
assert backproj in reco_space
backproj = astra_cuda.call_backward(proj_data)
assert backproj in vol_space
assert backproj.norm() > 0
assert np.all(proj_data.asarray() >= 0)

Expand Down
12 changes: 10 additions & 2 deletions odl/test/tomo/operators/ray_trafo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
skip_if_no_astra, skip_if_no_astra_cuda, skip_if_no_skimage)
from odl.util.testutils import all_almost_equal, simple_fixture


# --- pytest fixtures --- #


Expand Down Expand Up @@ -106,7 +105,6 @@ def geometry(request):
'par2d skimage half_uniform'])
)


projector_ids = [
" geom='{}' - impl='{}' - angles='{}' ".format(*p.values[0].split())
for p in projectors
Expand Down Expand Up @@ -339,6 +337,16 @@ def test_complex(impl):
assert all_almost_equal(data.real, true_data_re)
assert all_almost_equal(data.imag, true_data_im)

# test adjoint for complex data
backproj_r = ray_trafo_r.adjoint
backproj_c = ray_trafo_c.adjoint
true_vol_re = backproj_r(data.real)
true_vol_im = backproj_r(data.imag)
backproj_vol = backproj_c(data)

assert all_almost_equal(backproj_vol.real, true_vol_re)
assert all_almost_equal(backproj_vol.imag, true_vol_im)


def test_anisotropic_voxels(geometry):
"""Test projection and backprojection with anisotropic voxels."""
Expand Down
2 changes: 1 addition & 1 deletion odl/tomo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ This directory contains all of the source code related tomographic reconstructio
* [analytic](analytic) Analytic reconstruction methods such as filtered back-projection. Also contains various utilities like `parker_weighting`.
* [backends](backends) Bindings to external libraries.
* [geometry](geometry) Definitions of projection geometries.
* [operators](operators) Defines the `RayTransform` operator and its adjoint ("back-projection").
* [operators](operators) Defines the `RayTransform` operator.
* [util](util) Utilities used internally.
2 changes: 2 additions & 0 deletions odl/tomo/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
from .astra_cuda import *
from .astra_setup import *
from .skimage_radon import *
from .util import *

__all__ = ()
__all__ += astra_cpu.__all__
__all__ += astra_cuda.__all__
__all__ += astra_setup.__all__
__all__ += util.__all__
__all__ += skimage_radon.__all__
167 changes: 132 additions & 35 deletions odl/tomo/backends/astra_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@

from __future__ import absolute_import, division, print_function

import warnings

import numpy as np

from odl.discr import DiscretizedSpace, DiscretizedSpaceElement
from odl.tomo.backends.astra_setup import (
astra_algorithm, astra_data, astra_projection_geometry, astra_projector,
astra_volume_geometry)
from odl.tomo.backends.util import _add_default_complex_impl
from odl.tomo.geometry import (
DivergentBeamGeometry, Geometry, ParallelBeamGeometry)
from odl.util import writable_array
Expand Down Expand Up @@ -94,30 +97,42 @@ def astra_cpu_forward_projector(vol_data, geometry, proj_space, out=None,
If ``out`` was provided, the returned object is a reference to it.
"""
if not isinstance(vol_data, DiscretizedSpaceElement):
raise TypeError('volume data {!r} is not a `DiscretizedSpaceElement` '
'instance.'.format(vol_data))
raise TypeError(
'volume data {!r} is not a `DiscretizedSpaceElement` instance'
''.format(vol_data)
)
if vol_data.space.impl != 'numpy':
raise TypeError("`vol_data.space.impl` must be 'numpy', got {!r}"
"".format(vol_data.space.impl))
raise TypeError(
"`vol_data.space.impl` must be 'numpy', got {!r}"
"".format(vol_data.space.impl)
)
if not isinstance(geometry, Geometry):
raise TypeError('geometry {!r} is not a Geometry instance'
''.format(geometry))
raise TypeError(
'geometry {!r} is not a Geometry instance'.format(geometry)
)
if not isinstance(proj_space, DiscretizedSpace):
raise TypeError('`proj_space` {!r} is not a DiscretizedSpace '
'instance.'.format(proj_space))
raise TypeError(
'`proj_space` {!r} is not a DiscretizedSpace instance.'
''.format(proj_space)
)
if proj_space.impl != 'numpy':
raise TypeError("`proj_space.impl` must be 'numpy', got {!r}"
"".format(proj_space.impl))
raise TypeError(
"`proj_space.impl` must be 'numpy', got {!r}"
"".format(proj_space.impl)
)
if vol_data.ndim != geometry.ndim:
raise ValueError('dimensions {} of volume data and {} of geometry '
'do not match'
''.format(vol_data.ndim, geometry.ndim))
raise ValueError(
'dimensions {} of volume data and {} of geometry do not match'
''.format(vol_data.ndim, geometry.ndim)
)
if out is None:
out = proj_space.element()
else:
if out not in proj_space:
raise TypeError('`out` {} is neither None nor a '
'DiscretizedSpaceElement instance'.format(out))
raise TypeError(
'`out` {} is neither None nor a `DiscretizedSpaceElement` '
'instance'.format(out)
)

ndim = vol_data.ndim

Expand Down Expand Up @@ -188,28 +203,37 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
'instance'.format(proj_data)
)
if proj_data.space.impl != 'numpy':
raise TypeError('`proj_data` must be a `numpy.ndarray` based, '
"container got `impl` {!r}"
"".format(proj_data.space.impl))
raise TypeError(
'`proj_data` must be a `numpy.ndarray` based, container, '
"got `impl` {!r}".format(proj_data.space.impl)
)
if not isinstance(geometry, Geometry):
raise TypeError('geometry {!r} is not a Geometry instance'
''.format(geometry))
raise TypeError(
'geometry {!r} is not a Geometry instance'.format(geometry)
)
if not isinstance(vol_space, DiscretizedSpace):
raise TypeError('volume space {!r} is not a DiscretizedSpace '
'instance'.format(vol_space))
raise TypeError(
'volume space {!r} is not a DiscretizedSpace instance'
''.format(vol_space)
)
if vol_space.impl != 'numpy':
raise TypeError("`vol_space.impl` must be 'numpy', got {!r}"
"".format(vol_space.impl))
raise TypeError(
"`vol_space.impl` must be 'numpy', got {!r}".format(vol_space.impl)
)
if vol_space.ndim != geometry.ndim:
raise ValueError('dimensions {} of reconstruction space and {} of '
'geometry do not match'.format(
vol_space.ndim, geometry.ndim))
raise ValueError(
'dimensions {} of reconstruction space and {} of geometry '
'do not match'
''.format(vol_space.ndim, geometry.ndim)
)
if out is None:
out = vol_space.element()
else:
if out not in vol_space:
raise TypeError('`out` {} is neither None nor a '
'DiscretizedSpaceElement instance'.format(out))
raise TypeError(
'`out` {} is neither None nor a `DiscretizedSpaceElement` '
'instance'.format(out)
)

ndim = proj_data.ndim

Expand All @@ -218,8 +242,9 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
proj_geom = astra_projection_geometry(geometry)

# Create ASTRA data structure
sino_id = astra_data(proj_geom, datatype='projection', data=proj_data,
allow_copy=True)
sino_id = astra_data(
proj_geom, datatype='projection', data=proj_data, allow_copy=True
)

# Create projector
if astra_proj_type is None:
Expand All @@ -228,11 +253,13 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,

# Convert out to correct dtype and order if needed.
with writable_array(out, dtype='float32', order='C') as out_arr:
vol_id = astra_data(vol_geom, datatype='volume', data=out_arr,
ndim=vol_space.ndim)
vol_id = astra_data(
vol_geom, datatype='volume', data=out_arr, ndim=vol_space.ndim
)
# Create algorithm
algo_id = astra_algorithm('backward', ndim, vol_id, sino_id, proj_id,
impl='cpu')
algo_id = astra_algorithm(
'backward', ndim, vol_id, sino_id, proj_id, impl='cpu'
)

# Run algorithm
astra.algorithm.run(algo_id)
Expand All @@ -251,6 +278,76 @@ def astra_cpu_back_projector(proj_data, geometry, vol_space, out=None,
return out


class AstraCpuImpl:
"""Thin wrapper implementing ASTRA CPU for `RayTransform`."""

def __init__(self, geometry, vol_space, proj_space):
"""Initialize a new instance.

Parameters
----------
geometry : `Geometry`
Geometry defining the tomographic setup.
vol_space : `DiscreteLp`
Reconstruction space, the space of the images to be forward
projected.
proj_space : `DiscreteLp`
Projection space, the space of the result.
"""
if not isinstance(geometry, Geometry):
raise TypeError(
'`geometry` must be a `Geometry` instance, got {!r}'
''.format(geometry)
)
if not isinstance(vol_space, DiscretizedSpace):
raise TypeError(
'`vol_space` must be a `DiscretizedSpace` instance, got {!r}'
''.format(vol_space)
)
if not isinstance(proj_space, DiscretizedSpace):
raise TypeError(
'`proj_space` must be a `DiscretizedSpace` instance, got {!r}'
''.format(proj_space)
)
if geometry.ndim > 2:
raise ValueError(
'`impl` {!r} only works for 2d'.format(self.__name__)
)

if vol_space.size >= 512 ** 2:
warnings.warn(
"The 'astra_cpu' backend may be too slow for volumes of this "
"size. Consider using 'astra_cuda' if your machine has an "
"Nvidia GPU.",
RuntimeWarning,
)

self.geometry = geometry
self._vol_space = vol_space
self._proj_space = proj_space

@property
def vol_space(self):
return self._vol_space

@property
def proj_space(self):
return self._proj_space

@_add_default_complex_impl
def call_backward(self, x, out, **kwargs):
return astra_cpu_back_projector(
x, self.geometry, self.vol_space.real_space, out, **kwargs
)

@_add_default_complex_impl
def call_forward(self, x, out, **kwargs):
return astra_cpu_forward_projector(
x, self.geometry, self.proj_space.real_space, out, **kwargs
)


if __name__ == '__main__':
from odl.util.testutils import run_doctests

run_doctests()
Loading