Skip to content

Commit

Permalink
Create an all-purpose ImageStack factory (#1348)
Browse files Browse the repository at this point in the history
1. Create a `LocationAwareFetchedTile` class, which is like `FetchedTile`, but is explicitly aware of its location in 5D space.
2. Create an all_purpose.imagestack_factory method that produces an ImageStack with the provided coordinate ranges.
3. Fix existing tests that did not deal with coordinates properly.
4. Fix `imagestack_test_utils.py::verify_physical_coordinates`, which used zplane as an index rather than a value.  In the case of labeled images, this makes a difference.

Test plan: `make -j all`
  • Loading branch information
Tony Tung authored May 22, 2019
1 parent 6a2a2e3 commit 8af3e83
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 109 deletions.
138 changes: 138 additions & 0 deletions starfish/core/imagestack/test/factories/all_purpose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from abc import ABCMeta
from typing import Mapping, Optional, Sequence, Tuple, Type, Union

import numpy as np

from starfish.core.experiment.builder import (
build_image,
FetchedTile,
tile_fetcher_factory,
TileFetcher,
)
from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.imagestack.parser.crop import CropParameters
from starfish.core.types import Axes, Coordinates, Number


class LocationAwareFetchedTile(FetchedTile, metaclass=ABCMeta):
"""This is the base class for tiles that are aware of their location in the 5D tensor.
"""
def __init__(
self,
# these are the arguments passed in as a result of tile_fetcher_factory's
# pass_tile_indices parameter.
fov: int, _round: int, ch: int, zplane: int,
# these are the arguments we are passing through tile_fetcher_factory.
rounds: Sequence[int], chs: Sequence[int], zplanes: Sequence[int],
tile_height: int, tile_width: int,
) -> None:
super().__init__()
self.round = _round
self.ch = ch
self.zplane = zplane
self.rounds = rounds
self.chs = chs
self.zplanes = zplanes
self.tile_height = tile_height
self.tile_width = tile_width


def _apply_coords_range_fetcher(
backing_tile_fetcher: TileFetcher,
zplanes: Sequence[int],
xrange: Tuple[Number, Number],
yrange: Tuple[Number, Number],
zrange: Tuple[Number, Number],
) -> TileFetcher:
"""Given a :py:class:`TileFetcher`, intercept all the returned :py:class:`FetchedTile` instances
and replace the coordinates such that the resulting tensor has coordinates that range from
`xrange[0]:xrange[1]`, `yrange[0]:yrange[1]`, `zrange[0]:zrange[1]` """
class ModifiedTile(FetchedTile):
def __init__(self, backing_tile: FetchedTile, zplane: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.backing_tile = backing_tile
self.zplane = zplane

@property
def shape(self) -> Mapping[Axes, int]:
return self.backing_tile.shape

@property
def coordinates(self) -> Mapping[Union[str, Coordinates],
Union[Number, Tuple[Number, Number]]]:
zplane_offset = zplanes.index(self.zplane)
zplane_coords = np.linspace(zrange[0], zrange[1], len(zplanes))

return {
Coordinates.X: xrange,
Coordinates.Y: yrange,
Coordinates.Z: zplane_coords[zplane_offset],
}

def tile_data(self) -> np.ndarray:
return self.backing_tile.tile_data()

class ModifiedTileFetcher(TileFetcher):
def get_tile(self, fov: int, r: int, ch: int, z: int) -> FetchedTile:
original_fetched_tile = backing_tile_fetcher.get_tile(fov, r, ch, z)
return ModifiedTile(original_fetched_tile, z)

return ModifiedTileFetcher()


def imagestack_factory(
fetched_tile_cls: Type[LocationAwareFetchedTile],
round_labels: Sequence[int],
ch_labels: Sequence[int],
zplane_labels: Sequence[int],
tile_height: int,
tile_width: int,
xrange: Tuple[Number, Number],
yrange: Tuple[Number, Number],
zrange: Tuple[Number, Number],
crop_parameters: Optional[CropParameters] = None) -> ImageStack:
"""Given a type that implements the :py:class:`LocationAwareFetchedTile` contract, produce an
imagestack with those tiles, and apply coordinates such that the 5D tensor has coordinates
that range from `xrange[0]:xrange[1]`, `yrange[0]:yrange[1]`, `zrange[0]:zrange[1]`.
Parameters
----------
fetched_tile_cls : Type[LocationAwareFetchedTile]
The class of the FetchedTile.
round_labels : Sequence[int]
Labels for the rounds.
ch_labels : Sequence[int]
Labels for the channels.
zplane_labels : Sequence[int]
Labels for the zplanes.
tile_height : int
Height of each tile, in pixels.
tile_width : int
Width of each tile, in pixels.
xrange : Tuple[Number, Number]
The starting and ending x physical coordinates for the tile.
yrange : Tuple[Number, Number]
The starting and ending y physical coordinates for the tile.
zrange : Tuple[Number, Number]
The starting and ending z physical coordinates for the tile.
crop_parameters : Optional[CropParameters]
The crop parameters to apply during ImageStack construction.
"""
original_tile_fetcher = tile_fetcher_factory(
fetched_tile_cls, True,
round_labels, ch_labels, zplane_labels,
tile_height, tile_width,
)
modified_tile_fetcher = _apply_coords_range_fetcher(
original_tile_fetcher, zplane_labels, xrange, yrange, zrange)

collection = build_image(
range(1),
round_labels,
ch_labels,
zplane_labels,
modified_tile_fetcher,
)
tileset = list(collection.all_tilesets())[0][1]

return ImageStack.from_tileset(tileset, crop_parameters)
75 changes: 27 additions & 48 deletions starfish/core/imagestack/test/factories/unique_tiles.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Mapping, Optional, Sequence, Tuple, Union
from abc import ABCMeta
from typing import Mapping, Optional, Sequence

import numpy as np
from skimage import img_as_float32
from slicedimage import ImageFormat

from starfish.core.experiment.builder import build_image, FetchedTile, tile_fetcher_factory
from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.imagestack.parser.crop import CropParameters
from starfish.core.types import Axes, Coordinates, Number
from starfish.core.types import Axes
from .all_purpose import imagestack_factory, LocationAwareFetchedTile

X_COORDS = 0.01, 0.1
Y_COORDS = 0.001, 0.01
Expand All @@ -32,70 +33,48 @@ def unique_data(
return img_as_float32(result)


class UniqueTiles(FetchedTile):
class UniqueTiles(LocationAwareFetchedTile, metaclass=ABCMeta):
"""Tiles where the pixel values are unique per round/ch/z."""
def __init__(
self,
# these are the arguments passed in as a result of tile_fetcher_factory's
# pass_tile_indices parameter.
fov: int, _round: int, ch: int, zplane: int,
# these are the arguments we are passing through tile_fetcher_factory.
num_rounds: int, num_chs: int, num_zplanes: int, tile_height: int, tile_width: int
) -> None:
super().__init__()
self._round = _round
self._ch = ch
self._zplane = zplane
self.num_rounds = num_rounds
self.num_chs = num_chs
self.num_zplanes = num_zplanes
self.tile_height = tile_height
self.tile_width = tile_width
@property
def format(self) -> ImageFormat:
return ImageFormat.TIFF

@property
def shape(self) -> Mapping[Axes, int]:
return {Axes.Y: self.tile_height, Axes.X: self.tile_width}

@property
def coordinates(self) -> Mapping[Union[str, Coordinates], Union[Number, Tuple[Number, Number]]]:
return {
Coordinates.X: X_COORDS,
Coordinates.Y: Y_COORDS,
Coordinates.Z: Z_COORDS,
}

@property
def format(self) -> ImageFormat:
return ImageFormat.TIFF

def tile_data(self) -> np.ndarray:
"""Return the data for a given tile."""
return unique_data(
self._round, self._ch, self._zplane,
self.num_rounds, self.num_chs, self.num_zplanes,
self.tile_height, self.tile_width,
self.rounds.index(self.round),
self.chs.index(self.ch),
self.zplanes.index(self.zplane),
len(self.rounds),
len(self.chs),
len(self.zplanes),
self.tile_height,
self.tile_width,
)


def unique_tiles_imagestack(
round_labels: Sequence[int],
ch_labels: Sequence[int],
z_labels: Sequence[int],
zplane_labels: Sequence[int],
tile_height: int,
tile_width: int,
crop_parameters: Optional[CropParameters] = None) -> ImageStack:
"""Build an imagestack with unique values per tile.
"""
collection = build_image(
range(1),
return imagestack_factory(
UniqueTiles,
round_labels,
ch_labels,
z_labels,
tile_fetcher_factory(
UniqueTiles, True,
len(round_labels), len(ch_labels), len(z_labels),
tile_height, tile_width,
),
zplane_labels,
tile_height,
tile_width,
X_COORDS,
Y_COORDS,
Z_COORDS,
crop_parameters,
)
tileset = list(collection.all_tilesets())[0][1]

return ImageStack.from_tileset(tileset, crop_parameters)
9 changes: 5 additions & 4 deletions starfish/core/imagestack/test/imagestack_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@ def verify_physical_coordinates(stack: ImageStack,
expected_y_coordinates: Tuple[float, float],
expected_z_coordinates: Tuple[float, float],
zplane: Optional[int] = None) -> None:
"""Given an imagestack and a set of coordinate min/max values
verify that the physical coordinates on the stack match the expected
range of values for each coord dimension.
"""Given an imagestack and a set of coordinate min/max values verify that the physical
coordinates on the stack match the expected range of values for each coord dimension.
"""
assert np.all(np.isclose(stack.xarray[Coordinates.X.value],
np.linspace(expected_x_coordinates[0],
Expand All @@ -41,7 +40,9 @@ def verify_physical_coordinates(stack: ImageStack,
# If zplane provided, test expected_z_coordinates on specific plane.
# Else just test expected_z_coordinates on entire array
if zplane is not None:
assert np.isclose(stack.xarray[Coordinates.Z.value][zplane], expected_z_coordinates)
assert np.isclose(
stack.xarray.sel({Axes.ZPLANE.value: zplane})[Coordinates.Z.value],
expected_z_coordinates)
else:
assert np.all(np.isclose(stack.xarray[Coordinates.Z.value], expected_z_coordinates))

Expand Down
86 changes: 48 additions & 38 deletions starfish/core/imagestack/test/test_cropped_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
These tests center around creating an ImageStack but selectively loading data from the original
TileSet.
"""
import numpy as np

from starfish.core.types import Axes
from .factories.unique_tiles import (
unique_data, unique_tiles_imagestack, X_COORDS, Y_COORDS, Z_COORDS,
Expand All @@ -13,7 +15,6 @@
)
from ..imagestack import ImageStack
from ..parser.crop import CropParameters
from ..physical_coordinates import _get_physical_coordinates_of_z_plane


NUM_ROUND = 3
Expand Down Expand Up @@ -53,23 +54,28 @@ def test_crop_rcz():
assert stack.axis_labels(Axes.CH) == chs
assert stack.axis_labels(Axes.ZPLANE) == ZPLANE_LABELS

for round_ in stack.axis_labels(Axes.ROUND):
for ch in stack.axis_labels(Axes.CH):
for zplane in stack.axis_labels(Axes.ZPLANE):
expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE)

for zplane in stack.axis_labels(Axes.ZPLANE):
for round_ in stack.axis_labels(Axes.ROUND):
for ch in stack.axis_labels(Axes.CH):
expected_tile_data = expected_data(round_, ch, zplane)

verify_stack_data(
stack,
{Axes.ROUND: round_, Axes.CH: ch, Axes.ZPLANE: zplane},
expected_tile_data,
)
expected_z_coordinates = _get_physical_coordinates_of_z_plane(Z_COORDS)
verify_physical_coordinates(
stack,
X_COORDS,
Y_COORDS,
expected_z_coordinates,
)

zplane_index = ZPLANE_LABELS.index(zplane)
expected_zplane_coordinate = expected_zplane_coordinates[zplane_index]
verify_physical_coordinates(
stack,
X_COORDS,
Y_COORDS,
expected_zplane_coordinate,
zplane,
)


def test_crop_xy():
Expand All @@ -90,9 +96,11 @@ def test_crop_xy():
assert stack.raw_shape[3] == Y_SLICE[1] - Y_SLICE[0]
assert stack.raw_shape[4] == X_SLICE[1] - X_SLICE[0]

for round_ in stack.axis_labels(Axes.ROUND):
for ch in stack.axis_labels(Axes.CH):
for zplane in stack.axis_labels(Axes.ZPLANE):
expected_zplane_coordinates = np.linspace(Z_COORDS[0], Z_COORDS[1], NUM_ZPLANE)

for zplane in stack.axis_labels(Axes.ZPLANE):
for round_ in stack.axis_labels(Axes.ROUND):
for ch in stack.axis_labels(Axes.CH):
expected_tile_data = expected_data(round_, ch, zplane)
expected_tile_data = expected_tile_data[
Y_SLICE[0]:Y_SLICE[1], X_SLICE[0]:X_SLICE[1]]
Expand All @@ -103,27 +111,29 @@ def test_crop_xy():
expected_tile_data,
)

# the coordinates should be rescaled. verify that the coordinates on the ImageStack
# are also rescaled.
original_x_coordinates = X_COORDS
expected_x_coordinates = recalculate_physical_coordinate_range(
original_x_coordinates[0], original_x_coordinates[1],
WIDTH,
slice(*X_SLICE),
)

original_y_coordinates = Y_COORDS
expected_y_coordinates = recalculate_physical_coordinate_range(
original_y_coordinates[0], original_y_coordinates[1],
HEIGHT,
slice(*Y_SLICE),
)

expected_z_coordinates = _get_physical_coordinates_of_z_plane(Z_COORDS)

verify_physical_coordinates(
stack,
expected_x_coordinates,
expected_y_coordinates,
expected_z_coordinates,
)
# the coordinates should be rescaled. verify that the coordinates on the ImageStack
# are also rescaled.
original_x_coordinates = X_COORDS
expected_x_coordinates = recalculate_physical_coordinate_range(
original_x_coordinates[0], original_x_coordinates[1],
WIDTH,
slice(*X_SLICE),
)

original_y_coordinates = Y_COORDS
expected_y_coordinates = recalculate_physical_coordinate_range(
original_y_coordinates[0], original_y_coordinates[1],
HEIGHT,
slice(*Y_SLICE),
)

zplane_index = ZPLANE_LABELS.index(zplane)
expected_zplane_coordinate = expected_zplane_coordinates[zplane_index]

verify_physical_coordinates(
stack,
expected_x_coordinates,
expected_y_coordinates,
expected_zplane_coordinate,
zplane,
)
Loading

0 comments on commit 8af3e83

Please sign in to comment.