Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate axes for writer #123

Merged
merged 8 commits into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 65 additions & 19 deletions ome_zarr/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,59 @@
LOGGER = logging.getLogger("ome_zarr.writer")


def _validate_axes_names(
    ndim: int, axes: Union[str, List[str]] = None, fmt: Format = CurrentFormat()
) -> Union[None, List[str]]:
    """Return the validated list of axes names, or raise if they are invalid.

    Parameters
    ----------
    ndim: int
        Number of dimensions of the image data the axes describe.
    axes: str or list of str, optional
        The axes names, e.g. ``"tczyx"`` or ``["t", "c", "z", "y", "x"]``.
        A string is split into one axis name per character.
        If None, axes are assigned automatically for 2D and 5D data.
    fmt: Format
        The format the data is written with. For versions "0.1" and "0.2"
        axes are not used, so any ``axes`` value is ignored.

    Returns
    -------
    None for format versions "0.1" and "0.2", otherwise the validated axes.

    Raises
    ------
    ValueError
        If ``axes`` is None and ``ndim`` is neither 2 nor 5, if the number
        of axes does not match ``ndim``, or if the axes are not one of the
        combinations accepted for that number of dimensions.
    """

    if fmt.version in ("0.1", "0.2"):
        if axes is not None:
            LOGGER.info("axes ignored for version 0.1 or 0.2")
        return None

    # handle version 0.3...
    if axes is None:
        if ndim == 2:
            axes = ["y", "x"]
            LOGGER.info("Auto using axes %s for 2D data", axes)
        elif ndim == 5:
            axes = ["t", "c", "z", "y", "x"]
            LOGGER.info("Auto using axes %s for 5D data", axes)
        else:
            raise ValueError(
                "axes must be provided. Can't be guessed for 3D or 4D data"
            )

    if isinstance(axes, str):
        axes = list(axes)

    # axes can no longer be None here: it was either supplied or assigned above
    if len(axes) != ndim:
        raise ValueError("axes length must match number of dimensions")

    # from https://github.com/constantinpape/ome-ngff-implementations/
    val_axes = tuple(axes)
    if ndim == 2:
        if val_axes != ("y", "x"):
            raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}")
    elif ndim == 3:
        if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]:
            raise ValueError(
                "3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')"
                " or ('t', 'y', 'x'), not %s" % (val_axes,)
            )
    elif ndim == 4:
        if val_axes not in [
            ("t", "z", "y", "x"),
            ("c", "z", "y", "x"),
            ("t", "c", "y", "x"),
        ]:
            raise ValueError("4D data must have axes tzyx or czyx or tcyx")
    else:
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O`` and would let invalid axes through, and every
        # other invalid-axes case in this function raises ValueError.
        if val_axes != ("t", "c", "z", "y", "x"):
            raise ValueError(
                f"5D data must have axes ('t', 'c', 'z', 'y', 'x'), not {val_axes}"
            )

    return axes


def write_multiscale(
pyramid: List,
group: zarr.Group,
Expand All @@ -28,6 +81,8 @@ def write_multiscale(
----------
pyramid: List of np.ndarray
the image data to save. Largest level first
All image arrays MUST be up to 5-dimensional with dimensions
ordered (t, c, z, y, x)
group: zarr.Group
the group within the zarr store to store the data in
chunks: int or tuple of ints,
Expand All @@ -41,25 +96,7 @@ def write_multiscale(
"""

dims = len(pyramid[0].shape)
if fmt.version not in ("0.1", "0.2"):
if axes is None:
if dims == 2:
axes = ["y", "x"]
elif dims == 5:
axes = ["t", "c", "z", "y", "x"]
else:
raise ValueError(
"axes must be provided. Can't be guessed for 3D or 4D data"
)
if len(axes) != dims:
raise ValueError("axes length must match number of dimensions")

if isinstance(axes, str):
axes = list(axes)

for dim in axes:
if dim not in ("t", "c", "z", "y", "x"):
raise ValueError("axes must each be one of 'x', 'y', 'z', 'c' or 't'")
axes = _validate_axes_names(dims, axes, fmt)

paths = []
for path, dataset in enumerate(pyramid):
Expand Down Expand Up @@ -90,6 +127,8 @@ def write_image(
image: np.ndarray
the image data to save. A downsampling of the data will be computed
if the scaler argument is non-None.
Image array MUST be up to 5-dimensional with dimensions
ordered (t, c, z, y, x)
group: zarr.Group
the group within the zarr store to store the data in
chunks: int or tuple of ints,
Expand All @@ -115,11 +154,18 @@ def write_image(
# v0.1 and v0.2 are strictly 5D
shape_5d: Tuple[Any, ...] = (*(1,) * (5 - image.ndim), *image.shape)
image = image.reshape(shape_5d)
# and we don't need axes
axes = None

if chunks is not None:
chunks = _retuple(chunks, image.shape)

if scaler is not None:
if image.shape[-1] == 1 or image.shape[-2] == 1:
raise ValueError(
"Can't downsample if size of x or y dimension is 1. "
"Shape: %s" % (image.shape,)
)
image = scaler.nearest(image)
else:
LOGGER.debug("disabling pyramid")
Expand Down
48 changes: 47 additions & 1 deletion tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ome_zarr.io import parse_url
from ome_zarr.reader import Multiscales, Reader
from ome_zarr.scale import Scaler
from ome_zarr.writer import write_image
from ome_zarr.writer import _validate_axes_names, write_image


class TestWriter:
Expand Down Expand Up @@ -77,3 +77,49 @@ def test_writer(self, shape, scaler, format_version):
else:
assert node.data[0].ndim == 5
assert np.allclose(data, node.data[0][...].compute())

def test_dim_names(self):
    """Check axes validation in _validate_axes_names and write_image."""

    v03 = FormatV03()

    # (ndim, axes) combinations that v0.3 must reject with a ValueError
    rejected = [
        (3, None),  # v0.3 MUST specify axes for 3D or 4D data
        (3, "yx"),  # ndims must match axes length
        (3, "yxt"),  # axes must be ordered tczyx
        (2, ["x", "y"]),  # axes must be ordered tczyx
    ]
    for ndim, bad_axes in rejected:
        with pytest.raises(ValueError):
            _validate_axes_names(ndim, axes=bad_axes, fmt=v03)

    expected_2d = ["y", "x"]
    expected_5d = ["t", "c", "z", "y", "x"]

    # valid axes - no change, converted to list
    assert _validate_axes_names(2, axes=["y", "x"], fmt=v03) == expected_2d
    assert _validate_axes_names(5, axes="tczyx", fmt=v03) == expected_5d

    # if 2D or 5D, axes can be assigned automatically
    assert _validate_axes_names(2, axes=None, fmt=v03) == expected_2d
    assert _validate_axes_names(5, axes=None, fmt=v03) == expected_5d

    # for v0.1 or v0.2, axes should be None
    for legacy_format in (FormatV01, FormatV02):
        result = _validate_axes_names(2, axes=["y", "x"], fmt=legacy_format())
        assert result is None

    # check that write_image is checking axes
    data = self.create_data((125, 125))
    with pytest.raises(ValueError):
        write_image(image=data, group=self.group, fmt=v03, axes="xyz")