Skip to content

Commit

Permalink
ENH: add utility to check bounds for geometry methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Sep 26, 2017
1 parent 8d890a6 commit 6292fa7
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 42 deletions.
6 changes: 3 additions & 3 deletions odl/tomo/geometry/conebeam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from odl.tomo.geometry.detector import Flat1dDetector, Flat2dDetector
from odl.tomo.geometry.geometry import (
DivergentBeamGeometry, AxisOrientedGeometry)
from odl.tomo.util.utility import euler_matrix, transform_system
from odl.tomo.util.utility import (
euler_matrix, transform_system, is_inside_bounds)
from odl.util import signature_string, indent_rows


Expand Down Expand Up @@ -484,8 +485,7 @@ def rotation_matrix(self, angle):
squeeze_out = (np.shape(angle) == ())
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
if (self.check_bounds and
not self.motion_params.contains_all(angle.ravel())):
# Allow `angle` with ndim > 1 by checking the raveled array
not is_inside_bounds(angle, self.motion_params)):
raise ValueError('`angle` {} not in the valid range {}'
''.format(angle, self.motion_params))

Expand Down
40 changes: 14 additions & 26 deletions odl/tomo/geometry/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np

from odl.discr import RectPartition
from odl.tomo.util.utility import perpendicular_vector
from odl.tomo.util import perpendicular_vector, is_inside_bounds
from odl.util import indent_rows, signature_string


Expand Down Expand Up @@ -171,6 +171,7 @@ def surface_normal(self, param):
- ``param.shape + (space_ndim,)`` if `ndim` is 1,
- ``param.shape[:-1] + (space_ndim,)`` otherwise.
"""
# Checking is done by `surface_deriv`
if self.ndim == 1 and self.space_ndim == 2:
return -perpendicular_vector(self.surface_deriv(param))
elif self.ndim == 2 and self.space_ndim == 3:
Expand Down Expand Up @@ -215,6 +216,7 @@ def surface_measure(self, param):
.. _Surface area:
https://en.wikipedia.org/wiki/Surface_area
"""
# Checking is done by `surface_deriv`
if self.ndim == 1:
scalar_out = (np.shape(param) == ())
measure = np.linalg.norm(self.surface_deriv(param), axis=-1)
Expand Down Expand Up @@ -324,8 +326,7 @@ def surface(self, param):
"""
squeeze_out = (np.shape(param) == ())
param = np.array(param, dtype=float, copy=False, ndmin=1)
if self.check_bounds and not self.params.contains_all(param.ravel()):
# Allow `param` with ndim > 1 by checking the raveled array
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param, self.params))

Expand Down Expand Up @@ -377,8 +378,7 @@ def surface_deriv(self, param):
"""
squeeze_out = (np.shape(param) == ())
param = np.array(param, dtype=float, copy=False, ndmin=1)
if self.check_bounds and not self.params.contains_all(param.ravel()):
# Allow `param` with ndim > 1 by checking the raveled array
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param, self.params))
if squeeze_out:
Expand Down Expand Up @@ -510,14 +510,9 @@ def surface(self, param):
param_in = param
param = tuple(np.array(p, dtype=float, copy=False, ndmin=1)
for p in param)
if self.check_bounds:
# Flesh out and flatten to check bounds
bcast_param = np.broadcast_arrays(*param)
stacked_param = np.vstack(bcast_param)
flat_param = stacked_param.reshape(self.ndim, -1)
if not self.params.contains_all(flat_param):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param_in, self.params))
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param_in, self.params))

# Compute outer product of the i-th spatial component of the
# parameter and sum up the contributions
Expand Down Expand Up @@ -594,14 +589,9 @@ def surface_deriv(self, param):
param_in = param
param = tuple(np.array(p, dtype=float, copy=False, ndmin=1)
for p in param)
if self.check_bounds:
# Flesh out and flatten to check bounds
bcast_param = np.broadcast_arrays(*param)
stacked_param = np.vstack(bcast_param)
flat_param = stacked_param.reshape(self.ndim, -1)
if not self.params.contains_all(flat_param):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param_in, self.params))
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param_in, self.params))

if squeeze_out:
return self.axes
Expand Down Expand Up @@ -754,8 +744,7 @@ def surface(self, param):
"""
squeeze_out = (np.shape(param) == ())
param = np.array(param, dtype=float, copy=False, ndmin=1)
if self.check_bounds and not self.params.contains_all(param.ravel()):
# Allow `param` with ndim > 1 by checking the raveled array
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param, self.params))

Expand Down Expand Up @@ -820,8 +809,7 @@ def surface_deriv(self, param):
"""
squeeze_out = (np.shape(param) == ())
param = np.array(param, dtype=float, copy=False, ndmin=1)
if self.check_bounds and not self.params.contains_all(param.ravel()):
# Allow `param` with ndim > 1 by checking the raveled array
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param, self.params))

