From 8af3e83706ac07c586b3b8b46133f403e842e24c Mon Sep 17 00:00:00 2001 From: Tony Tung Date: Wed, 22 May 2019 16:00:35 -0700 Subject: [PATCH] Create an all-purpose ImageStack factory (#1348) 1. Create a `LocationAwareFetchedTile` class, which is like `FetchedTile`, but is explicitly aware of its location in 5D space. 2. Create an all_purpose.imagestack_factory method that produces an ImageStack with the provided coordinate ranges. 3. Fix existing tests that did not deal with coordinates properly. 4. Fix `imagestack_test_utils.py::verify_physical_coordinates`, which used zplane as an index rather than a value. In the case of labeled images, this makes a difference. Test plan: `make -j all` --- .../imagestack/test/factories/all_purpose.py | 138 ++++++++++++++++++ .../imagestack/test/factories/unique_tiles.py | 75 ++++------ .../imagestack/test/imagestack_test_utils.py | 9 +- .../core/imagestack/test/test_cropped_load.py | 86 ++++++----- .../imagestack/test/test_labeled_indices.py | 54 ++++--- 5 files changed, 253 insertions(+), 109 deletions(-) create mode 100644 starfish/core/imagestack/test/factories/all_purpose.py diff --git a/starfish/core/imagestack/test/factories/all_purpose.py b/starfish/core/imagestack/test/factories/all_purpose.py new file mode 100644 index 000000000..35dd219a9 --- /dev/null +++ b/starfish/core/imagestack/test/factories/all_purpose.py @@ -0,0 +1,138 @@ +from abc import ABCMeta +from typing import Mapping, Optional, Sequence, Tuple, Type, Union + +import numpy as np + +from starfish.core.experiment.builder import ( + build_image, + FetchedTile, + tile_fetcher_factory, + TileFetcher, +) +from starfish.core.imagestack.imagestack import ImageStack +from starfish.core.imagestack.parser.crop import CropParameters +from starfish.core.types import Axes, Coordinates, Number + + +class LocationAwareFetchedTile(FetchedTile, metaclass=ABCMeta): + """This is the base class for tiles that are aware of their location in the 5D tensor. + """ + def __init__( + self, + # these are the arguments passed in as a result of tile_fetcher_factory's + # pass_tile_indices parameter. + fov: int, _round: int, ch: int, zplane: int, + # these are the arguments we are passing through tile_fetcher_factory. + rounds: Sequence[int], chs: Sequence[int], zplanes: Sequence[int], + tile_height: int, tile_width: int, + ) -> None: + super().__init__() + self.round = _round + self.ch = ch + self.zplane = zplane + self.rounds = rounds + self.chs = chs + self.zplanes = zplanes + self.tile_height = tile_height + self.tile_width = tile_width + + +def _apply_coords_range_fetcher( + backing_tile_fetcher: TileFetcher, + zplanes: Sequence[int], + xrange: Tuple[Number, Number], + yrange: Tuple[Number, Number], + zrange: Tuple[Number, Number], +) -> TileFetcher: + """Given a :py:class:`TileFetcher`, intercept all the returned :py:class:`FetchedTile` instances + and replace the coordinates such that the resulting tensor has coordinates that range from + `xrange[0]:xrange[1]`, `yrange[0]:yrange[1]`, `zrange[0]:zrange[1]` """ + class ModifiedTile(FetchedTile): + def __init__(self, backing_tile: FetchedTile, zplane: int, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backing_tile = backing_tile + self.zplane = zplane + + @property + def shape(self) -> Mapping[Axes, int]: + return self.backing_tile.shape + + @property + def coordinates(self) -> Mapping[Union[str, Coordinates], + Union[Number, Tuple[Number, Number]]]: + zplane_offset = zplanes.index(self.zplane) + zplane_coords = np.linspace(zrange[0], zrange[1], len(zplanes)) + + return { + Coordinates.X: xrange, + Coordinates.Y: yrange, + Coordinates.Z: zplane_coords[zplane_offset], + } + + def tile_data(self) -> np.ndarray: + return self.backing_tile.tile_data() + + class ModifiedTileFetcher(TileFetcher): + def get_tile(self, fov: int, r: int, ch: int, z: int) -> FetchedTile: + original_fetched_tile = backing_tile_fetcher.get_tile(fov, r, ch, z) + return ModifiedTile(original_fetched_tile, z) + + return ModifiedTileFetcher() + + +def imagestack_factory( + fetched_tile_cls: Type[LocationAwareFetchedTile], + round_labels: Sequence[int], + ch_labels: Sequence[int], + zplane_labels: Sequence[int], + tile_height: int, + tile_width: int, + xrange: Tuple[Number, Number], + yrange: Tuple[Number, Number], + zrange: Tuple[Number, Number], + crop_parameters: Optional[CropParameters] = None) -> ImageStack: + """Given a type that implements the :py:class:`LocationAwareFetchedTile` contract, produce an + imagestack with those tiles, and apply coordinates such that the 5D tensor has coordinates + that range from `xrange[0]:xrange[1]`, `yrange[0]:yrange[1]`, `zrange[0]:zrange[1]`. + + Parameters + ---------- + fetched_tile_cls : Type[LocationAwareFetchedTile] + The class of the FetchedTile. + round_labels : Sequence[int] + Labels for the rounds. + ch_labels : Sequence[int] + Labels for the channels. + zplane_labels : Sequence[int] + Labels for the zplanes. + tile_height : int + Height of each tile, in pixels. + tile_width : int + Width of each tile, in pixels. + xrange : Tuple[Number, Number] + The starting and ending x physical coordinates for the tile. + yrange : Tuple[Number, Number] + The starting and ending y physical coordinates for the tile. + zrange : Tuple[Number, Number] + The starting and ending z physical coordinates for the tile. + crop_parameters : Optional[CropParameters] + The crop parameters to apply during ImageStack construction. + """ + original_tile_fetcher = tile_fetcher_factory( + fetched_tile_cls, True, + round_labels, ch_labels, zplane_labels, + tile_height, tile_width, + ) + modified_tile_fetcher = _apply_coords_range_fetcher( + original_tile_fetcher, zplane_labels, xrange, yrange, zrange) + + collection = build_image( + range(1), + round_labels, + ch_labels, + zplane_labels, + modified_tile_fetcher, + ) + tileset = list(collection.all_tilesets())[0][1] + + return ImageStack.from_tileset(tileset, crop_parameters) diff --git a/starfish/core/imagestack/test/factories/unique_tiles.py b/starfish/core/imagestack/test/factories/unique_tiles.py index 6288e67b7..a108265b8 100644 --- a/starfish/core/imagestack/test/factories/unique_tiles.py +++ b/starfish/core/imagestack/test/factories/unique_tiles.py @@ -1,13 +1,14 @@ -from typing import Mapping, Optional, Sequence, Tuple, Union +from abc import ABCMeta +from typing import Mapping, Optional, Sequence import numpy as np from skimage import img_as_float32 from slicedimage import ImageFormat -from starfish.core.experiment.builder import build_image, FetchedTile, tile_fetcher_factory from starfish.core.imagestack.imagestack import ImageStack from starfish.core.imagestack.parser.crop import CropParameters -from starfish.core.types import Axes, Coordinates, Number +from starfish.core.types import Axes +from .all_purpose import imagestack_factory, LocationAwareFetchedTile X_COORDS = 0.01, 0.1 Y_COORDS = 0.001, 0.01 @@ -32,70 +33,48 @@ def unique_data( return img_as_float32(result) -class UniqueTiles(FetchedTile): +class UniqueTiles(LocationAwareFetchedTile, metaclass=ABCMeta): """Tiles where the pixel values are unique per round/ch/z.""" - def __init__( - self, - # these are the arguments passed in as a result of tile_fetcher_factory's - # pass_tile_indices parameter. - fov: int, _round: int, ch: int, zplane: int, - # these are the arguments we are passing through tile_fetcher_factory. - num_rounds: int, num_chs: int, num_zplanes: int, tile_height: int, tile_width: int - ) -> None: - super().__init__() - self._round = _round - self._ch = ch - self._zplane = zplane - self.num_rounds = num_rounds - self.num_chs = num_chs - self.num_zplanes = num_zplanes - self.tile_height = tile_height - self.tile_width = tile_width + @property + def format(self) -> ImageFormat: + return ImageFormat.TIFF @property def shape(self) -> Mapping[Axes, int]: return {Axes.Y: self.tile_height, Axes.X: self.tile_width} - @property - def coordinates(self) -> Mapping[Union[str, Coordinates], Union[Number, Tuple[Number, Number]]]: - return { - Coordinates.X: X_COORDS, - Coordinates.Y: Y_COORDS, - Coordinates.Z: Z_COORDS, - } - - @property - def format(self) -> ImageFormat: - return ImageFormat.TIFF - def tile_data(self) -> np.ndarray: + """Return the data for a given tile.""" return unique_data( - self._round, self._ch, self._zplane, - self.num_rounds, self.num_chs, self.num_zplanes, - self.tile_height, self.tile_width, + self.rounds.index(self.round), + self.chs.index(self.ch), + self.zplanes.index(self.zplane), + len(self.rounds), + len(self.chs), + len(self.zplanes), + self.tile_height, + self.tile_width, ) def unique_tiles_imagestack( round_labels: Sequence[int], ch_labels: Sequence[int], - z_labels: Sequence[int], + zplane_labels: Sequence[int], tile_height: int, tile_width: int, crop_parameters: Optional[CropParameters] = None) -> ImageStack: """Build an imagestack with unique values per tile. """ - collection = build_image( - range(1), + return imagestack_factory( + UniqueTiles, round_labels, ch_labels, - z_labels, - tile_fetcher_factory( - UniqueTiles, True, - len(round_labels), len(ch_labels), len(z_labels), - tile_height, tile_width, - ), + zplane_labels, + tile_height, + tile_width, + X_COORDS, + Y_COORDS, + Z_COORDS, + crop_parameters, ) - tileset = list(collection.all_tilesets())[0][1] - - return ImageStack.from_tileset(tileset, crop_parameters) diff --git a/starfish/core/imagestack/test/imagestack_test_utils.py b/starfish/core/imagestack/test/imagestack_test_utils.py index 1a6e9104b..3acef9c1d 100644 --- a/starfish/core/imagestack/test/imagestack_test_utils.py +++ b/starfish/core/imagestack/test/imagestack_test_utils.py @@ -26,9 +26,8 @@ def verify_physical_coordinates(stack: ImageStack, expected_y_coordinates: Tuple[float, float], expected_z_coordinates: Tuple[float, float], zplane: Optional[int] = None) -> None: - """Given an imagestack and a set of coordinate min/max values - verify that the physical coordinates on the stack match the expected - range of values for each coord dimension. + """Given an imagestack and a set of coordinate min/max values verify that the physical + coordinates on the stack match the expected range of values for each coord dimension. """ assert np.all(np.isclose(stack.xarray[Coordinates.X.value], np.linspace(expected_x_coordinates[0], @@ -41,7 +40,9 @@ def verify_physical_coordinates(stack: ImageStack, # If zplane provided, test expected_z_coordinates on specific plane. # Else just test expected_z_coordinates on entire array if zplane is not None: - assert np.isclose(stack.xarray[Coordinates.Z.value][zplane], expected_z_coordinates) + assert np.isclose( + stack.xarray.sel({Axes.ZPLANE.value: zplane})[Coordinates.Z.value], + expected_z_coordinates) else: assert np.all(np.isclose(stack.xarray[Coordinates.Z.value], expected_z_coordinates)) diff --git a/starfish/core/imagestack/test/test_cropped_load.py b/starfish/core/imagestack/test/test_cropped_load.py index 46e233e0b..eaa228958 100644 --- a/starfish/core/imagestack/test/test_cropped_load.py +++ b/starfish/core/imagestack/test/test_cropped_load.py @@ -2,6 +2,8 @@ These tests center around creating an ImageStack but selectively loading data from the original TileSet. """ +import numpy as np + from starfish.core.types import Axes from .factories.unique_tiles import ( unique_data, unique_tiles_imagestack, X_COORDS, Y_COORDS, Z_COORDS, @@ -13,7 +15,6 @@ ) from ..imagestack import ImageStack from ..parser.crop import CropParameters -from ..physical_coordinates import _get_physical_coordinates_of_z_plane NUM_ROUND = 3 @@ -53,9 +54,11 @@ def test_crop_rcz(): assert stack.axis_labels(Axes.CH) == chs assert stack.axis_labels(Axes.ZPLANE) == ZPLANE_LABELS - for round_ in stack.axis_labels(Axes.ROUND): - for ch in stack.axis_labels(Axes.CH): - for zplane in stack.axis_labels(Axes.ZPLANE): + expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE) + + for zplane in stack.axis_labels(Axes.ZPLANE): + for round_ in stack.axis_labels(Axes.ROUND): + for ch in stack.axis_labels(Axes.CH): expected_tile_data = expected_data(round_, ch, zplane) verify_stack_data( @@ -63,13 +66,16 @@ def test_crop_rcz(): {Axes.ROUND: round_, Axes.CH: ch, Axes.ZPLANE: zplane}, expected_tile_data, ) - expected_z_coordinates = _get_physical_coordinates_of_z_plane(Z_COORDS) - verify_physical_coordinates( - stack, - X_COORDS, - Y_COORDS, - expected_z_coordinates, - ) + + zplane_index = ZPLANE_LABELS.index(zplane) + expected_zplane_coordinate = expected_zplane_coordinates[zplane_index] + verify_physical_coordinates( + stack, + X_COORDS, + Y_COORDS, + expected_zplane_coordinate, + zplane, + ) def test_crop_xy(): @@ -90,9 +96,11 @@ def test_crop_xy(): assert stack.raw_shape[3] == Y_SLICE[1] - Y_SLICE[0] assert stack.raw_shape[4] == X_SLICE[1] - X_SLICE[0] - for round_ in stack.axis_labels(Axes.ROUND): - for ch in stack.axis_labels(Axes.CH): - for zplane in stack.axis_labels(Axes.ZPLANE): + expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE) + + for zplane in stack.axis_labels(Axes.ZPLANE): + for round_ in stack.axis_labels(Axes.ROUND): + for ch in stack.axis_labels(Axes.CH): expected_tile_data = expected_data(round_, ch, zplane) expected_tile_data = expected_tile_data[ Y_SLICE[0]:Y_SLICE[1], X_SLICE[0]:X_SLICE[1]] @@ -103,27 +111,29 @@ def test_crop_xy(): expected_tile_data, ) - # the coordinates should be rescaled. verify that the coordinates on the ImageStack - # are also rescaled. - original_x_coordinates = X_COORDS - expected_x_coordinates = recalculate_physical_coordinate_range( - original_x_coordinates[0], original_x_coordinates[1], - WIDTH, - slice(*X_SLICE), - ) - - original_y_coordinates = Y_COORDS - expected_y_coordinates = recalculate_physical_coordinate_range( - original_y_coordinates[0], original_y_coordinates[1], - HEIGHT, - slice(*Y_SLICE), - ) - - expected_z_coordinates = _get_physical_coordinates_of_z_plane(Z_COORDS) - - verify_physical_coordinates( - stack, - expected_x_coordinates, - expected_y_coordinates, - expected_z_coordinates, - ) + # the coordinates should be rescaled. verify that the coordinates on the ImageStack + # are also rescaled. + original_x_coordinates = X_COORDS + expected_x_coordinates = recalculate_physical_coordinate_range( + original_x_coordinates[0], original_x_coordinates[1], + WIDTH, + slice(*X_SLICE), + ) + + original_y_coordinates = Y_COORDS + expected_y_coordinates = recalculate_physical_coordinate_range( + original_y_coordinates[0], original_y_coordinates[1], + HEIGHT, + slice(*Y_SLICE), + ) + + zplane_index = ZPLANE_LABELS.index(zplane) + expected_zplane_coordinate = expected_zplane_coordinates[zplane_index] + + verify_physical_coordinates( + stack, + expected_x_coordinates, + expected_y_coordinates, + expected_zplane_coordinate, + zplane, + ) diff --git a/starfish/core/imagestack/test/test_labeled_indices.py b/starfish/core/imagestack/test/test_labeled_indices.py index d4d592f45..1ddb2ab9a 100644 --- a/starfish/core/imagestack/test/test_labeled_indices.py +++ b/starfish/core/imagestack/test/test_labeled_indices.py @@ -10,7 +10,6 @@ ) from .imagestack_test_utils import verify_physical_coordinates, verify_stack_data from ..imagestack import ImageStack -from ..physical_coordinates import _get_physical_coordinates_of_z_plane ROUND_LABELS = (1, 4, 6) CH_LABELS = (2, 4, 6, 8) @@ -24,7 +23,12 @@ def expected_data(round_: int, ch: int, zplane: int): - return unique_data(round_, ch, zplane, NUM_ROUND, NUM_CH, NUM_ZPLANE, HEIGHT, WIDTH) + return unique_data( + ROUND_LABELS.index(round_), + CH_LABELS.index(ch), + ZPLANE_LABELS.index(zplane), + NUM_ROUND, NUM_CH, NUM_ZPLANE, + HEIGHT, WIDTH) def setup_imagestack() -> ImageStack: @@ -78,6 +82,8 @@ def test_labeled_indices_sel_single_tile(): the data is correct and that the physical coordinates are correctly set.""" stack = setup_imagestack() + expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE) + for selector in stack._iter_axes({Axes.ROUND, Axes.CH, Axes.ZPLANE}): subselected = stack.sel(selector) @@ -90,21 +96,25 @@ def test_labeled_indices_sel_single_tile(): selector[Axes.ROUND], selector[Axes.CH], selector[Axes.ZPLANE]) verify_stack_data(stack, selector, expected_tile_data) + zplane_index = ZPLANE_LABELS.index(selector[Axes.ZPLANE]) + expected_zplane_coordinate = expected_zplane_coordinates[zplane_index] + # assert that the physical coordinate values are what we expect. - verify_physical_coordinates( - stack, - X_COORDS, - Y_COORDS, - _get_physical_coordinates_of_z_plane(Z_COORDS), - ) + verify_physical_coordinates( + stack, + X_COORDS, + Y_COORDS, + expected_zplane_coordinate, + selector[Axes.ZPLANE], + ) def test_labeled_indices_sel_slice(): """Select a single tile across each index from an ImageStack with labeled indices. Verify that the data is correct and that the physical coordinates are correctly set.""" stack = setup_imagestack() - selector = {Axes.ROUND: slice(None, 4), Axes.CH: slice(4, 6), Axes.ZPLANE: 4} - subselected = stack.sel(selector) + set_selector = {Axes.ROUND: slice(None, 4), Axes.CH: slice(4, 6), Axes.ZPLANE: 4} + subselected = stack.sel(set_selector) # verify that the subselected stack has the correct index labels. for index_type, expected_results in ( @@ -113,19 +123,25 @@ def test_labeled_indices_sel_slice(): (Axes.ZPLANE, [4],)): assert subselected.axis_labels(index_type) == expected_results - for selectors in subselected._iter_axes({Axes.ROUND, Axes.CH, Axes.ZPLANE}): + expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE) + + for selector in subselected._iter_axes({Axes.ROUND, Axes.CH, Axes.ZPLANE}): # verify that the subselected stack has the correct data. expected_tile_data = expected_data( - selectors[Axes.ROUND], selectors[Axes.CH], selectors[Axes.ZPLANE]) - verify_stack_data(subselected, selectors, expected_tile_data) + selector[Axes.ROUND], selector[Axes.CH], selector[Axes.ZPLANE]) + verify_stack_data(subselected, selector, expected_tile_data) + + zplane_index = ZPLANE_LABELS.index(set_selector[Axes.ZPLANE]) + expected_zplane_coordinate = expected_zplane_coordinates[zplane_index] # verify that each tile in the subselected stack has the correct physical coordinates. - verify_physical_coordinates( - stack, - X_COORDS, - Y_COORDS, - _get_physical_coordinates_of_z_plane(Z_COORDS), - ) + verify_physical_coordinates( + stack, + X_COORDS, + Y_COORDS, + expected_zplane_coordinate, + selector[Axes.ZPLANE] + ) def multiply(array, value):