Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace label images with segmentation masks #1135

Merged
merged 6 commits into from
Apr 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions notebooks/ISS_Pipeline_-_Breast_-_1_FOV.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@
" input_threshold=stain_thresh,\n",
" min_distance=min_dist\n",
")\n",
"label_image = seg.run(registered_image, nuclei)\n",
"masks = seg.run(registered_image, nuclei)\n",
"seg.show()"
]
},
Expand All @@ -392,7 +392,7 @@
"source": [
"from starfish.spots import TargetAssignment\n",
"al = TargetAssignment.Label()\n",
"labeled = al.run(label_image, decoded)"
"labeled = al.run(masks, decoded)"
]
},
{
Expand Down Expand Up @@ -513,4 +513,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
4 changes: 2 additions & 2 deletions notebooks/py/ISS_Pipeline_-_Breast_-_1_FOV.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@
input_threshold=stain_thresh,
min_distance=min_dist
)
label_image = seg.run(registered_image, nuclei)
masks = seg.run(registered_image, nuclei)
seg.show()
# EPY: END code

Expand All @@ -251,7 +251,7 @@
# EPY: START code
from starfish.spots import TargetAssignment
al = TargetAssignment.Label()
labeled = al.run(label_image, decoded)
labeled = al.run(masks, decoded)
# EPY: END code

# EPY: START code
Expand Down
19 changes: 11 additions & 8 deletions starfish/image/_segmentation/_base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from abc import abstractmethod
from typing import Type

from skimage.io import imsave

from starfish.imagestack.imagestack import ImageStack
from starfish.pipeline import PipelineComponent
from starfish.pipeline.algorithmbase import AlgorithmBase
from starfish.segmentation_mask import SegmentationMaskCollection
from starfish.util import click


Expand All @@ -24,10 +23,10 @@ def _cli_run(cls, ctx, instance):
pri_stack = ctx.obj["primary_images"]
nuc_stack = ctx.obj["nuclei"]

label_image = instance.run(pri_stack, nuc_stack)
masks = instance.run(pri_stack, nuc_stack)

print(f"Writing label image to {output}")
imsave(output, label_image)
print(f"Writing masks to {output}")
masks.save(output)

@staticmethod
@click.group(COMPONENT_NAME)
Expand All @@ -50,8 +49,12 @@ class SegmentationAlgorithmBase(AlgorithmBase):
@classmethod
def get_pipeline_component_class(cls) -> Type[PipelineComponent]:
return Segmentation

@abstractmethod
def run(self, primary_image_stack: ImageStack, nuclei_stack: ImageStack, *args):
"""Performs registration on the stack provided."""
def run(
self,
primary_image_stack: ImageStack,
nuclei_stack: ImageStack,
*args
) -> SegmentationMaskCollection:
"""Performs segmentation on the stack provided."""
raise NotImplementedError()
22 changes: 16 additions & 6 deletions starfish/image/_segmentation/watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from starfish.image._filter.util import bin_open, bin_thresh
from starfish.imagestack.imagestack import ImageStack
from starfish.types import Axes, Number
from starfish.segmentation_mask import SegmentationMaskCollection
from starfish.types import Axes, Coordinates, Number
from starfish.util import click
from ._base import SegmentationAlgorithmBase

Expand Down Expand Up @@ -49,7 +50,12 @@ def __init__(
self.min_distance = min_distance
self._segmentation_instance: Optional[_WatershedSegmenter] = None

def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarray:
def run(
self,
primary_images: ImageStack,
nuclei: ImageStack,
*args
kne42 marked this conversation as resolved.
Show resolved Hide resolved
) -> SegmentationMaskCollection:
"""Segments nuclei in 2-d using a nuclei ImageStack

Primary images are used to expand the nuclear mask, but only in cases where there are
Expand All @@ -64,9 +70,8 @@ def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarr

Returns
-------
np.ndarray :
label image where each cell is labeled by a different positive integer value. 0
implies that a pixel is not part of a cell.
masks : SegmentationMaskCollection
binary masks segmenting each cell
"""

# create a 'stain' for segmentation
Expand All @@ -88,7 +93,12 @@ def run(self, primary_images: ImageStack, nuclei: ImageStack, *args) -> np.ndarr
disk_size_mask, self.min_distance
)

return label_image
# we max-projected and squeezed the Z-plane so label_image.ndim == 2
physical_ticks = {coord: nuclei.xarray.coords[coord.value].data
for coord in (Coordinates.Y, Coordinates.X)}

return SegmentationMaskCollection.from_label_image(label_image,
physical_ticks)

def show(self, figsize: Tuple[int, int]=(10, 10)) -> None:
if isinstance(self._segmentation_instance, _WatershedSegmenter):
Expand Down
195 changes: 195 additions & 0 deletions starfish/segmentation_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
import io
import tarfile
from typing import Dict, List, Sequence, Tuple, Union

import numpy as np
import xarray as xr
from skimage.measure import regionprops

from starfish.types import Axes, Coordinates


AXES = [a.value for a in Axes if a not in (Axes.ROUND, Axes.CH)]
COORDS = [c.value for c in Coordinates]


