Skip to content

Commit

Permalink
WIP: Fix broadcasting in vec geometry methods
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Sep 6, 2020
1 parent 6b8eb8b commit 252b7d0
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions odl/tomo/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,16 +889,16 @@ def det_point_position(self, index, dparam):
>>> geom_2d = odl.tomo.ParallelVecGeometry(det_shape_2d, vecs_2d)
>>> # This is equal to d(0) = (0, 1)
>>> geom_2d.det_point_position(0, 0)
array([ 0., 1.])
array([ 0., 1.])
>>> # d(0) + 2 * u(0) = (0, 1) + 2 * (1, 0)
>>> geom_2d.det_point_position(0, 2)
array([ 2., 1.])
array([ 2., 1.])
>>> # d(1) + 2 * u(1) = (-1, 0) + 2 * (0, 1)
>>> geom_2d.det_point_position(1, 2)
array([-1., 2.])
array([-1., 2.])
>>> # d(0.4) + 2 * u(0.4) = d(0) + 2 * u(0)
>>> geom_2d.det_point_position(0.4, 2)
array([ 2., 1.])
array([ 2., 1.])
Broadcasting of arguments:
Expand Down Expand Up @@ -935,7 +935,7 @@ def det_point_position(self, index, dparam):
Broadcasting of arguments:
>>> idcs = np.array([0.4, 0.6])[:, None]
>>> dpar = np.array([2.0, 1.0])[None, :]
>>> dpar = np.array([2.0, 1.0])[:, None]
>>> geom_3d.det_point_position(idcs, dpar)
"""
Expand All @@ -952,10 +952,10 @@ def det_point_position(self, index, dparam):
''.format(dparam, self.det_params)
)

# TODO: broadcast correctly
if self.ndim == 2:
det_shift = dparam * self.det_axis(index)
elif self.ndim == 3:
axes = self.det_axes(index)
det_shift = sum(
di * ax for di, ax in zip(dparam, self.det_axes(index))
)
Expand Down Expand Up @@ -1038,9 +1038,15 @@ def det_axes(self, index):
Returns
-------
axes : tuple of `numpy.ndarray`, shape ``(ndim,)``
The detector axes at ``index``, 1 for ``ndim == 2`` and
2 for ``ndim == 3``.
axes : `numpy.ndarray`
Unit vectors along which the detector is aligned, an array
with following shape:
- In 2D: If ``index`` is a single parameter, the shape is
``(2,)``, otherwise ``index.shape + (2,)``.
- In 3D: If ``index`` is a single parameter, the shape is
``(2, 3)``, otherwise ``index.shape + (2, 3)``.
Examples
--------
Expand All @@ -1056,16 +1062,17 @@ def det_axes(self, index):
>>> det_shape_3d = (10, 20)
>>> geom_3d = odl.tomo.ParallelVecGeometry(det_shape_3d, vecs_3d)
>>> geom_3d.det_axes(0)
(array([ 1., 0., 0.]), array([ 0., 0., 1.]))
array([[ 1., 0., 0.],
[ 0., 0., 1.]])
>>> geom_3d.det_axes(1)
(array([ 0., 1., 0.]), array([ 0., 0., 1.]))
>>> axs = geom_3d.det_axes([0.4, 0.6]) # values at closest indices
>>> axs[0] # first axis
array([[ 1., 0., 0.],
[ 0., 1., 0.]])
>>> axs[1] # second axis
array([[ 0., 0., 1.],
[ 0., 0., 1.]])
array([[ 0., 1., 0.],
[ 0., 0., 1.]])
>>> geom_3d.det_axes([0.4, 0.6]) # values at closest indices
array([[[ 1., 0., 0.],
[ 0., 0., 1.]],
<BLANKLINE>
[[ 0., 1., 0.],
[ 0., 0., 1.]]])
"""
if (
self.check_bounds
Expand All @@ -1085,19 +1092,18 @@ def det_axes(self, index):

vectors = self.vectors[index_int]
if self.ndim == 2:
det_us = vectors[:, self._slice_det_u]
retval_lst = [det_us[0]] if squeeze_index else [det_us]
axes = np.empty(index_int.shape + (2,))
axes[:] = vectors[:, self._slice_det_u]
elif self.ndim == 3:
det_us = vectors[:, self._slice_det_u]
det_vs = vectors[:, self._slice_det_v]
if squeeze_index:
retval_lst = [det_us[0], det_vs[0]]
else:
retval_lst = [det_us, det_vs]
axes = np.empty(index_int.shape + (2, 3))
axes[..., 0, :] = vectors[..., self._slice_det_u]
axes[..., 1, :] = vectors[..., self._slice_det_v]
else:
raise RuntimeError('invalid `ndim`')

return tuple(retval_lst)
if squeeze_index:
axes = axes[0]
return axes

def __getitem__(self, indices):
"""Return ``self[indices]``.
Expand Down

0 comments on commit 252b7d0

Please sign in to comment.