Skip to content

Commit

Permalink
add ome_zarr prepare/open methods
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Dec 9, 2024
1 parent 7b3c728 commit 0d81ef3
Show file tree
Hide file tree
Showing 4 changed files with 444 additions and 98 deletions.
2 changes: 1 addition & 1 deletion funlib/persistence/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .arrays import Array, open_ds, prepare_ds # noqa
from .arrays import Array, open_ds, prepare_ds, open_ome_ds, prepare_ome_ds # noqa

__version__ = "0.5.3"
__version_info__ = tuple(int(i) for i in __version__.split("."))
1 change: 1 addition & 0 deletions funlib/persistence/arrays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .array import Array # noqa
from .datasets import prepare_ds, open_ds # noqa
from .ome_datasets import prepare_ome_ds, open_ome_ds # noqa
312 changes: 215 additions & 97 deletions funlib/persistence/arrays/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,103 +8,6 @@
from funlib.geometry import Coordinate


class MetaDataFormat(BaseModel):
offset_attr: str = "offset"
voxel_size_attr: str = "voxel_size"
axis_names_attr: str = "axis_names"
units_attr: str = "units"
types_attr: str = "types"

class Config:
extra = "forbid"

def fetch(self, data: dict[str | int, Any], keys: Sequence[str]):
current_key: str | int
current_key, *keys = keys
try:
current_key = int(current_key)
except ValueError:
pass
if isinstance(current_key, int):
return self.fetch(data[current_key], keys)
if len(keys) == 0:
return data.get(current_key, None)
elif isinstance(data, list):
assert current_key == "{dim}", current_key
values = []
for sub_data in data:
try:
values.append(self.fetch(sub_data, keys))
except KeyError:
values.append(None)
return values
else:
return self.fetch(data[current_key], keys)

def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None:
to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1]
for ll in to_strip:
if ll is not None and len(ll) == len(types):
for i in to_delete:
del ll[i]

def parse(
self,
shape,
data: dict[str | int, Any],
offset=None,
voxel_size=None,
axis_names=None,
units=None,
types=None,
strict=False,
):
offset = (
offset
if offset is not None
else self.fetch(data, self.offset_attr.split("/"))
)
voxel_size = (
voxel_size
if voxel_size is not None
else self.fetch(data, self.voxel_size_attr.split("/"))
)
axis_names = (
axis_names
if axis_names is not None
else self.fetch(data, self.axis_names_attr.split("/"))
)
units = (
units if units is not None else self.fetch(data, self.units_attr.split("/"))
)
types = (
types if types is not None else self.fetch(data, self.types_attr.split("/"))
)

# we expect offset, voxel_size, and units to only apply to time and space dimensions
# so here we strip off values that are not space or time
if types is not None:
self.strip_channels(types, [offset, voxel_size, units])

offset = Coordinate(offset) if offset is not None else None
voxel_size = Coordinate(voxel_size) if voxel_size is not None else None
axis_names = list(axis_names) if axis_names is not None else None
units = list(units) if units is not None else None
types = list(types) if types is not None else None

metadata = MetaData(
shape=shape,
offset=offset,
voxel_size=voxel_size,
axis_names=axis_names,
units=units,
types=types,
strict=strict,
)

return metadata


