diff --git a/starfish/core/label_image/__init__.py b/starfish/core/label_image/__init__.py new file mode 100644 index 000000000..772b912ec --- /dev/null +++ b/starfish/core/label_image/__init__.py @@ -0,0 +1 @@ +from .label_image import LabelImage diff --git a/starfish/core/label_image/label_image.py b/starfish/core/label_image/label_image.py new file mode 100644 index 000000000..93f88d8a0 --- /dev/null +++ b/starfish/core/label_image/label_image.py @@ -0,0 +1,172 @@ +from pathlib import Path +from typing import Any, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple, Union + +import numpy as np +import xarray as xr +from semantic_version import Version + +from starfish.core.types import Axes, Coordinates, LOG, Number, STARFISH_EXTRAS_KEY +from starfish.core.util.logging import Log +from .util import _get_axes_names + + +class AttrKeys: + DOCTYPE = f"{STARFISH_EXTRAS_KEY}.DOCTYPE" + LOG = f"{STARFISH_EXTRAS_KEY}.{LOG}" + VERSION = f"{STARFISH_EXTRAS_KEY}.VERSION" + + +DOCTYPE_STRING = "starfish/LabelImage" +CURRENT_VERSION = Version("0.0.0") +MIN_SUPPORTED_VERSION = Version("0.0.0") +MAX_SUPPORTED_VERSION = Version("0.0.0") + + +class LabelImage: + """Wraps an xarray that contains a 2D or 3D labeled image. Each axis is labeled with physical + coordinate data.""" + + def __init__(self, label_image: xr.DataArray): + # verify that the data array has the required elements. + if label_image.dtype.kind not in ("i", "u"): + raise TypeError("label image should be an integer type") + for axis in (Axes.X, Axes.Y): + if axis.value not in label_image.coords: + raise ValueError(f"label image should have an {axis.value} axis") + expected_coordinates: Tuple[Coordinates, ...] + if label_image.ndim == 5: + expected_coordinates = (Coordinates.X, Coordinates.Y, Coordinates.Z) + else: + expected_coordinates = (Coordinates.X, Coordinates.Y) + for coord in expected_coordinates: + if coord.value not in label_image.coords: + raise ValueError(f"label image should have a {coord.value} coordinates") + + self.label_image = label_image.copy(deep=False) + if AttrKeys.DOCTYPE not in self.label_image.attrs: + self.label_image.attrs[AttrKeys.DOCTYPE] = DOCTYPE_STRING + if AttrKeys.LOG not in self.label_image.attrs: + self.label_image.attrs[AttrKeys.LOG] = Log().encode() + + @classmethod + def from_array_and_coords( + cls, + array: np.ndarray, + pixel_coordinates: Optional[Union[ + Mapping[Axes, Sequence[int]], + Mapping[str, Sequence[int]]]], + physical_coordinates: Union[ + Mapping[Coordinates, Sequence[Number]], + Mapping[str, Sequence[Number]]], + log: Optional[Log], + ) -> "LabelImage": + """Constructs a LabelImage from an array containing the labels, a set of physical + coordinates, and an optional log of how this label image came to be. + + Parameters + ---------- + array : np.ndarray + A 2D or 3D array containing the labels. The ordering of the axes must be Y, X for 2D + images and ZPLANE, Y, X for 3D images. + pixel_coordinates : Optional[Mapping[Union[Axes, str], Sequence[int]]] + A map from the axis to the values for that axis. For any axis that exist in the array + but not in pixel_coordinates, the pixel coordinates are assigned from 0..N-1, where N is + the size along that axis. + physical_coordinates : Mapping[Union[Coordinates, str], Sequence[Number]] + A map from the physical coordinate type to the values for axis. For 2D label images, + X and Y physical coordinates must be provided. For 3D label images, Z physical + coordinates must also be provided. + log : Optional[Log] + A log of how this label image came to be. + """ + # normalize the pixel coordinates to Mapping[Axes, Sequence[int]] + pixel_coordinates = { + axis if isinstance(axis, Axes) else Axes(axis): axis_values + for axis, axis_values in (pixel_coordinates or {}).items() + } + # normalize the physical coordinates to Mapping[Coordinates, Sequence[Number]] + physical_coordinates = { + coord if isinstance(coord, Coordinates) else Coordinates(coord): coord_values + for coord, coord_values in physical_coordinates.items() + } + + img_axes, img_coords = _get_axes_names(array.ndim) + xr_axes = [axis.value for axis in img_axes] + try: + xr_coords: MutableMapping[Hashable, Any] = { + coord.value: (axis.value, physical_coordinates[coord]) + for axis, coord in zip(img_axes, img_coords) + } + except KeyError as ex: + raise KeyError(f"missing physical coordinates {ex.args[0]}") from ex + + xr_coords[Axes.X.value] = pixel_coordinates.get(Axes.X, np.arange(0, array.shape[-1])) + xr_coords[Axes.Y.value] = pixel_coordinates.get(Axes.Y, np.arange(0, array.shape[-2])) + if array.ndim == 3: + xr_coords[Axes.ZPLANE.value] = pixel_coordinates.get( + Axes.ZPLANE, np.arange(0, array.shape[-3])) + + dataarray = xr.DataArray( + array, + dims=xr_axes, + coords=xr_coords, + ) + dataarray.attrs.update({ + AttrKeys.LOG: (log or Log()).encode(), + AttrKeys.DOCTYPE: DOCTYPE_STRING, + AttrKeys.VERSION: str(CURRENT_VERSION), + }) + + return LabelImage(dataarray) + + @property + def xarray(self): + """Returns the xarray that contains the label image and the physical coordinates.""" + return self.label_image + + @property + def log(self) -> Log: + """Returns a copy of the provenance data. Modifications to this copy will not affect the + log stored on this label image.""" + return Log.decode(self.label_image.attrs[AttrKeys.LOG]) + + @classmethod + def from_disk(cls, path: Union[str, Path]) -> "LabelImage": + """Load a label image from disk. + + Parameters + ---------- + path : Union[str, Path] + Path of the label image to instantiate from. + + Returns + ------- + label_image : LabelImage + Label image from the path. + """ + label_image = xr.open_dataarray(path) + if ( + AttrKeys.DOCTYPE not in label_image.attrs + or label_image.attrs[AttrKeys.DOCTYPE] != DOCTYPE_STRING + or AttrKeys.VERSION not in label_image.attrs + ): + raise ValueError(f"{path} does not appear to be a starfish label image") + if not ( + MIN_SUPPORTED_VERSION + <= Version(label_image.attrs[AttrKeys.VERSION]) + <= MAX_SUPPORTED_VERSION): + raise ValueError( + f"{path} contains a label image, but the version " + f"{label_image.attrs[AttrKeys.VERSION]} is not supported") + + return cls(label_image) + + def save(self, path: Union[str, Path]): + """Save the label image to disk. + + Parameters + ---------- + path : Union[str, Path] + Path of the netcdf file to write to. + """ + self.label_image.to_netcdf(path) diff --git a/starfish/core/label_image/test/__init__.py b/starfish/core/label_image/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/starfish/core/label_image/test/test_label_image.py b/starfish/core/label_image/test/test_label_image.py new file mode 100644 index 000000000..b272158d9 --- /dev/null +++ b/starfish/core/label_image/test/test_label_image.py @@ -0,0 +1,118 @@ +from typing import Mapping, Optional, Sequence, Type + +import numpy as np +import pytest + +from starfish import Log +from starfish.image import Filter +from starfish.types import Axes, Coordinates, Number +from ..label_image import AttrKeys, CURRENT_VERSION, DOCTYPE_STRING, LabelImage + + +@pytest.mark.parametrize( + "array, physical_coordinates, log, expected_error", + [ + # 3D label image + [ + np.zeros((1, 1, 1), dtype=np.int32), + { + Coordinates.X: [0], + Coordinates.Y: [0], + Coordinates.Z: [0], + }, + None, + None, + ], + # 2D label image + [ + np.zeros((1, 1), dtype=np.int32), + { + Coordinates.X: [0], + Coordinates.Y: [0], + }, + None, + None, + ], + # wrong dtype + [ + np.zeros((1, 1), dtype=np.float32), + { + Coordinates.X: [0], + Coordinates.Y: [0], + }, + None, + TypeError, + ], + # missing some coordinates + [ + np.zeros((1, 1), dtype=np.float32), + { + Coordinates.X: [0], + }, + None, + KeyError, + ], + ] +) +def test_from_array_and_coords( + array: np.ndarray, + physical_coordinates: Mapping[Coordinates, Sequence[Number]], + log: Optional[Log], + expected_error: Optional[Type[Exception]], +): + """Test that we can construct a LabelImage and that some common error conditions are caught.""" + if expected_error is not None: + with pytest.raises(expected_error): + LabelImage.from_array_and_coords(array, None, physical_coordinates, log) + else: + label_image = LabelImage.from_array_and_coords(array, None, physical_coordinates, log) + assert isinstance(label_image.log, Log) + assert label_image.xarray.attrs.get(AttrKeys.DOCTYPE, None) == DOCTYPE_STRING + assert label_image.xarray.attrs.get(AttrKeys.VERSION, None) == str(CURRENT_VERSION) + + +def test_pixel_coordinates(): + array = np.zeros((2, 2, 2), dtype=np.int32) + pixel_coordinates = { + Axes.X: [2, 3], + Axes.ZPLANE: [0, 1], + } + physical_coordinates = { + Coordinates.X: [0, 0.5], + Coordinates.Y: [0, 0.2], + Coordinates.Z: [0, 0.1], + } + label_image = LabelImage.from_array_and_coords( + array, pixel_coordinates, physical_coordinates, None) + + assert np.array_equal(label_image.xarray.coords[Axes.X.value], [2, 3]) + # not provided, should be 0..N-1 + assert np.array_equal(label_image.xarray.coords[Axes.Y.value], [0, 1]) + assert np.array_equal(label_image.xarray.coords[Axes.ZPLANE.value], [0, 1]) + + +def test_save_and_load(tmp_path): + """Verify that we can save the label image and load it correctly.""" + array = np.zeros((2, 2, 2), dtype=np.int32) + pixel_coordinates = { + Axes.X: [2, 3], + Axes.ZPLANE: [0, 1], + } + physical_coordinates = { + Coordinates.X: [0, 0.5], + Coordinates.Y: [0, 0.2], + Coordinates.Z: [0, 0.1], + } + log = Log() + # instantiate a filter (even though that makes no sense in this context) + filt = Filter.Reduce((Axes.ROUND,), func="max") + log.update_log(filt) + + label_image = LabelImage.from_array_and_coords( + array, pixel_coordinates, physical_coordinates, log) + label_image.save(tmp_path / "label_image.netcdf") + + loaded_label_image = LabelImage.from_disk(tmp_path / "label_image.netcdf") + + assert label_image.xarray.equals(loaded_label_image.xarray) + assert label_image.xarray.attrs == loaded_label_image.xarray.attrs diff --git a/starfish/core/label_image/util.py b/starfish/core/label_image/util.py new file mode 100644 index 000000000..66d4ae5b5 --- /dev/null +++ b/starfish/core/label_image/util.py @@ -0,0 +1,30 @@ +from typing import List, Tuple + +from starfish.core.types import Axes, Coordinates + + +def _get_axes_names(ndim: int) -> Tuple[List[Axes], List[Coordinates]]: + """Get needed axes and coordinates given the number of dimensions. + + Parameters + ---------- + ndim : int + Number of dimensions. + + Returns + ------- + axes : List[Axes] + Axes. + coords : List[Coordinates] + Coordinates. + """ + if ndim == 2: + axes = [Axes.Y, Axes.X] + coords = [Coordinates.Y, Coordinates.X] + elif ndim == 3: + axes = [Axes.ZPLANE, Axes.Y, Axes.X] + coords = [Coordinates.Z, Coordinates.Y, Coordinates.X] + else: + raise TypeError('expected 2- or 3-D image') + + return axes, coords