Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 2D support for Cellpose & napari workflows task (closes #398) #403

Merged
merged 50 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
675dd61
Add 2D OME-Zarr support for Cellpose & napari workflow task
jluethi Jun 7, 2023
c326d26
Merge branch 'main' into 398_2D_support
jluethi Jun 7, 2023
a31916e
Cleanup
jluethi Jun 7, 2023
e6c3982
Cleanup
jluethi Jun 7, 2023
35b28c3
Update region loading to return 2D arrays for napari workflows when n…
jluethi Jun 8, 2023
cb3d38e
Cleanup
jluethi Jun 8, 2023
7f8a54c
Change the default of the load_region on return_as_3D
jluethi Jun 8, 2023
e03cc48
Add unit test for load_region
jluethi Jun 9, 2023
fef5e36
cleanup test_unit_ROIs
jluethi Jun 9, 2023
497e7d0
cleanup test_unit_ROIs
jluethi Jun 9, 2023
ee16a45
Merge branch 'main' into 398_2D_support
tcompa Jun 9, 2023
7c0089e
Add workaround for #420
jluethi Jun 9, 2023
a894bd4
Fix precommit issue
jluethi Jun 9, 2023
bba8043
Revert coordinate assumption of coordinateTransformations to work wit…
jluethi Jun 9, 2023
92c1789
Merge branch 'main' into 398_2D_support
tcompa Jun 12, 2023
02e66e6
Replace `Exception`s with `ValueError`s in `lib_zattrs_utils.py`
tcompa Jun 12, 2023
3ad869f
Improve error message in `lib_zattrs_utils.py`
tcompa Jun 12, 2023
63e04d2
Improve rescale_datasets (ref #420)
tcompa Jun 12, 2023
b828cb0
Clean up test_load_region with multiple parametrize loops
tcompa Jun 12, 2023
137d823
Better assignment of pixel_sizes in extract_zyx_pixel_sizes
tcompa Jun 12, 2023
a0ab7e7
Always use 4D scale transformations in create-ome-zarr tasks (ref #420)
tcompa Jun 12, 2023
5a5b86a
Fix bug introduced in 63e04d2, and improve error messages
tcompa Jun 12, 2023
a242e81
Add test_unit_zattrs_utils.py
tcompa Jun 12, 2023
c5e4939
Add trivial test
tcompa Jun 12, 2023
b310ce5
Add failure modes to new load_region (ref #398)
tcompa Jun 12, 2023
8657a6f
Fix coordinateTransformations of Zenodo OME-Zarrs (ref #420)
tcompa Jun 12, 2023
527b192
Better indentation of zattrs in tests
tcompa Jun 12, 2023
d260206
Add more logs
tcompa Jun 12, 2023
bad56c4
BROKEN add test_CYX_input (ref #398)
tcompa Jun 12, 2023
6dc68ae
Add validate_axes_and_coordinateTransformations to tests
tcompa Jun 12, 2023
1b7ed17
BROKEN Introduce validate_axes_and_coordinateTransformations in test_…
tcompa Jun 12, 2023
a63013e
BROKEN Add validate_axes_and_coordinateTransformations in test_workfl…
tcompa Jun 12, 2023
d9a0fe6
Add remove_channel_axis argument to rescale_datasets (ref #398)
tcompa Jun 12, 2023
313680c
Set `remove_channel_axis=True` when doing labeling in cellpose/napari…
tcompa Jun 12, 2023
d256bb2
Add warning for missing Z axis
tcompa Jun 12, 2023
f357505
Add f to f-string
tcompa Jun 12, 2023
479c472
Remove obsolete `raise` in tests
tcompa Jun 12, 2023
2556509
Update test_extract_zyx_pixel_sizes
tcompa Jun 12, 2023
e4cce32
Indent JSON output in tests
tcompa Jun 12, 2023
5a5ea4b
Update test due to update in patched_segment_ROI_overlapping_organoids
tcompa Jun 12, 2023
476f301
Remove debugging log
tcompa Jun 12, 2023
1c25faa
Rename tests/utils.py into tests/_validation.py
tcompa Jun 13, 2023
22c4980
Move some auxiliary test functions into `_zenodo_ome_zarrs.py` module
tcompa Jun 13, 2023
df2089c
Expose num_axes in `tests/_validation.py`
tcompa Jun 13, 2023
90c7528
Add test_napari_workflow_CYX
tcompa Jun 13, 2023
905af27
Add test_napari_workflow_CYX_wrong_dimensions
tcompa Jun 13, 2023
461c5de
Update test_expected_dimensions
tcompa Jun 13, 2023
801256a
Merge branch 'main' into 398_2D_support
tcompa Jun 13, 2023
2a30696
bump version 0.10.0a1 -> 0.10.0a2
tcompa Jun 13, 2023
09f0c8d
Add check about axes list not starting with "c", before removing chan…
tcompa Jun 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"2022, Friedrich Miescher Institute for Biomedical Research and "
"University of Zurich"
)
version = "0.10.0a1"
version = "0.10.0a2"
language = "en"

extensions = [
Expand Down
2 changes: 1 addition & 1 deletion fractal_tasks_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
)


__VERSION__ = "0.10.0a1"
__VERSION__ = "0.10.0a2"
__OME_NGFF_VERSION__ = "0.4"
50 changes: 49 additions & 1 deletion fractal_tasks_core/lib_regions_of_interest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
import logging
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
import zarr


logger = logging.getLogger(__name__)


def prepare_FOV_ROI_table(
df: pd.DataFrame, metadata: list[str] = ["time"]
) -> ad.AnnData:
Expand Down Expand Up @@ -389,7 +395,7 @@ def is_ROI_table_valid(*, table_path: str, use_masks: bool) -> Optional[bool]:
# Soft constraint: the table can be used for masked loading (if not, return
# False)
attrs = zarr.group(table_path).attrs
logging.info(f"ROI table at {table_path} has attrs: {attrs}")
logger.info(f"ROI table at {table_path} has attrs: {attrs.asdict()}")
valid = set(("type", "region", "instance_key")).issubset(attrs.keys())
if valid:
valid = valid and attrs["type"] == "ngff:region_table"
Expand All @@ -398,3 +404,45 @@ def is_ROI_table_valid(*, table_path: str, use_masks: bool) -> Optional[bool]:
return True
else:
return False


def load_region(
    data_zyx: "da.Array",
    region: Tuple[slice, slice, slice],
    compute: bool = True,
    return_as_3D: bool = False,
) -> Union["da.Array", np.ndarray]:
    """
    Load a region from a dask array.

    Can handle both 2D and 3D dask arrays as input and return them as-is,
    or always as a 3D array.

    :param data_zyx: dask array, 2D or 3D
    :param region: region to load, tuple of three slices (ZYX)
    :param compute: whether to compute the result. If ``True``, returns a
        numpy array. If ``False``, returns a dask array.
    :param return_as_3D: whether to return a 3D array, even if the input is
        2D
    :return: 2D or 3D array (always 3D when the input is 3D or when
        ``return_as_3D=True``)
    """

    if len(region) != 3:
        raise ValueError(
            f"In `load_region`, `region` must have three elements "
            f"(given: {len(region)})."
        )

    if len(data_zyx.shape) == 3:
        img = data_zyx[region]
    elif len(data_zyx.shape) == 2:
        # For 2D data, the Z slice of `region` is ignored and only the
        # (Y, X) slices are applied.
        img = data_zyx[(region[1], region[2])]
        if return_as_3D:
            # Prepend a singleton Z axis, so that the output is (1, y, x)
            img = np.expand_dims(img, axis=0)
    else:
        raise ValueError(
            f"Shape {data_zyx.shape} not supported for `load_region`"
        )
    if compute:
        return img.compute()
    else:
        return img
52 changes: 36 additions & 16 deletions fractal_tasks_core/lib_zattrs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
Functions to handle .zattrs files and their contents
"""
import json
import logging
from typing import Any
from typing import Dict
from typing import List


logger = logging.getLogger(__name__)


def extract_zyx_pixel_sizes(zattrs_path: str, level: int = 0) -> List[float]:
"""
Load multiscales/datasets from .zattrs file and read the pixel sizes for a
Expand All @@ -40,12 +44,22 @@ def extract_zyx_pixel_sizes(zattrs_path: str, level: int = 0) -> List[float]:

# Check that there is a single multiscale
if len(multiscales) > 1:
raise Exception(f"ERROR: There are {len(multiscales)} multiscales")
raise ValueError(
f"ERROR: There are {len(multiscales)} multiscales"
)

# Check that Z axis is present, raise a warning otherwise
axes = [ax["name"] for ax in multiscales[0]["axes"]]
if "z" not in axes:
logger.warning(
f"Z axis is not present in {axes=}. This case may work "
"by accident, but it is not fully supported."
)

# Check that there are no datasets-global transformations
if "coordinateTransformations" in multiscales[0].keys():
raise NotImplementedError(
"global coordinateTransformations at the multiscales "
"Global coordinateTransformations at the multiscales "
"level are not currently supported"
)

Expand All @@ -56,22 +70,24 @@ def extract_zyx_pixel_sizes(zattrs_path: str, level: int = 0) -> List[float]:
transformations = datasets[level]["coordinateTransformations"]
for t in transformations:
if t["type"] == "scale":
pixel_sizes = t["scale"]
# FIXME: Using [-3:] indices is a hack to deal with the fact
# that the coordinationTransformation can contain additional
# entries (e.g. scaling for the channels)
# https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/420
pixel_sizes = t["scale"][-3:]
if min(pixel_sizes) < 1e-9:
raise Exception(
f"ERROR: pixel_sizes in {zattrs_path} are", pixel_sizes
raise ValueError(
f"pixel_sizes in {zattrs_path} are {pixel_sizes}"
)
return pixel_sizes

raise Exception(
"ERROR:"
f" no scale transformation found for level {level}"
f" in {zattrs_path}"
raise ValueError(
f"No scale transformation found for level {level} in {zattrs_path}"
)

except KeyError as e:
raise KeyError(
"extract_zyx_pixel_sizes_from_zattrs failed, for {zattrs_path}\n",
f"extract_zyx_pixel_sizes_from_zattrs failed, for {zattrs_path}\n",
e,
)

Expand All @@ -81,6 +97,7 @@ def rescale_datasets(
datasets: List[Dict],
coarsening_xy: int,
reference_level: int,
remove_channel_axis: bool = False,
) -> List[Dict]:
"""
Given a set of datasets (as per OME-NGFF specs), update their "scale"
Expand All @@ -90,6 +107,8 @@ def rescale_datasets(
:param datasets: list of datasets (as per OME-NGFF specs)
:param coarsening_xy: linear coarsening factor between subsequent levels
:param reference_level: TBD
:param remove_channel_axis: If ``True``, remove the first item of all
``scale`` transformations.
"""

# Construct rescaled datasets
Expand All @@ -107,12 +126,13 @@ def rescale_datasets(
new_transformations = []
for t in old_transformations:
if t["type"] == "scale":
new_t: Dict[str, Any] = {"type": "scale"}
new_t["scale"] = [
t["scale"][0],
t["scale"][1] * coarsening_xy**reference_level,
t["scale"][2] * coarsening_xy**reference_level,
]
new_t: Dict[str, Any] = t.copy()
# Rescale last two dimensions (that is, Y and X)
prefactor = coarsening_xy**reference_level
new_t["scale"][-2] = new_t["scale"][-2] * prefactor
new_t["scale"][-1] = new_t["scale"][-1] * prefactor
if remove_channel_axis:
new_t["scale"].pop(0)
new_transformations.append(new_t)
else:
new_transformations.append(t)
Expand Down
51 changes: 41 additions & 10 deletions fractal_tasks_core/tasks/cellpose_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
convert_ROI_table_to_indices,
)
from fractal_tasks_core.lib_regions_of_interest import is_ROI_table_valid
from fractal_tasks_core.lib_regions_of_interest import load_region
from fractal_tasks_core.lib_ROI_overlaps import find_overlaps_in_ROI_indices
from fractal_tasks_core.lib_ROI_overlaps import get_overlapping_pairs_3D
from fractal_tasks_core.lib_zattrs_utils import extract_zyx_pixel_sizes
Expand Down Expand Up @@ -385,9 +386,11 @@ def cellpose_segmentation(
full_res_pxl_sizes_zyx = extract_zyx_pixel_sizes(
f"{zarrurl}/.zattrs", level=0
)
logger.info(f"{full_res_pxl_sizes_zyx=}")
actual_res_pxl_sizes_zyx = extract_zyx_pixel_sizes(
f"{zarrurl}/.zattrs", level=level
)
logger.info(f"{actual_res_pxl_sizes_zyx=}")

# Heuristic to determine reset_origin # FIXME, see issue #339
if input_ROI_table in ["FOV_ROI_table", "well_ROI_table"]:
Expand Down Expand Up @@ -415,7 +418,7 @@ def cellpose_segmentation(
)

# Select 2D/3D behavior and set some parameters
do_3D = data_zyx.shape[0] > 1
do_3D = data_zyx.shape[0] > 1 and len(data_zyx.shape) == 3
if do_3D:
if anisotropy is None:
# Read pixel sizes from zattrs file
Expand Down Expand Up @@ -451,10 +454,17 @@ def cellpose_segmentation(
)

# Rescale datasets (only relevant for level>0)
if not multiscales[0]["axes"][0]["name"] == "c":
raise ValueError(
"Cannot set `remove_channel_axis=True` for multiscale "
f'metadata with axes={multiscales[0]["axes"]}. '
'First axis should have name "c".'
)
new_datasets = rescale_datasets(
datasets=multiscales[0]["datasets"],
coarsening_xy=coarsening_xy,
reference_level=level,
remove_channel_axis=True,
)

# Write zattrs for labels and for specific label
Expand Down Expand Up @@ -493,9 +503,18 @@ def cellpose_segmentation(
logger.info(f"Output label path: {zarrurl}/labels/{output_label_name}/0")
store = zarr.storage.FSStore(f"{zarrurl}/labels/{output_label_name}/0")
label_dtype = np.uint32

# Ensure that all output shapes & chunks are 3D (for 2D data: (1, y, x))
# https://github.com/fractal-analytics-platform/fractal-tasks-core/issues/398
shape = data_zyx.shape
if len(shape) == 2:
shape = (1, *shape)
chunks = data_zyx.chunksize
if len(chunks) == 2:
chunks = (1, *chunks)
mask_zarr = zarr.create(
shape=data_zyx.shape,
chunks=data_zyx.chunksize,
shape=shape,
chunks=chunks,
dtype=label_dtype,
store=store,
overwrite=False,
Expand Down Expand Up @@ -557,12 +576,26 @@ def cellpose_segmentation(
# Prepare single-channel or dual-channel input for cellpose
if wavelength_id_c2 or channel_label_c2:
# Dual channel mode, first channel is the membrane channel
img_np = np.zeros((2, *data_zyx[region].shape))
img_np[0, :, :, :] = data_zyx[region].compute()
img_np[1, :, :, :] = data_zyx_c2[region].compute()
img_1 = load_region(
data_zyx,
region,
compute=True,
return_as_3D=True,
)
img_np = np.zeros((2, *img_1.shape))
img_np[0, :, :, :] = img_1
img_np[1, :, :, :] = load_region(
data_zyx_c2,
region,
compute=True,
return_as_3D=True,
)
channels = [1, 2]
else:
img_np = np.expand_dims(data_zyx[region].compute(), axis=0)
img_np = np.expand_dims(
load_region(data_zyx, region, compute=True, return_as_3D=True),
axis=0,
)
channels = [0, 0]

# Prepare keyword arguments for segment_ROI function
Expand Down Expand Up @@ -620,7 +653,6 @@ def cellpose_segmentation(
)

if output_ROI_table:

bbox_df = array_to_bounding_box_table(
new_label_img, actual_res_pxl_sizes_zyx
)
Expand Down Expand Up @@ -656,7 +688,7 @@ def cellpose_segmentation(
overwrite=False,
num_levels=num_levels,
coarsening_xy=coarsening_xy,
chunksize=data_zyx.chunksize,
chunksize=chunks,
aggregation_function=np.max,
)

Expand Down Expand Up @@ -711,7 +743,6 @@ def cellpose_segmentation(


if __name__ == "__main__":

from fractal_tasks_core.tasks._utils import run_fractal_task

run_fractal_task(
Expand Down
1 change: 1 addition & 0 deletions fractal_tasks_core/tasks/create_ome_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def create_ome_zarr(
{
"type": "scale",
"scale": [
1,
pixel_size_z,
pixel_size_y
* coarsening_xy**ind_level,
Expand Down
1 change: 1 addition & 0 deletions fractal_tasks_core/tasks/create_ome_zarr_multiplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def create_ome_zarr_multiplex(
{
"type": "scale",
"scale": [
1,
pixel_size_z,
pixel_size_y
* coarsening_xy**ind_level,
Expand Down
Loading