Skip to content

Commit

Permalink
MAINT: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Holger Kohr committed Jan 21, 2017
1 parent 670864e commit 61a29db
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 16 deletions.
21 changes: 13 additions & 8 deletions odl/discr/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ def __init__(self, *coord_vectors):
[1.0, 2.0, 5.0],
[-2.0, 1.5, 2.0]
)
>>> print(g)
grid [1.0, 2.0, 5.0] x [-2.0, 1.5, 2.0]
>>> g.ndim # number of axes
2
>>> g.shape # points per axis
Expand Down Expand Up @@ -392,11 +390,17 @@ def mid_pt(self):
def stride(self):
"""Step per axis between neighboring points of a uniform grid.
Raises
------
NotImplementedError
if the grid is not uniform
Examples
--------
>>> rg = uniform_grid([-1.5, -1], [-0.5, 3], (2, 3))
>>> rg.stride
array([ 1., 2.])
"""
if not self.is_uniform:
raise NotImplementedError('`stride` not defined for non-uniform '
Expand Down Expand Up @@ -570,10 +574,11 @@ def is_subgrid(self, other, atol=0.0):
# Optimization for some common cases
if other is self:
return True
if not (isinstance(other, RectGrid) and
np.all(self.shape <= other.shape) and
np.all(self.min_pt >= other.min_pt - atol) and
np.all(self.max_pt <= other.max_pt + atol)):
if not isinstance(other, RectGrid):
return False
if not(np.all(self.shape <= other.shape) and
np.all(self.min_pt >= other.min_pt - atol) and
np.all(self.max_pt <= other.max_pt + atol)):
return False

if self.is_uniform and other.is_uniform:
Expand All @@ -599,7 +604,7 @@ def is_subgrid(self, other, atol=0.0):
# vec_s. If there is no almost zero entry in each row,
# return False.
vec_o_mg, vec_s_mg = sparse_meshgrid(vec_o, vec_s)
if not np.all(np.any(np.abs(vec_s_mg - vec_o_mg) <= atol,
if not np.all(np.any(np.isclose(vec_s_mg, vec_o_mg, atol=atol),
axis=0)):
return False
return True
Expand Down Expand Up @@ -701,7 +706,7 @@ def squeeze(self):
)
"""
nondegen_indcs = np.where(self.nondegen_byaxis)[0]
nondegen_indcs = np.flatnonzero(self.nondegen_byaxis)
coord_vecs = [self.coord_vectors[axis] for axis in nondegen_indcs]
return RectGrid(*coord_vecs)

Expand Down
11 changes: 3 additions & 8 deletions odl/discr/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,6 @@ def __repr__(self):

return RectPartitionByAxis()

def __str__(self):
"""Return ``str(self)``."""
return 'partition of {} using {}'.format(self.set, self.grid)

def __repr__(self):
"""Return ``repr(self)``."""
bdry_fracs = np.vstack(self.boundary_cell_fractions)
Expand Down Expand Up @@ -835,6 +831,8 @@ def __repr__(self):
sep=[',\n', ', ', ',\n'])
return '{}(\n{}\n)'.format(constructor, indent_rows(sig_str))

__str__ = __repr__


def uniform_partition_fromintv(intv_prod, shape, nodes_on_bdry=False):
"""Return a partition of an interval product into equally sized cells.
Expand Down Expand Up @@ -898,10 +896,7 @@ def uniform_partition_fromintv(intv_prod, shape, nodes_on_bdry=False):
>>> part.grid.coord_vectors[1]
array([ 0.2, 0.6, 1. ])
"""

grid = uniform_grid_fromintv(intv_prod, shape,
nodes_on_bdry=nodes_on_bdry)

grid = uniform_grid_fromintv(intv_prod, shape, nodes_on_bdry=nodes_on_bdry)
return RectPartition(intv_prod, grid)


Expand Down
2 changes: 2 additions & 0 deletions odl/test/discr/grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with ODL. If not, see <http://www.gnu.org/licenses/>.

from __future__ import division

import pytest
import numpy as np

Expand Down
2 changes: 2 additions & 0 deletions odl/test/discr/partition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with ODL. If not, see <http://www.gnu.org/licenses/>.

from __future__ import division

import pytest
import numpy as np

Expand Down
7 changes: 7 additions & 0 deletions odl/test/tomo/backends/astra_setup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,13 @@ def test_astra_projection_geometry():
apart = odl.uniform_partition(0, 2 * np.pi, 5)
dpart = odl.uniform_partition(-40, 40, 10)

# motion sampling grid, detector sampling grid but not uniform
dpart_0 = odl.RectPartition(odl.IntervalProd(0, 3),
odl.RectGrid([0, 1, 3]))
geom_p2d = odl.tomo.Parallel2dGeometry(apart, dpart=dpart_0)
with pytest.raises(ValueError):
odl.tomo.astra_projection_geometry(geom_p2d)

# detector sampling grid, motion sampling grid
geom_p2d = odl.tomo.Parallel2dGeometry(apart, dpart)
odl.tomo.astra_projection_geometry(geom_p2d)
Expand Down

0 comments on commit 61a29db

Please sign in to comment.