Skip to content

Commit

Permalink
allow a list of label_name's in get_array_pair_by_coordinate
Browse files Browse the repository at this point in the history
Co-authored-by: Silvia Barbiero <[email protected]>
Co-authored-by: Charlotte Soneson <[email protected]>
Co-authored-by: Michael Stadler <[email protected]>
  • Loading branch information
3 people committed Aug 8, 2024
1 parent 7912864 commit 3e52e7a
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 77 deletions.
168 changes: 93 additions & 75 deletions src/ez_zarr/ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def get_array_by_coordinate(self,
return arr

def get_array_pair_by_coordinate(self,
label_name: str,
label_name: Union[str, list[str]],
upper_left_yx: Optional[tuple[int]]=None,
lower_right_yx: Optional[tuple[int]]=None,
size_yx: Optional[tuple[int]]=None,
Expand All @@ -684,7 +684,8 @@ def get_array_pair_by_coordinate(self,
necessary, the label array is resized to match the intensity array.
Parameters:
label_name (str): The name of the label image to be extracted.
label_name (str or list of str): The name(s) of the label image(s)
to be extracted.
upper_left_yx (tuple, optional): Tuple of (y, x) intensity image
coordinates for the upper-left (lower) coordinates defining the
region of interest.
Expand All @@ -706,17 +707,23 @@ def get_array_pair_by_coordinate(self,
By default, this is `None`, which will use `pyramid_level`.
Returns:
A tuple of two `numpy.ndarray` objects with the extracted intensity
and (possibly resized) label arrays.
A tuple of length two, with the first element corresponding to
a `numpy.ndarray` with the extracted intensity array and the
second element a dictionary with keys corresponding to the
`label_name` and the values corresponding to label arrays.
Label arrays are resized if necessary to match the intensity
array.
Examples:
Obtain the whole image and matching 'organoids' label arrays:
>>> img, lab = img.get_array_pair_by_coordinate(label_name = 'organoids')
"""
# digest arguments
assert isinstance(label_name, str) and label_name in self.label_names, (
f"Unknown label_name ({label_name}), should be one of "
if isinstance(label_name, str):
label_name = [label_name]
assert isinstance(label_name, list) and all([ln in self.label_names for ln in label_name]), (
f"Unknown label_name(s) ({', '.join([ln for ln in label_name if ln not in self.label_names])}), should be one of "
', '.join(self.label_names)
)
pyramid_level = self._digest_pyramid_level_argument(
Expand All @@ -726,85 +733,95 @@ def get_array_pair_by_coordinate(self,
if pyramid_level_coord is None:
pyramid_level_coord = pyramid_level

# find matching label pyramid level
# get intensity image scale
img_scale_spatial = self.get_scale(
pyramid_level=pyramid_level,
label_name=None,
spatial_axes_only=True
)
lab_scale_spatial_dict = {
pl: self.get_scale(pyramid_level=pl, label_name=label_name, spatial_axes_only=True) for pl in self.get_pyramid_levels(label_name=label_name)
}
# ... filter out label scales with higher resolution than the intensity image
lab_scale_spatial_dict = {pl: lab_scale_spatial for pl, lab_scale_spatial in lab_scale_spatial_dict.items() if all([lab_scale_spatial[i] >= img_scale_spatial[i] for i in range(len(lab_scale_spatial))])}
if len(lab_scale_spatial_dict) == 0:
raise ValueError(f"For the requested pyramid level ({pyramid_level}) of the intensity image, down-scaling of an available label ('{label_name}') would be required. Down-scaling of labels is not supported - try selecting a higher-resolution intensity image.")

nearest_scale_idx = np.argmin([np.mean(np.array(lab_scale_spatial_dict[pl]) / np.array(img_scale_spatial)) for pl in lab_scale_spatial_dict.keys()])
nearest_scale_pl = list(lab_scale_spatial_dict.keys())[nearest_scale_idx]
lab_scale_spatial = lab_scale_spatial_dict[nearest_scale_pl]

# calculate image corner points
imgpixel_upper_left_yx, imgpixel_lower_right_yx = self._digest_bounding_box(
upper_left_yx=upper_left_yx,
lower_right_yx=lower_right_yx,
size_yx=size_yx,
coordinate_unit=coordinate_unit,
label_name=None,
pyramid_level=pyramid_level,
pyramid_level_coord=pyramid_level_coord
)

# make sure that the dimensions are divisible by
# the yx scaling factor between intensity and label arrays
scalefact_yx = np.divide(lab_scale_spatial, img_scale_spatial)
imgpixel_upper_left_yx = tuple((np.floor_divide(imgpixel_upper_left_yx, scalefact_yx[-2:]) * scalefact_yx[-2:]))
imgpixel_lower_right_yx = tuple((np.floor_divide(imgpixel_lower_right_yx, scalefact_yx[-2:]) * scalefact_yx[-2:]))

# get intensity array
img_arr = np.array(self.get_array_by_coordinate(
upper_left_yx=imgpixel_upper_left_yx,
lower_right_yx=imgpixel_lower_right_yx,
size_yx=None,
coordinate_unit='pixel',
label_name=None,
pyramid_level=pyramid_level
))

# convert intensity coordiantes to label coordinates
labpixel_upper_left_yx = convert_coordinates(
imgpixel_upper_left_yx,
img_scale_spatial[-2:],
lab_scale_spatial[-2:]
)
labpixel_lower_right_yx = convert_coordinates(
imgpixel_lower_right_yx,
img_scale_spatial[-2:],
lab_scale_spatial[-2:]
)
# loop over label names
lab_arr_dict = {}

for lname in label_name:
# find matching label pyramid level
lab_scale_spatial_dict = {
pl: self.get_scale(pyramid_level=pl, label_name=lname, spatial_axes_only=True) for pl in self.get_pyramid_levels(label_name=lname)
}
# ... filter out label scales with higher resolution than the intensity image
lab_scale_spatial_dict = {pl: lab_scale_spatial for pl, lab_scale_spatial in lab_scale_spatial_dict.items() if all([lab_scale_spatial[i] >= img_scale_spatial[i] for i in range(len(lab_scale_spatial))])}
if len(lab_scale_spatial_dict) == 0:
raise ValueError(f"For the requested pyramid level ({pyramid_level}) of the intensity image, down-scaling of an available label ('{lname}') would be required. Down-scaling of labels is not supported - try selecting a higher-resolution intensity image.")

nearest_scale_idx = np.argmin([np.mean(np.array(lab_scale_spatial_dict[pl]) / np.array(img_scale_spatial)) for pl in lab_scale_spatial_dict.keys()])
nearest_scale_pl = list(lab_scale_spatial_dict.keys())[nearest_scale_idx]
lab_scale_spatial = lab_scale_spatial_dict[nearest_scale_pl]

# calculate image corner points
imgpixel_upper_left_yx, imgpixel_lower_right_yx = self._digest_bounding_box(
upper_left_yx=upper_left_yx,
lower_right_yx=lower_right_yx,
size_yx=size_yx,
coordinate_unit=coordinate_unit,
label_name=None,
pyramid_level=pyramid_level,
pyramid_level_coord=pyramid_level_coord
)

# get label array
lab_arr = np.array(self.get_array_by_coordinate(
upper_left_yx=labpixel_upper_left_yx,
lower_right_yx=labpixel_lower_right_yx,
size_yx=None,
coordinate_unit='pixel',
label_name=label_name,
pyramid_level=nearest_scale_pl
))

# resize label if needed (correct non-matching scales or rounding errors)
if lab_arr.shape[-2:] != img_arr.shape[-2:]:
warnings.warn(f"For the requested pyramid level ({pyramid_level}) of the intensity image, no matching label ('{label_name}') is available. Up-scaling the label using factor(s) {scalefact_yx}")
lab_arr = resize_image(
im=lab_arr,
output_shape=img_arr.shape[(img_arr.ndim-lab_arr.ndim):],
im_type='label',
number_nonspatial_axes=sum([int(s not in ['z','y','x']) for s in self.channel_info_labels[label_name]])
# make sure that the dimensions are divisible by
# the yx scaling factor between intensity and label arrays
scalefact_yx = np.divide(lab_scale_spatial, img_scale_spatial)
imgpixel_upper_left_yx = tuple((np.floor_divide(imgpixel_upper_left_yx, scalefact_yx[-2:]) * scalefact_yx[-2:]))
imgpixel_lower_right_yx = tuple((np.floor_divide(imgpixel_lower_right_yx, scalefact_yx[-2:]) * scalefact_yx[-2:]))

# get intensity array
img_arr = self.get_array_by_coordinate(
upper_left_yx=imgpixel_upper_left_yx,
lower_right_yx=imgpixel_lower_right_yx,
size_yx=None,
coordinate_unit='pixel',
label_name=None,
pyramid_level=pyramid_level,
as_NumPy=False
)

# convert intensity coordiantes to label coordinates
labpixel_upper_left_yx = convert_coordinates(
imgpixel_upper_left_yx,
img_scale_spatial[-2:],
lab_scale_spatial[-2:]
)
labpixel_lower_right_yx = convert_coordinates(
imgpixel_lower_right_yx,
img_scale_spatial[-2:],
lab_scale_spatial[-2:]
)

# get label array
lab_arr = np.array(self.get_array_by_coordinate(
upper_left_yx=labpixel_upper_left_yx,
lower_right_yx=labpixel_lower_right_yx,
size_yx=None,
coordinate_unit='pixel',
label_name=lname,
pyramid_level=nearest_scale_pl
))

# resize label if needed (correct non-matching scales or rounding errors)
if lab_arr.shape[-2:] != img_arr.shape[-2:]:
warnings.warn(f"For the requested pyramid level ({pyramid_level}) of the intensity image, no matching label ('{lname}') is available. Up-scaling the label using factor(s) {scalefact_yx}")
lab_arr = resize_image(
im=lab_arr,
output_shape=img_arr.shape[(img_arr.ndim-lab_arr.ndim):],
im_type='label',
number_nonspatial_axes=sum([int(s not in ['z','y','x']) for s in self.channel_info_labels[lname]])
)

# store label array in dictionary
lab_arr_dict[lname] = lab_arr

# return arrays
return tuple([img_arr, lab_arr])
return tuple([np.array(img_arr), lab_arr_dict])


def get_table(self,
Expand Down Expand Up @@ -1002,6 +1019,7 @@ def plot(self,
pyramid_level=pyramid_level,
pyramid_level_coord=pyramid_level_coord
)
lab = lab[label_name] # extract label array from dict

# calculate scalebar length in pixel in x direction
if scalebar_micrometer != 0:
Expand Down
24 changes: 22 additions & 2 deletions tests/test_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def test_get_array_by_coordinate(img2d: ome_zarr.Image, img3d: ome_zarr.Image):
assert (img1b == img1a).all()
assert (img1c == img1a).all()

def test_get_array_pair_by_coordinate(img2d: ome_zarr.Image, tmpdir: str):
def test_get_array_pair_by_coordinate(tmpdir: str):
"""Test `ome_zarr.Image` object get_array_pair_by_coordinate() method."""

# using pyramid_level corresponding to a lower resolution intensity image
Expand All @@ -637,13 +637,33 @@ def test_get_array_pair_by_coordinate(img2d: ome_zarr.Image, tmpdir: str):
zattr['multiscales'][0]['datasets'] = zattr['multiscales'][0]['datasets'][:2]
with open(zattr_file, "w") as jsonfile:
json.dump(zattr, jsonfile, indent=4)
# ... plot
# ... extract array pair
imgtmp = ome_zarr.Image(str(tmpdir) + '/example_img')
with pytest.raises(Exception) as e_info:
imgtmp.get_array_pair_by_coordinate(pyramid_level='2', label_name='organoids')
# ... clean up
shutil.rmtree(str(tmpdir) + '/example_img')

# expected return values with multiple label_name values
# ... copy zarr fileset
assert tmpdir.check()
shutil.copytree('tests/example_data/plate_ones_mip.zarr/B/03/0',
str(tmpdir) + '/example_img')
assert tmpdir.join('/example_img/1').check()
shutil.copytree(str(tmpdir) + '/example_img/labels/organoids',
str(tmpdir) + '/example_img/labels/cells')
# ... initialize Image object
img2d = ome_zarr.Image(str(tmpdir) + '/example_img')
# ... extract array pair
img, lab = img2d.get_array_pair_by_coordinate(label_name=['organoids', 'cells'])
assert isinstance(img, np.ndarray)
assert isinstance(lab, dict)
assert list(lab.keys()) == ['organoids', 'cells']
assert isinstance(lab['organoids'], np.ndarray)
assert isinstance(lab['cells'], np.ndarray)
# ... clean up
shutil.rmtree(str(tmpdir) + '/example_img')

def test_get_table(img2d: ome_zarr.Image):
"""Test `Image.get_table()`."""

Expand Down

0 comments on commit 3e52e7a

Please sign in to comment.