diff --git a/starfish/codebook/test/test_metric_decode.py b/starfish/codebook/test/test_metric_decode.py index d4c522f4b..cb6bca17d 100644 --- a/starfish/codebook/test/test_metric_decode.py +++ b/starfish/codebook/test/test_metric_decode.py @@ -22,7 +22,12 @@ def intensity_table_factory(data: np.ndarray=np.array([[[0, 3], [4, 0]]])) -> In columns=[Axes.ZPLANE, Axes.Y, Axes.X, Features.SPOT_RADIUS] ) - intensity_table = IntensityTable.from_spot_data(data, SpotAttributes(spot_attributes_data)) + intensity_table = IntensityTable.from_spot_data( + data, + SpotAttributes(spot_attributes_data), + ch_values=np.arange(data.shape[1]), + round_values=np.arange(data.shape[2]), + ) return intensity_table diff --git a/starfish/codebook/test/test_normalize_code_traces.py b/starfish/codebook/test/test_normalize_code_traces.py index d2820c111..44db58c64 100644 --- a/starfish/codebook/test/test_normalize_code_traces.py +++ b/starfish/codebook/test/test_normalize_code_traces.py @@ -20,7 +20,11 @@ def intensity_table_factory() -> IntensityTable: ).T spot_attributes = SpotAttributes(spot_attribute_data) - intensity_table = IntensityTable.from_spot_data(intensities, spot_attributes) + intensity_table = IntensityTable.from_spot_data( + intensities, spot_attributes, + ch_values=np.arange(intensities.shape[1]), + round_values=np.arange(intensities.shape[2]), + ) return intensity_table diff --git a/starfish/codebook/test/test_per_round_max_decode.py b/starfish/codebook/test/test_per_round_max_decode.py index 3a6c9b12c..4424950ca 100644 --- a/starfish/codebook/test/test_per_round_max_decode.py +++ b/starfish/codebook/test/test_per_round_max_decode.py @@ -21,7 +21,11 @@ def intensity_table_factory(data: np.ndarray=np.array([[[0, 3], [4, 0]]])) -> In ) spot_attributes = SpotAttributes(spot_attributes_data) - intensity_table = IntensityTable.from_spot_data(data, spot_attributes) + intensity_table = IntensityTable.from_spot_data( + data, spot_attributes, + ch_values=np.arange(data.shape[1]), + round_values=np.arange(data.shape[2]), + ) return intensity_table diff --git a/starfish/imagestack/imagestack.py b/starfish/imagestack/imagestack.py index 6015d60f3..d47b13647 100644 --- a/starfish/imagestack/imagestack.py +++ b/starfish/imagestack/imagestack.py @@ -8,7 +8,6 @@ from typing import ( Any, Callable, - Iterable, Iterator, List, Mapping, @@ -1000,7 +999,7 @@ def num_zplanes(self): """Return the number of z_planes in the ImageStack""" return self.xarray.sizes[Axes.ZPLANE] - def axis_labels(self, axis: Axes) -> Iterable[int]: + def axis_labels(self, axis: Axes) -> Sequence[int]: """Given an axis, return the sorted unique values for that axis in this ImageStack. For instance, ``imagestack.axis_labels(Axes.ROUND)`` returns all the round ids in this imagestack.""" diff --git a/starfish/intensity_table/intensity_table.py b/starfish/intensity_table/intensity_table.py index 9d816ca2c..1492468a4 100644 --- a/starfish/intensity_table/intensity_table.py +++ b/starfish/intensity_table/intensity_table.py @@ -1,6 +1,6 @@ from itertools import product from json import loads -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Sequence, Union import numpy as np import pandas as pd @@ -72,7 +72,9 @@ class IntensityTable(xr.DataArray): @staticmethod def _build_xarray_coords( - spot_attributes: SpotAttributes, channel_index: np.ndarray, round_index: np.ndarray + spot_attributes: SpotAttributes, + channel_values: Sequence[int], + round_values: Sequence[int] ) -> Dict[str, np.ndarray]: """build coordinates for intensity-table""" coordinates = { @@ -80,14 +82,17 @@ def _build_xarray_coords( for k in spot_attributes.data} coordinates.update({ Features.AXIS: np.arange(len(spot_attributes.data)), - Axes.CH.value: channel_index, - Axes.ROUND.value: round_index + Axes.CH.value: np.array(channel_values), + Axes.ROUND.value: np.array(round_values), }) return coordinates @classmethod def zeros( - cls, spot_attributes: SpotAttributes, n_ch: int, n_round: int, + cls, + spot_attributes: SpotAttributes, + ch_values: Sequence[int], + round_values: Sequence[int], ) -> "IntensityTable": """ Create an empty intensity table with pre-set shape whose values are zero. @@ -97,10 +102,12 @@ def zeros( spot_attributes : SpotAttributes Table containing spot metadata. Must contain the values specified in Axes.X, Y, Z, and RADIUS. - n_ch : int - Number of channels measured in the imaging experiment. - n_round : int - Number of imaging rounds measured in the imaging experiment. + ch_values : Sequence[int] + The possible values for the channel number, in the order that they are in the ImageStack + 5D tensor. + round_values : Sequence[int] + The possible values for the round number, in the order that they are in the ImageStack + 5D tensor. Returns ------- @@ -111,11 +118,10 @@ def zeros( if not isinstance(spot_attributes, SpotAttributes): raise TypeError('parameter spot_attributes must be a starfish SpotAttributes object.') - channel_index = np.arange(n_ch) - round_index = np.arange(n_round) - data = np.zeros((spot_attributes.data.shape[0], n_ch, n_round)) + data = np.zeros((spot_attributes.data.shape[0], len(ch_values), len(round_values))) dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value) - coords = cls._build_xarray_coords(spot_attributes, channel_index, round_index) + coords = cls._build_xarray_coords( + spot_attributes, np.array(ch_values), round_values) intensity_table = cls( data=data, coords=coords, dims=dims, @@ -125,7 +131,11 @@ def zeros( @classmethod def from_spot_data( - cls, intensities: Union[xr.DataArray, np.ndarray], spot_attributes: SpotAttributes, + cls, + intensities: Union[xr.DataArray, np.ndarray], + spot_attributes: SpotAttributes, + ch_values: Sequence[int], + round_values: Sequence[int], *args, **kwargs) -> "IntensityTable": """ Creates an IntensityTable from a :code:`(features, channel, round)` @@ -139,6 +149,11 @@ def from_spot_data( spot_attributes : SpotAttributes Table containing spot metadata. Must contain the values specified in Axes.X, Y, Z, and RADIUS. + ch_values : Sequence[int] + The possible values for the channel number, in the order that they are in the ImageStack + 5D tensor. + round_values : Sequence[int] + The possible values for the round number, in the order that they are in the ImageStack args : Additional arguments to pass to the xarray constructor. kwargs : @@ -155,13 +170,22 @@ def from_spot_data( f'intensities must be a (features * ch * round) 3-d tensor. Provided intensities ' f'shape ({intensities.shape}) is invalid.') + if len(ch_values) != intensities.shape[1]: + raise ValueError( + f"The number of ch values ({len(ch_values)}) should be equal to intensities' " + f"shape[1] ({intensities.shape[1]})." + ) + + if len(round_values) != intensities.shape[2]: + raise ValueError( + f"The number of round values ({len(ch_values)}) should be equal to intensities' " + f"shape[2] ({intensities.shape[2]})." + ) + if not isinstance(spot_attributes, SpotAttributes): raise TypeError('parameter spot_attributes must be a starfish SpotAttributes object.') - coords = cls._build_xarray_coords( - spot_attributes, - np.arange(intensities.shape[1]), - np.arange(intensities.shape[2])) + coords = cls._build_xarray_coords(spot_attributes, ch_values, round_values) dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value) @@ -324,7 +348,8 @@ def synthetic_intensities( data = preserve_float_range(data) assert 0 < data.max() <= 1 - intensities = cls.from_spot_data(data, spot_attributes) + intensities = cls.from_spot_data( + data, spot_attributes, np.arange(data.shape[1]), np.arange(data.shape[2])) intensities[Features.TARGET] = (Features.AXIS, targets) return intensities @@ -360,10 +385,10 @@ def from_image_stack( assert crop_y * 2 < image_stack.shape['y'] assert crop_x * 2 < image_stack.shape['x'] - zmin = crop_z + zmin = image_stack.axis_labels(Axes.ZPLANE)[crop_z] ymin = crop_y xmin = crop_x - zmax = image_stack.shape['z'] - crop_z + zmax = image_stack.axis_labels(Axes.ZPLANE)[-crop_z - 1] ymax = image_stack.shape['y'] - crop_y xmax = image_stack.shape['x'] - crop_x cropped_stack = image_stack.sel({Axes.ZPLANE: (zmin, zmax), @@ -383,7 +408,7 @@ def from_image_stack( -1, image_stack.num_chs, image_stack.num_rounds) # IntensityTable pixel coordinates - z = np.arange(zmin, zmax) + z = image_stack.axis_labels(Axes.ZPLANE) y = np.arange(ymin, ymax) x = np.arange(xmin, xmax) @@ -397,7 +422,12 @@ def from_image_stack( pixel_coordinates = SpotAttributes(feature_attribute_data) - return IntensityTable.from_spot_data(intensity_data, pixel_coordinates) + return IntensityTable.from_spot_data( + intensity_data, + pixel_coordinates, + image_stack.axis_labels(Axes.CH), + image_stack.axis_labels(Axes.ROUND), + ) @staticmethod def _process_overlaps( diff --git a/starfish/intensity_table/test/test_empty_intensity_table.py b/starfish/intensity_table/test/test_empty_intensity_table.py index ee9db9ffe..7f7e9035a 100644 --- a/starfish/intensity_table/test/test_empty_intensity_table.py +++ b/starfish/intensity_table/test/test_empty_intensity_table.py @@ -29,8 +29,8 @@ def test_intensity_table_can_be_created_from_spot_attributes(): intensities = IntensityTable.zeros( spot_attributes, - n_ch=1, - n_round=3 + ch_values=np.arange(1), + round_values=np.arange(3) ) assert intensities.sizes[Axes.CH] == 1 diff --git a/starfish/intensity_table/test/test_from_imagestack.py b/starfish/intensity_table/test/test_from_imagestack.py index dea27c7e0..5e9ebb6b9 100644 --- a/starfish/intensity_table/test/test_from_imagestack.py +++ b/starfish/intensity_table/test/test_from_imagestack.py @@ -5,6 +5,7 @@ import numpy as np from starfish import ImageStack, IntensityTable +from starfish.imagestack.test import test_labeled_indices from starfish.test.factories import ( codebook_intensities_image_for_single_synthetic_spot, synthetic_spot_pass_through_stack, @@ -54,3 +55,13 @@ def test_intensity_table_can_be_constructed_from_an_imagestack(): # the number of channels and rounds should match the ImageStack assert intensities.sizes[Axes.CH.value] == c assert intensities.sizes[Axes.ROUND.value] == r + + +def test_from_imagestack_labeled_indices(): + # use the ImageStack with labeled indices from the test. + imagestack = test_labeled_indices.setup_imagestack() + intensity_table = IntensityTable.from_image_stack(imagestack) + assert np.array_equal( + intensity_table[Axes.CH.value], np.array(test_labeled_indices.CH_LABELS)) + assert np.array_equal( + intensity_table[Axes.ROUND.value], np.array(test_labeled_indices.ROUND_LABELS)) diff --git a/starfish/intensity_table/test/test_from_spot_data.py b/starfish/intensity_table/test/test_from_spot_data.py index 19966bf82..9a456aa86 100644 --- a/starfish/intensity_table/test/test_from_spot_data.py +++ b/starfish/intensity_table/test/test_from_spot_data.py @@ -30,18 +30,30 @@ def test_intensity_table_can_be_constructed_from_a_numpy_array_and_spot_attribut """ spot_attributes = spot_attribute_factory(3) data = np.zeros(30).reshape(3, 5, 2) - intensities = IntensityTable.from_spot_data(data, spot_attributes) + intensities = IntensityTable.from_spot_data( + data, spot_attributes, np.arange(data.shape[1]), np.arange(data.shape[2])) assert intensities.shape == data.shape assert np.array_equal(intensities.values, data) -def test_from_spot_attributes_must_have_aligned_dimensions_spot_attributes_and_data(): +@pytest.mark.parametrize( + "num_features, num_ch_values, num_round_values", + [ + (2, 5, 2,), + (3, 4, 2,), + (3, 5, 1,), + ] +) +def test_from_spot_attributes_must_have_aligned_dimensions_spot_attributes_and_data( + num_features, num_ch_values, num_round_values, +): """ Number of features must match number of SpotAttributes. Pass two attributes and 3 features and verify a ValueError is raised. """ - spot_attributes = spot_attribute_factory(2) + spot_attributes = spot_attribute_factory(num_features) data = np.zeros(30).reshape(3, 5, 2) with pytest.raises(ValueError): - IntensityTable.from_spot_data(data, spot_attributes) + IntensityTable.from_spot_data( + data, spot_attributes, np.arange(num_ch_values), np.arange(num_round_values)) diff --git a/starfish/spots/_detector/detect.py b/starfish/spots/_detector/detect.py index f8dc9390f..4b9dd05d1 100644 --- a/starfish/spots/_detector/detect.py +++ b/starfish/spots/_detector/detect.py @@ -94,14 +94,14 @@ def measure_spot_intensities( """ # determine the shape of the intensity table - n_ch = data_image.shape[Axes.CH] - n_round = data_image.shape[Axes.ROUND] + ch_values = data_image.axis_labels(Axes.CH) + round_values = data_image.axis_labels(Axes.ROUND) # construct the empty intensity table intensity_table = IntensityTable.zeros( spot_attributes=spot_attributes, - n_ch=n_ch, - n_round=n_round, + ch_values=ch_values, + round_values=round_values, ) # if no spots were detected, return the empty IntensityTable @@ -109,7 +109,7 @@ def measure_spot_intensities( return intensity_table # fill the intensity table - indices = product(range(n_ch), range(n_round)) + indices = product(ch_values, round_values) for c, r in indices: image, _ = data_image.get_slice({Axes.CH: c, Axes.ROUND: r}) blob_intensities: pd.Series = measure_spot_intensity( @@ -142,15 +142,15 @@ def concatenate_spot_attributes_to_intensities( concatenated input SpotAttributes, converted to an IntensityTable object """ - n_ch: int = max(inds[Axes.CH] for _, inds in spot_attributes) + 1 - n_round: int = max(inds[Axes.ROUND] for _, inds in spot_attributes) + 1 + ch_values: Sequence[int] = sorted(set(inds[Axes.CH] for _, inds in spot_attributes)) + round_values: Sequence[int] = sorted(set(inds[Axes.ROUND] for _, inds in spot_attributes)) all_spots = pd.concat([sa.data for sa, inds in spot_attributes], sort=True) # this drop call ensures only x, y, z, radius, and quality, are passed to the IntensityTable features_coordinates = all_spots.drop(['spot_id', 'intensity'], axis=1) intensity_table = IntensityTable.zeros( - SpotAttributes(features_coordinates), n_ch, n_round, + SpotAttributes(features_coordinates), ch_values, round_values, ) i = 0