Skip to content

Commit

Permalink
Merge pull request #124 from will-moore/validate_axes_types_v0.4
Browse files Browse the repository at this point in the history
Validate axes types v0.4
  • Loading branch information
sbesson authored Jan 18, 2022
2 parents 2ed4426 + d9a44f7 commit 73aee67
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 74 deletions.
120 changes: 120 additions & 0 deletions ome_zarr/axes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Axes class for validating and transforming axes
"""
from typing import Any, Dict, List, Union

from .format import CurrentFormat, Format

# Maps each recognised axis name to its axis type ("space", "channel" or "time").
# Used to fill in the "type" of axes that are supplied as plain name strings and
# to decide which axis types count as "known" during validation.
KNOWN_AXES = {"x": "space", "y": "space", "z": "space", "c": "channel", "t": "time"}


class Axes:
    """Normalise and validate the ``axes`` metadata of a multiscales image.

    Axes may be given either as a list of names (``["t", "c", "z", "y", "x"]``)
    or as a list of dicts with a ``"name"`` and an optional ``"type"`` key.
    They are always stored on ``self.axes`` as a list of dicts.
    """

    def __init__(
        self,
        axes: Union[List[str], List[Dict[str, str]], None],
        fmt: Format = CurrentFormat(),
    ) -> None:
        """
        Constructor, transforms axes and validates.

        :param axes: axis names or axis dicts. May be ``None`` only for
            versions "0.1" and "0.2", where the axes are implicitly
            ``t, c, z, y, x``.
        :param fmt: the format whose ``version`` selects the validation rules.
        :raises ValueError: if the axes are missing (version 0.3+) or not valid.
        """
        self.fmt = fmt
        if axes is not None:
            self.axes = self._axes_to_dicts(axes)
        elif fmt.version in ("0.1", "0.2"):
            # strictly 5D
            self.axes = self._axes_to_dicts(["t", "c", "z", "y", "x"])
        else:
            # Without this branch ``self.axes`` would never be assigned and
            # validate() would fail with an AttributeError instead of the
            # documented ValueError.
            raise ValueError(f"axes must be provided for version {fmt.version}")
        self.validate()

    def validate(self) -> None:
        """Raises ValueError if not valid"""
        version = self.fmt.version

        # versions 0.1 and 0.2 have no axes metadata to check
        if version in ("0.1", "0.2"):
            return

        # check names (only enforced for version 0.3)
        if version == "0.3":
            self._validate_03()
            return

        self._validate_axes_types()

    def to_list(
        self, fmt: Format = CurrentFormat()
    ) -> Union[List[str], List[Dict[str, str]]]:
        """
        Return the axes in the representation used by ``fmt``.

        For version "0.3" this is a list of axis names; for every other
        version it is the stored list of axis dicts (not a copy).
        """
        if fmt.version == "0.3":
            return self._get_names()
        return self.axes

    @staticmethod
    def _axes_to_dicts(
        axes: Union[List[str], List[Dict[str, str]]]
    ) -> List[Dict[str, str]]:
        """Returns a list of axis dicts with name and type"""
        axes_dicts = []
        for axis in axes:
            if isinstance(axis, str):
                axis_dict = {"name": axis}
                # only names listed in KNOWN_AXES get a "type"
                if axis in KNOWN_AXES:
                    axis_dict["type"] = KNOWN_AXES[axis]
                axes_dicts.append(axis_dict)
            else:
                # anything that is not a string is stored unchanged
                axes_dicts.append(axis)
        return axes_dicts

    def _validate_axes_types(self) -> None:
        """
        Validate the axes types according to the spec, version 0.4+

        Checks, in order: at most one axis with an unknown (or missing) type,
        a 'time' axis only in first position, at most one 'channel' axis, and
        the 'channel' axis placed before the first 'space' axis.

        :raises ValueError: if any of these checks fails.
        """
        axes_types = [axis.get("type") for axis in self.axes]
        known_types = list(KNOWN_AXES.values())
        unknown_types = [atype for atype in axes_types if atype not in known_types]
        if len(unknown_types) > 1:
            raise ValueError(
                "Too many unknown axes types. 1 allowed, found: %s" % unknown_types
            )

        # any 'time' entry beyond index 0 is invalid
        if "time" in axes_types[1:]:
            raise ValueError("'time' axis must be first dimension only")

        if axes_types.count("channel") > 1:
            raise ValueError("Only 1 axis can be type 'channel'")

        if "channel" in axes_types:
            # list.index() on a missing 'space' would raise an unhelpful
            # "'space' is not in list" ValueError; report it explicitly.
            if "space" not in axes_types:
                raise ValueError(
                    "No 'space' axes found: 'space' axes must come after 'channel'"
                )
            if axes_types.index("channel") > axes_types.index("space"):
                raise ValueError("'space' axes must come after 'channel'")

    def _get_names(self) -> List[str]:
        """
        Returns a list of axis names

        :raises ValueError: if an axis dict has no 'name' key.
        """
        axes_names = []
        for axis in self.axes:
            if "name" not in axis:
                raise ValueError("Axis Dict %s has no 'name'" % axis)
            axes_names.append(axis["name"])
        return axes_names

    def _validate_03(self) -> None:
        """
        Validate the axis names against the combinations allowed for 0.3.

        The final branch handles every length other than 2, 3 or 4, so any
        such input must be exactly ('t', 'c', 'z', 'y', 'x').

        :raises ValueError: if the names are not an allowed combination.
        """
        val_axes = tuple(self._get_names())
        if len(val_axes) == 2:
            if val_axes != ("y", "x"):
                raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}")
        elif len(val_axes) == 3:
            if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]:
                raise ValueError(
                    "3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')"
                    " or ('t', 'y', 'x'), not %s" % (val_axes,)
                )
        elif len(val_axes) == 4:
            if val_axes not in [
                ("t", "z", "y", "x"),
                ("c", "z", "y", "x"),
                ("t", "c", "y", "x"),
            ]:
                raise ValueError("4D data must have axes tzyx or czyx or tcyx")
        else:
            if val_axes != ("t", "c", "z", "y", "x"):
                raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')")
22 changes: 21 additions & 1 deletion ome_zarr/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@
# Logger for this module, registered under the name "ome_zarr.format".
LOGGER = logging.getLogger("ome_zarr.format")


def format_from_version(version: str) -> "Format":
    """
    Look up the format implementation whose ``version`` equals ``version``.

    Implementations are tried in the order produced by
    ``format_implementations()`` and the first match is returned.

    :raises ValueError: if no implementation reports the requested version.
    """
    found = next(
        (
            candidate
            for candidate in format_implementations()
            if candidate.version == version
        ),
        None,
    )
    if found is None:
        raise ValueError(f"Version {version} not recognized")
    return found


def format_implementations() -> Iterator["Format"]:
    """
    Yield one freshly created instance per format implementation,
    ordered from the newest version down to the oldest.
    """
    # Each class is instantiated only when the consumer asks for it.
    for implementation in (FormatV04, FormatV03, FormatV02, FormatV01):
        yield implementation()
Expand Down Expand Up @@ -136,4 +145,15 @@ def version(self) -> str:
return "0.3"


CurrentFormat = FormatV03
class FormatV04(FormatV03):
    """
    Format implementation reporting version "0.4".

    Changelog: axes is list of dicts,
    introduce transformations in multiscales (Nov 2021)
    """

    @property
    def version(self) -> str:
        """Version string of this format implementation."""
        return "0.4"


# Alias for the format class used as the default elsewhere in the package
# (e.g. the default ``fmt`` argument of ``Axes``).
CurrentFormat = FormatV04
27 changes: 18 additions & 9 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import numpy as np
from dask import delayed

from .axes import Axes
from .format import format_from_version
from .io import ZarrLocation
from .types import JSONDict

Expand Down Expand Up @@ -275,20 +277,22 @@ def matches(zarr: ZarrLocation) -> bool:
def __init__(self, node: Node) -> None:
super().__init__(node)

axes_values = {"t", "c", "z", "y", "x"}
try:
multiscales = self.lookup("multiscales", [])
version = multiscales[0].get(
"version", "0.1"
) # should this be matched with Format.version?
datasets = multiscales[0]["datasets"]
# axes field was introduced in 0.3, before all data was 5d
axes = tuple(multiscales[0].get("axes", ["t", "c", "z", "y", "x"]))
if len(set(axes) - axes_values) > 0:
raise RuntimeError(f"Invalid axes names: {set(axes) - axes_values}")
node.metadata["axes"] = axes
datasets = [d["path"] for d in datasets]
self.datasets: List[str] = datasets
axes = multiscales[0].get("axes")
fmt = format_from_version(version)
# Raises ValueError if not valid
axes_obj = Axes(axes, fmt)
node.metadata["axes"] = axes_obj.to_list()
paths = [d["path"] for d in datasets]
self.datasets: List[str] = paths
transformations = [d.get("transformations") for d in datasets]
if any(trans is not None for trans in transformations):
node.metadata["transformations"] = transformations
LOGGER.info("datasets %s", datasets)
except Exception as e:
LOGGER.error(f"failed to parse multiscale metadata: {e}")
Expand All @@ -301,7 +305,12 @@ def __init__(self, node: Node) -> None:
for c in data.chunks
]
LOGGER.info("resolution: %s", resolution)
LOGGER.info(" - shape %s = %s", axes, data.shape)
axes_names = None
if axes is not None:
axes_names = tuple(
axis if isinstance(axis, str) else axis["name"] for axis in axes
)
LOGGER.info(" - shape %s = %s", axes_names, data.shape)
LOGGER.info(" - chunks = %s", chunk_sizes)
LOGGER.info(" - dtype = %s", data.dtype)
node.data.append(data)
Expand Down
Loading

0 comments on commit 73aee67

Please sign in to comment.