diff --git a/starfish/core/spots/AssignTargets/label.py b/starfish/core/spots/AssignTargets/label.py index e8d0fd187..4e940b2da 100644 --- a/starfish/core/spots/AssignTargets/label.py +++ b/starfish/core/spots/AssignTargets/label.py @@ -1,8 +1,10 @@ +import warnings + import numpy as np from starfish.core.intensity_table.decoded_intensity_table import DecodedIntensityTable from starfish.core.morphology.binary_mask import BinaryMaskCollection -from starfish.core.types import Features +from starfish.core.types import Axes, Features from ._base import AssignTargetsAlgorithm @@ -32,19 +34,32 @@ def _assign( decoded_intensities[Features.CELL_ID] = cell_ids + # it's 3D data. for _, mask in masks: + has_z_data = Axes.ZPLANE.value in mask.coords + if has_z_data: + z_min, z_max = float(mask.z.min()), float(mask.z.max()) + else: + warnings.warn( + "AssignTargets will require 3D masks in the future.", DeprecationWarning) + z_min, z_max = np.NINF, np.inf y_min, y_max = float(mask.y.min()), float(mask.y.max()) x_min, x_max = float(mask.x.min()), float(mask.x.max()) in_bbox = decoded_intensities.where( - (decoded_intensities.y >= y_min) + (decoded_intensities.z >= z_min) + & (decoded_intensities.z <= z_max) + & (decoded_intensities.y >= y_min) & (decoded_intensities.y <= y_max) & (decoded_intensities.x >= x_min) & (decoded_intensities.x <= x_max), drop=True ) - in_mask = mask.sel(y=in_bbox.y, x=in_bbox.x) + selectors = {'y': in_bbox.y, 'x': in_bbox.x} + if has_z_data: + selectors['z'] = in_bbox.z + in_mask = mask.sel(**selectors) spot_ids = in_bbox[Features.SPOT_ID][in_mask.values] decoded_intensities[Features.CELL_ID].loc[spot_ids] = mask.name