diff --git a/funlib/persistence/__init__.py b/funlib/persistence/__init__.py index dbf9630..a413963 100644 --- a/funlib/persistence/__init__.py +++ b/funlib/persistence/__init__.py @@ -1,4 +1,4 @@ -from .arrays import Array, open_ds, prepare_ds # noqa +from .arrays import Array, open_ds, prepare_ds, open_ome_ds, prepare_ome_ds # noqa __version__ = "0.5.3" __version_info__ = tuple(int(i) for i in __version__.split(".")) diff --git a/funlib/persistence/arrays/__init__.py b/funlib/persistence/arrays/__init__.py index 5f70d38..779ea54 100644 --- a/funlib/persistence/arrays/__init__.py +++ b/funlib/persistence/arrays/__init__.py @@ -1,2 +1,3 @@ from .array import Array # noqa from .datasets import prepare_ds, open_ds # noqa +from .ome_datasets import prepare_ome_ds, open_ome_ds # noqa diff --git a/funlib/persistence/arrays/metadata.py b/funlib/persistence/arrays/metadata.py index d400415..b2eccc0 100644 --- a/funlib/persistence/arrays/metadata.py +++ b/funlib/persistence/arrays/metadata.py @@ -8,103 +8,6 @@ from funlib.geometry import Coordinate -class MetaDataFormat(BaseModel): - offset_attr: str = "offset" - voxel_size_attr: str = "voxel_size" - axis_names_attr: str = "axis_names" - units_attr: str = "units" - types_attr: str = "types" - - class Config: - extra = "forbid" - - def fetch(self, data: dict[str | int, Any], keys: Sequence[str]): - current_key: str | int - current_key, *keys = keys - try: - current_key = int(current_key) - except ValueError: - pass - if isinstance(current_key, int): - return self.fetch(data[current_key], keys) - if len(keys) == 0: - return data.get(current_key, None) - elif isinstance(data, list): - assert current_key == "{dim}", current_key - values = [] - for sub_data in data: - try: - values.append(self.fetch(sub_data, keys)) - except KeyError: - values.append(None) - return values - else: - return self.fetch(data[current_key], keys) - - def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None: - to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1] - for ll in to_strip: - if ll is not None and len(ll) == len(types): - for i in to_delete: - del ll[i] - - def parse( - self, - shape, - data: dict[str | int, Any], - offset=None, - voxel_size=None, - axis_names=None, - units=None, - types=None, - strict=False, - ): - offset = ( - offset - if offset is not None - else self.fetch(data, self.offset_attr.split("/")) - ) - voxel_size = ( - voxel_size - if voxel_size is not None - else self.fetch(data, self.voxel_size_attr.split("/")) - ) - axis_names = ( - axis_names - if axis_names is not None - else self.fetch(data, self.axis_names_attr.split("/")) - ) - units = ( - units if units is not None else self.fetch(data, self.units_attr.split("/")) - ) - types = ( - types if types is not None else self.fetch(data, self.types_attr.split("/")) - ) - - # we expect offset, voxel_size, and units to only apply to time and space dimensions - # so here we strip off values that are not space or time - if types is not None: - self.strip_channels(types, [offset, voxel_size, units]) - - offset = Coordinate(offset) if offset is not None else None - voxel_size = Coordinate(voxel_size) if voxel_size is not None else None - axis_names = list(axis_names) if axis_names is not None else None - units = list(units) if units is not None else None - types = list(types) if types is not None else None - - metadata = MetaData( - shape=shape, - offset=offset, - voxel_size=voxel_size, - axis_names=axis_names, - units=units, - types=types, - strict=strict, - ) - - return metadata - - class MetaData: def __init__( self, @@ -125,6 +28,35 @@ def __init__( self.validate(strict) + def interleave_physical( + self, physical: Sequence[int | str], non_physical: int | str | None + ) -> Sequence[int | str | None]: + interleaved = [] + physical_ind = 0 + for i, type in enumerate(self.types): + if type in ["space", "time"]: + interleaved.append(physical[physical_ind]) + physical_ind += 1 + else: + interleaved.append(non_physical) + return interleaved + + @property + def ome_scale(self) -> Sequence[int]: + return self.interleave_physical(self.voxel_size, 1) + + @property + def ome_translate(self) -> Sequence[int]: + assert self.offset % self.voxel_size == self.voxel_size * 0, ( + "funlib.persistence only supports ome-zarr with integer multiples of voxel_size as an offset." + f"offset: {self.offset}, voxel_size:{self.voxel_size}, offset % voxel_size: {self.offset % self.voxel_size}" + ) + return self.interleave_physical(self.offset / self.voxel_size, 0) + + @property + def ome_units(self) -> list[str | None]: + return self.interleave_physical(self.units, None) + @property def offset(self) -> Coordinate: return ( @@ -224,6 +156,192 @@ def validate(self, strict: bool): assert self.dims == self.physical_dims + self.channel_dims +class OME_MetaDataFormat(BaseModel): + class Config: + extra = "forbid" + + def fetch(self, data: dict[str | int, Any], keys: Sequence[str]): + current_key: str | int + current_key, *keys = keys + try: + current_key = int(current_key) + except ValueError: + pass + if isinstance(current_key, int): + return self.fetch(data[current_key], keys) + if len(keys) == 0: + return data.get(current_key, None) + elif isinstance(data, list): + assert current_key == "{dim}", current_key + values = [] + for sub_data in data: + try: + values.append(self.fetch(sub_data, keys)) + except KeyError: + values.append(None) + return values + else: + return self.fetch(data[current_key], keys) + + def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None: + to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1] + for ll in to_strip: + if ll is not None and len(ll) == len(types): + for i in to_delete: + del ll[i] + + def parse( + self, + shape, + data: dict[str | int, Any], + offset=None, + voxel_size=None, + axis_names=None, + units=None, + types=None, + strict=False, + ) -> MetaData: + offset = ( + offset + if offset is not None + else self.fetch(data, self.offset_attr.split("/")) + ) + voxel_size = ( + voxel_size + if voxel_size is not None + else self.fetch(data, self.voxel_size_attr.split("/")) + ) + axis_names = ( + axis_names + if axis_names is not None + else self.fetch(data, self.axis_names_attr.split("/")) + ) + units = ( + units if units is not None else self.fetch(data, self.units_attr.split("/")) + ) + types = ( + types if types is not None else self.fetch(data, self.types_attr.split("/")) + ) + + if types is not None: + self.strip_channels(types, [offset, voxel_size, units]) + + offset = Coordinate(offset) if offset is not None else None + voxel_size = Coordinate(voxel_size) if voxel_size is not None else None + axis_names = list(axis_names) if axis_names is not None else None + units = list(units) if units is not None else None + types = list(types) if types is not None else None + + metadata = MetaData( + shape=shape, + offset=offset, + voxel_size=voxel_size, + axis_names=axis_names, + units=units, + types=types, + strict=strict, + ) + + return metadata + + +class MetaDataFormat(BaseModel): + offset_attr: str = "offset" + voxel_size_attr: str = "voxel_size" + axis_names_attr: str = "axis_names" + units_attr: str = "units" + types_attr: str = "types" + + class Config: + extra = "forbid" + + def fetch(self, data: dict[str | int, Any], keys: Sequence[str]): + current_key: str | int + current_key, *keys = keys + try: + current_key = int(current_key) + except ValueError: + pass + if isinstance(current_key, int): + return self.fetch(data[current_key], keys) + if len(keys) == 0: + return data.get(current_key, None) + elif isinstance(data, list): + assert current_key == "{dim}", current_key + values = [] + for sub_data in data: + try: + values.append(self.fetch(sub_data, keys)) + except KeyError: + values.append(None) + return values + else: + return self.fetch(data[current_key], keys) + + def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None: + to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1] + for ll in to_strip: + if ll is not None and len(ll) == len(types): + for i in to_delete: + del ll[i] + + def parse( + self, + shape, + data: dict[str | int, Any], + offset=None, + voxel_size=None, + axis_names=None, + units=None, + types=None, + strict=False, + ) -> MetaData: + offset = ( + offset + if offset is not None + else self.fetch(data, self.offset_attr.split("/")) + ) + voxel_size = ( + voxel_size + if voxel_size is not None + else self.fetch(data, self.voxel_size_attr.split("/")) + ) + axis_names = ( + axis_names + if axis_names is not None + else self.fetch(data, self.axis_names_attr.split("/")) + ) + units = ( + units if units is not None else self.fetch(data, self.units_attr.split("/")) + ) + types = ( + types if types is not None else self.fetch(data, self.types_attr.split("/")) + ) + + # we expect offset, voxel_size, and units to only apply to time and space dimensions + # so here we strip off values that are not space or time + if types is not None: + self.strip_channels(types, [offset, voxel_size, units]) + + offset = Coordinate(offset) if offset is not None else None + voxel_size = Coordinate(voxel_size) if voxel_size is not None else None + axis_names = list(axis_names) if axis_names is not None else None + units = list(units) if units is not None else None + types = list(types) if types is not None else None + + metadata = MetaData( + shape=shape, + offset=offset, + voxel_size=voxel_size, + axis_names=axis_names, + units=units, + types=types, + strict=strict, + ) + + return metadata + + DEFAULT_METADATA_FORMAT = MetaDataFormat() LOCAL_PATHS = [Path("pyproject.toml"), Path("funlib_persistence.toml")] USER_PATHS = [ diff --git a/funlib/persistence/arrays/ome_datasets.py b/funlib/persistence/arrays/ome_datasets.py new file mode 100644 index 0000000..72ed664 --- /dev/null +++ b/funlib/persistence/arrays/ome_datasets.py @@ -0,0 +1,227 @@ +import logging +from typing import Sequence +from funlib.geometry import Coordinate + +import zarr +from iohub.ngff import open_ome_zarr, AxisMeta, TransformationMeta + +import dask.array as da +import numpy as np +from numpy.typing import DTypeLike + +from .array import Array +from .metadata import MetaData, OME_MetaDataFormat +from pathlib import Path +from itertools import chain + +logger = logging.getLogger(__name__) + + +def open_ome_ds( + store: Path, + name: str, + mode: str = "r", + **kwargs, +) -> Array: + """ + Open an ome-zarr dataset with common metadata that is useful for indexing with physcial coordinates. + + Args: + + store: + + See https://czbiohub-sf.github.io/iohub/main/api/ngff.html#iohub.open_ome_zarr + + name: + + The name of the dataset in your ome-zarr dataset. + + mode: + + See https://zarr.readthedocs.io/en/stable/api/convenience.html#zarr.convenience.open + + kwargs: + + See additional arguments available here: + https://czbiohub-sf.github.io/iohub/main/api/ngff.html#iohub.open_ome_zarr + + + Returns: + + A :class:`Array` supporting spatial indexing on your dataset. + """ + + assert (store / name).exists(), "Store does not exist!" + + ome_zarr = open_ome_zarr(store, mode=mode, **kwargs) + axes = ome_zarr.axes + axis_names = [axis.name for axis in axes] + units = [axis.unit for axis in axes if axis.unit is not None] + types = [axis.type for axis in axes if axis.type is not None] + + base_transform = ome_zarr.metadata.multiscales[0].coordinate_transformations + img_transforms = [ + ome_zarr.metadata.multiscales[0].datasets[i].coordinate_transformations + for i, dataset_meta in enumerate(ome_zarr.metadata.multiscales[0].datasets) + if dataset_meta.path == name + ][0] + + scales = [ + t.scale for t in chain(base_transform, img_transforms) if t.type == "scale" + ] + translations = [ + t.translation + for t in chain(base_transform, img_transforms) + if t.type == "translation" + ] + assert all( + all(np.isclose(scale, Coordinate(scale))) for scale in scales + ), f"funlib.persistence only supports integer scales: {scales}" + assert all( + all(np.isclose(translation, Coordinate(translation))) + for translation in translations + ), f"funlib.persistence only supports integer translations: {translations}" + scales = [Coordinate(s) for s in scales] + + # apply translations in order to get final scale/transform for this array + base_scale = scales[0] * 0 + 1 + base_translation = scales[0] * 0 + for t in chain(base_transform, img_transforms): + if t.type == "translation": + base_translation += base_scale * Coordinate(t.translation) + elif t.type == "scale": + base_scale *= Coordinate(t.scale) + + dataset = ome_zarr[name] + + metadata = OME_MetaDataFormat().parse( + dataset.shape, + {}, + offset=list(base_translation), + voxel_size=list(base_scale), + axis_names=axis_names, + units=units, + types=types, + ) + + return Array( + dataset, + metadata.offset, + metadata.voxel_size, + metadata.axis_names, + metadata.units, + metadata.types, + ) + + +def prepare_ome_ds( + store: Path, + name: str, + shape: Sequence[int], + dtype: DTypeLike, + chunk_shape: Sequence[int] | None = None, + offset: Sequence[int] | None = None, + voxel_size: Sequence[int] | None = None, + axis_names: Sequence[str] | None = None, + units: Sequence[str] | None = None, + types: Sequence[str] | None = None, + channel_names: list["str"] | None = None, + **kwargs, +) -> Array: + """Prepare an OME-Zarr dataset with common metadata that we use for indexing images with + spatial coordinates. + + Args: + + Store: + + See https://czbiohub-sf.github.io/iohub/main/api/ngff.html#iohub.open_ome_zarr + + shape: + + The shape of the dataset to create. For all dimensions, + including non-physical. + + chunk_shape: + + The shape of the chunks to use. If None, the default chunk shape + is used. + + offset: + + The offset of the dataset in physical coordinates. If None, the + default offset (0, ...) is used. + + voxel_size: + + The size of a voxel in physical coordinates. If None, the default + voxel size (1, ...) is used. + + axis_names: + + The names of the axes in the dataset. If None, the default axis + names ("d0", "d1", ...) are used. + + units: + + The units of the axes in the dataset. If None, the default units + ("", "", ...) are used. + + types: + + The types of the axes in the dataset. If None, the default types + ("space", "space", ...) are used. + + channel_names: + + The names of the channels in the dataset. If None, there must not be any + channels. If channels are present and no channel names are provided an exception + will be thrown. + + mode: + + The mode to open the dataset in. + See https://zarr.readthedocs.io/en/stable/api/creation.html#zarr.creation.open_array + + kwargs: + + See additional arguments available here: + https://czbiohub-sf.github.io/iohub/main/api/ngff.html#iohub.open_ome_zarr + + Returns: + + A :class:`Array` pointing to the newly created dataset. + """ + + assert not store.exists(), "Store already exists!" + + metadata = MetaData( + shape, Coordinate(offset), Coordinate(voxel_size), axis_names, units, types + ) + + axis_metadata = [ + AxisMeta(name=n, type=t, unit=u) + for n, t, u in zip(metadata.axis_names, metadata.types, metadata.ome_units) + ] + + # create the dataset + with open_ome_zarr( + store, mode="w", layout="fov", axes=axis_metadata, channel_names=channel_names + ) as ds: + transforms = [ + TransformationMeta(type="scale", scale=metadata.ome_scale), + TransformationMeta(type="translation", translation=metadata.ome_translate), + ] + + ds.create_zeros( + name=name, + shape=shape, + dtype=dtype, + chunks=chunk_shape, + transform=transforms, + ) + + # open array + array = open_ome_ds(store, name, mode="r+") + + return array