Skip to content

Commit

Permalink
Allow for intensity tables with labeled axes
Browse files Browse the repository at this point in the history
Currently, we use the positional offsets of the round/channel/z-plane axes to label the axes of the intensity table. We should instead use the axis labels from the ImageStack. This PR makes that change in three key areas:

1. `IntensityTable.zeros` (the empty intensity table constructor) now accepts the labels for the channel and round axes.
2. `IntensityTable.from_spot_data` now accepts the labels for the axes and verifies that the number of labels matches the intensity data.
3. `IntensityTable.from_image_stack` reads the labels for the zplane axes and assigns them correctly.

Test plan: Add a test that instantiates a labeled ImageStack and derives an IntensityTable from it.  Also ran `make test && make -j run-notebooks`.

Depends on #1178
Fixes #1168
  • Loading branch information
Tony Tung committed Apr 18, 2019
1 parent d574826 commit bd6e533
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 42 deletions.
7 changes: 6 additions & 1 deletion starfish/codebook/test/test_metric_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ def intensity_table_factory(data: np.ndarray=np.array([[[0, 3], [4, 0]]])) -> In
columns=[Axes.ZPLANE, Axes.Y, Axes.X, Features.SPOT_RADIUS]
)

intensity_table = IntensityTable.from_spot_data(data, SpotAttributes(spot_attributes_data))
intensity_table = IntensityTable.from_spot_data(
data,
SpotAttributes(spot_attributes_data),
ch_values=np.arange(data.shape[1]),
round_values=np.arange(data.shape[2]),
)
return intensity_table


Expand Down
6 changes: 5 additions & 1 deletion starfish/codebook/test/test_normalize_code_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def intensity_table_factory() -> IntensityTable:
).T
spot_attributes = SpotAttributes(spot_attribute_data)

intensity_table = IntensityTable.from_spot_data(intensities, spot_attributes)
intensity_table = IntensityTable.from_spot_data(
intensities, spot_attributes,
ch_values=np.arange(intensities.shape[1]),
round_values=np.arange(intensities.shape[2]),
)
return intensity_table


Expand Down
6 changes: 5 additions & 1 deletion starfish/codebook/test/test_per_round_max_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ def intensity_table_factory(data: np.ndarray=np.array([[[0, 3], [4, 0]]])) -> In
)

spot_attributes = SpotAttributes(spot_attributes_data)
intensity_table = IntensityTable.from_spot_data(data, spot_attributes)
intensity_table = IntensityTable.from_spot_data(
data, spot_attributes,
ch_values=np.arange(data.shape[1]),
round_values=np.arange(data.shape[2]),
)
return intensity_table


Expand Down
3 changes: 1 addition & 2 deletions starfish/imagestack/imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -1000,7 +999,7 @@ def num_zplanes(self):
"""Return the number of z_planes in the ImageStack"""
return self.xarray.sizes[Axes.ZPLANE]

def axis_labels(self, axis: Axes) -> Iterable[int]:
def axis_labels(self, axis: Axes) -> Sequence[int]:
"""Given an axis, return the sorted unique values for that axis in this ImageStack. For
instance, ``imagestack.axis_labels(Axes.ROUND)`` returns all the round ids in this
imagestack."""
Expand Down
76 changes: 53 additions & 23 deletions starfish/intensity_table/intensity_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from itertools import product
from json import loads
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -72,22 +72,27 @@ class IntensityTable(xr.DataArray):

@staticmethod
def _build_xarray_coords(
spot_attributes: SpotAttributes, channel_index: np.ndarray, round_index: np.ndarray
spot_attributes: SpotAttributes,
channel_values: Sequence[int],
round_values: Sequence[int]
) -> Dict[str, np.ndarray]:
"""build coordinates for intensity-table"""
coordinates = {
k: (Features.AXIS, spot_attributes.data[k].values)
for k in spot_attributes.data}
coordinates.update({
Features.AXIS: np.arange(len(spot_attributes.data)),
Axes.CH.value: channel_index,
Axes.ROUND.value: round_index
Axes.CH.value: np.array(channel_values),
Axes.ROUND.value: np.array(round_values),
})
return coordinates

@classmethod
def zeros(
cls, spot_attributes: SpotAttributes, n_ch: int, n_round: int,
cls,
spot_attributes: SpotAttributes,
ch_values: Sequence[int],
round_values: Sequence[int],
) -> "IntensityTable":
"""
Create an empty intensity table with pre-set shape whose values are zero.
Expand All @@ -97,10 +102,12 @@ def zeros(
spot_attributes : SpotAttributes
Table containing spot metadata. Must contain the values specified in Axes.X,
Y, Z, and RADIUS.
n_ch : int
Number of channels measured in the imaging experiment.
n_round : int
Number of imaging rounds measured in the imaging experiment.
ch_values : Sequence[int]
The possible values for the channel number, in the order that they are in the ImageStack
5D tensor.
round_values : Sequence[int]
The possible values for the round number, in the order that they are in the ImageStack
5D tensor.
Returns
-------
Expand All @@ -111,11 +118,10 @@ def zeros(
if not isinstance(spot_attributes, SpotAttributes):
raise TypeError('parameter spot_attributes must be a starfish SpotAttributes object.')

channel_index = np.arange(n_ch)
round_index = np.arange(n_round)
data = np.zeros((spot_attributes.data.shape[0], n_ch, n_round))
data = np.zeros((spot_attributes.data.shape[0], len(ch_values), len(round_values)))
dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value)
coords = cls._build_xarray_coords(spot_attributes, channel_index, round_index)
coords = cls._build_xarray_coords(
spot_attributes, np.array(ch_values), round_values)

intensity_table = cls(
data=data, coords=coords, dims=dims,
Expand All @@ -125,7 +131,11 @@ def zeros(

@classmethod
def from_spot_data(
cls, intensities: Union[xr.DataArray, np.ndarray], spot_attributes: SpotAttributes,
cls,
intensities: Union[xr.DataArray, np.ndarray],
spot_attributes: SpotAttributes,
ch_values: Sequence[int],
round_values: Sequence[int],
*args, **kwargs) -> "IntensityTable":
"""
Creates an IntensityTable from a :code:`(features, channel, round)`
Expand All @@ -139,6 +149,11 @@ def from_spot_data(
spot_attributes : SpotAttributes
Table containing spot metadata. Must contain the values specified in Axes.X,
Y, Z, and RADIUS.
ch_values : Sequence[int]
The possible values for the channel number, in the order that they are in the ImageStack
5D tensor.
round_values : Sequence[int]
The possible values for the round number, in the order that they are in the ImageStack
5D tensor.
args :
Additional arguments to pass to the xarray constructor.
kwargs :
Expand All @@ -155,13 +170,22 @@ def from_spot_data(
f'intensities must be a (features * ch * round) 3-d tensor. Provided intensities '
f'shape ({intensities.shape}) is invalid.')

if len(ch_values) != intensities.shape[1]:
raise ValueError(
f"The number of ch values ({len(ch_values)}) should be equal to intensities' "
f"shape[1] ({intensities.shape[1]})."
)

if len(round_values) != intensities.shape[2]:
    # The message must report the number of round labels (not channel labels), mirroring
    # the channel-count check performed just before this one.
    raise ValueError(
        f"The number of round values ({len(round_values)}) should be equal to intensities' "
        f"shape[2] ({intensities.shape[2]})."
    )

if not isinstance(spot_attributes, SpotAttributes):
raise TypeError('parameter spot_attributes must be a starfish SpotAttributes object.')

coords = cls._build_xarray_coords(
spot_attributes,
np.arange(intensities.shape[1]),
np.arange(intensities.shape[2]))
coords = cls._build_xarray_coords(spot_attributes, ch_values, round_values)

dims = (Features.AXIS, Axes.CH.value, Axes.ROUND.value)

Expand Down Expand Up @@ -324,7 +348,8 @@ def synthetic_intensities(
data = preserve_float_range(data)
assert 0 < data.max() <= 1

intensities = cls.from_spot_data(data, spot_attributes)
intensities = cls.from_spot_data(
data, spot_attributes, np.arange(data.shape[1]), np.arange(data.shape[2]))
intensities[Features.TARGET] = (Features.AXIS, targets)

return intensities
Expand Down Expand Up @@ -360,10 +385,10 @@ def from_image_stack(
assert crop_y * 2 < image_stack.shape['y']
assert crop_x * 2 < image_stack.shape['x']

zmin = crop_z
zmin = image_stack.axis_labels(Axes.ZPLANE)[crop_z]
ymin = crop_y
xmin = crop_x
zmax = image_stack.shape['z'] - crop_z
zmax = image_stack.axis_labels(Axes.ZPLANE)[-crop_z - 1]
ymax = image_stack.shape['y'] - crop_y
xmax = image_stack.shape['x'] - crop_x
cropped_stack = image_stack.sel({Axes.ZPLANE: (zmin, zmax),
Expand All @@ -383,7 +408,7 @@ def from_image_stack(
-1, image_stack.num_chs, image_stack.num_rounds)

# IntensityTable pixel coordinates
z = np.arange(zmin, zmax)
z = image_stack.axis_labels(Axes.ZPLANE)
y = np.arange(ymin, ymax)
x = np.arange(xmin, xmax)

Expand All @@ -397,7 +422,12 @@ def from_image_stack(

pixel_coordinates = SpotAttributes(feature_attribute_data)

return IntensityTable.from_spot_data(intensity_data, pixel_coordinates)
return IntensityTable.from_spot_data(
intensity_data,
pixel_coordinates,
image_stack.axis_labels(Axes.CH),
image_stack.axis_labels(Axes.ROUND),
)

@staticmethod
def _process_overlaps(
Expand Down
4 changes: 2 additions & 2 deletions starfish/intensity_table/test/test_empty_intensity_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def test_intensity_table_can_be_created_from_spot_attributes():

intensities = IntensityTable.zeros(
spot_attributes,
n_ch=1,
n_round=3
ch_values=np.arange(1),
round_values=np.arange(3)
)

assert intensities.sizes[Axes.CH] == 1
Expand Down
11 changes: 11 additions & 0 deletions starfish/intensity_table/test/test_from_imagestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from starfish import ImageStack, IntensityTable
from starfish.imagestack.test import test_labeled_indices
from starfish.test.factories import (
codebook_intensities_image_for_single_synthetic_spot,
synthetic_spot_pass_through_stack,
Expand Down Expand Up @@ -54,3 +55,13 @@ def test_intensity_table_can_be_constructed_from_an_imagestack():
# the number of channels and rounds should match the ImageStack
assert intensities.sizes[Axes.CH.value] == c
assert intensities.sizes[Axes.ROUND.value] == r


def test_from_imagestack_labeled_indices():
# use the ImageStack with labeled indices from the test.
imagestack = test_labeled_indices.setup_imagestack()
intensity_table = IntensityTable.from_image_stack(imagestack)
assert np.array_equal(
intensity_table[Axes.CH.value], np.array(test_labeled_indices.CH_LABELS))
assert np.array_equal(
intensity_table[Axes.ROUND.value], np.array(test_labeled_indices.ROUND_LABELS))
20 changes: 16 additions & 4 deletions starfish/intensity_table/test/test_from_spot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,30 @@ def test_intensity_table_can_be_constructed_from_a_numpy_array_and_spot_attribut
"""
spot_attributes = spot_attribute_factory(3)
data = np.zeros(30).reshape(3, 5, 2)
intensities = IntensityTable.from_spot_data(data, spot_attributes)
intensities = IntensityTable.from_spot_data(
data, spot_attributes, np.arange(data.shape[1]), np.arange(data.shape[2]))

assert intensities.shape == data.shape
assert np.array_equal(intensities.values, data)


def test_from_spot_attributes_must_have_aligned_dimensions_spot_attributes_and_data():
@pytest.mark.parametrize(
"num_features, num_ch_values, num_round_values",
[
(2, 5, 2,),
(3, 4, 2,),
(3, 5, 1,),
]
)
def test_from_spot_attributes_must_have_aligned_dimensions_spot_attributes_and_data(
num_features, num_ch_values, num_round_values,
):
"""
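The number of spot attributes, channel labels, and round labels must each match the
corresponding dimension of the (3, 5, 2) intensity data. Each parametrized case mismatches
exactly one of them; verify a ValueError is raised.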
"""
spot_attributes = spot_attribute_factory(2)
spot_attributes = spot_attribute_factory(num_features)
data = np.zeros(30).reshape(3, 5, 2)
with pytest.raises(ValueError):
IntensityTable.from_spot_data(data, spot_attributes)
IntensityTable.from_spot_data(
data, spot_attributes, np.arange(num_ch_values), np.arange(num_round_values))
16 changes: 8 additions & 8 deletions starfish/spots/_detector/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,22 @@ def measure_spot_intensities(
"""

# determine the shape of the intensity table
n_ch = data_image.shape[Axes.CH]
n_round = data_image.shape[Axes.ROUND]
ch_values = data_image.axis_labels(Axes.CH)
round_values = data_image.axis_labels(Axes.ROUND)

# construct the empty intensity table
intensity_table = IntensityTable.zeros(
spot_attributes=spot_attributes,
n_ch=n_ch,
n_round=n_round,
ch_values=ch_values,
round_values=round_values,
)

# if no spots were detected, return the empty IntensityTable
if intensity_table.sizes[Features.AXIS] == 0:
return intensity_table

# fill the intensity table
indices = product(range(n_ch), range(n_round))
indices = product(ch_values, round_values)
for c, r in indices:
image, _ = data_image.get_slice({Axes.CH: c, Axes.ROUND: r})
blob_intensities: pd.Series = measure_spot_intensity(
Expand Down Expand Up @@ -142,15 +142,15 @@ def concatenate_spot_attributes_to_intensities(
concatenated input SpotAttributes, converted to an IntensityTable object
"""
n_ch: int = max(inds[Axes.CH] for _, inds in spot_attributes) + 1
n_round: int = max(inds[Axes.ROUND] for _, inds in spot_attributes) + 1
ch_values: Sequence[int] = sorted(set(inds[Axes.CH] for _, inds in spot_attributes))
round_values: Sequence[int] = sorted(set(inds[Axes.ROUND] for _, inds in spot_attributes))

all_spots = pd.concat([sa.data for sa, inds in spot_attributes], sort=True)
# this drop call ensures only x, y, z, radius, and quality, are passed to the IntensityTable
features_coordinates = all_spots.drop(['spot_id', 'intensity'], axis=1)

intensity_table = IntensityTable.zeros(
SpotAttributes(features_coordinates), n_ch, n_round,
SpotAttributes(features_coordinates), ch_values, round_values,
)

i = 0
Expand Down

0 comments on commit bd6e533

Please sign in to comment.