def _get_axes_names(ndim: int) -> Tuple[List[str], List[str]]:
"""Get needed axes names given the number of dimensions.

Parameters
----------
ndim : int
Number of dimensions.

Returns
-------
axes : List[str]
Axes names.
coords : List[str]
Coordinates names.
"""
if ndim == 2:
axes = [axis for axis in AXES if axis != Axes.ZPLANE.value]
coords = [coord for coord in COORDS if coord != Coordinates.Z.value]
elif ndim == 3:
axes = AXES
coords = COORDS
else:
raise TypeError('expected 2- or 3-D image')

return axes, coords


def validate_segmentation_mask(arr: xr.DataArray):
"""Validate if the given array is a segmentation mask.

Parameters
----------
arr : xr.DataArray
Array to check.
"""
if not isinstance(arr, xr.DataArray):
raise TypeError(f"expected DataArray; got {type(arr)}")

if arr.ndim not in (2, 3):
raise TypeError(f"expected 2 or 3 dimensions; got {arr.ndim}")

if arr.dtype != np.bool:
raise TypeError(f"expected dtype of bool; got {arr.dtype}")

axes, coords = _get_axes_names(arr.ndim)
dims = set(axes)

if dims != set(arr.dims):
raise TypeError(f"missing dimensions '{dims.difference(arr.dims)}'")

if dims.union(coords) != set(arr.coords):
raise TypeError(f"missing coordinates '{dims.union(coords).difference(arr.coords)}'")


class SegmentationMaskCollection:
"""Collection of binary segmentation masks with a list-like access pattern.

Parameters
----------
masks : List[xr.DataArray]
Segmentation masks.
"""
def __init__(self, masks: List[xr.DataArray]):
for mask in masks:
validate_segmentation_mask(mask)

self._masks = masks

def __getitem__(self, index):
return self._masks[index]

def __iter__(self):
return iter(self._masks)

def __len__(self):
return len(self._masks)

def append(self, mask: xr.DataArray):
"""Add an existing segmentation mask.

Parameters
----------
arr : xr.DataArray
Segmentation mask.
"""
validate_segmentation_mask(mask)
self._masks.append(mask)

@classmethod
def from_label_image(
cls,
label_image: np.ndarray,
physical_ticks: Dict[Coordinates, Sequence[float]]
) -> "SegmentationMaskCollection":
"""Creates segmentation masks from a label image.

Parameters
----------
label_image : int array
Integer array where each integer corresponds to a region.
physical_ticks : Dict[Coordinates, Sequence[float]]
Physical coordinates for each axis.

Returns
-------
masks : SegmentationMaskCollection
Masks generated from the label image.
"""
props = regionprops(label_image)

dims, _ = _get_axes_names(label_image.ndim)

masks: List[xr.DataArray] = []

coords: Dict[str, Union[list, Tuple[str, Sequence]]]

# for each region (and its properties):
for label, prop in enumerate(props):
kne42 marked this conversation as resolved.
Show resolved Hide resolved
# create pixel coordinate labels from the bounding box
# to preserve spatial indexing relative to the original image
coords = {d: list(range(prop.bbox[i], prop.bbox[i + len(dims)]))
for i, d in enumerate(dims)}

# create physical coordinate labels by taking the overlapping
# subset from the full span of labels
for d, c in physical_ticks.items():
axis = d.value[0]
i = dims.index(axis)
coords[d.value] = (axis, c[prop.bbox[i]:prop.bbox[i + len(dims)]])

name = str(label + 1)
name = name.zfill(len(str(len(props)))) # pad with zeros

mask = xr.DataArray(prop.image,
dims=dims,
coords=coords,
name=name)
masks.append(mask)

return cls(masks)

@classmethod
def from_disk(cls, path: str) -> "SegmentationMaskCollection":
"""Load the collection from disk.

Parameters
----------
path : str
Path of the tar file to instantiate from.

Returns
-------
masks : SegmentationMaskCollection
Collection of segmentation masks.
"""
masks = []

with tarfile.open(path) as t:
for info in t.getmembers():
f = t.extractfile(info.name)
mask = xr.open_dataarray(f)
masks.append(mask)

return cls(masks)

def save(self, path: str):
"""Save the segmentation masks to disk.

Parameters
----------
path : str
Path of the tar file to write to.
"""
with tarfile.open(path, 'w:gz') as t:
for i, mask in enumerate(self._masks):
data = mask.to_netcdf()
with io.BytesIO(data) as buff:
info = tarfile.TarInfo(name=str(i) + '.nc')
info.size = len(data)
t.addfile(tarinfo=info, fileobj=buff)
4 changes: 2 additions & 2 deletions starfish/spots/_target_assignment/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Type

import numpy as np
from skimage.io import imread

from starfish.intensity_table.intensity_table import IntensityTable
from starfish.pipeline.algorithmbase import AlgorithmBase
from starfish.pipeline.pipelinecomponent import PipelineComponent
from starfish.segmentation_mask import SegmentationMaskCollection
from starfish.util import click


Expand Down Expand Up @@ -43,7 +43,7 @@ def _cli(ctx, label_image, intensities, output):
component=TargetAssignment,
output=output,
intensity_table=IntensityTable.load(intensities),
label_image=imread(label_image)
label_image=SegmentationMaskCollection.from_disk(label_image)
)


Expand Down
Loading