-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rather than passing around a naked 2D/3D array as a label image, create a type for holding that data. Test plan: see tests! Depends on #1617 Part of #1497
- Loading branch information
Tony Tung
authored
Oct 28, 2019
1 parent
421ad08
commit 98d54af
Showing
5 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .label_image import LabelImage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from pathlib import Path | ||
from typing import Any, Hashable, Mapping, MutableMapping, Optional, Sequence, Tuple, Union | ||
|
||
import numpy as np | ||
import xarray as xr | ||
from semantic_version import Version | ||
|
||
from starfish.core.types import Axes, Coordinates, LOG, Number, STARFISH_EXTRAS_KEY | ||
from starfish.core.util.logging import Log | ||
from .util import _get_axes_names | ||
|
||
|
||
class AttrKeys: | ||
DOCTYPE = f"{STARFISH_EXTRAS_KEY}.DOCTYPE" | ||
LOG = f"{STARFISH_EXTRAS_KEY}.{LOG}" | ||
VERSION = f"{STARFISH_EXTRAS_KEY}.VERSION" | ||
|
||
|
||
DOCTYPE_STRING = "starfish/LabelImage" | ||
CURRENT_VERSION = Version("0.0.0") | ||
MIN_SUPPORTED_VERSION = Version("0.0.0") | ||
MAX_SUPPORTED_VERSION = Version("0.0.0") | ||
|
||
|
||
class LabelImage: | ||
"""Wraps an xarray that contains a 2D or 3D labeled image. Each axis is labeled with physical | ||
coordinate data.""" | ||
|
||
def __init__(self, label_image: xr.DataArray): | ||
# verify that the data array has the required elements. | ||
if label_image.dtype.kind not in ("i", "u"): | ||
raise TypeError("label image should be an integer type") | ||
for axis in (Axes.X, Axes.Y): | ||
if axis.value not in label_image.coords: | ||
raise ValueError(f"label image should have an {axis.value} axis") | ||
expected_coordinates: Tuple[Coordinates, ...] | ||
if label_image.ndim == 5: | ||
expected_coordinates = (Coordinates.X, Coordinates.Y, Coordinates.Z) | ||
else: | ||
expected_coordinates = (Coordinates.X, Coordinates.Y) | ||
for coord in expected_coordinates: | ||
if coord.value not in label_image.coords: | ||
raise ValueError(f"label image should have a {coord.value} coordinates") | ||
|
||
self.label_image = label_image.copy(deep=False) | ||
if AttrKeys.DOCTYPE not in self.label_image.attrs: | ||
self.label_image.attrs[AttrKeys.DOCTYPE] = DOCTYPE_STRING | ||
if AttrKeys.LOG not in self.label_image.attrs: | ||
self.label_image.attrs[AttrKeys.LOG] = Log().encode() | ||
|
||
@classmethod | ||
def from_array_and_coords( | ||
cls, | ||
array: np.ndarray, | ||
pixel_coordinates: Optional[Union[ | ||
Mapping[Axes, Sequence[int]], | ||
Mapping[str, Sequence[int]]]], | ||
physical_coordinates: Union[ | ||
Mapping[Coordinates, Sequence[Number]], | ||
Mapping[str, Sequence[Number]]], | ||
log: Optional[Log], | ||
) -> "LabelImage": | ||
"""Constructs a LabelImage from an array containing the labels, a set of physical | ||
coordinates, and an optional log of how this label image came to be. | ||
Parameters | ||
---------- | ||
array : np.ndarray | ||
A 2D or 3D array containing the labels. The ordering of the axes must be Y, X for 2D | ||
images and ZPLANE, Y, X for 3D images. | ||
pixel_coordinates : Optional[Mapping[Union[Axes, str], Sequence[int]]] | ||
A map from the axis to the values for that axis. For any axis that exist in the array | ||
but not in pixel_coordinates, the pixel coordinates are assigned from 0..N-1, where N is | ||
the size along that axis. | ||
physical_coordinates : Mapping[Union[Coordinates, str], Sequence[Number]] | ||
A map from the physical coordinate type to the values for axis. For 2D label images, | ||
X and Y physical coordinates must be provided. For 3D label images, Z physical | ||
coordinates must also be provided. | ||
log : Optional[Log] | ||
A log of how this label image came to be. | ||
""" | ||
# normalize the pixel coordinates to Mapping[Axes, Sequence[int]] | ||
pixel_coordinates = { | ||
axis if isinstance(axis, Axes) else Axes(axis): axis_values | ||
for axis, axis_values in (pixel_coordinates or {}).items() | ||
} | ||
# normalize the physical coordinates to Mapping[Coordinates, Sequence[Number]] | ||
physical_coordinates = { | ||
coord if isinstance(coord, Coordinates) else Coordinates(coord): coord_values | ||
for coord, coord_values in physical_coordinates.items() | ||
} | ||
|
||
img_axes, img_coords = _get_axes_names(array.ndim) | ||
xr_axes = [axis.value for axis in img_axes] | ||
try: | ||
xr_coords: MutableMapping[Hashable, Any] = { | ||
coord.value: (axis.value, physical_coordinates[coord]) | ||
for axis, coord in zip(img_axes, img_coords) | ||
} | ||
except KeyError as ex: | ||
raise KeyError(f"missing physical coordinates {ex.args[0]}") from ex | ||
|
||
xr_coords[Axes.X.value] = pixel_coordinates.get(Axes.X, np.arange(0, array.shape[-1])) | ||
xr_coords[Axes.Y.value] = pixel_coordinates.get(Axes.Y, np.arange(0, array.shape[-2])) | ||
if array.ndim == 3: | ||
xr_coords[Axes.ZPLANE.value] = pixel_coordinates.get( | ||
Axes.ZPLANE, np.arange(0, array.shape[-3])) | ||
|
||
dataarray = xr.DataArray( | ||
array, | ||
dims=xr_axes, | ||
coords=xr_coords, | ||
) | ||
dataarray.attrs.update({ | ||
AttrKeys.LOG: (log or Log()).encode(), | ||
AttrKeys.DOCTYPE: DOCTYPE_STRING, | ||
AttrKeys.VERSION: str(CURRENT_VERSION), | ||
}) | ||
|
||
return LabelImage(dataarray) | ||
|
||
@property | ||
def xarray(self): | ||
"""Returns the xarray that contains the label image and the physical coordinates.""" | ||
return self.label_image | ||
|
||
@property | ||
def log(self) -> Log: | ||
"""Returns a copy of the provenance data. Modifications to this copy will not affect the | ||
log stored on this label image.""" | ||
return Log.decode(self.label_image.attrs[AttrKeys.LOG]) | ||
|
||
@classmethod | ||
def from_disk(cls, path: Union[str, Path]) -> "LabelImage": | ||
"""Load a label image from disk. | ||
Parameters | ||
---------- | ||
path : Union[str, Path] | ||
Path of the label image to instantiate from. | ||
Returns | ||
------- | ||
label_image : LabelImage | ||
Label image from the path. | ||
""" | ||
label_image = xr.open_dataarray(path) | ||
if ( | ||
AttrKeys.DOCTYPE not in label_image.attrs | ||
or label_image.attrs[AttrKeys.DOCTYPE] != DOCTYPE_STRING | ||
or AttrKeys.VERSION not in label_image.attrs | ||
): | ||
raise ValueError(f"{path} does not appear to be a starfish label image") | ||
if not ( | ||
MIN_SUPPORTED_VERSION | ||
<= Version(label_image.attrs[AttrKeys.VERSION]) | ||
<= MAX_SUPPORTED_VERSION): | ||
raise ValueError( | ||
f"{path} contains a label image, but the version " | ||
f"{label_image.attrs[AttrKeys.VERSION]} is not supported") | ||
|
||
return cls(label_image) | ||
|
||
def save(self, path: Union[str, Path]): | ||
"""Save the label image to disk. | ||
Parameters | ||
---------- | ||
path : Union[str, Path] | ||
Path of the netcdf file to write to. | ||
""" | ||
self.label_image.to_netcdf(path) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from typing import Mapping, Optional, Sequence, Type | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from starfish import Log | ||
from starfish.image import Filter | ||
from starfish.types import Axes, Coordinates, Number | ||
from ..label_image import AttrKeys, CURRENT_VERSION, DOCTYPE_STRING, LabelImage | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"array, physical_coordinates, log, expected_error", | ||
[ | ||
# 3D label image | ||
[ | ||
np.zeros((1, 1, 1), dtype=np.int32), | ||
{ | ||
Coordinates.X: [0], | ||
Coordinates.Y: [0], | ||
Coordinates.Z: [0], | ||
}, | ||
None, | ||
None, | ||
], | ||
# 2D label image | ||
[ | ||
np.zeros((1, 1), dtype=np.int32), | ||
{ | ||
Coordinates.X: [0], | ||
Coordinates.Y: [0], | ||
}, | ||
None, | ||
None, | ||
], | ||
# wrong dtype | ||
[ | ||
np.zeros((1, 1), dtype=np.float32), | ||
{ | ||
Coordinates.X: [0], | ||
Coordinates.Y: [0], | ||
}, | ||
None, | ||
TypeError, | ||
], | ||
# missing some coordinates | ||
[ | ||
np.zeros((1, 1), dtype=np.float32), | ||
{ | ||
Coordinates.X: [0], | ||
}, | ||
None, | ||
KeyError, | ||
], | ||
] | ||
) | ||
def test_from_array_and_coords( | ||
array: np.ndarray, | ||
physical_coordinates: Mapping[Coordinates, Sequence[Number]], | ||
log: Optional[Log], | ||
expected_error: Optional[Type[Exception]], | ||
): | ||
"""Test that we can construct a LabelImage and that some common error conditions are caught.""" | ||
if expected_error is not None: | ||
with pytest.raises(expected_error): | ||
LabelImage.from_array_and_coords(array, None, physical_coordinates, log) | ||
else: | ||
label_image = LabelImage.from_array_and_coords(array, None, physical_coordinates, log) | ||
assert isinstance(label_image.log, Log) | ||
assert label_image.xarray.attrs.get(AttrKeys.DOCTYPE, None) == DOCTYPE_STRING | ||
assert label_image.xarray.attrs.get(AttrKeys.VERSION, None) == str(CURRENT_VERSION) | ||
|
||
|
||
def test_pixel_coordinates(): | ||
array = np.zeros((2, 2, 2), dtype=np.int32) | ||
pixel_coordinates = { | ||
Axes.X: [2, 3], | ||
Axes.ZPLANE: [0, 1], | ||
} | ||
physical_coordinates = { | ||
Coordinates.X: [0, 0.5], | ||
Coordinates.Y: [0, 0.2], | ||
Coordinates.Z: [0, 0.1], | ||
} | ||
label_image = LabelImage.from_array_and_coords( | ||
array, pixel_coordinates, physical_coordinates, None) | ||
|
||
assert np.array_equal(label_image.xarray.coords[Axes.X.value], [2, 3]) | ||
# not provided, should be 0..N-1 | ||
assert np.array_equal(label_image.xarray.coords[Axes.Y.value], [0, 1]) | ||
assert np.array_equal(label_image.xarray.coords[Axes.ZPLANE.value], [0, 1]) | ||
|
||
|
||
def test_save_and_load(tmp_path): | ||
"""Verify that we can save the label image and load it correctly.""" | ||
array = np.zeros((2, 2, 2), dtype=np.int32) | ||
pixel_coordinates = { | ||
Axes.X: [2, 3], | ||
Axes.ZPLANE: [0, 1], | ||
} | ||
physical_coordinates = { | ||
Coordinates.X: [0, 0.5], | ||
Coordinates.Y: [0, 0.2], | ||
Coordinates.Z: [0, 0.1], | ||
} | ||
log = Log() | ||
# instantiate a filter (even though that makes no sense in this context) | ||
filt = Filter.Reduce((Axes.ROUND,), func="max") | ||
log.update_log(filt) | ||
|
||
label_image = LabelImage.from_array_and_coords( | ||
array, pixel_coordinates, physical_coordinates, log) | ||
label_image.save(tmp_path / "label_image.netcdf") | ||
|
||
loaded_label_image = LabelImage.from_disk(tmp_path / "label_image.netcdf") | ||
|
||
assert label_image.xarray.equals(loaded_label_image.xarray) | ||
assert label_image.xarray.attrs == loaded_label_image.xarray.attrs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import List, Tuple | ||
|
||
from starfish.core.types import Axes, Coordinates | ||
|
||
|
||
def _get_axes_names(ndim: int) -> Tuple[List[Axes], List[Coordinates]]: | ||
"""Get needed axes and coordinates given the number of dimensions. | ||
Parameters | ||
---------- | ||
ndim : int | ||
Number of dimensions. | ||
Returns | ||
------- | ||
axes : List[Axes] | ||
Axes. | ||
coords : List[Coordinates] | ||
Coordinates. | ||
""" | ||
if ndim == 2: | ||
axes = [Axes.Y, Axes.X] | ||
coords = [Coordinates.Y, Coordinates.X] | ||
elif ndim == 3: | ||
axes = [Axes.ZPLANE, Axes.Y, Axes.X] | ||
coords = [Coordinates.Z, Coordinates.Y, Coordinates.X] | ||
else: | ||
raise TypeError('expected 2- or 3-D image') | ||
|
||
return axes, coords |