Expand Down Expand Up @@ -880,7 +868,7 @@ def surface_measure(self, param):
"""
scalar_out = (np.shape(param) == ())
param = np.array(param, dtype=float, copy=False, ndmin=1)
if self.check_bounds and not self.params.contains_all(param):
if self.check_bounds and not is_inside_bounds(param, self.params):
raise ValueError('`param` {} not in the valid range '
'{}'.format(param, self.params))

Expand Down
4 changes: 2 additions & 2 deletions odl/tomo/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from odl.discr import RectPartition
from odl.tomo.geometry.detector import Detector
from odl.tomo.util import axis_rotation_matrix
from odl.tomo.util import axis_rotation_matrix, is_inside_bounds


__all__ = ('Geometry', 'DivergentBeamGeometry', 'AxisOrientedGeometry')
Expand Down Expand Up @@ -610,7 +610,7 @@ def rotation_matrix(self, angle):
squeeze_out = (np.shape(angle) == ())
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
if (self.check_bounds and
not self.motion_params.contains_all(angle.ravel())):
not is_inside_bounds(angle, self.motion_params)):
raise ValueError('`angle` {} not in the valid range {}'
''.format(angle, self.motion_params))

Expand Down
16 changes: 6 additions & 10 deletions odl/tomo/geometry/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from odl.discr import uniform_partition
from odl.tomo.geometry.detector import Flat1dDetector, Flat2dDetector
from odl.tomo.geometry.geometry import Geometry, AxisOrientedGeometry
from odl.tomo.util import euler_matrix, transform_system
from odl.tomo.util import euler_matrix, transform_system, is_inside_bounds
from odl.util import signature_string, indent_rows


Expand Down Expand Up @@ -623,7 +623,7 @@ def rotation_matrix(self, angle):
squeeze_out = (np.shape(angle) == ())
angle = np.array(angle, dtype=float, copy=False, ndmin=1)
if (self.check_bounds and
not self.motion_params.contains_all(angle.ravel())):
not is_inside_bounds(angle, self.motion_params)):
raise ValueError('`angle` {} not in the valid range {}'
''.format(angle, self.motion_params))

Expand Down Expand Up @@ -1033,14 +1033,10 @@ def rotation_matrix(self, angles):
angles_in = angles
angles = tuple(np.array(angle, dtype=float, copy=False, ndmin=1)
for angle in angles)
if self.check_bounds:
# Flesh out and flatten to check bounds
bcast_angles = np.broadcast_arrays(*angles)
stacked_angles = np.vstack(bcast_angles)
flat_angles = stacked_angles.reshape(self.motion_params.ndim, -1)
if not self.motion_params.contains_all(flat_angles):
raise ValueError('`angles` {} not in the valid range '
'{}'.format(angles_in, self.motion_params))
if (self.check_bounds and
not is_inside_bounds(angles, self.motion_params)):
raise ValueError('`angles` {} not in the valid range '
'{}'.format(angles_in, self.motion_params))

matrix = euler_matrix(*angles)
if squeeze_out:
Expand Down
56 changes: 55 additions & 1 deletion odl/tomo/util/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

__all__ = ('euler_matrix', 'axis_rotation', 'axis_rotation_matrix',
'rotation_matrix_from_to', 'transform_system',
'perpendicular_vector')
'perpendicular_vector', 'is_inside_bounds')


def euler_matrix(phi, theta=None, psi=None):
Expand Down Expand Up @@ -620,6 +620,60 @@ def perpendicular_vector(vec):
return result


def is_inside_bounds(value, params):
"""Return ``True`` if ``value`` is contained in ``params``.
This method supports broadcasting in the sense that for
``params.ndim >= 2``, if more than one value is given, the inputs
are broadcast against each other.
Parameters
----------
value : `array-like`
Value(s) to be checked. For several inputs, the final bool
tells whether all inputs pass the check or not.
params : `IntervalProd`
Set in which the value is / the values are supposed to lie.
Returns
-------
is_inside_bounds : bool
``True`` is all values lie in ``params``, ``False`` otherwise.
Examples
--------
Check a single point:
>>> params = odl.IntervalProd([0, 0], [1, 2])
>>> is_inside_bounds([0, 0], params)
True
>>> is_inside_bounds([0, -1], params)
False
Using broadcasting:
>>> pts_ax0 = np.array([0, 0, 1, 0, 1])[:, None]
>>> pts_ax1 = np.array([2, 0, 1])[None, :]
>>> is_inside_bounds([pts_ax0, pts_ax1], params)
True
>>> pts_ax1 = np.array([-2, 1])[None, :]
>>> is_inside_bounds([pts_ax0, pts_ax1], params)
False
"""
if value in params:
# Single parameter
return True
else:
if params.ndim == 1:
return params.contains_all(np.ravel(value))
else:
# Flesh out and flatten to check bounds
bcast_value = np.broadcast_arrays(*value)
stacked_value = np.vstack(bcast_value)
flat_value = stacked_value.reshape(params.ndim, -1)
return params.contains_all(flat_value)


if __name__ == '__main__':
from odl.util.testutils import run_doctests
run_doctests()

0 comments on commit 6292fa7

Please sign in to comment.