Skip to content

Commit

Permalink
Add a label image data type (#1619)
Browse files Browse the repository at this point in the history
Rather than passing around a naked 2D/3D array as a label image, create a type for holding that data.

Test plan: see tests!
Depends on #1617
Part of #1497
  • Loading branch information
Tony Tung authored Oct 28, 2019
1 parent 421ad08 commit 98d54af
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 0 deletions.
1 change: 1 addition & 0 deletions starfish/core/label_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .label_image import LabelImage
172 changes: 172 additions & 0 deletions starfish/core/label_image/label_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from pathlib import Path
from typing import Any, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple, Union

import numpy as np
import xarray as xr
from semantic_version import Version

from starfish.core.types import Axes, Coordinates, LOG, Number, STARFISH_EXTRAS_KEY
from starfish.core.util.logging import Log
from .util import _get_axes_names


class AttrKeys:
    """Keys under which label image metadata is stored in the xarray ``attrs`` mapping.

    Every key is prefixed with ``STARFISH_EXTRAS_KEY``.
    """
    # key for the document type marker
    DOCTYPE = "{}.DOCTYPE".format(STARFISH_EXTRAS_KEY)
    # key for the encoded log; the ``LOG`` on the right-hand side is the module-level import
    LOG = "{}.{}".format(STARFISH_EXTRAS_KEY, LOG)
    # key for the format version string
    VERSION = "{}.VERSION".format(STARFISH_EXTRAS_KEY)


# Value stored under AttrKeys.DOCTYPE to identify an array as a starfish label image.
DOCTYPE_STRING = "starfish/LabelImage"
# Version stamped (as a string) under AttrKeys.VERSION on label images built by
# LabelImage.from_array_and_coords.
CURRENT_VERSION = Version("0.0.0")
# Inclusive bounds of the stored versions that LabelImage.from_disk accepts.
MIN_SUPPORTED_VERSION = Version("0.0.0")
MAX_SUPPORTED_VERSION = Version("0.0.0")


class LabelImage:
    """Wraps an xarray that contains a 2D or 3D labeled image.

    Each axis is labeled with physical coordinate data.  The wrapped array's ``attrs`` carry a
    document type marker, a format version, and an encoded provenance log (see ``AttrKeys``).
    """

    def __init__(self, label_image: xr.DataArray):
        """Validate and wrap an existing data array.

        Parameters
        ----------
        label_image : xr.DataArray
            An integer-typed array with X and Y pixel coordinates, X and Y physical coordinates,
            and, for 3D arrays, Z physical coordinates.

        Raises
        ------
        TypeError
            If the array's dtype is not a signed or unsigned integer type.
        ValueError
            If a required pixel axis or physical coordinate is missing.
        """
        # verify that the data array has the required elements.
        if label_image.dtype.kind not in ("i", "u"):
            raise TypeError("label image should be an integer type")
        for axis in (Axes.X, Axes.Y):
            if axis.value not in label_image.coords:
                raise ValueError(f"label image should have an {axis.value} axis")
        expected_coordinates: Tuple[Coordinates, ...]
        # label images are 2D (Y, X) or 3D (ZPLANE, Y, X); only the 3D form needs physical Z
        # coordinates.
        if label_image.ndim == 3:
            expected_coordinates = (Coordinates.X, Coordinates.Y, Coordinates.Z)
        else:
            expected_coordinates = (Coordinates.X, Coordinates.Y)
        for coord in expected_coordinates:
            if coord.value not in label_image.coords:
                raise ValueError(f"label image should have a {coord.value} coordinates")

        # hold a shallow copy rather than the caller's DataArray object itself.
        self.label_image = label_image.copy(deep=False)
        if AttrKeys.DOCTYPE not in self.label_image.attrs:
            self.label_image.attrs[AttrKeys.DOCTYPE] = DOCTYPE_STRING
        if AttrKeys.LOG not in self.label_image.attrs:
            self.label_image.attrs[AttrKeys.LOG] = Log().encode()
        if AttrKeys.VERSION not in self.label_image.attrs:
            # from_disk() rejects arrays without a version attribute, so stamp one here to keep
            # save() followed by from_disk() working for directly-constructed label images.
            self.label_image.attrs[AttrKeys.VERSION] = str(CURRENT_VERSION)

    @classmethod
    def from_array_and_coords(
            cls,
            array: np.ndarray,
            pixel_coordinates: Optional[Union[
                Mapping[Axes, Sequence[int]],
                Mapping[str, Sequence[int]]]],
            physical_coordinates: Union[
                Mapping[Coordinates, Sequence[Number]],
                Mapping[str, Sequence[Number]]],
            log: Optional[Log] = None,
    ) -> "LabelImage":
        """Constructs a LabelImage from an array containing the labels, a set of physical
        coordinates, and an optional log of how this label image came to be.

        Parameters
        ----------
        array : np.ndarray
            A 2D or 3D array containing the labels.  The ordering of the axes must be Y, X for 2D
            images and ZPLANE, Y, X for 3D images.
        pixel_coordinates : Optional[Mapping[Union[Axes, str], Sequence[int]]]
            A map from the axis to the values for that axis.  For any axis that exists in the
            array but not in pixel_coordinates, the pixel coordinates are assigned from 0..N-1,
            where N is the size along that axis.
        physical_coordinates : Mapping[Union[Coordinates, str], Sequence[Number]]
            A map from the physical coordinate type to the values for that axis.  For 2D label
            images, X and Y physical coordinates must be provided.  For 3D label images, Z
            physical coordinates must also be provided.
        log : Optional[Log]
            A log of how this label image came to be.  If omitted or None, a new ``Log()`` is
            encoded instead.

        Returns
        -------
        LabelImage
            The label image wrapping the constructed data array.

        Raises
        ------
        KeyError
            If a required physical coordinate is missing from ``physical_coordinates``.
        """
        # normalize the pixel coordinates to Mapping[Axes, Sequence[int]]
        pixel_values_by_axis = {
            axis if isinstance(axis, Axes) else Axes(axis): axis_values
            for axis, axis_values in (pixel_coordinates or {}).items()
        }
        # normalize the physical coordinates to Mapping[Coordinates, Sequence[Number]]
        physical_values_by_coord = {
            coord if isinstance(coord, Coordinates) else Coordinates(coord): coord_values
            for coord, coord_values in physical_coordinates.items()
        }

        img_axes, img_coords = _get_axes_names(array.ndim)
        xr_axes = [axis.value for axis in img_axes]
        try:
            # each physical coordinate is attached to the pixel axis it runs along.
            xr_coords: MutableMapping[Hashable, Any] = {
                coord.value: (axis.value, physical_values_by_coord[coord])
                for axis, coord in zip(img_axes, img_coords)
            }
        except KeyError as ex:
            raise KeyError(f"missing physical coordinates {ex.args[0]}") from ex

        # pixel coordinates default to 0..N-1 along each axis when not supplied.
        xr_coords[Axes.X.value] = pixel_values_by_axis.get(
            Axes.X, np.arange(0, array.shape[-1]))
        xr_coords[Axes.Y.value] = pixel_values_by_axis.get(
            Axes.Y, np.arange(0, array.shape[-2]))
        if array.ndim == 3:
            xr_coords[Axes.ZPLANE.value] = pixel_values_by_axis.get(
                Axes.ZPLANE, np.arange(0, array.shape[-3]))

        dataarray = xr.DataArray(
            array,
            dims=xr_axes,
            coords=xr_coords,
        )
        dataarray.attrs.update({
            AttrKeys.LOG: (log or Log()).encode(),
            AttrKeys.DOCTYPE: DOCTYPE_STRING,
            AttrKeys.VERSION: str(CURRENT_VERSION),
        })

        # use cls (as from_disk does) so that subclasses get instances of their own type.
        return cls(dataarray)

    @property
    def xarray(self) -> xr.DataArray:
        """Returns the xarray that contains the label image and the physical coordinates.

        This is the array held by this object, not a copy of it.
        """
        return self.label_image

    @property
    def log(self) -> Log:
        """Returns the provenance data, decoded from the array's attributes on each access.
        Modifications to the returned object are not written back to the label image."""
        return Log.decode(self.label_image.attrs[AttrKeys.LOG])

    @classmethod
    def from_disk(cls, path: Union[str, Path]) -> "LabelImage":
        """Load a label image from disk.

        Parameters
        ----------
        path : Union[str, Path]
            Path of the label image to instantiate from.

        Returns
        -------
        label_image : LabelImage
            Label image from the path.

        Raises
        ------
        ValueError
            If the file lacks the starfish label image doctype or a version attribute, or if its
            version is outside the supported range.
        """
        # NOTE(review): xr.open_dataarray is documented as loading lazily, which may keep the
        # underlying file open -- confirm whether an eager load is wanted here.
        label_image = xr.open_dataarray(path)
        if (
                AttrKeys.DOCTYPE not in label_image.attrs
                or label_image.attrs[AttrKeys.DOCTYPE] != DOCTYPE_STRING
                or AttrKeys.VERSION not in label_image.attrs
        ):
            raise ValueError(f"{path} does not appear to be a starfish label image")
        if not (
                MIN_SUPPORTED_VERSION
                <= Version(label_image.attrs[AttrKeys.VERSION])
                <= MAX_SUPPORTED_VERSION):
            raise ValueError(
                f"{path} contains a label image, but the version "
                f"{label_image.attrs[AttrKeys.VERSION]} is not supported")

        return cls(label_image)

    def save(self, path: Union[str, Path]) -> None:
        """Save the label image to disk.

        Parameters
        ----------
        path : Union[str, Path]
            Path of the netcdf file to write to.
        """
        self.label_image.to_netcdf(path)
Empty file.
118 changes: 118 additions & 0 deletions starfish/core/label_image/test/test_label_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Mapping, Optional, Sequence, Type

import numpy as np
import pytest

from starfish import Log
from starfish.image import Filter
from starfish.types import Axes, Coordinates, Number
from ..label_image import AttrKeys, CURRENT_VERSION, DOCTYPE_STRING, LabelImage


@pytest.mark.parametrize(
    "array, physical_coordinates, log, expected_error",
    (
        # 3D label image
        (
            np.zeros((1, 1, 1), dtype=np.int32),
            {Coordinates.X: [0], Coordinates.Y: [0], Coordinates.Z: [0]},
            None,
            None,
        ),
        # 2D label image
        (
            np.zeros((1, 1), dtype=np.int32),
            {Coordinates.X: [0], Coordinates.Y: [0]},
            None,
            None,
        ),
        # wrong dtype
        (
            np.zeros((1, 1), dtype=np.float32),
            {Coordinates.X: [0], Coordinates.Y: [0]},
            None,
            TypeError,
        ),
        # missing some coordinates
        (
            np.zeros((1, 1), dtype=np.float32),
            {Coordinates.X: [0]},
            None,
            KeyError,
        ),
    )
)
def test_from_array_and_coords(
        array: np.ndarray,
        physical_coordinates: Mapping[Coordinates, Sequence[Number]],
        log: Optional[Log],
        expected_error: Optional[Type[Exception]],
):
    """Check that LabelImage construction succeeds for valid inputs and raises the expected
    exception for a few common invalid ones."""

    def construct() -> LabelImage:
        # no pixel coordinates are supplied in any of the cases
        return LabelImage.from_array_and_coords(array, None, physical_coordinates, log)

    if expected_error is None:
        # success path: the result must carry a decodable log plus doctype and version markers
        constructed = construct()
        attrs = constructed.xarray.attrs
        assert isinstance(constructed.log, Log)
        assert attrs.get(AttrKeys.DOCTYPE, None) == DOCTYPE_STRING
        assert attrs.get(AttrKeys.VERSION, None) == str(CURRENT_VERSION)
        return

    with pytest.raises(expected_error):
        construct()


def test_pixel_coordinates():
    """Verify that supplied pixel coordinates are kept and omitted ones default to 0..N-1."""
    label_image = LabelImage.from_array_and_coords(
        np.zeros((2, 2, 2), dtype=np.int32),
        {
            Axes.X: [2, 3],
            Axes.ZPLANE: [0, 1],
        },
        {
            Coordinates.X: [0, 0.5],
            Coordinates.Y: [0, 0.2],
            Coordinates.Z: [0, 0.1],
        },
        None,
    )
    coords = label_image.xarray.coords

    expected_pixel_values = (
        (Axes.X, [2, 3]),
        (Axes.Y, [0, 1]),  # not provided, should be 0..N-1
        (Axes.ZPLANE, [0, 1]),
    )
    for axis, expected in expected_pixel_values:
        assert np.array_equal(coords[axis.value], expected)


def test_save_and_load(tmp_path):
    """Round-trip a label image through a netcdf file and check that data and attrs match."""
    # instantiate a filter (even though that makes no sense in this context) so that there is
    # something to record in the log
    provenance = Log()
    provenance.update_log(Filter.Reduce((Axes.ROUND,), func="max"))

    original = LabelImage.from_array_and_coords(
        np.zeros((2, 2, 2), dtype=np.int32),
        {
            Axes.X: [2, 3],
            Axes.ZPLANE: [0, 1],
        },
        {
            Coordinates.X: [0, 0.5],
            Coordinates.Y: [0, 0.2],
            Coordinates.Z: [0, 0.1],
        },
        provenance,
    )

    destination = tmp_path / "label_image.netcdf"
    original.save(destination)
    restored = LabelImage.from_disk(destination)

    assert original.xarray.equals(restored.xarray)
    assert original.xarray.attrs == restored.xarray.attrs
30 changes: 30 additions & 0 deletions starfish/core/label_image/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import List, Tuple

from starfish.core.types import Axes, Coordinates


def _get_axes_names(ndim: int) -> Tuple[List[Axes], List[Coordinates]]:
    """Return the pixel axes and physical coordinates for an image of the given rank.

    Parameters
    ----------
    ndim : int
        Number of dimensions; must be 2 or 3.

    Returns
    -------
    axes : List[Axes]
        ``[Y, X]`` for 2 dimensions, ``[ZPLANE, Y, X]`` for 3.
    coords : List[Coordinates]
        ``[Y, X]`` for 2 dimensions, ``[Z, Y, X]`` for 3.

    Raises
    ------
    TypeError
        If ``ndim`` is neither 2 nor 3.
    """
    if ndim not in (2, 3):
        raise TypeError('expected 2- or 3-D image')

    # the in-plane pair is common to both ranks
    axes = [Axes.Y, Axes.X]
    coords = [Coordinates.Y, Coordinates.X]
    if ndim == 3:
        # the z entries lead the in-plane pair
        axes.insert(0, Axes.ZPLANE)
        coords.insert(0, Coordinates.Z)

    return axes, coords

0 comments on commit 98d54af

Please sign in to comment.