From 856314c31564cd9c7f8129a12a47c16f9094fb06 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Mon, 1 Apr 2019 16:10:57 -0700 Subject: [PATCH 1/6] replace label images with segmentation masks --- notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb | 6 +- notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py | 4 +- starfish/image/_segmentation/_base.py | 9 +- starfish/image/_segmentation/watershed.py | 22 ++- starfish/segmentation_mask.py | 177 ++++++++++++++++++ starfish/spots/_target_assignment/label.py | 51 ++--- .../test/full_pipelines/api/test_iss_api.py | 6 +- starfish/test/test_segmentation_mask.py | 125 +++++++++++++ 8 files changed, 359 insertions(+), 41 deletions(-) create mode 100644 starfish/segmentation_mask.py create mode 100644 starfish/test/test_segmentation_mask.py diff --git a/notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb b/notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb index fc66c9c35..c4a85eaf0 100644 --- a/notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb +++ b/notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb @@ -373,7 +373,7 @@ " input_threshold=stain_thresh,\n", " min_distance=min_dist\n", ")\n", - "label_image = seg.run(registered_image, nuclei)\n", + "masks = seg.run(registered_image, nuclei)\n", "seg.show()" ] }, @@ -392,7 +392,7 @@ "source": [ "from starfish.spots import TargetAssignment\n", "al = TargetAssignment.Label()\n", - "labeled = al.run(label_image, decoded)" + "labeled = al.run(masks, decoded)" ] }, { @@ -513,4 +513,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py b/notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py index 867d4477a..18f54ab12 100644 --- a/notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py +++ b/notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py @@ -240,7 +240,7 @@ input_threshold=stain_thresh, min_distance=min_dist ) -label_image = seg.run(registered_image, nuclei) +masks = seg.run(registered_image, nuclei) seg.show() # EPY: END code @@ -251,7 +251,7 @@ # EPY: START code from starfish.spots import TargetAssignment al = TargetAssignment.Label() -labeled = al.run(label_image, decoded) +labeled = al.run(masks, decoded) # EPY: END code # EPY: START code diff --git a/starfish/image/_segmentation/_base.py b/starfish/image/_segmentation/_base.py index 2735594e0..7ef4a2541 100644 --- a/starfish/image/_segmentation/_base.py +++ b/starfish/image/_segmentation/_base.py @@ -6,6 +6,7 @@ from starfish.imagestack.imagestack import ImageStack from starfish.pipeline import PipelineComponent from starfish.pipeline.algorithmbase import AlgorithmBase +from starfish.segmentation_mask import SegmentationMaskCollection from starfish.util import click @@ -50,8 +51,12 @@ class SegmentationAlgorithmBase(AlgorithmBase): @classmethod def get_pipeline_component_class(cls) -> Type[PipelineComponent]: return Segmentation - @abstractmethod - def run(self, primary_image_stack: ImageStack, nuclei_stack: ImageStack, *args): + def run( + self, + primary_image_stack: ImageStack, + nuclei_stack: ImageStack, + *args + ) -> SegmentationMaskCollection: """Performs registration on the stack provided.""" raise NotImplementedError() diff --git a/starfish/image/_segmentation/watershed.py b/starfish/image/_segmentation/watershed.py index e62a4890a..3f4d97047 100644 --- a/starfish/image/_segmentation/watershed.py +++ b/starfish/image/_segmentation/watershed.py @@ -9,7 +9,8 @@ from starfish.image._filter.util import bin_open, bin_thresh from starfish.imagestack.imagestack import ImageStack -from starfish.types import Axes, Number +from starfish.segmentation_mask import SegmentationMaskCollection +from starfish.types import Axes, Coordinates, Number from starfish.util import click from ._base import SegmentationAlgorithmBase @@ -49,7 +50,12 @@ def __init__( self.min_distance = min_distance self._segmentation_instance: Optional[_WatershedSegmenter] = None - def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarray: + def run( + self, + primary_images: ImageStack, + nuclei: ImageStack, + *args + ) -> SegmentationMaskCollection: """Segments nuclei in 2-d using a nuclei ImageStack Primary images are used to expand the nuclear mask, but only in cases where there are @@ -64,9 +70,8 @@ def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarr Returns ------- - np.ndarray : - label image where each cell is labeled by a different positive integer value. 0 - implies that a pixel is not part of a cell. + masks : SegmentationMaskCollection + binary masks segmenting each cell """ # create a 'stain' for segmentation @@ -88,7 +93,12 @@ def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarr disk_size_mask, self.min_distance ) - return label_image + # we max-projected and squeezed the Z-plane so label_image.ndim == 2 + physical_ticks = {coord: nuclei.xarray.coords[coord.value].data.tolist() + for coord in (Coordinates.Y, Coordinates.X)} + + return SegmentationMaskCollection.from_label_image(label_image, + physical_ticks) def show(self, figsize: Tuple[int, int]=(10, 10)) -> None: if isinstance(self._segmentation_instance, _WatershedSegmenter): diff --git a/starfish/segmentation_mask.py b/starfish/segmentation_mask.py new file mode 100644 index 000000000..3bc22ece1 --- /dev/null +++ b/starfish/segmentation_mask.py @@ -0,0 +1,177 @@ +import itertools +import os +import os.path as osp +import shutil +from typing import Dict, List, Tuple, Union + +import numpy as np +import xarray as xr +from skimage.measure import regionprops + +from starfish.types import Axes, Coordinates + + +AXES = [a.value for a in Axes if a not in (Axes.ROUND, Axes.CH)] +COORDS = [c.value for c in Coordinates] + + +def validate_segmentation_mask(arr: xr.DataArray): + """Validate if the given array is a segmentation mask. + + Parameters + ---------- + arr : xr.DataArray + Array to check. + """ + if not isinstance(arr, xr.DataArray): + raise TypeError(f"expected DataArray; got {type(arr)}") + + if arr.ndim not in (2, 3): + raise TypeError(f"expected 2 or 3 dimensions; got {arr.ndim}") + + if arr.dtype != np.bool: + raise TypeError(f"expected dtype of bool; got {arr.dtype}") + + if arr.ndim == 2: + axes = AXES[1:] + coords = COORDS[1:] + else: + axes = AXES + coords = COORDS + + for dim in axes: + if dim not in arr.dims: + raise TypeError(f"no dimension '{dim}'") + + for coord in itertools.chain(axes, coords): + if coord not in arr.coords: + raise TypeError(f"no coordinate '{coord}'") + + +class SegmentationMaskCollection: + """Collection of binary segmentation masks with a list-like access pattern. + + Parameters + ---------- + masks : list of xr.DataArray + Segmentation masks. + """ + _masks: List[xr.DataArray] + + def __init__(self, masks: List[xr.DataArray]): + for mask in masks: + validate_segmentation_mask(mask) + + self._masks = masks + + def __getitem__(self, index): + return self._masks[index] + + def __iter__(self): + return iter(self._masks) + + def __len__(self): + return len(self._masks) + + def add_mask(self, mask: xr.DataArray): + """Add an existing segmentation mask. + + Parameters + ---------- + arr : xr.DataArray + Segmentation mask. + """ + validate_segmentation_mask(mask) + self._masks.append(mask) + + @classmethod + def from_label_image( + cls, + label_image: np.ndarray, + physical_ticks: Dict[Coordinates, List[float]] + ) -> "SegmentationMaskCollection": + """Creates segmentation masks from a label image. + + Parameters + ---------- + label_image : int array + Integer array where each integer corresponds to a cell. + physical_ticks : Dict[Coordinates, List[float]] + Physical coordinates for each axis. + + Returns + ------- + masks : SegmentationMaskCollection + Masks generated from the label image. + """ + props = regionprops(label_image) + + if label_image.ndim == 2: + dims = AXES[1:] + elif label_image.ndim == 3: + dims = AXES + else: + raise TypeError('expected 2- or 3-D image') + + masks: List[xr.DataArray] = [] + + coords: Dict[str, Union[list, Tuple[str, list]]] + + for label, prop in enumerate(props): + coords = {d: list(range(prop.bbox[i], prop.bbox[i + len(dims)])) + for i, d in enumerate(dims)} + + for d, c in physical_ticks.items(): + axis = d.value[0] + i = dims.index(axis) + coords[d.value] = (axis, c[prop.bbox[i]:prop.bbox[i + len(dims)]]) + + mask = xr.DataArray(prop.image, + dims=dims, + coords=coords, + name=str(label + 1)) + masks.append(mask) + + return cls(masks) + + @classmethod + def from_disk(cls, path: str) -> "SegmentationMaskCollection": + """Load the collection from disk. + + Parameters + ---------- + path : str + Path of the directory to instantiate from. + + Returns + ------- + masks : SegmentationMaskCollection + Collection of segmentation masks. + """ + masks = [] + for p in os.listdir(path): + mask = xr.open_dataarray(osp.join(path, p)) + masks.append(mask) + + return cls(masks) + + def save(self, path: str, overwrite: bool = False): + """Save the segmentation masks to disk. + + Parameters + ---------- + path : str + Path of the directory to write to. + overwrite : bool, optional + Whether to overwrite the directory if it exists. + """ + try: + os.mkdir(path) + except FileExistsError: + if not overwrite: + raise + shutil.rmtree(path, ignore_errors=True) + os.mkdir(path) + + for i, mask in enumerate(self._masks): + mask.to_netcdf(osp.join(path, str(i)), 'w') diff --git a/starfish/spots/_target_assignment/label.py b/starfish/spots/_target_assignment/label.py index 9c7ca324a..5f45071bd 100644 --- a/starfish/spots/_target_assignment/label.py +++ b/starfish/spots/_target_assignment/label.py @@ -1,6 +1,5 @@ -import numpy as np - from starfish.intensity_table.intensity_table import IntensityTable +from starfish.segmentation_mask import SegmentationMaskCollection from starfish.types import Axes, Features from starfish.util import click from ._base import TargetAssignmentAlgorithm @@ -19,26 +18,29 @@ def _add_arguments(cls, parser) -> None: @staticmethod def _assign( - label_image: np.ndarray, + masks: SegmentationMaskCollection, intensities: IntensityTable, in_place: bool, ) -> IntensityTable: + cell_ids = [] + + for spot in intensities: + for mask in masks: + sel = {Axes.X.value: spot[Axes.X.value], + Axes.Y.value: spot[Axes.Y.value]} + if mask.ndim == 3: + sel[Axes.ZPLANE.value] = spot[Axes.ZPLANE.value] + + try: + if mask.sel(sel): + cell_id = mask.name + break + except KeyError: + pass + else: + cell_id = '' - if len(label_image.shape) == 3: - cell_ids = label_image[ - intensities[Axes.ZPLANE.value].values, - intensities[Axes.Y.value].values, - intensities[Axes.X.value].values - ] - elif len(label_image.shape) == 2: - cell_ids = label_image[ - intensities[Axes.Y.value].values, - intensities[Axes.X.value].values - ] - else: - raise ValueError( - f"`label_image` must be 2 or 3 dimensional, not {len(label_image.shape)}D." - ) + cell_ids.append(cell_id) if not in_place: intensities = intensities.copy() @@ -49,18 +51,17 @@ def _assign( def run( self, - label_image: np.ndarray, + masks: SegmentationMaskCollection, intensity_table: IntensityTable, - verbose: bool=False, - in_place: bool=False, + verbose: bool = False, + in_place: bool = False, ) -> IntensityTable: """Extract cell ids for features in IntensityTable from a segmentation label image Parameters ---------- - label_image : np.ndarray[np.uint32] - integer array produced from segmentation where each pixel in a cell is labeled by the - same integer, and each cell is labeled by a different integer + masks : SegmentaionMaskCollection + binary masks segmenting each cell intensity_table : IntensityTable spot information in_place : bool @@ -75,7 +76,7 @@ def run( cells will be assigned zero. """ - return self._assign(label_image, intensity_table, in_place=in_place) + return self._assign(masks, intensity_table, in_place=in_place) @staticmethod @click.command("Label") diff --git a/starfish/test/full_pipelines/api/test_iss_api.py b/starfish/test/full_pipelines/api/test_iss_api.py index c5b73e392..7703399cf 100644 --- a/starfish/test/full_pipelines/api/test_iss_api.py +++ b/starfish/test/full_pipelines/api/test_iss_api.py @@ -106,7 +106,7 @@ def test_iss_pipeline_cropped_data(): 'TFRC', 'TP53', 'VEGF'])) assert np.array_equal(gene_counts, [20, 1, 5, 2, 1, 11, 1, 3, 2, 1, 1, 2]) - label_image = iss.label_image + masks = iss.masks seg = iss.seg @@ -115,7 +115,7 @@ def test_iss_pipeline_cropped_data(): # assign targets lab = TargetAssignment.Label() - assigned = lab.run(label_image, decoded) + assigned = lab.run(masks, decoded) pipeline_log = assigned.get_log() @@ -139,4 +139,4 @@ def test_iss_pipeline_cropped_data(): assert pipeline_log[3]['method'] == 'BlobDetector' # 28 of the spots are assigned to cell 1 (although most spots do not decode!) - assert np.sum(assigned['cell_id'] == 1) == 28 + assert np.sum(assigned['cell_id'] == '1') == 28 diff --git a/starfish/test/test_segmentation_mask.py b/starfish/test/test_segmentation_mask.py new file mode 100644 index 000000000..f99197076 --- /dev/null +++ b/starfish/test/test_segmentation_mask.py @@ -0,0 +1,125 @@ +import shutil + +import numpy as np +import pytest +import xarray as xr + +from starfish.segmentation_mask import (SegmentationMaskCollection, + validate_segmentation_mask) +from starfish.types import Axes, Coordinates + + +def test_validate_segmentation_mask(): + good = xr.DataArray([[True, False, False], + [False, True, True]], + dims=('y', 'x'), + coords=dict(x=[0, 1, 2], + y=[0, 1], + xc=('x', [0.5, 1.5, 2.5]), + yc=('y', [0.5, 1.5]))) + validate_segmentation_mask(good) + + good = xr.DataArray([[[True], [False], [False]], + [[False], [True], [True]]], + dims=('z', 'y', 'x'), + coords=dict(z=[0, 1], + y=[1, 2, 3], + x=[42], + zc=('z', [0.5, 1.5]), + yc=('y', [1.5, 2.5, 3.5]), + xc=('x', [42.5]))) + validate_segmentation_mask(good) + + bad = xr.DataArray([[1, 2, 3], + [4, 5, 6]], + dims=('y', 'x'), + coords=dict(x=[0, 1, 2], + y=[0, 1], + xc=('x', [0.5, 1.5, 2.5]), + yc=('y', [0.5, 1.5]))) + with pytest.raises(TypeError): + validate_segmentation_mask(bad) + + bad = xr.DataArray([True], + dims=('x'), + coords=dict(x=[0], + xc=('x', [0.5]))) + with pytest.raises(TypeError): + validate_segmentation_mask(bad) + + bad = xr.DataArray([[True]], + dims=('z', 'y'), + coords=dict(z=[0], + y=[0], + zc=('z', [0.5]), + yc=('y', [0.5]))) + with pytest.raises(TypeError): + validate_segmentation_mask(bad) + + bad = xr.DataArray([[True]], + dims=('x', 'y')) + with pytest.raises(TypeError): + validate_segmentation_mask(bad) + + +def test_from_label_image(): + label_image = np.zeros((5, 5), dtype=np.int32) + label_image[0] = 1 + label_image[3:5, 3:5] = 2 + label_image[-1, -1] = 0 + + physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0], + Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]} + + masks = SegmentationMaskCollection.from_label_image(label_image, + physical_ticks) + + assert len(masks) == 2 + + region_1, region_2 = masks + + assert region_1.name == '1' + assert region_2.name == '2' + + assert np.array_equal(region_1, np.ones((1, 5), dtype=np.bool)) + temp = np.ones((2, 2), dtype=np.bool) + temp[-1, -1] = False + assert np.array_equal(region_2, temp) + + assert np.array_equal(region_1[Axes.Y.value], [0]) + assert np.array_equal(region_1[Axes.X.value], [0, 1, 2, 3, 4]) + + assert np.array_equal(region_2[Axes.Y.value], [3, 4]) + assert np.array_equal(region_2[Axes.X.value], [3, 4]) + + assert np.array_equal(region_1[Coordinates.Y.value], + physical_ticks[Coordinates.Y][0:1]) + assert np.array_equal(region_1[Coordinates.X.value], + physical_ticks[Coordinates.X][0:5]) + + assert np.array_equal(region_2[Coordinates.Y.value], + physical_ticks[Coordinates.Y][3:5]) + assert np.array_equal(region_2[Coordinates.X.value], + physical_ticks[Coordinates.X][3:5]) + + +def test_save_load(): + label_image = np.zeros((5, 5), dtype=np.int32) + label_image[0] = 1 + label_image[3:5, 3:5] = 2 + label_image[-1, -1] = 0 + + physical_ticks = {Coordinates.Y: [1.2, 2.4, 3.6, 4.8, 6.0], + Coordinates.X: [7.2, 8.4, 9.6, 10.8, 12]} + + masks = SegmentationMaskCollection.from_label_image(label_image, + physical_ticks) + + path = 'data' + try: + masks.save(path) + masks2 = SegmentationMaskCollection.from_disk(path) + for m, m2 in zip(masks, masks2): + assert np.array_equal(m, m2) + finally: + shutil.rmtree(path, ignore_errors=True) From 460f69113659d2733818749734adc0fbaea0fdc4 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Tue, 9 Apr 2019 11:01:03 -0700 Subject: [PATCH 2/6] improve docs & organization --- starfish/image/_segmentation/_base.py | 10 ++-- starfish/image/_segmentation/watershed.py | 2 +- starfish/segmentation_mask.py | 62 +++++++++++++++------- starfish/spots/_target_assignment/label.py | 1 + 4 files changed, 48 insertions(+), 27 deletions(-) diff --git a/starfish/image/_segmentation/_base.py b/starfish/image/_segmentation/_base.py index 7ef4a2541..c86826f54 100644 --- a/starfish/image/_segmentation/_base.py +++ b/starfish/image/_segmentation/_base.py @@ -1,8 +1,6 @@ from abc import abstractmethod from typing import Type -from skimage.io import imsave - from starfish.imagestack.imagestack import ImageStack from starfish.pipeline import PipelineComponent from starfish.pipeline.algorithmbase import AlgorithmBase @@ -25,10 +23,10 @@ def _cli_run(cls, ctx, instance): pri_stack = ctx.obj["primary_images"] nuc_stack = ctx.obj["nuclei"] - label_image = instance.run(pri_stack, nuc_stack) + masks = instance.run(pri_stack, nuc_stack) - print(f"Writing label image to {output}") - imsave(output, label_image) + print(f"Writing masks to {output}") + masks.save(output) @staticmethod @click.group(COMPONENT_NAME) @@ -58,5 +56,5 @@ def run( nuclei_stack: ImageStack, *args ) -> SegmentationMaskCollection: - """Performs registration on the stack provided.""" + """Performs segmentation on the stack provided.""" raise NotImplementedError() diff --git a/starfish/image/_segmentation/watershed.py b/starfish/image/_segmentation/watershed.py index 3f4d97047..57f129dfd 100644 --- a/starfish/image/_segmentation/watershed.py +++ b/starfish/image/_segmentation/watershed.py @@ -94,7 +94,7 @@ def run( ) # we max-projected and squeezed the Z-plane so label_image.ndim == 2 - physical_ticks = {coord: nuclei.xarray.coords[coord.value].data.tolist() + physical_ticks = {coord: nuclei.xarray.coords[coord.value].data for coord in (Coordinates.Y, Coordinates.X)} return SegmentationMaskCollection.from_label_image(label_image, diff --git a/starfish/segmentation_mask.py b/starfish/segmentation_mask.py index 3bc22ece1..f178255a7 100644 --- a/starfish/segmentation_mask.py +++ b/starfish/segmentation_mask.py @@ -2,7 +2,7 @@ import os import os.path as osp import shutil -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Sequence, Tuple, Union import numpy as np import xarray as xr @@ -15,6 +15,33 @@ COORDS = [c.value for c in Coordinates] +def _get_axes_names(ndim: int) -> Tuple[List[str], List[str]]: + """Get needed axes names given the number of dimensions. + + Parameters + ---------- + ndim : int + Number of dimensions. + + Returns + ------- + axes : List[str] + Axes names. + coords : List[str] + Coordinates names. + """ + if ndim == 2: + axes = [axis for axis in AXES if axis != Axes.ZPLANE.value] + coords = [coord for coord in COORDS if coord != Coordinates.Z.value] + elif ndim == 3: + axes = AXES + coords = COORDS + else: + raise TypeError('expected 2- or 3-D image') + + return axes, coords + + def validate_segmentation_mask(arr: xr.DataArray): """Validate if the given array is a segmentation mask. @@ -32,12 +59,7 @@ def validate_segmentation_mask(arr: xr.DataArray): if arr.dtype != np.bool: raise TypeError(f"expected dtype of bool; got {arr.dtype}") - if arr.ndim == 2: - axes = AXES[1:] - coords = COORDS[1:] - else: - axes = AXES - coords = COORDS + axes, coords = _get_axes_names(arr.ndim) for dim in axes: if dim not in arr.dims: @@ -73,7 +95,7 @@ def __iter__(self): def __len__(self): return len(self._masks) - def add_mask(self, mask: xr.DataArray): + def append(self, mask: xr.DataArray): """Add an existing segmentation mask. Parameters @@ -88,15 +110,15 @@ def add_mask(self, mask: xr.DataArray): def from_label_image( cls, label_image: np.ndarray, - physical_ticks: Dict[Coordinates, List[float]] + physical_ticks: Dict[Coordinates, Sequence[float]] ) -> "SegmentationMaskCollection": """Creates segmentation masks from a label image. Parameters ---------- label_image : int array - Integer array where each integer corresponds to a cell. - physical_ticks : Dict[Coordinates, List[float]] + Integer array where each integer corresponds to a region. + physical_ticks : Dict[Coordinates, Sequence[float]] Physical coordinates for each axis. Returns @@ -106,21 +128,21 @@ def from_label_image( """ props = regionprops(label_image) - if label_image.ndim == 2: - dims = AXES[1:] - elif label_image.ndim == 3: - dims = AXES - else: - raise TypeError('expected 2- or 3-D image') + dims, _ = _get_axes_names(label_image.ndim) masks: List[xr.DataArray] = [] - coords: Dict[str, Union[list, Tuple[str, list]]] + coords: Dict[str, Union[list, Tuple[str, Sequence]]] + # for each region (and its properties): for label, prop in enumerate(props): + # create pixel coordinate labels from the bounding box + # to preserve spatial indexing relative to the original image coords = {d: list(range(prop.bbox[i], prop.bbox[i + len(dims)])) for i, d in enumerate(dims)} + # create physical coordinate labels by taking the overlapping + # subset from the full span of labels for d, c in physical_ticks.items(): axis = d.value[0] i = dims.index(axis) @@ -162,8 +184,8 @@ def save(self, path: str, overwrite: bool = False): ---------- path : str Path of the directory to write to. - overwrite : bool, optional - Whether to overwrite the directory if it exists. + overwrite : bool + Whether to overwrite the directory if it exists. (default: False) """ try: os.mkdir(path) diff --git a/starfish/spots/_target_assignment/label.py b/starfish/spots/_target_assignment/label.py index 5f45071bd..f15e085bf 100644 --- a/starfish/spots/_target_assignment/label.py +++ b/starfish/spots/_target_assignment/label.py @@ -24,6 +24,7 @@ def _assign( ) -> IntensityTable: cell_ids = [] + # for each spot, test whether the spot falls inside the area of each mask for spot in intensities: for mask in masks: sel = {Axes.X.value: spot[Axes.X.value], From e21d700728c82eff52c07197c6b2b4f6604bffe4 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Wed, 10 Apr 2019 12:38:46 -0700 Subject: [PATCH 3/6] save as tar file instead of directory --- starfish/segmentation_mask.py | 38 ++++++++++++------------- starfish/test/test_segmentation_mask.py | 6 ++-- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/starfish/segmentation_mask.py b/starfish/segmentation_mask.py index f178255a7..9a92c1def 100644 --- a/starfish/segmentation_mask.py +++ b/starfish/segmentation_mask.py @@ -1,7 +1,6 @@ +import io import itertools -import os -import os.path as osp -import shutil +import tarfile from typing import Dict, List, Sequence, Tuple, Union import numpy as np @@ -171,29 +170,28 @@ def from_disk(cls, path: str) -> "SegmentationMaskCollection": Collection of segmentation masks. """ masks = [] - for p in os.listdir(path): - mask = xr.open_dataarray(osp.join(path, p)) - masks.append(mask) + + with tarfile.open(path) as t: + for info in t.getmembers(): + f = t.extractfile(info.name) + print(f) + mask = xr.open_dataarray(f) + masks.append(mask) return cls(masks) - def save(self, path: str, overwrite: bool = False): + def save(self, path: str): """Save the segmentation masks to disk. Parameters ---------- path : str - Path of the directory to write to. - overwrite : bool - Whether to overwrite the directory if it exists. (default: False) + Path of the tar file to write to. """ - try: - os.mkdir(path) - except FileExistsError: - if not overwrite: - raise - shutil.rmtree(path, ignore_errors=True) - os.mkdir(path) - - for i, mask in enumerate(self._masks): - mask.to_netcdf(osp.join(path, str(i)), 'w') + with tarfile.open(path, 'w:gz') as t: + for i, mask in enumerate(self._masks): + data = mask.to_netcdf() + with io.BytesIO(data) as buff: + info = tarfile.TarInfo(name=str(i) + '.nc') + info.size = len(data) + t.addfile(tarinfo=info, fileobj=buff) diff --git a/starfish/test/test_segmentation_mask.py b/starfish/test/test_segmentation_mask.py index f99197076..e7f518a3b 100644 --- a/starfish/test/test_segmentation_mask.py +++ b/starfish/test/test_segmentation_mask.py @@ -1,4 +1,4 @@ -import shutil +import os import numpy as np import pytest @@ -115,11 +115,11 @@ def test_save_load(): masks = SegmentationMaskCollection.from_label_image(label_image, physical_ticks) - path = 'data' + path = 'data.tgz' try: masks.save(path) masks2 = SegmentationMaskCollection.from_disk(path) for m, m2 in zip(masks, masks2): assert np.array_equal(m, m2) finally: - shutil.rmtree(path, ignore_errors=True) + os.remove(path) From 06a3e24c7734f99273e3bbbb04ca680813577685 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Wed, 10 Apr 2019 17:10:28 -0700 Subject: [PATCH 4/6] update cli --- starfish/spots/_target_assignment/_base.py | 4 ++-- starfish/test/full_pipelines/cli/test_iss.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/starfish/spots/_target_assignment/_base.py b/starfish/spots/_target_assignment/_base.py index e08fbc62e..f74884087 100644 --- a/starfish/spots/_target_assignment/_base.py +++ b/starfish/spots/_target_assignment/_base.py @@ -3,11 +3,11 @@ from typing import Type import numpy as np -from skimage.io import imread from starfish.intensity_table.intensity_table import IntensityTable from starfish.pipeline.algorithmbase import AlgorithmBase from starfish.pipeline.pipelinecomponent import PipelineComponent +from starfish.segmentation_mask import SegmentationMaskCollection from starfish.util import click @@ -43,7 +43,7 @@ def _cli(ctx, label_image, intensities, output): component=TargetAssignment, output=output, intensity_table=IntensityTable.load(intensities), - label_image=imread(label_image) + label_image=SegmentationMaskCollection.from_disk(label_image) ) diff --git a/starfish/test/full_pipelines/cli/test_iss.py b/starfish/test/full_pipelines/cli/test_iss.py index 07794f59e..8e09d5cb0 100644 --- a/starfish/test/full_pipelines/cli/test_iss.py +++ b/starfish/test/full_pipelines/cli/test_iss.py @@ -131,7 +131,7 @@ def stages(self): "--nuclei", lambda tempdir, *args, **kwargs: os.path.join( tempdir, "filtered", "nuclei.json"), "-o", lambda tempdir, *args, **kwargs: os.path.join( - tempdir, "results", "label_image.png"), + tempdir, "results", "masks.tgz"), "Watershed", "--nuclei-threshold", ".16", "--input-threshold", ".22", @@ -141,7 +141,7 @@ def stages(self): "starfish", "target_assignment", "--label-image", lambda tempdir, *args, **kwargs: os.path.join( - tempdir, "results", "label_image.png"), + tempdir, "results", "masks.tgz"), "--intensities", lambda tempdir, *args, **kwargs: os.path.join( tempdir, "results", "spots.nc"), "--output", lambda tempdir, *args, **kwargs: os.path.join( From bfd6ec00c57fd1173f385140f788954772478374 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Thu, 11 Apr 2019 13:24:24 -0700 Subject: [PATCH 5/6] more concise validation using sets --- starfish/segmentation_mask.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/starfish/segmentation_mask.py b/starfish/segmentation_mask.py index 9a92c1def..2f2b96e16 100644 --- a/starfish/segmentation_mask.py +++ b/starfish/segmentation_mask.py @@ -1,7 +1,7 @@ import io import itertools import tarfile -from typing import Dict, List, Sequence, Tuple, Union +from typing import Dict, List, Set, Sequence, Tuple, Union import numpy as np import xarray as xr @@ -59,14 +59,13 @@ def validate_segmentation_mask(arr: xr.DataArray): raise TypeError(f"expected dtype of bool; got {arr.dtype}") axes, coords = _get_axes_names(arr.ndim) + dims = set(axes) - for dim in axes: - if dim not in arr.dims: - raise TypeError(f"no dimension '{dim}'") + if dims != set(arr.dims): + raise TypeError(f"missing dimensions '{dims.difference(arr.dims)}'") - for coord in itertools.chain(axes, coords): - if coord not in arr.coords: - raise TypeError(f"no coordinate '{coord}'") + if dims.union(coords) != set(arr.coords): + raise TypeError(f"missing coordinates '{dims.union(coords).difference(arr.coords)}'") class SegmentationMaskCollection: @@ -74,7 +73,7 @@ class SegmentationMaskCollection: Parameters ---------- - masks : list of xr.DataArray + masks : List[xr.DataArray] Segmentation masks. """ _masks: List[xr.DataArray] From c8df372917ebfb14eb097ac72b71f7d2b7896b37 Mon Sep 17 00:00:00 2001 From: Kira Evans Date: Mon, 15 Apr 2019 07:48:22 -0700 Subject: [PATCH 6/6] various syntax/organization changes --- starfish/segmentation_mask.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/starfish/segmentation_mask.py b/starfish/segmentation_mask.py index 2f2b96e16..38db00ff3 100644 --- a/starfish/segmentation_mask.py +++ b/starfish/segmentation_mask.py @@ -1,7 +1,6 @@ import io -import itertools import tarfile -from typing import Dict, List, Set, Sequence, Tuple, Union +from typing import Dict, List, Sequence, Tuple, Union import numpy as np import xarray as xr @@ -76,8 +75,6 @@ class SegmentationMaskCollection: masks : List[xr.DataArray] Segmentation masks. """ - _masks: List[xr.DataArray] - def __init__(self, masks: List[xr.DataArray]): for mask in masks: validate_segmentation_mask(mask) @@ -146,10 +143,13 @@ def from_label_image( i = dims.index(axis) coords[d.value] = (axis, c[prop.bbox[i]:prop.bbox[i + len(dims)]]) + name = str(label + 1) + name = name.zfill(len(str(len(props)))) # pad with zeros + mask = xr.DataArray(prop.image, dims=dims, coords=coords, - name=str(label + 1)) + name=name) masks.append(mask) return cls(masks) @@ -161,7 +161,7 @@ def from_disk(cls, path: str) -> "SegmentationMaskCollection": Parameters ---------- path : str - Path of the directory to instantiate from. + Path of the tar file to instantiate from. Returns ------- @@ -173,7 +173,6 @@ def from_disk(cls, path: str) -> "SegmentationMaskCollection": with tarfile.open(path) as t: for info in t.getmembers(): f = t.extractfile(info.name) - print(f) mask = xr.open_dataarray(f) masks.append(mask)