class MetaData:
def __init__(
self,
Expand All @@ -125,6 +28,35 @@ def __init__(

self.validate(strict)

def interleave_physical(
self, physical: Sequence[int | str], non_physical: int | str | None
) -> Sequence[int | str | None]:
interleaved = []
physical_ind = 0
for i, type in enumerate(self.types):
if type in ["space", "time"]:
interleaved.append(physical[physical_ind])
physical_ind += 1
else:
interleaved.append(non_physical)
return interleaved

@property
def ome_scale(self) -> Sequence[int]:
return self.interleave_physical(self.voxel_size, 1)

@property
def ome_translate(self) -> Sequence[int]:
assert self.offset % self.voxel_size == self.voxel_size * 0, (
"funlib.persistence only supports ome-zarr with integer multiples of voxel_size as an offset."
f"offset: {self.offset}, voxel_size:{self.voxel_size}, offset % voxel_size: {self.offset % self.voxel_size}"
)
return self.interleave_physical(self.offset / self.voxel_size, 0)

@property
def ome_units(self) -> list[str | None]:
return self.interleave_physical(self.units, None)

@property
def offset(self) -> Coordinate:
return (
Expand Down Expand Up @@ -224,6 +156,192 @@ def validate(self, strict: bool):
assert self.dims == self.physical_dims + self.channel_dims


class OME_MetaDataFormat(BaseModel):
class Config:
extra = "forbid"

def fetch(self, data: dict[str | int, Any], keys: Sequence[str]):
current_key: str | int
current_key, *keys = keys
try:
current_key = int(current_key)
except ValueError:
pass
if isinstance(current_key, int):
return self.fetch(data[current_key], keys)
if len(keys) == 0:
return data.get(current_key, None)
elif isinstance(data, list):
assert current_key == "{dim}", current_key
values = []
for sub_data in data:
try:
values.append(self.fetch(sub_data, keys))
except KeyError:
values.append(None)
return values
else:
return self.fetch(data[current_key], keys)

def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None:
to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1]
for ll in to_strip:
if ll is not None and len(ll) == len(types):
for i in to_delete:
del ll[i]

def parse(
self,
shape,
data: dict[str | int, Any],
offset=None,
voxel_size=None,
axis_names=None,
units=None,
types=None,
strict=False,
) -> MetaData:
offset = (
offset
if offset is not None
else self.fetch(data, self.offset_attr.split("/"))
)
voxel_size = (
voxel_size
if voxel_size is not None
else self.fetch(data, self.voxel_size_attr.split("/"))
)
axis_names = (
axis_names
if axis_names is not None
else self.fetch(data, self.axis_names_attr.split("/"))
)
units = (
units if units is not None else self.fetch(data, self.units_attr.split("/"))
)
types = (
types if types is not None else self.fetch(data, self.types_attr.split("/"))
)

if types is not None:
self.strip_channels(types, [offset, voxel_size, units])

offset = Coordinate(offset) if offset is not None else None
voxel_size = Coordinate(voxel_size) if voxel_size is not None else None
axis_names = list(axis_names) if axis_names is not None else None
units = list(units) if units is not None else None
types = list(types) if types is not None else None

metadata = MetaData(
shape=shape,
offset=offset,
voxel_size=voxel_size,
axis_names=axis_names,
units=units,
types=types,
strict=strict,
)

return metadata


class MetaDataFormat(BaseModel):
offset_attr: str = "offset"
voxel_size_attr: str = "voxel_size"
axis_names_attr: str = "axis_names"
units_attr: str = "units"
types_attr: str = "types"

class Config:
extra = "forbid"

def fetch(self, data: dict[str | int, Any], keys: Sequence[str]):
current_key: str | int
current_key, *keys = keys
try:
current_key = int(current_key)
except ValueError:
pass
if isinstance(current_key, int):
return self.fetch(data[current_key], keys)
if len(keys) == 0:
return data.get(current_key, None)
elif isinstance(data, list):
assert current_key == "{dim}", current_key
values = []
for sub_data in data:
try:
values.append(self.fetch(sub_data, keys))
except KeyError:
values.append(None)
return values
else:
return self.fetch(data[current_key], keys)

def strip_channels(self, types: list[str], to_strip: list[Sequence]) -> None:
to_delete = [i for i, t in enumerate(types) if t not in ["space", "time"]][::-1]
for ll in to_strip:
if ll is not None and len(ll) == len(types):
for i in to_delete:
del ll[i]

def parse(
self,
shape,
data: dict[str | int, Any],
offset=None,
voxel_size=None,
axis_names=None,
units=None,
types=None,
strict=False,
) -> MetaData:
offset = (
offset
if offset is not None
else self.fetch(data, self.offset_attr.split("/"))
)
voxel_size = (
voxel_size
if voxel_size is not None
else self.fetch(data, self.voxel_size_attr.split("/"))
)
axis_names = (
axis_names
if axis_names is not None
else self.fetch(data, self.axis_names_attr.split("/"))
)
units = (
units if units is not None else self.fetch(data, self.units_attr.split("/"))
)
types = (
types if types is not None else self.fetch(data, self.types_attr.split("/"))
)

# we expect offset, voxel_size, and units to only apply to time and space dimensions
# so here we strip off values that are not space or time
if types is not None:
self.strip_channels(types, [offset, voxel_size, units])

offset = Coordinate(offset) if offset is not None else None
voxel_size = Coordinate(voxel_size) if voxel_size is not None else None
axis_names = list(axis_names) if axis_names is not None else None
units = list(units) if units is not None else None
types = list(types) if types is not None else None

metadata = MetaData(
shape=shape,
offset=offset,
voxel_size=voxel_size,
axis_names=axis_names,
units=units,
types=types,
strict=strict,
)

return metadata


DEFAULT_METADATA_FORMAT = MetaDataFormat()
LOCAL_PATHS = [Path("pyproject.toml"), Path("funlib_persistence.toml")]
USER_PATHS = [
Expand Down
Loading

0 comments on commit 0d81ef3

Please sign in to comment.